
BUG: ADVI - arguments start and start_sigma have inconsistent keys for transformed random variables #7534

Open
trendelkampschroer opened this issue Oct 11, 2024 · 5 comments
Labels
bug VI Variational Inference

Comments


trendelkampschroer commented Oct 11, 2024

Describe the issue:

There are a couple of issues with the current design:
i) The keys of `start` and `start_sigma` are inconsistent, as can be seen in the example below: `start` uses untransformed variable names (e.g. `"sigma"`), while `start_sigma` requires transformed names (e.g. `"sigma_log__"`).
ii) I can specify arbitrary (str-valued) keys for `start_sigma`; if they are not in the model, the default zero initialisation is silently used for the `rho` variable in the variational approximation.
iii) The `np.log(np.expm1(np.abs(sigma)))` transformation from `sigma` to `rho` in the method `create_shared_params` of the class `MeanFieldGroup` is also somewhat surprising.
iv) It is also unclear how the values produced by `advi.approx.mean.eval` and `advi.approx.std.eval` relate to the variational parameters `mu` and `rho` returned by `create_shared_params`, and/or to the free random variables `beta` and `sigma` of the model. For example, I think that `tracker["mean"][-1][-1]` (cf. example below) is the variational parameter corresponding to the variable `sigma_log__`.
v) In `pm.fit`, the arguments `start` and `start_sigma` are ignored if an instance of `ADVI` is passed as `method`.
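Regarding point iii: as far as I can tell, `np.log(np.expm1(np.abs(sigma)))` is the inverse of the softplus function, which the mean-field approximation uses to map the unconstrained parameter `rho` to a positive standard deviation. A minimal numpy sketch (the function names here are my own, not PyMC's):

```python
import numpy as np

def softplus(rho):
    # maps an unconstrained rho to a positive standard deviation
    return np.log1p(np.exp(rho))

def softplus_inverse(sigma):
    # the transform applied to start_sigma in MeanFieldGroup.create_shared_params
    return np.log(np.expm1(np.abs(sigma)))

# round trip: softplus(rho) recovers the requested initial standard deviation
sigma = 0.5
rho = softplus_inverse(sigma)
print(np.isclose(softplus(rho), sigma))  # True
```

So the transform itself is consistent, it is just undocumented and surprising at the call site.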

Reproducible code example:

Minimal example below 


import numpy as np
import pandas as pd
import pymc as pm

from pymc.variational.callbacks import CheckParametersConvergence, Tracker


def generate_data(num_samples: int) -> pd.DataFrame:
    rng = np.random.default_rng(seed=42)
    beta = 1.0
    sigma = 10.0
    x = rng.normal(loc=0.0, scale=1.0, size=num_samples)
    y = beta * x + sigma * rng.normal(size=num_samples)
    return pd.DataFrame({"x": x, "y": y})


def make_model(frame: pd.DataFrame) -> pm.Model:
    with pm.Model() as model:
        # Data
        x = pm.Data("x", frame["x"])
        y = pm.Data("y", frame["y"])

        # Prior
        beta = pm.Normal("beta", sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=20.0)

        # Linear model
        mu = beta * x

        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
    return model

if __name__ == "__main__":
    frame = generate_data(10000)
    model = make_model(frame)
    with model:
        advi = pm.ADVI(
            start={"beta": 1.0, "sigma": 10.0},
            start_sigma={"beta": 0.5, "sigma_log__": 0.5}
        )
        tracker = Tracker(
            mean=advi.approx.mean.eval,
            std=advi.approx.std.eval
        )
        approx = pm.fit(
            n=1_000_000,
            method=advi,
            callbacks=[
                CheckParametersConvergence(diff="relative", tolerance=1e-3),
                CheckParametersConvergence(diff="absolute", tolerance=1e-3),
                tracker
            ],
        )

This is from the class `MeanFieldGroup`:

    def create_shared_params(self, start=None, start_sigma=None):
        # NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
        # `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
        # variables are given by `self.group` and that the mapping between original variables and flat array is given
        # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
        # future code changes that break this assumption.
        start = self._prepare_start(start)
        rho1 = np.zeros((self.ddim,))

        if start_sigma is not None:
            for name, slice_, *_ in self.ordering.values():
                sigma = start_sigma.get(name)
                if sigma is not None:
                    rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
        rho = rho1

        return {
            "mu": pytensor.shared(pm.floatX(start), "mu"),
            "rho": pytensor.shared(pm.floatX(rho), "rho"),
        }

I think this simple example is already hitting the problematic edge case mentioned in the NOTE. Also, the line `sigma = start_sigma.get(name)` is problematic, as passing `start_sigma` with wrong keys will never raise an error.
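A possible fix would be to validate the keys of `start_sigma` against the group's ordering before filling `rho`. The standalone sketch below uses a mocked two-variable ordering; `fill_rho` is a hypothetical name, not PyMC code:

```python
import numpy as np

def fill_rho(start_sigma, ordering, ddim):
    """Hypothetical stricter variant of the rho initialisation in
    MeanFieldGroup.create_shared_params: unknown keys raise instead of
    being silently ignored."""
    known = {name for name, *_ in ordering.values()}
    unknown = set(start_sigma) - known
    if unknown:
        raise KeyError(f"start_sigma contains unknown variables: {sorted(unknown)}")
    rho = np.zeros((ddim,))
    for name, slice_, *_ in ordering.values():
        sigma = start_sigma.get(name)
        if sigma is not None:
            rho[slice_] = np.log(np.expm1(np.abs(sigma)))
    return rho

# mocked ordering: two scalar variables flattened into a length-2 array
ordering = {
    "beta": ("beta", slice(0, 1)),
    "sigma_log__": ("sigma_log__", slice(1, 2)),
}
rho = fill_rho({"beta": 0.5}, ordering, ddim=2)
```

With this variant, `fill_rho({"beta_foo": 0.5}, ordering, ddim=2)` would raise a `KeyError` instead of quietly falling back to the zero initialisation.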

This is in contrast to the parameter `start`, where a wrong key will raise, e.g. using `"beta_foo"` instead of `"beta"`:

Traceback (most recent call last):
  File ".../univariate.py", line 167, in <module>
    advi = pm.ADVI(
           ^^^^^^^^
  File ".../pymc/variational/inference.py", line 471, in __init__
    super().__init__(MeanField(*args, **kwargs))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../pymc/variational/approximations.py", line 339, in __init__
    super().__init__(groups, model=kwargs.get("model"))
  File ".../pymc/variational/opvi.py", line 1229, in __init__
    rest.__init_group__(unseen_free_RVs)
  File ".../pytensor/configparser.py", line 44, in res
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File ".../pymc/variational/approximations.py", line 74, in __init_group__
    self.shared_params = self.create_shared_params(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../pymc/variational/approximations.py", line 85, in create_shared_params
    start = self._prepare_start(start)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../pymc/variational/opvi.py", line 760, in _prepare_start
    ipfn = make_initial_point_fn(
           ^^^^^^^^^^^^^^^^^^^^^^
  File ".../pymc/initial_point.py", line 134, in make_initial_point_fn
    sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../pymc/initial_point.py", line 49, in convert_str_to_rv_dict
    initvals[model[key]] = initval
             ~~~~~^^^^^
  File ".../pymc/model/core.py", line 1575, in __getitem__
    raise e
  File ".../pymc/model/core.py", line 1570, in __getitem__
    return self.named_vars[key]
           ~~~~~~~~~~~~~~~^^^^^
KeyError: 'beta_foo'

Also, this block in `pm.fit` is problematic, since it ignores `start` and `start_sigma` if an instance of e.g. `ADVI` is passed as the `method` argument instead of the string `'advi'`:

    _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)
    if isinstance(method, str):
        method = method.lower()
        if method in _select:
            inference = _select[method](model=model, **inf_kwargs)
        else:
            raise KeyError(f"method should be one of {set(_select.keys())} or Inference instance")
    elif isinstance(method, Inference):
        inference = method
...

Initially this made debugging the actual issue even more contrived. It is also not clear whether I can use any instance of `ADVI` to get tracking for mean and std, or whether it has to be the same instance that is passed to `pm.fit`.
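One way to at least surface this would be for `pm.fit` to reject `start`/`start_sigma` when an `Inference` instance is passed. The guard below is only a sketch with made-up names (`fit_dispatch` mimics the dispatch logic quoted above, it is not the real `pm.fit`):

```python
def fit_dispatch(method, inf_kwargs):
    """Sketch of the dispatch logic in pm.fit, extended to raise when
    initialisation kwargs would be silently ignored (hypothetical)."""
    if isinstance(method, str):
        # string case: kwargs are forwarded to the Inference constructor
        return f"construct {method} with {inf_kwargs}"
    # method is an already-constructed Inference instance: start/start_sigma
    # passed to fit would never reach it, so fail loudly instead
    ignored = {k for k in inf_kwargs if k in ("start", "start_sigma")}
    if ignored:
        raise ValueError(
            f"arguments {sorted(ignored)} are ignored when an Inference "
            "instance is passed as `method`; pass them to the instance's "
            "constructor instead"
        )
    return method
```

This would have turned my silent misconfiguration into an immediate, explicit error.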

I'd be happy to make a PR with a fix, but I am lacking familiarity with the APIs/concepts mentioned in the NOTE, i.e.

    `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
    `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
    variables are given by `self.group` and that the mapping between original variables and flat array is given
    by `self.ordering`.

I'd be super glad to receive guidance so that I can work on this. I think that a fix could significantly improve the API and usability of the initialisation for ADVI.

Context for the issue:

cf. https://discourse.pymc.io/t/variational-fit-advi-initialisation/15940/4 for more context. Also thanks a lot @jessegrabowski for already very helpful suggestions.

A working example that demonstrates the quirks of the current state is below. It shows how to use the "final state" (the variational parameters mu and rho, I believe) to initialise a fit for a model with a slightly enlarged data set. I wouldn't know how to implement this for a complicated (hierarchical) model with many (transformed) random variables, though, since the transformations need to be applied in order to use the "final state" from the first fit as the initialisation of the second fit.
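For this two-parameter model the bookkeeping can still be done by hand. The helper below is my own sketch of it (`warm_start` is a hypothetical name), assuming the flat parameter vector is ordered `[beta, sigma_log__]` as in my runs:

```python
import numpy as np

def warm_start(mean, std):
    """Map the final variational mean/std (in transformed space, ordered
    [beta, sigma_log__]) back to `start`/`start_sigma` dicts.
    Hypothetical helper, not part of PyMC."""
    start = {
        "beta": float(mean[0]),
        # `start` expects sigma on its original (positive) scale,
        # so the log transform has to be undone by hand
        "sigma": float(np.exp(mean[1])),
    }
    # `start_sigma` expects the *transformed* variable name and its std
    start_sigma = {"beta": float(std[0]), "sigma_log__": float(std[1])}
    return start, start_sigma

start, start_sigma = warm_start(np.array([1.0, np.log(10.0)]), np.array([0.01, 0.007]))
```

For a hierarchical model with many transformed variables, this per-variable untransforming is exactly the part I would not know how to generalise.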

import time

import numpy as np
import pandas as pd
import pymc as pm

from pymc.variational.callbacks import CheckParametersConvergence, Tracker


def generate_data(num_samples: int) -> pd.DataFrame:
    rng = np.random.default_rng(seed=42)
    beta = 1.0
    sigma = 10.0
    x = rng.normal(loc=0.0, scale=1.0, size=num_samples)
    y = beta * x + sigma * rng.normal(size=num_samples)
    return pd.DataFrame({"x": x, "y": y})


def make_model(frame: pd.DataFrame) -> pm.Model:
    with pm.Model() as model:
        # Data
        x = pm.Data("x", frame["x"])
        y = pm.Data("y", frame["y"])

        # Prior
        beta = pm.Normal("beta", sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=20.0)

        # Linear model
        mu = beta * x

        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
    return model


if __name__ == "__main__":
    num_samples = 10_000
    frame = generate_data(num_samples=num_samples)
    model = make_model(frame)
    with model:
        advi = pm.ADVI()
        tracker = Tracker(
            mean=advi.approx.mean.eval,
            std=advi.approx.std.eval
        )
        t0 = time.time()
        approx = pm.fit(
            n=1_000_000,
            method=advi,
            callbacks=[
                CheckParametersConvergence(diff="relative", tolerance=1e-3),
                CheckParametersConvergence(diff="absolute", tolerance=1e-3),
                tracker
            ],
        )
        t = time.time() - t0
        print(f"Time for fit is {t:.3f}s.")

    frame_new = pd.concat([frame, generate_data(100)], axis=0)
    model = make_model(frame_new)
    with model:
        advi = pm.ADVI(
            start={"beta": tracker["mean"][-1][0], "sigma": np.exp(tracker["mean"][-1][1])},
            start_sigma={"beta": tracker["std"][-1][0], "sigma_log__": tracker["std"][-1][1]}
        )
        tracker = Tracker(
            mean=advi.approx.mean.eval,
            std=advi.approx.std.eval
        )
        t0 = time.time()
        approx = pm.fit(
            n=1_000_000,
            method=advi,
            callbacks=[
                CheckParametersConvergence(diff="relative", tolerance=1e-3),
                CheckParametersConvergence(diff="absolute", tolerance=1e-3),
                tracker
            ],
        )
        t = time.time() - t0
        print(f"Time for fit is {t:.3f}s.")

This is the trace for the initial ADVI fit:

[figure: initial_fit — convergence trace of the first fit]

and here for the model with the data updated and the parameters initialised:

[figure: update_fit — convergence trace of the warm-started fit]

As we can see, quite a lot of computation can be saved by using the final parameters from the first fit as the initial guess for the new fit.

Error message:

No response

PyMC version information:

# packages in environment at .../.miniconda3/envs/pymc (abridged, relevant entries):
pymc            5.17.0   conda-forge
pymc-base       5.17.0   conda-forge
pytensor        2.25.5   conda-forge
pytensor-base   2.25.5   conda-forge
python          3.12.7   conda-forge
numpy           1.26.4   conda-forge
scipy           1.14.1   conda-forge
arviz           0.20.0   conda-forge

welcome bot commented Oct 11, 2024

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

@fonnesbeck fonnesbeck added the VI Variational Inference label Oct 11, 2024
@fonnesbeck
Member

Thanks for this. Yeah, the start values should really be associated with fit rather than the ADVI object.

@jessegrabowski
Member

cc @ferrine

@ferrine
Member

ferrine commented Oct 12, 2024

It's also problematic because some approximations might not have a very clear correspondence to model variables, e.g. normalizing flows or other low-rank approximations; the initialization parameters of ADVI just happen to coincide with model parameters. For every approximation, I think there is a specific way to initialize its parameters. If we consider rethinking the API, not just the error messages, this should be taken into account.

@trendelkampschroer
Author

@ferrine thanks for chiming in. Would it make sense to ask the user to directly specify the parameters of the variational model (`mu`, `rho`) and to provide some means of converting a dictionary with an initial guess in terms of the (free) model parameters into the variational parameters? This would obviously be a large breaking change to the API and to the types that are used internally.

I'd like to move forward with a less invasive solution that at least ensures that `start` and `start_sigma` are consistent in terms of the keys that are used. If possible, I'd also want an API that lets me get the final values of the variational parameters from a fitted approximation, together with a clear recipe for turning them back into an initial guess for a new fit.
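To make that concrete, an accessor along these lines would already help (a hypothetical sketch with a mocked ordering; `params_by_variable` is my own name, not an existing PyMC API):

```python
import numpy as np

def params_by_variable(flat_mean, flat_std, ordering):
    """Hypothetical accessor: split the flat variational parameter arrays
    back into per-(transformed-)variable dicts using the group's ordering,
    so they can be fed back as start/start_sigma for a new fit."""
    means, stds = {}, {}
    for name, slice_, *_ in ordering.values():
        means[name] = flat_mean[slice_]
        stds[name] = flat_std[slice_]
    return means, stds

# mocked ordering with two scalar variables in a flat length-2 vector
ordering = {
    "beta": ("beta", slice(0, 1)),
    "sigma_log__": ("sigma_log__", slice(1, 2)),
}
means, stds = params_by_variable(np.array([1.0, 2.3]), np.array([0.1, 0.2]), ordering)
```

The remaining (harder) piece would be applying the inverse transforms, e.g. mapping `sigma_log__` back to `sigma` for the `start` dict.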
