diff --git a/aepsych/factory/factory.py b/aepsych/factory/factory.py index c8c581ed7..39521f001 100644 --- a/aepsych/factory/factory.py +++ b/aepsych/factory/factory.py @@ -13,9 +13,10 @@ from aepsych.config import Config from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad -from aepsych.utils import get_dim from scipy.stats import norm +from ..kernels.pairwisekernel import PairwiseKernel + """AEPsych factory functions. These functions generate a gpytorch Mean and Kernel objects from aepsych.config.Config configurations, including setting lengthscale @@ -36,7 +37,9 @@ def default_mean_covar_factory( - config: Optional[Config] = None, dim: Optional[int] = None + config: Optional[Config] = None, + dim: Optional[int] = None, + stimuli_per_trial: int = 1, ) -> Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]: """Default factory for generic GP models @@ -55,6 +58,8 @@ def default_mean_covar_factory( dim is not None ), "Either config or dim must be provided!" + assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!" + fixed_mean = False lengthscale_prior = "gamma" outputscale_prior = "box" @@ -136,6 +141,9 @@ def default_mean_covar_factory( outputscale_prior=os_prior, ) + if stimuli_per_trial == 2: + covar = PairwiseKernel(covar) + return mean, covar diff --git a/aepsych/generators/__init__.py b/aepsych/generators/__init__.py index cb9bd8cfb..30fcc6f37 100644 --- a/aepsych/generators/__init__.py +++ b/aepsych/generators/__init__.py @@ -13,8 +13,6 @@ from .monotonic_rejection_generator import MonotonicRejectionGenerator from .monotonic_thompson_sampler_generator import MonotonicThompsonSamplerGenerator from .optimize_acqf_generator import OptimizeAcqfGenerator -from .pairwise_optimize_acqf_generator import PairwiseOptimizeAcqfGenerator -from .pairwise_sobol_generator import PairwiseSobolGenerator from .random_generator import RandomGenerator from .semi_p import IntensityAwareSemiPGenerator from .sobol_generator import SobolGenerator @@ -28,8 +26,6 @@ "SobolGenerator", "EpsilonGreedyGenerator", "ManualGenerator", - "PairwiseOptimizeAcqfGenerator", - "PairwiseSobolGenerator", "IntensityAwareSemiPGenerator", "AcqfThompsonSamplerGenerator" ] diff --git a/aepsych/generators/base.py b/aepsych/generators/base.py index e4e33e7db..0cda2b4a2 100644 --- a/aepsych/generators/base.py +++ b/aepsych/generators/base.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import abc from inspect import signature -from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional +from typing import Any, Dict, Generic, Optional, Protocol, runtime_checkable, TypeVar import re import torch @@ -13,10 +13,10 @@ from aepsych.models.base import AEPsychMixin from botorch.acquisition import ( AcquisitionFunction, - NoisyExpectedImprovement, - qNoisyExpectedImprovement, LogNoisyExpectedImprovement, + NoisyExpectedImprovement, qLogNoisyExpectedImprovement, + qNoisyExpectedImprovement, ) @@ -40,7 +40,6 @@ class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]): qLogNoisyExpectedImprovement, LogNoisyExpectedImprovement, ] - stimuli_per_trial = 1 max_asks: Optional[int] = None def __init__( diff --git a/aepsych/generators/optimize_acqf_generator.py b/aepsych/generators/optimize_acqf_generator.py index 77ab1a42d..301fe0e1f 100644 --- a/aepsych/generators/optimize_acqf_generator.py +++ b/aepsych/generators/optimize_acqf_generator.py @@ -14,15 +14,15 @@ from aepsych.generators.base import AEPsychGenerator from aepsych.models.base import ModelProtocol from aepsych.utils_logging import getLogger -from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption -from botorch.optim import optimize_acqf from botorch.acquisition import ( AcquisitionFunction, - NoisyExpectedImprovement, - qNoisyExpectedImprovement, LogNoisyExpectedImprovement, + NoisyExpectedImprovement, qLogNoisyExpectedImprovement, + qNoisyExpectedImprovement, ) +from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption +from botorch.optim import optimize_acqf logger = getLogger() @@ -44,7 +44,6 @@ def __init__( restarts: int = 10, samps: int = 1000, max_gen_time: Optional[float] = None, - stimuli_per_trial: int = 1, ) -> None: """Initialize OptimizeAcqfGenerator. Args: @@ -63,7 +62,6 @@ def __init__( self.restarts = restarts self.samps = samps self.max_gen_time = max_gen_time - self.stimuli_per_trial = stimuli_per_trial def _instantiate_acquisition_fn(self, model: ModelProtocol): if self.acqf == AnalyticExpectedUtilityOfBestOption: @@ -83,17 +81,7 @@ def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Ten np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. 
""" - if self.stimuli_per_trial == 2: - qbatch_points = self._gen( - num_points=num_points * 2, model=model, **gen_options - ) - - # output of super() is (q, dim) but the contract is (num_points, dim, 2) - # so we need to split q into q and pairs and then move the pair dim to the end - return qbatch_points.reshape(num_points, 2, -1).swapaxes(-1, -2) - - else: - return self._gen(num_points=num_points, model=model, **gen_options) + return self._gen(num_points=num_points, model=model, **gen_options) def _gen( self, num_points: int, model: ModelProtocol, **gen_options @@ -124,7 +112,6 @@ def from_config(cls, config: Config): classname = cls.__name__ acqf = config.getobj(classname, "acqf", fallback=None) extra_acqf_args = cls._get_acqf_options(acqf, config) - stimuli_per_trial = config.getint(classname, "stimuli_per_trial") restarts = config.getint(classname, "restarts", fallback=10) samps = config.getint(classname, "samps", fallback=1000) max_gen_time = config.getfloat(classname, "max_gen_time", fallback=None) @@ -135,5 +122,4 @@ def from_config(cls, config: Config): restarts=restarts, samps=samps, max_gen_time=max_gen_time, - stimuli_per_trial=stimuli_per_trial, ) diff --git a/aepsych/generators/pairwise_optimize_acqf_generator.py b/aepsych/generators/pairwise_optimize_acqf_generator.py deleted file mode 100644 index 4bbe46156..000000000 --- a/aepsych/generators/pairwise_optimize_acqf_generator.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import warnings - -from aepsych.config import Config -from aepsych.generators import OptimizeAcqfGenerator - - -class PairwiseOptimizeAcqfGenerator(OptimizeAcqfGenerator): - """Deprecated. Use OptimizeAcqfGenerator instead.""" - - stimuli_per_trial = 2 - - @classmethod - def from_config(cls, config: Config): - warnings.warn( - "PairwiseOptimizeAcqfGenerator is deprecated. Use OptimizeAcqfGenerator instead.", - DeprecationWarning, - ) - return super().from_config(config) diff --git a/aepsych/generators/pairwise_sobol_generator.py b/aepsych/generators/pairwise_sobol_generator.py deleted file mode 100644 index 6e5f6d205..000000000 --- a/aepsych/generators/pairwise_sobol_generator.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import warnings - -from aepsych.config import Config - -from .sobol_generator import SobolGenerator - - -class PairwiseSobolGenerator(SobolGenerator): - """Deprecated. Use SobolGenerator instead.""" - - stimuli_per_trial = 2 - - @classmethod - def from_config(cls, config: Config): - warnings.warn( - "PairwiseSobolGenerator is deprecated. Use SobolGenerator instead.", - DeprecationWarning, - ) - return super().from_config(config) diff --git a/aepsych/generators/sobol_generator.py b/aepsych/generators/sobol_generator.py index ce54150f3..53e48e388 100644 --- a/aepsych/generators/sobol_generator.py +++ b/aepsych/generators/sobol_generator.py @@ -28,7 +28,6 @@ def __init__( ub: Union[np.ndarray, torch.Tensor], dim: Optional[int] = None, seed: Optional[int] = None, - stimuli_per_trial: int = 1, ): """Iniatialize SobolGenerator. Args: @@ -38,13 +37,8 @@ def __init__( seed (int, optional): Random seed. 
""" self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim) - self.lb = self.lb.repeat(stimuli_per_trial) - self.ub = self.ub.repeat(stimuli_per_trial) - self.stimuli_per_trial = stimuli_per_trial self.seed = seed - self.engine = SobolEngine( - dimension=self.dim * stimuli_per_trial, scramble=True, seed=self.seed - ) + self.engine = SobolEngine(dimension=self.dim, scramble=True, seed=self.seed) def gen( self, @@ -59,16 +53,7 @@ def gen( """ grid = self.engine.draw(num_points) grid = self.lb + (self.ub - self.lb) * grid - if self.stimuli_per_trial == 1: - return grid - - return torch.tensor( - np.moveaxis( - grid.reshape(num_points, self.stimuli_per_trial, -1).numpy(), - -1, - -self.stimuli_per_trial, - ) - ) + return grid @classmethod def from_config(cls, config: Config): @@ -78,8 +63,10 @@ def from_config(cls, config: Config): ub = config.gettensor(classname, "ub") dim = config.getint(classname, "dim", fallback=None) seed = config.getint(classname, "seed", fallback=None) - stimuli_per_trial = config.getint(classname, "stimuli_per_trial") return cls( - lb=lb, ub=ub, dim=dim, seed=seed, stimuli_per_trial=stimuli_per_trial + lb=lb, + ub=ub, + dim=dim, + seed=seed, ) diff --git a/aepsych/kernels/__init__.py b/aepsych/kernels/__init__.py index 8b2df349c..30439fbe1 100644 --- a/aepsych/kernels/__init__.py +++ b/aepsych/kernels/__init__.py @@ -4,3 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. + +from .pairwisekernel import PairwiseKernel +from .rbf_partial_grad import RBFKernelPartialObsGrad + +__all__ = ["PairwiseKernel", "RBFKernelPartialObsGrad"] diff --git a/aepsych/kernels/pairwisekernel.py b/aepsych/kernels/pairwisekernel.py new file mode 100644 index 000000000..18f4976cc --- /dev/null +++ b/aepsych/kernels/pairwisekernel.py @@ -0,0 +1,85 @@ +import torch +from gpytorch.kernels import Kernel +from linear_operator import to_linear_operator + + +class PairwiseKernel(Kernel): + """ + Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling + functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K). + + Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K') + where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d). + + """ + + def __init__(self, latent_kernel, is_partial_obs=False, **kwargs): + super(PairwiseKernel, self).__init__(**kwargs) + + self.latent_kernel = latent_kernel + self.is_partial_obs = is_partial_obs + + def forward(self, x1, x2, diag=False, **params): + r""" + TODO: make last_batch_dim work properly + + d must be 2*k for integer k, k is the dimension of the latent space + Args: + :attr:`x1` (Tensor `n x d` or `b x n x d`): + First set of data + :attr:`x2` (Tensor `m x d` or `b x m x d`): + Second set of data + :attr:`diag` (bool): + Should the Kernel compute the whole kernel, or just the diag? + + Returns: + :class:`Tensor` or :class:`gpytorch.lazy.LazyTensor`. 
+ The exact size depends on the kernel's evaluation mode: + + * `full_covar`: `n x m` or `b x n x m` + * `diag`: `n` or `b x n` + """ + if self.is_partial_obs: + d = x1.shape[-1] - 1 + assert d == x2.shape[-1] - 1, "tensors not the same dimension" + assert d % 2 == 0, "dimension must be even" + + k = int(d / 2) + + # special handling for kernels that (also) do funky + # things with the input dimension + deriv_idx_1 = x1[..., -1][:, None] + deriv_idx_2 = x2[..., -1][:, None] + + a = torch.cat((x1[..., :k], deriv_idx_1), dim=1) + b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1) + c = torch.cat((x2[..., :k], deriv_idx_2), dim=1) + d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1) + + else: + d = x1.shape[-1] + + assert d == x2.shape[-1], "tensors not the same dimension" + assert d % 2 == 0, "dimension must be even" + + k = int(d / 2) + + a = x1[..., :k] + b = x1[..., k:] + c = x2[..., :k] + d = x2[..., k:] + + if not diag: + return ( + to_linear_operator(self.latent_kernel(a, c, diag=diag, **params)) + + to_linear_operator(self.latent_kernel(b, d, diag=diag, **params)) + - to_linear_operator(self.latent_kernel(b, c, diag=diag, **params)) + - to_linear_operator(self.latent_kernel(a, d, diag=diag, **params)) + ) + else: + return ( + self.latent_kernel(a, c, diag=diag, **params) + + self.latent_kernel(b, d, diag=diag, **params) + - self.latent_kernel(b, c, diag=diag, **params) + - self.latent_kernel(a, d, diag=diag, **params) + ) diff --git a/aepsych/models/__init__.py b/aepsych/models/__init__.py index 1380c8a8d..00b66219b 100644 --- a/aepsych/models/__init__.py +++ b/aepsych/models/__init__.py @@ -14,7 +14,6 @@ from .monotonic_rejection_gp import MonotonicRejectionGP from .multitask_regression import IndependentMultitaskGPRModel, MultitaskGPRModel from .ordinal_gp import OrdinalGPModel -from .pairwise_probit import PairwiseProbitModel from .semi_p import ( HadamardSemiPModel, semi_p_posterior_transform, @@ -26,7 +25,6 @@ "GPClassificationModel", "MonotonicRejectionGP", "GPRegressionModel", - "PairwiseProbitModel", "OrdinalGPModel", "MonotonicProjectionGP", "MultitaskGPRModel", diff --git a/aepsych/models/gp_classification.py b/aepsych/models/gp_classification.py index 7f8785536..b2a25b89d 100644 --- a/aepsych/models/gp_classification.py +++ b/aepsych/models/gp_classification.py @@ -42,7 +42,6 @@ class GPClassificationModel(AEPsychMixin, ApproximateGP): _batch_size = 1 _num_outputs = 1 - stimuli_per_trial = 1 outcome_type = "binary" def __init__( @@ -56,6 +55,7 @@ def __init__( inducing_size: int = 100, max_fit_time: Optional[float] = None, inducing_point_method: str = "auto", + stimuli_per_trial: int = 1, ): """Initialize the GP Classification model @@ -77,11 +77,20 @@ def __init__( If "pivoted_chol", selects points based on the pivoted Cholesky heuristic. If "kmeans++", selects points by performing kmeans++ clustering on the training data. If "auto", tries to determine the best method automatically. + stimuli_per_trial (int): Number of stimuli that will be presented each trial. Currently only 1 or 2 are + supported. If 2, covar_module should use PairwiseKernel, and the dimensionality will be interpreted as + 2*d, where the first d dimensions correspond to the dimensions of the first stimulus, and the second d + dimensions correspond to the dimensions of the second stimulus. """ self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim) self.max_fit_time = max_fit_time self.inducing_size = inducing_size + assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!" 
+ if stimuli_per_trial == 2: + assert self.dim % 2 == 0, "Dimensionality does not match stimuli_per_trial!" + self.stimuli_per_trial = stimuli_per_trial + if likelihood is None: likelihood = BernoulliLikelihood() @@ -104,7 +113,9 @@ def __init__( super().__init__(variational_strategy) if mean_module is None or covar_module is None: - default_mean, default_covar = default_mean_covar_factory(dim=self.dim) + default_mean, default_covar = default_mean_covar_factory( + dim=self.dim, stimuli_per_trial=self.stimuli_per_trial + ) self.mean_module = mean_module or default_mean self.covar_module = covar_module or default_covar @@ -155,6 +166,8 @@ def from_config(cls, config: Config) -> GPClassificationModel: else: likelihood = None # fall back to __init__ default + stimuli_per_trial = config.getint(classname, "stimuli_per_trial", fallback=1) + return cls( lb=lb, ub=ub, @@ -165,6 +178,7 @@ def from_config(cls, config: Config) -> GPClassificationModel: max_fit_time=max_fit_time, inducing_point_method=inducing_point_method, likelihood=likelihood, + stimuli_per_trial=stimuli_per_trial, ) def _reset_hyperparameters(self): diff --git a/aepsych/models/pairwise_probit.py b/aepsych/models/pairwise_probit.py deleted file mode 100644 index f0c6a05f5..000000000 --- a/aepsych/models/pairwise_probit.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -import time -from typing import Any, Dict, Optional, Union - -import gpytorch -import numpy as np -import torch -from aepsych.config import Config -from aepsych.factory import default_mean_covar_factory -from aepsych.models.base import AEPsychMixin -from aepsych.utils import _process_bounds, promote_0d -from aepsych.utils_logging import getLogger -from botorch.fit import fit_gpytorch_mll -from botorch.models import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood -from botorch.models.transforms.input import Normalize -from scipy.stats import norm - -logger = getLogger() - - -class PairwiseProbitModel(PairwiseGP, AEPsychMixin): - _num_outputs = 1 - stimuli_per_trial = 2 - outcome_type = "binary" - - def _pairs_to_comparisons(self, x, y): - """ - Takes x, y structured as pairs and judgments and - returns pairs and comparisons as PairwiseGP requires - """ - # This needs to take a unique over the feature dim by flattening - # over pairs but not instances/batches. 
This is actually tensor - # matricization over the feature dimension but awkward in numpy - unique_coords = torch.unique( - torch.transpose(x, 1, 0).reshape(self.dim, -1), dim=1 - ) - - def _get_index_of_equal_row(arr, x, dim=0): - return torch.all(torch.eq(arr, x[:, None]), dim=dim).nonzero().item() - - comparisons = [] - for pair, judgement in zip(x, y): - comparison = ( - _get_index_of_equal_row(unique_coords, pair[..., 0]), - _get_index_of_equal_row(unique_coords, pair[..., 1]), - ) - if judgement == 0: - comparisons.append(comparison) - else: - comparisons.append(comparison[::-1]) - return unique_coords.T, torch.LongTensor(comparisons) - - def __init__( - self, - lb: Union[np.ndarray, torch.Tensor], - ub: Union[np.ndarray, torch.Tensor], - dim: Optional[int] = None, - covar_module: Optional[gpytorch.kernels.Kernel] = None, - max_fit_time: Optional[float] = None, - ): - self.lb, self.ub, dim = _process_bounds(lb, ub, dim) - - self.max_fit_time = max_fit_time - - bounds = torch.stack((self.lb, self.ub)) - input_transform = Normalize(d=dim, bounds=bounds) - if covar_module is None: - config = Config( - config_dict={ - "default_mean_covar_factory": { - "lb": str(self.lb.tolist()), - "ub": str(self.ub.tolist()), - } - } - ) # type: ignore - _, covar_module = default_mean_covar_factory(config) - - super().__init__( - datapoints=None, - comparisons=None, - covar_module=covar_module, - jitter=1e-3, - input_transform=input_transform, - ) - - self.dim = dim # The Pairwise constructor sets self.dim = None. - - def fit( - self, - train_x: torch.Tensor, - train_y: torch.Tensor, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ): - self.train() - mll = PairwiseLaplaceMarginalLogLikelihood(self.likelihood, self) - datapoints, comparisons = self._pairs_to_comparisons(train_x, train_y) - self.set_train_data(datapoints, comparisons) - - optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs.copy() - max_fit_time = kwargs.pop("max_fit_time", self.max_fit_time) - if max_fit_time is not None: - # figure out how long evaluating a single samp - starttime = time.time() - _ = mll(self(datapoints), comparisons) - single_eval_time = time.time() - starttime - n_eval = int(max_fit_time / single_eval_time) - optimizer_kwargs["maxfun"] = n_eval - logger.info(f"fit maxfun is {n_eval}") - - logger.info("Starting fit...") - starttime = time.time() - fit_gpytorch_mll(mll, **kwargs, **optimizer_kwargs) - logger.info(f"Fit done, time={time.time()-starttime}") - - def update( - self, train_x: torch.Tensor, train_y: torch.Tensor, warmstart: bool = True - ): - """Perform a warm-start update of the model from previous fit.""" - self.fit(train_x, train_y) - - def predict( - self, x, probability_space=False, num_samples=1000, rereference="x_min" - ): - if rereference is not None: - samps = self.sample(x, num_samples, rereference) - fmean, fvar = samps.mean(0).squeeze(), samps.var(0).squeeze() - else: - post = self.posterior(x) - fmean, fvar = post.mean.squeeze(), post.variance.squeeze() - - if probability_space: - return ( - promote_0d(norm.cdf(fmean)), - promote_0d(norm.cdf(fvar)), - ) - else: - return fmean, fvar - - def predict_probability( - self, x, probability_space=False, num_samples=1000, rereference="x_min" - ): - return self.predict( - x, probability_space=True, num_samples=num_samples, rereference=rereference - ) - - def sample(self, x, num_samples, rereference="x_min"): - if len(x.shape) < 2: - x = x.reshape(-1, 1) - if rereference is None: - return 
self.posterior(x).rsample(torch.Size([num_samples])) - - if rereference == "x_min": - x_ref = self.lb - elif rereference == "x_max": - x_ref = self.ub - elif rereference == "f_max": - x_ref = torch.Tensor(self.get_max()[1]) - elif rereference == "f_min": - x_ref = torch.Tensor(self.get_min()[1]) - else: - raise RuntimeError( - f"Unknown rereference type {rereference}! Options: x_min, x_max, f_min, f_max." - ) - - x_stack = torch.vstack([x, x_ref]) - samps = self.posterior(x_stack).rsample(torch.Size([num_samples])) - samps, samps_ref = torch.split(samps, [samps.shape[1] - 1, 1], dim=1) - if rereference == "x_min" or rereference == "f_min": - return samps - samps_ref - else: - return -samps + samps_ref - - @classmethod - def from_config(cls, config): - - classname = cls.__name__ - - mean_covar_factory = config.getobj( - "PairwiseProbitModel", - "mean_covar_factory", - fallback=default_mean_covar_factory, - ) - - # no way of passing mean into PairwiseGP right now - _, covar = mean_covar_factory(config) - - lb = config.gettensor(classname, "lb") - ub = config.gettensor(classname, "ub") - dim = lb.shape[0] - - max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None) - - return cls(lb=lb, ub=ub, dim=dim, covar_module=covar, max_fit_time=max_fit_time) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 704dd09fd..d1f6ef76a 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -19,10 +19,7 @@ from aepsych.generators.base import AEPsychGenerator from aepsych.generators.sobol_generator import SobolGenerator from aepsych.models.base import ModelProtocol -from aepsych.utils import ( - _process_bounds, - make_scaled_sobol, -) +from aepsych.utils import _process_bounds, make_scaled_sobol from aepsych.utils_logging import getLogger from botorch.exceptions.errors import ModelFittingError @@ -147,13 +144,7 @@ def __init__( self.min_total_tells = min_total_tells self.stimuli_per_trial = stimuli_per_trial self.outcome_types = outcome_types - - if self.stimuli_per_trial == 1: - self.event_shape: Tuple[int, ...] = (self.dim,) - - if self.stimuli_per_trial == 2: - self.event_shape = (self.dim, self.stimuli_per_trial) - + self.event_shape: Tuple[int, ...] 
= (self.dim,) self.model = model self.refit_every = refit_every self._model_is_fresh = False diff --git a/configs/pairwise_opt_example.ini b/configs/pairwise_opt_example.ini index 8fb15aa74..b305e32c6 100644 --- a/configs/pairwise_opt_example.ini +++ b/configs/pairwise_opt_example.ini @@ -8,9 +8,9 @@ ## The common section includes global server parameters and parameters ## reused in multiple other classes [common] -parnames = [par1, par2] # names of the parameters -lb = [0, 0] # lower bounds of the parameters, in the same order as above -ub = [1, 1] # upper bounds of parameter, in the same order as above +parnames = [par1_1, par2_1, par3_1, par1_2, par2_2, par3_2] # names of the parameters +lb = [0, 0, 0, 0, 0, 0] # lower bounds of the parameters, in the same order as above +ub = [1, 1, 1, 1, 1, 1] # upper bounds of parameter, in the same order as above stimuli_per_trial = 2 # the number of stimuli shown in each trial; 1 for single, or 2 for pairwise experiments outcome_types = [binary] # the type of response given by the participant; can only be [binary] for pairwise for now strategy_names = [init_strat, opt_strat] # The strategies that will be used, corresponding to the named sections below @@ -42,21 +42,13 @@ acqf = qLogNoisyExpectedImprovement # The model, which must conform to the stimuli_per_trial and outcome_types settings above. # Use GPClassificationModel or GPRegressionModel for single or PairwiseProbitModel for pairwise. -model = PairwiseProbitModel +model = GPClassificationModel ## Below this section are configurations of all the classes defined in the section above, ## matching the API in the code. -## Acquisition function settings; we recommend not changing this. -[PairwiseMCPosteriorVariance] -# The transformation of the latent function before threshold estimation. ProbitObjective -# lets us search where the probability is uncertain (vs where there is high variance -# in the function itself, which might still lead to low variance on the probability -# after the probit transform). -objective = ProbitObjective - ## This configures the PairwiseGP model -[PairwiseProbitModel] +[GPClassificationModel] # Number of inducing points for approximate inference. 100 is fine for 2d and overkill for 1d; # for larger dimensions, scale this up. 
inducing_size = 100 diff --git a/examples/minimal_pairwise_example.ipynb b/examples/minimal_pairwise_example.ipynb index dcbb30794..e36e3d6e1 100644 --- a/examples/minimal_pairwise_example.ipynb +++ b/examples/minimal_pairwise_example.ipynb @@ -2,7 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, + "metadata": {}, + "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -10,13 +12,13 @@ "from aepsych_prerelease.server import AEPsychServer\n", "from scipy.special import expit, logit\n", "from scipy.stats import bernoulli" - ], - "outputs": [], - "metadata": {} + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": {}, + "outputs": [], "source": [ "# Define the 75% lse to be where par1_1 - par1_2 + par2_1 - par2_2 = 1\n", "def get_response_probability(params):\n", @@ -24,13 +26,13 @@ " b = logit(0.75) - m\n", " p = expit(m * params.sum(1) + b)\n", " return p" - ], - "outputs": [], - "metadata": {} + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, + "metadata": {}, + "outputs": [], "source": [ "# Simulate participant responses; returns 1 if the participant detected the stimulus or 0 if they did not.\n", "def simulate_response(trial_params):\n", @@ -42,13 +44,22 @@ " response = bernoulli.rvs(p)\n", "\n", " return response" - ], - "outputs": [], - "metadata": {} + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-07-09 14:40:03,632 [INFO ] Found DB at pairwise_example.db, appending!\n", + "2024-07-09 14:40:03,689 [INFO ] Received msg [setup]\n" + ] + } + ], "source": [ "# Fix random seeds\n", "np.random.seed(0)\n", @@ -58,13 +69,208 @@ "server = AEPsychServer(database_path=\"pairwise_example.db\")\n", "client = AEPsychClient(server=server)\n", "client.configure(config_path=\"../configs/pairwise_opt_example.ini\")" - ], - "outputs": [], - "metadata": {} + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-07-09 14:40:08,526 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:21,648 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:47,334 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:47,359 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:47,485 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:47,499 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:47,627 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:47,643 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:47,788 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:47,805 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:47,932 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:47,948 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:48,092 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:48,115 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:48,257 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:48,272 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:48,385 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:48,399 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:48,545 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:48,563 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:48,695 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:48,697 [INFO ] Starting fitting (no warm start)...\n", + "2024-07-09 14:40:48,724 
[INFO ] Starting fit...\n", + "2024-07-09 14:40:50,053 [INFO ] Fit done, time=1.3274188041687012\n", + "2024-07-09 14:40:50,056 [INFO ] Fitting done, took 1.359386920928955\n", + "2024-07-09 14:40:50,079 [INFO ] Starting gen...\n", + "2024-07-09 14:40:50,776 [INFO ] Gen done, time=0.6950039863586426\n", + "2024-07-09 14:40:50,803 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:50,966 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:50,968 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:40:50,983 [INFO ] Starting fit...\n", + "2024-07-09 14:40:51,302 [INFO ] Fit done, time=0.3174622058868408\n", + "2024-07-09 14:40:51,307 [INFO ] Fitting done, took 0.33973193168640137\n", + "2024-07-09 14:40:51,329 [INFO ] Starting gen...\n", + "2024-07-09 14:40:52,033 [INFO ] Gen done, time=0.7015540599822998\n", + "2024-07-09 14:40:52,052 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:52,204 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:52,206 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:40:52,223 [INFO ] Starting fit...\n", + "2024-07-09 14:40:52,395 [INFO ] Fit done, time=0.16601300239562988\n", + "2024-07-09 14:40:52,396 [INFO ] Fitting done, took 0.19062399864196777\n", + "2024-07-09 14:40:52,406 [INFO ] Starting gen...\n", + "2024-07-09 14:40:53,034 [INFO ] Gen done, time=0.6267471313476562\n", + "2024-07-09 14:40:53,062 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:53,207 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:53,209 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:40:53,230 [INFO ] Starting fit...\n", + "2024-07-09 14:40:53,457 [INFO ] Fit done, time=0.2250969409942627\n", + "2024-07-09 14:40:53,459 [INFO ] Fitting done, took 0.24954485893249512\n", + "2024-07-09 14:40:53,472 [INFO ] Starting gen...\n", + "2024-07-09 14:40:54,775 [INFO ] Gen done, time=1.3010618686676025\n", + "2024-07-09 14:40:54,794 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:54,964 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:54,966 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:40:54,984 [INFO ] Starting fit...\n", + "2024-07-09 14:40:55,412 [INFO ] Fit done, time=0.4265940189361572\n", + "2024-07-09 14:40:55,415 [INFO ] Fitting done, took 0.449174165725708\n", + "2024-07-09 14:40:55,431 [INFO ] Starting gen...\n", + "2024-07-09 14:40:56,299 [INFO ] Gen done, time=0.8664309978485107\n", + "2024-07-09 14:40:56,318 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:56,502 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:56,503 [INFO ] Starting fitting (no warm start)...\n", + "2024-07-09 14:40:56,519 [INFO ] Starting fit...\n", + "/Users/craigsanders/opt/anaconda3/envs/aepsych/lib/python3.9/site-packages/botorch/optim/fit.py:102: OptimizationWarning: `scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH\n", + " warn(\n", + "2024-07-09 14:40:58,481 [INFO ] Fit done, time=1.9604310989379883\n", + "2024-07-09 14:40:58,483 [INFO ] Fitting done, took 1.9796431064605713\n", + "2024-07-09 14:40:58,502 [INFO ] Starting gen...\n", + "2024-07-09 14:40:59,020 [INFO ] Gen done, time=0.5169198513031006\n", + "2024-07-09 14:40:59,038 [INFO ] Received msg [tell]\n", + "2024-07-09 14:40:59,174 [INFO ] Received msg [ask]\n", + "2024-07-09 14:40:59,175 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:40:59,191 [INFO ] Starting fit...\n", + "2024-07-09 14:40:59,547 [INFO ] Fit done, time=0.3543732166290283\n", + "2024-07-09 14:40:59,549 [INFO ] 
Fitting done, took 0.37344884872436523\n", + "2024-07-09 14:40:59,559 [INFO ] Starting gen...\n", + "2024-07-09 14:41:01,032 [INFO ] Gen done, time=1.4708058834075928\n", + "2024-07-09 14:41:01,067 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:01,327 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:01,330 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:01,353 [INFO ] Starting fit...\n", + "2024-07-09 14:41:01,843 [INFO ] Fit done, time=0.48758888244628906\n", + "2024-07-09 14:41:01,844 [INFO ] Fitting done, took 0.5147519111633301\n", + "2024-07-09 14:41:01,857 [INFO ] Starting gen...\n", + "2024-07-09 14:41:02,621 [INFO ] Gen done, time=0.761469841003418\n", + "2024-07-09 14:41:02,657 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:02,983 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:02,985 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:03,020 [INFO ] Starting fit...\n", + "2024-07-09 14:41:03,337 [INFO ] Fit done, time=0.31476712226867676\n", + "2024-07-09 14:41:03,354 [INFO ] Fitting done, took 0.36829710006713867\n", + "2024-07-09 14:41:03,369 [INFO ] Starting gen...\n", + "2024-07-09 14:41:04,476 [INFO ] Gen done, time=1.104773998260498\n", + "2024-07-09 14:41:04,512 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:04,779 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:04,783 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:04,825 [INFO ] Starting fit...\n", + "2024-07-09 14:41:05,403 [INFO ] Fit done, time=0.5752570629119873\n", + "2024-07-09 14:41:05,405 [INFO ] Fitting done, took 0.6217248439788818\n", + "2024-07-09 14:41:05,420 [INFO ] Starting gen...\n", + "2024-07-09 14:41:06,344 [INFO ] Gen done, time=0.9227659702301025\n", + "2024-07-09 14:41:06,405 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:06,663 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:06,665 [INFO ] Starting fitting (no warm start)...\n", + "2024-07-09 14:41:06,682 [INFO ] Starting fit...\n", + "2024-07-09 14:41:07,239 [INFO ] Fit done, time=0.5554029941558838\n", + "2024-07-09 14:41:07,254 [INFO ] Fitting done, took 0.5888071060180664\n", + "2024-07-09 14:41:07,317 [INFO ] Starting gen...\n", + "2024-07-09 14:41:08,826 [INFO ] Gen done, time=1.5064280033111572\n", + "2024-07-09 14:41:08,865 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:09,120 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:09,122 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:09,146 [INFO ] Starting fit...\n", + "2024-07-09 14:41:09,611 [INFO ] Fit done, time=0.4568631649017334\n", + "2024-07-09 14:41:09,614 [INFO ] Fitting done, took 0.4918668270111084\n", + "2024-07-09 14:41:09,627 [INFO ] Starting gen...\n", + "2024-07-09 14:41:10,517 [INFO ] Gen done, time=0.8875226974487305\n", + "2024-07-09 14:41:10,554 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:10,911 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:10,913 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:10,932 [INFO ] Starting fit...\n", + "2024-07-09 14:41:11,531 [INFO ] Fit done, time=0.5971672534942627\n", + "2024-07-09 14:41:11,534 [INFO ] Fitting done, took 0.6207730770111084\n", + "2024-07-09 14:41:11,545 [INFO ] Starting gen...\n", + "2024-07-09 14:41:12,601 [INFO ] Gen done, time=1.0476117134094238\n", + "2024-07-09 14:41:12,685 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:13,159 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:13,160 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:13,181 [INFO ] Starting fit...\n", + 
"2024-07-09 14:41:13,295 [INFO ] Fit done, time=0.11230731010437012\n", + "2024-07-09 14:41:13,304 [INFO ] Fitting done, took 0.14335203170776367\n", + "2024-07-09 14:41:13,317 [INFO ] Starting gen...\n", + "2024-07-09 14:41:15,162 [INFO ] Gen done, time=1.843230962753296\n", + "2024-07-09 14:41:15,185 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:15,462 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:15,464 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:15,479 [INFO ] Starting fit...\n", + "2024-07-09 14:41:15,843 [INFO ] Fit done, time=0.3621230125427246\n", + "2024-07-09 14:41:15,854 [INFO ] Fitting done, took 0.3900420665740967\n", + "2024-07-09 14:41:15,866 [INFO ] Starting gen...\n", + "2024-07-09 14:41:17,817 [INFO ] Gen done, time=1.9494099617004395\n", + "2024-07-09 14:41:17,855 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:18,058 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:18,059 [INFO ] Starting fitting (no warm start)...\n", + "2024-07-09 14:41:18,077 [INFO ] Starting fit...\n", + "2024-07-09 14:41:18,230 [INFO ] Fit done, time=0.15170717239379883\n", + "2024-07-09 14:41:18,234 [INFO ] Fitting done, took 0.17423677444458008\n", + "2024-07-09 14:41:18,247 [INFO ] Starting gen...\n", + "2024-07-09 14:41:19,977 [INFO ] Gen done, time=1.7231180667877197\n", + "2024-07-09 14:41:20,012 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:20,212 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:20,214 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:20,232 [INFO ] Starting fit...\n", + "/Users/craigsanders/opt/anaconda3/envs/aepsych/lib/python3.9/site-packages/botorch/optim/fit.py:102: OptimizationWarning: `scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH\n", + " warn(\n", + "2024-07-09 14:41:21,918 [INFO ] Fit done, time=1.6844849586486816\n", + "2024-07-09 14:41:21,919 [INFO ] Fitting done, took 1.7053167819976807\n", + "2024-07-09 14:41:21,929 [INFO ] Starting gen...\n", + "2024-07-09 14:41:23,582 [INFO ] Gen done, time=1.6515789031982422\n", + "2024-07-09 14:41:23,618 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:23,824 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:23,825 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:23,844 [INFO ] Starting fit...\n", + "2024-07-09 14:41:24,203 [INFO ] Fit done, time=0.3493356704711914\n", + "2024-07-09 14:41:24,205 [INFO ] Fitting done, took 0.37955403327941895\n", + "2024-07-09 14:41:24,216 [INFO ] Starting gen...\n", + "2024-07-09 14:41:25,891 [INFO ] Gen done, time=1.6728451251983643\n", + "2024-07-09 14:41:25,920 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:26,173 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:26,175 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:26,192 [INFO ] Starting fit...\n", + "2024-07-09 14:41:26,440 [INFO ] Fit done, time=0.23659110069274902\n", + "2024-07-09 14:41:26,455 [INFO ] Fitting done, took 0.27997803688049316\n", + "2024-07-09 14:41:26,476 [INFO ] Starting gen...\n", + "2024-07-09 14:41:28,868 [INFO ] Gen done, time=2.389058828353882\n", + "2024-07-09 14:41:28,907 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:29,108 [INFO ] Received msg [ask]\n", + "2024-07-09 14:41:29,109 [INFO ] Starting fitting (warm start)...\n", + "2024-07-09 14:41:29,125 [INFO ] Starting fit...\n", + "2024-07-09 14:41:29,343 [INFO ] Fit done, time=0.21627116203308105\n", + "2024-07-09 14:41:29,354 [INFO ] Fitting done, took 
0.2449021339416504\n", + "2024-07-09 14:41:29,367 [INFO ] Starting gen...\n", + "2024-07-09 14:41:32,660 [INFO ] Gen done, time=3.290616035461426\n", + "2024-07-09 14:41:32,683 [INFO ] Received msg [tell]\n", + "2024-07-09 14:41:32,963 [INFO ] Recording strat because the experiment is complete.\n", + "2024-07-09 14:41:33,061 [INFO ] Received msg [exit]\n", + "2024-07-09 14:41:33,063 [INFO ] Got termination message!\n", + "2024-07-09 14:41:33,064 [INFO ] Dumping strats to DB due to Normal termination.\n" + ] + } + ], "source": [ "# Do the ask/tell loop\n", "finished = False\n", @@ -82,17 +288,29 @@ "\n", "# Finish the experiment\n", "client.finalize()" - ], - "outputs": [], - "metadata": {} + ] } ], "metadata": { - "orig_nbformat": 4, + "kernelspec": { + "display_name": "aepsych", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" - } + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/examples/pairwise_example.db b/examples/pairwise_example.db new file mode 100644 index 000000000..cc935fff2 Binary files /dev/null and b/examples/pairwise_example.db differ diff --git a/examples/untracked/databases/default.db b/examples/untracked/databases/default.db new file mode 100644 index 000000000..61551e67a Binary files /dev/null and b/examples/untracked/databases/default.db differ diff --git a/examples/untracked/minimal_classification_example.py b/examples/untracked/minimal_classification_example.py new file mode 100644 index 000000000..06463b1b5 --- /dev/null +++ b/examples/untracked/minimal_classification_example.py @@ -0,0 +1,59 @@ +import time + +import numpy as np +import torch + +from scipy.stats import bernoulli + +from scipy.special import expit, logit +from aepsych.server import AEPsychServer +from aepsych_client import AEPsychClient +from aepsych.plotting import plot_strat + + +# Define the 75% to be where par1 + par2 = 1 +def get_response_probability(params): + m = 10 + b = logit(0.5) - m + p = expit(m * params.sum(1) + b) + return p + + +# Simulate participant responses; returns 1 if the participant detected the stimulus or 0 if they did not. +def simulate_response(trial_params): + params = np.array([[trial_params[par][0] for par in trial_params]]) + prob = get_response_probability(params) + response = bernoulli.rvs(prob) + + return response + + +# Fix random seeds +np.random.seed(0) +torch.manual_seed(0) + +# Create a server object configured to run a 2d threshold experiment +server = AEPsychServer() +client = AEPsychClient(server=server) +client.configure( + config_path="/Users/craigsanders/fbsource/fbcode/frl/ae/aepsych/configs/single_lse_example.ini" +) + +is_finished = False +while not is_finished: + # Ask the server what the next parameter values to test should be. + starttime = time.time() + trial_params = client.ask() + print(f"Ask time={time.time()-starttime}") + + # Simulate a participant response. + outcome = simulate_response(trial_params["config"]) + # time.sleep(2) + + # Tell the server what happened so that it can update its model. 
+ client.tell(config=trial_params["config"], outcome=outcome) + is_finished = trial_params["is_finished"] + +# print(client.query("max")) +# Plot the results +plot_strat(server.strat, target_level=0.5) diff --git a/examples/untracked/minimal_pairwise_example.py b/examples/untracked/minimal_pairwise_example.py new file mode 100644 index 000000000..13f220651 --- /dev/null +++ b/examples/untracked/minimal_pairwise_example.py @@ -0,0 +1,57 @@ +import numpy as np +import torch +from aepsych_client import AEPsychClient +from aepsych_prerelease.server import AEPsychServer +from scipy.special import expit, logit +from scipy.stats import bernoulli + + +def get_response_probability(params): + m = 10 + b = logit(0.75) - m + p = expit(m * params.sum(1) + b) + return p + + +# Simulate participant responses; returns 1 if the participant detected the stimulus or 0 if they did not. +def simulate_response(trial_params): + params = np.array( + [ + [ + trial_params[f"{par}_1"][0] - trial_params[f"{par}_2"][0] + for par in ["par1", "par2", "par3"] + ] + ] + ) + + p = get_response_probability(params) + response = bernoulli.rvs(p) + + return response + + +# Fix random seeds +np.random.seed(0) +torch.manual_seed(0) + +# Configure the client/server to do pairwise optimization +server = AEPsychServer(database_path="pairwise_example.db") +client = AEPsychClient(server=server) +client.configure(config_path="../../configs/pairwise_opt_example.ini") + +# Do the ask/tell loop +finished = False +while not finished: + # Ask the server what the next parameter values to test should be. + response = client.ask() + trial_params = response["config"] + finished = response["is_finished"] + + # Simulate a participant response. + outcome = simulate_response(trial_params) + + # Tell the server what happened so that it can update its model. + client.tell(config=trial_params, outcome=outcome) + +# Finish the experiment +client.finalize() diff --git a/examples/untracked/pairwise_example.db b/examples/untracked/pairwise_example.db new file mode 100644 index 000000000..0a66e114b Binary files /dev/null and b/examples/untracked/pairwise_example.db differ diff --git a/tests/models/test_pairwise_probit.py b/tests/models/test_pairwise_probit.py deleted file mode 100644 index 0bafa8004..000000000 --- a/tests/models/test_pairwise_probit.py +++ /dev/null @@ -1,803 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import logging -import unittest -import uuid - -import numpy as np -import numpy.testing as npt -import torch -from aepsych import server, utils_logging -from aepsych.acquisition.objective import ProbitObjective -from aepsych.benchmark.test_functions import f_1d, f_2d, f_pairwise, new_novel_det -from aepsych.config import Config -from aepsych.generators import OptimizeAcqfGenerator, SobolGenerator -from aepsych.models import PairwiseProbitModel -from aepsych.server.message_handlers.handle_ask import ask -from aepsych.server.message_handlers.handle_setup import configure -from aepsych.server.message_handlers.handle_tell import tell -from aepsych.strategy import SequentialStrategy, Strategy -from botorch.acquisition import qUpperConfidenceBound -from botorch.acquisition.active_learning import PairwiseMCPosteriorVariance -from scipy.stats import bernoulli, norm, pearsonr - - -class PairwiseProbitModelStrategyTest(unittest.TestCase): - def test_pairs_to_comparisons(self): - def ptc_numpy(x, y, dim): - """ - old numpy impl of pairs to comparisons - """ - - # This needs to take a unique over the feature dim by flattening - # over pairs but not instances/batches. This is actually tensor - # matricization over the feature dimension but awkward in numpy - unique_coords = np.unique(np.moveaxis(x, 1, 0).reshape(dim, -1), axis=1) - - def _get_index_of_equal_row(arr, x, axis=0): - return np.argwhere(np.all(np.equal(arr, x[:, None]), axis=axis)).item() - - comparisons = [] - for pair, judgement in zip(x, y): - comparison = ( - _get_index_of_equal_row(unique_coords, pair[..., 0]), - _get_index_of_equal_row(unique_coords, pair[..., 1]), - ) - if judgement == 0: - comparisons.append(comparison) - else: - comparisons.append(comparison[::-1]) - return torch.Tensor(unique_coords.T), torch.LongTensor(comparisons) - - x = np.random.normal(size=(10, 1, 2)) - y = np.random.choice((0, 1), size=10) - - datapoints1, comparisons1 = ptc_numpy(x, y, 1) - - pbo = PairwiseProbitModel(lb=[-10], ub=[10]) - datapoints2, comparisons2 = pbo._pairs_to_comparisons( - torch.Tensor(x), torch.Tensor(y) - ) - npt.assert_equal(datapoints1.numpy(), datapoints2.numpy()) - npt.assert_equal(comparisons1.numpy(), comparisons2.numpy()) - - x = np.random.normal(size=(10, 2, 2)) - y = np.random.choice((0, 1), size=10) - - datapoints1, comparisons1 = ptc_numpy(x, y, 2) - - pbo = PairwiseProbitModel(lb=[-10], ub=[10], dim=2) - datapoints2, comparisons2 = pbo._pairs_to_comparisons( - torch.Tensor(x), torch.Tensor(y) - ) - npt.assert_equal(datapoints1.numpy(), datapoints2.numpy()) - npt.assert_equal(comparisons1.numpy(), comparisons2.numpy()) - - def test_pairwise_probit_batched(self): - """ - test our 1d gaussian bump example - """ - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 20 - n_opt = 1 - lb = [-4.0, 1e-5] - ub = [-1e-5, 4.0] - extra_acqf_args = {"beta": 3.84} - model_list = [ - Strategy( - lb=lb, - ub=ub, - generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2), - min_asks=n_init, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - Strategy( - lb=lb, - ub=ub, - model=PairwiseProbitModel(lb=lb, ub=ub), - generator=OptimizeAcqfGenerator( - acqf=qUpperConfidenceBound, - acqf_kwargs=extra_acqf_args, - stimuli_per_trial=2, - ), - min_asks=n_opt, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - ] - - strat = SequentialStrategy(model_list) - - while not strat.finished: - next_pair = strat.gen(num_points=3) - # next_pair is batch x dim x pair, - # this checks that we have the reshapes - 
# right - self.assertTrue((next_pair[:, 0, :] < 0).all()) - self.assertTrue((next_pair[:, 1, :] > 0).all()) - strat.add_data( - next_pair, - bernoulli.rvs( - f_pairwise(f_1d, next_pair.sum(1), noise_scale=0.1).squeeze() - ), - ) - - xgrid = strat.model.dim_grid(gridsize=10) - - zhat, _ = strat.predict(xgrid) - # true max is 0, very loose test - self.assertTrue(xgrid[torch.argmax(zhat, 0)].sum().detach().numpy() < 0.5) - - def test_pairwise_memorize(self): - """ - can we memorize a simple function - """ - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - lb = [-1, -1] - ub = [1, 1] - gen = SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2) - x = torch.Tensor(gen.gen(num_points=20)) - # "noiseless" new_novel_det (just take the mean instead of sampling) - y = torch.Tensor(f_pairwise(new_novel_det, x) > 0.5).int() - model = PairwiseProbitModel(lb=lb, ub=ub) - model.fit(x[:18], y[:18]) - with torch.no_grad(): - f0, _ = model.predict(x[18:, ..., 0]) - f1, _ = model.predict(x[18:, ..., 1]) - pred_diff = norm.cdf(f1 - f0) - pred = pred_diff > 0.5 - npt.assert_allclose(pred, y[18:]) - - def test_pairwise_memorize_rescaled(self): - """ - can we memorize a simple function (with rescaled inputs) - """ - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - lb = [-1000, 0] - ub = [0, 1e-5] - gen = SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2) - x = torch.Tensor(gen.gen(num_points=20)) - # "noiseless" new_novel_det (just take the mean instead of sampling) - xrescaled = x.clone() - xrescaled[:, 0, :] = xrescaled[:, 0, :] / 500 + 1 - xrescaled[:, 1, :] = xrescaled[:, 1, :] / 5e-6 - 1 - y = torch.Tensor(f_pairwise(new_novel_det, xrescaled) > 0.5).int() - model = PairwiseProbitModel(lb=lb, ub=ub) - model.fit(x[:18], y[:18]) - with torch.no_grad(): - f0, _ = model.predict(x[18:, ..., 0]) - f1, _ = model.predict(x[18:, ..., 1]) - pred_diff = norm.cdf(f1 - f0) - pred = pred_diff > 0.5 - npt.assert_allclose(pred, y[18:]) - - def test_1d_pairwise_probit(self): - """ - test our 1d gaussian bump example - """ - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 50 - n_opt = 1 - lb = -4.0 - ub = 4.0 - extra_acqf_args = {"beta": 3.84} - model_list = [ - Strategy( - lb=lb, - ub=ub, - generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2), - min_asks=n_init, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - Strategy( - lb=lb, - ub=ub, - model=PairwiseProbitModel(lb=lb, ub=ub), - generator=OptimizeAcqfGenerator( - acqf=qUpperConfidenceBound, - acqf_kwargs=extra_acqf_args, - stimuli_per_trial=2, - ), - min_asks=n_opt, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - ] - - strat = SequentialStrategy(model_list) - - for _i in range(n_init + n_opt): - next_pair = strat.gen() - strat.add_data( - next_pair, [bernoulli.rvs(f_pairwise(f_1d, next_pair, noise_scale=0.1))] - ) - - x = torch.linspace(-4, 4, 100) - - zhat, _ = strat.predict(x) - # true max is 0, very loose test - self.assertTrue(np.abs(x[np.argmax(zhat.detach().numpy())]) < 0.5) - - def test_1d_pairwise_probit_pure_exploration(self): - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 50 - n_opt = 1 - lb = -2.0 - ub = 2.0 - - acqf = PairwiseMCPosteriorVariance - extra_acqf_args = {"objective": ProbitObjective()} - - model_list = [ - Strategy( - lb=lb, - ub=ub, - generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2), - min_asks=n_init, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - Strategy( - lb=lb, - ub=ub, - 
model=PairwiseProbitModel(lb=lb, ub=ub), - generator=OptimizeAcqfGenerator( - acqf=acqf, acqf_kwargs=extra_acqf_args, stimuli_per_trial=2 - ), - min_asks=n_opt, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - ] - - strat = SequentialStrategy(model_list) - - for _i in range(n_init + n_opt): - next_pair = strat.gen() - strat.add_data( - next_pair, - [bernoulli.rvs(f_pairwise(lambda x: x, next_pair, noise_scale=0.1))], - ) - - test_gen = SobolGenerator(lb=lb, ub=ub, seed=seed + 1, stimuli_per_trial=2) - test_x = torch.Tensor(test_gen.gen(100)) - - ftrue_test = (test_x[..., 0] - test_x[..., 1]).squeeze() - - with torch.no_grad(): - fdiff_test = ( - strat.model.predict(test_x[..., 0], rereference=None)[0] - - strat.model.predict(test_x[..., 1], rereference=None)[0] - ) - - self.assertTrue(pearsonr(fdiff_test, ftrue_test)[0] >= 0.9) - - with torch.no_grad(): - fdiff_test_reref = ( - strat.model.predict(test_x[..., 0])[0] - - strat.model.predict(test_x[..., 1])[0] - ) - - self.assertTrue(pearsonr(fdiff_test_reref, ftrue_test)[0] >= 0.9) - - def test_2d_pairwise_probit(self): - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 20 - n_opt = 1 - lb = np.r_[-1, -1] - ub = np.r_[1, 1] - extra_acqf_args = {"beta": 3.84} - - model_list = [ - Strategy( - lb=lb, - ub=ub, - generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2), - min_asks=n_init, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - Strategy( - lb=lb, - ub=ub, - model=PairwiseProbitModel(lb=lb, ub=ub), - generator=OptimizeAcqfGenerator( - acqf=qUpperConfidenceBound, - acqf_kwargs=extra_acqf_args, - stimuli_per_trial=2, - ), - min_asks=n_opt, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - ] - - strat = SequentialStrategy(model_list) - - for _i in range(n_init + n_opt): - next_pair = strat.gen() - strat.add_data( - next_pair, [bernoulli.rvs(f_pairwise(f_2d, next_pair, noise_scale=0.1))] - ) - - xy = np.mgrid[-1:1:30j, -1:1:30j].reshape(2, -1).T - - zhat, _ = strat.predict(torch.Tensor(xy)) - - # true min is at 0,0 - self.assertTrue(np.all(np.abs(xy[np.argmax(zhat.detach().numpy())]) < 0.2)) - - def test_2d_pairwise_probit_pure_exploration(self): - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 20 - n_opt = 1 - lb = np.r_[-1, -1] - ub = np.r_[1, 1] - acqf = PairwiseMCPosteriorVariance - extra_acqf_args = {"objective": ProbitObjective()} - - model_list = [ - Strategy( - lb=lb, - ub=ub, - generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2), - min_asks=n_init, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - Strategy( - lb=lb, - ub=ub, - model=PairwiseProbitModel(lb=lb, ub=ub), - generator=OptimizeAcqfGenerator( - acqf=acqf, acqf_kwargs=extra_acqf_args, stimuli_per_trial=2 - ), - min_asks=n_opt, - stimuli_per_trial=2, - outcome_types=["binary"], - ), - ] - - strat = SequentialStrategy(model_list) - - for _i in range(n_init + n_opt): - next_pair = strat.gen() - strat.add_data( - next_pair, [bernoulli.rvs(f_pairwise(new_novel_det, next_pair))] - ) - - xy = np.mgrid[-1:1:30j, -1:1:30j].reshape(2, -1).T - - zhat, _ = strat.predict(torch.Tensor(xy)) - - ztrue = new_novel_det(xy) - - corr = pearsonr(zhat.detach().flatten(), ztrue.flatten())[0] - self.assertTrue(corr > 0.80) - - def test_sobolmodel_pairwise(self): - # test that SobolModel correctly gets bounds - - sobol_x = np.zeros((10, 3, 2)) - mod = Strategy( - lb=[1, 2, 3], - ub=[2, 3, 4], - min_asks=10, - stimuli_per_trial=2, - outcome_types=["binary"], - generator=SobolGenerator( - 
lb=[1, 2, 3], ub=[2, 3, 4], seed=12345, stimuli_per_trial=2 - ), - ) - - for i in range(10): - sobol_x[i, ...] = mod.gen() - - self.assertTrue(np.all(sobol_x[:, 0, :] > 1)) - self.assertTrue(np.all(sobol_x[:, 1, :] > 2)) - self.assertTrue(np.all(sobol_x[:, 2, :] > 3)) - self.assertTrue(np.all(sobol_x[:, 0, :] < 2)) - self.assertTrue(np.all(sobol_x[:, 1, :] < 3)) - self.assertTrue(np.all(sobol_x[:, 2, :] < 4)) - - def test_hyperparam_consistency(self): - # verify that creating the model `from_config` or with `__init__` has the same hyperparams - - m1 = PairwiseProbitModel(lb=[1, 2], ub=[3, 4]) - - m2 = PairwiseProbitModel.from_config( - config=Config(config_dict={"common": {"lb": "[1,2]", "ub": "[3,4]"}}) - ) - - self.assertTrue(isinstance(m1.covar_module, type(m2.covar_module))) - self.assertTrue( - isinstance(m1.covar_module.base_kernel, type(m2.covar_module.base_kernel)) - ) - self.assertTrue(isinstance(m1.mean_module, type(m2.mean_module))) - m1priors = list(m1.covar_module.named_priors()) - m2priors = list(m2.covar_module.named_priors()) - for p1, p2 in zip(m1priors, m2priors): - name1, parent1, prior1, paramtransforms1, priortransforms1 = p1 - name2, parent2, prior2, paramtransforms2, priortransforms2 = p2 - self.assertTrue(name1 == name2) - self.assertTrue(isinstance(parent1, type(parent2))) - self.assertTrue(isinstance(prior1, type(prior2))) - # no obvious way to test paramtransform equivalence - - -class PairwiseProbitModelServerTest(unittest.TestCase): - def setUp(self): - # setup logger - server.logger = utils_logging.getLogger(logging.DEBUG, "logs") - # random datebase path name without dashes - database_path = "./{}.db".format(str(uuid.uuid4().hex)) - self.s = server.AEPsychServer(database_path=database_path) - - def tearDown(self): - self.s.cleanup() - - # cleanup the db - if self.s.db is not None: - self.s.db.delete_db() - - def test_1d_pairwise_server(self): - seed = 123 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 50 - n_opt = 2 - config_str = f""" - [common] - lb = [-4] - ub = [4] - stimuli_per_trial = 2 - outcome_types =[binary] - parnames = [x] - strategy_names = [init_strat, opt_strat] - acqf = PairwiseMCPosteriorVariance - - [init_strat] - min_asks = {n_init} - generator = PairwiseSobolGenerator - - [opt_strat] - model = PairwiseProbitModel - min_asks = {n_opt} - generator = OptimizeAcqfGenerator - - [PairwiseProbitModel] - mean_covar_factory = default_mean_covar_factory - - [PairwiseMCPosteriorVariance] - objective = ProbitObjective - - [OptimizeAcqfGenerator] - restarts = 10 - samps = 1000 - """ - - server = self.s - configure( - server, - config_str=config_str, - ) - - for _i in range(n_init + n_opt): - next_config = ask(server) - next_y = bernoulli.rvs(f_pairwise(f_1d, next_config["x"], noise_scale=0.1)) - tell(server, config=next_config, outcome=next_y) - - x = torch.linspace(-4, 4, 100) - zhat, _ = server.strat.predict(x) - self.assertTrue(np.abs(x[np.argmax(zhat.detach().numpy())]) < 0.5) - - def test_2d_pairwise_server(self): - seed = 1 - torch.manual_seed(seed) - np.random.seed(seed) - n_init = 50 - n_opt = 1 - config_str = f""" - [common] - lb = [-1, -1] - ub = [1, 1] - stimuli_per_trial=2 - outcome_types=[binary] - parnames = [x, y] - strategy_names = [init_strat, opt_strat] - acqf = PairwiseMCPosteriorVariance - - [init_strat] - min_asks = {n_init} - generator = PairwiseSobolGenerator - - [opt_strat] - min_asks = {n_opt} - model = PairwiseProbitModel - generator = OptimizeAcqfGenerator - - [PairwiseProbitModel] - mean_covar_factory = 
-
-        [PairwiseMCPosteriorVariance]
-        objective = ProbitObjective
-
-        [OptimizeAcqfGenerator]
-        restarts = 10
-        samps = 1000
-        """
-
-        server = self.s
-        configure(
-            server,
-            config_str=config_str,
-        )
-        for _i in range(n_init + n_opt):
-            next_config = ask(server)
-            next_pair = np.c_[next_config["x"], next_config["y"]].T
-            next_y = bernoulli.rvs(f_pairwise(f_2d, next_pair, noise_scale=0.1))
-            tell(server, config=next_config, outcome=next_y)
-
-        xy = np.mgrid[-1:1:30j, -1:1:30j].reshape(2, -1).T
-
-        zhat, _ = server.strat.predict(torch.Tensor(xy))
-
-        # true min is at 0,0
-        self.assertTrue(np.all(np.abs(xy[np.argmax(zhat.detach().numpy())]) < 0.2))
-
-    def test_serialization_1d(self):
-        seed = 1
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-        n_init = 3
-        n_opt = 1
-        config_str = f"""
-        [common]
-        lb = [-4]
-        ub = [4]
-        stimuli_per_trial=2
-        outcome_types=[binary]
-        parnames = [x]
-        strategy_names = [init_strat, opt_strat]
-        acqf = PairwiseMCPosteriorVariance
-
-        [init_strat]
-        min_asks = {n_init}
-        generator = PairwiseSobolGenerator
-
-        [opt_strat]
-        model = PairwiseProbitModel
-        min_asks = {n_opt}
-        generator = OptimizeAcqfGenerator
-
-        [PairwiseProbitModel]
-        mean_covar_factory = default_mean_covar_factory
-
-        [PairwiseMCPosteriorVariance]
-        objective = ProbitObjective
-
-        [OptimizeAcqfGenerator]
-        restarts = 10
-        samps = 1000
-        """
-
-        server = self.s
-        configure(server, config_str=config_str)
-
-        for _i in range(n_init + n_opt):
-            next_config = ask(server)
-            next_y = bernoulli.rvs(f_pairwise(f_1d, next_config["x"]))
-            tell(server, config=next_config, outcome=next_y)
-
-        import dill
-
-        # just make sure it works
-        try:
-            s = dill.dumps(server)
-            server2 = dill.loads(s)
-            self.assertEqual(len(server2._strats), len(server._strats))
-            for strat1, strat2 in zip(server._strats, server2._strats):
-                self.assertEqual(type(strat1), type(strat2))
-                self.assertEqual(type(strat1.model), type(strat2.model))
-                self.assertTrue(torch.equal(strat1.x, strat2.x))
-                self.assertTrue(torch.equal(strat1.y, strat2.y))
-
-        except Exception:
-            self.fail()
-
-    def test_serialization_2d(self):
-        seed = 1
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-        n_init = 3
-        n_opt = 1
-
-        config_str = f"""
-        [common]
-        lb = [-1, -1]
-        ub = [1, 1]
-        stimuli_per_trial=2
-        outcome_types=[binary]
-        parnames = [x, y]
-        strategy_names = [init_strat, opt_strat]
-        acqf = PairwiseMCPosteriorVariance
-
-        [init_strat]
-        min_asks = {n_init}
-        generator = PairwiseSobolGenerator
-
-        [opt_strat]
-        model = PairwiseProbitModel
-        min_asks = {n_opt}
-        generator = PairwiseOptimizeAcqfGenerator
-
-        [PairwiseProbitModel]
-        mean_covar_factory = default_mean_covar_factory
-
-        [PairwiseMCPosteriorVariance]
-        objective = ProbitObjective
-
-        [PairwiseOptimizeAcqfGenerator]
-        restarts = 10
-        samps = 1000
-        """
-
-        server = self.s
-
-        configure(server, config_str=config_str)
-
-        for _i in range(n_init + n_opt):
-            next_config = ask(server)
-            next_pair = np.c_[next_config["x"], next_config["y"]].T
-            next_y = bernoulli.rvs(f_pairwise(f_2d, next_pair))
-            tell(server, config=next_config, outcome=next_y)
-
-        import dill
-
-        # just make sure it works
-        try:
-            s = dill.dumps(server)
-            server2 = dill.loads(s)
-            self.assertEqual(len(server2._strats), len(server._strats))
-            for strat1, strat2 in zip(server._strats, server2._strats):
-                self.assertEqual(type(strat1), type(strat2))
-                self.assertEqual(type(strat1.model), type(strat2.model))
-                self.assertTrue(torch.equal(strat1.x, strat2.x))
-                self.assertTrue(torch.equal(strat1.y, strat2.y))
-        except Exception:
-            self.fail()
-
-    def test_config_to_tensor(self):
-        config_str = """
-        [common]
-        lb = [-1]
-        ub = [1]
-        stimuli_per_trial=2
-        outcome_types=[binary]
-        parnames = [x]
-        strategy_names = [init_strat, opt_strat]
-        acqf = PairwiseMCPosteriorVariance
-
-        [init_strat]
-        min_asks = 1
-        generator = PairwiseSobolGenerator
-
-        [opt_strat]
-        model = PairwiseProbitModel
-        min_asks = 1
-        generator = OptimizeAcqfGenerator
-
-        [PairwiseProbitModel]
-        mean_covar_factory = default_mean_covar_factory
-
-        [PairwiseMCPosteriorVariance]
-        objective = ProbitObjective
-
-        [OptimizeAcqfGenerator]
-        restarts = 10
-        samps = 1000
-        """
-        server = self.s
-
-        configure(server, config_str=config_str)
-
-        conf = ask(server)
-
-        self.assertTrue(server._config_to_tensor(conf).shape == (1, 2))
-
-        config_str = """
-        [common]
-        lb = [-1, -1]
-        ub = [1, 1]
-        stimuli_per_trial=2
-        outcome_types=[binary]
-        parnames = [x, y]
-        strategy_names = [init_strat, opt_strat]
-        acqf = PairwiseMCPosteriorVariance
-
-        [init_strat]
-        min_asks = 1
-        generator = PairwiseSobolGenerator
-
-        [opt_strat]
-        model = PairwiseProbitModel
-        min_asks = 1
-        generator = OptimizeAcqfGenerator
-
-        [PairwiseProbitModel]
-        mean_covar_factory = default_mean_covar_factory
-
-        [PairwiseMCPosteriorVariance]
-        objective = ProbitObjective
-
-        [OptimizeAcqfGenerator]
-        restarts = 10
-        samps = 1000
-        """
-
-        configure(server, config_str=config_str)
-
-        conf = ask(server)
-
-        self.assertTrue(server._config_to_tensor(conf).shape == (2, 2))
-
-        config_str = """
-        [common]
-        lb = [-1, -1, -1]
-        ub = [1, 1, 1]
-        stimuli_per_trial=2
-        outcome_types=[binary]
-        parnames = [x, y, z]
-        strategy_names = [init_strat, opt_strat]
-        acqf = PairwiseMCPosteriorVariance
-
-        [init_strat]
-        min_asks = 1
-        generator = PairwiseSobolGenerator
-
-        [opt_strat]
-        model = PairwiseProbitModel
-        min_asks = 1
-        generator = OptimizeAcqfGenerator
-
-        [PairwiseProbitModel]
-        mean_covar_factory = default_mean_covar_factory
-
-        [PairwiseMCPosteriorVariance]
-        objective = ProbitObjective
-
-        [OptimizeAcqfGenerator]
-        restarts = 10
-        samps = 1000
-        """
-
-        configure(server, config_str=config_str)
-
-        conf = ask(server)
-
-        self.assertTrue(server._config_to_tensor(conf).shape == (3, 2))
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/test_pairwise_kernel.py b/tests/test_pairwise_kernel.py
new file mode 100644
index 000000000..04e0a8f88
--- /dev/null
+++ b/tests/test_pairwise_kernel.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+import unittest
+
+import numpy as np
+import numpy.testing as npt
+import torch
+from aepsych.kernels.pairwisekernel import PairwiseKernel
+from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
+from gpytorch.kernels import RBFKernel
+
+
+class PairwiseKernelTest(unittest.TestCase):
+    """
+    Basic tests that PairwiseKernel is working
+    """
+
+    def setUp(self):
+        self.latent_kernel = RBFKernel()
+        self.kernel = PairwiseKernel(self.latent_kernel)
+
+    def test_kernelgrad_pairwise(self):
+        kernel = PairwiseKernel(RBFKernelPartialObsGrad(), is_partial_obs=True)
+        x1 = torch.rand(torch.Size([2, 4]))
+        x2 = torch.rand(torch.Size([2, 4]))
+
+        x1 = torch.cat((x1, torch.zeros(2, 1)), dim=1)
+        x2 = torch.cat((x2, torch.zeros(2, 1)), dim=1)
+
+        deriv_idx_1 = x1[..., -1][:, None]
+        deriv_idx_2 = x2[..., -1][:, None]
+
+        a = torch.cat((x1[..., :2], deriv_idx_1), dim=1)
+        b = torch.cat((x1[..., 2:-1], deriv_idx_1), dim=1)
+        c = torch.cat((x2[..., :2], deriv_idx_2), dim=1)
+        d = torch.cat((x2[..., 2:-1], deriv_idx_2), dim=1)
+
+        c12 = kernel.forward(x1, x2).evaluate().detach().numpy()
+        pwc = (
+            (
+                kernel.latent_kernel.forward(a, c)
+                - kernel.latent_kernel.forward(a, d)
+                - kernel.latent_kernel.forward(b, c)
+                + kernel.latent_kernel.forward(b, d)
+            )
+            .detach()
+            .numpy()
+        )
+        npt.assert_allclose(c12, pwc, atol=1e-6)
+
+    def test_dim_check(self):
+        """
+        Test that we get expected errors.
+        """
+        x1 = torch.zeros(torch.Size([3]))
+        x2 = torch.zeros(torch.Size([3]))
+        x3 = torch.zeros(torch.Size([2]))
+        x4 = torch.zeros(torch.Size([4]))
+
+        self.assertRaises(AssertionError, self.kernel.forward, x1=x1, x2=x2)
+
+        self.assertRaises(AssertionError, self.kernel.forward, x1=x3, x2=x4)
+
+    def test_covar(self):
+        """
+        Test that we get expected covariances
+        """
+        np.random.seed(1)
+        torch.manual_seed(1)
+
+        x1 = torch.rand(torch.Size([2, 4]))
+        x2 = torch.rand(torch.Size([2, 4]))
+        a = x1[..., :2]
+        b = x1[..., 2:]
+        c = x2[..., :2]
+        d = x2[..., 2:]
+        c12 = self.kernel.forward(x1, x2).evaluate().detach().numpy()
+        pwc = (
+            (
+                self.latent_kernel.forward(a, c)
+                - self.latent_kernel.forward(a, d)
+                - self.latent_kernel.forward(b, c)
+                + self.latent_kernel.forward(b, d)
+            )
+            .detach()
+            .numpy()
+        )
+        npt.assert_allclose(c12, pwc, atol=1e-6)
+
+        shape = np.array(c12.shape)
+        npt.assert_equal(shape, np.array([2, 2]))
+
+        x3 = torch.rand(torch.Size([3, 4]))
+        x4 = torch.rand(torch.Size([6, 4]))
+        a = x3[..., :2]
+        b = x3[..., 2:]
+        c = x4[..., :2]
+        d = x4[..., 2:]
+        c34 = self.kernel.forward(x3, x4).evaluate().detach().numpy()
+        pwc = (
+            (
+                self.latent_kernel.forward(a, c)
+                - self.latent_kernel.forward(a, d)
+                - self.latent_kernel.forward(b, c)
+                + self.latent_kernel.forward(b, d)
+            )
+            .detach()
+            .numpy()
+        )
+        npt.assert_allclose(c34, pwc, atol=1e-6)
+
+        shape = np.array(c34.shape)
+        npt.assert_equal(shape, np.array([3, 6]))
+
+    def test_latent_diag(self):
+        """
+        g(a, a) = 0 for all a, so K((a, a), (a, a)) = 0
+        """
+
+        np.random.seed(1)
+        torch.manual_seed(1)
+        a = torch.rand(torch.Size([2, 2]))
+
+        # should get 0 variance on pairs (a,a)
+        diag = torch.cat((a, a), dim=1)
+        diagv = self.kernel.forward(diag, diag).evaluate().detach().numpy()
+        npt.assert_allclose(diagv, 0.0)
+
+    def test_diag(self):
+        """
+        make sure the diagonal is the right shape
+        """
+        np.random.seed(1)
+        torch.manual_seed(1)
+
+        x1 = torch.rand(torch.Size([2, 2, 4]))
+        x2 = torch.rand(torch.Size([2, 2, 4]))
+
+        diag = self.kernel(x1, x2, diag=True)
+        shape = np.array(diag.shape)
+        npt.assert_equal(shape, np.array([2, 2]))
+
+        x1 = torch.rand(torch.Size([2, 4]))
+        x2 = torch.rand(torch.Size([2, 4]))
+
+        diag = self.kernel(x1, x2, diag=True)
+        shape = np.array(diag.shape)
+        npt.assert_equal(shape, np.array([2]))
+
+
+if __name__ == "__main__":
+    unittest.main()