Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement pairwisekernel #371

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions aepsych/factory/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from aepsych.config import Config
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.utils import get_dim
from scipy.stats import norm

from ..kernels.pairwisekernel import PairwiseKernel

"""AEPsych factory functions.
These functions generate a gpytorch Mean and Kernel objects from
aepsych.config.Config configurations, including setting lengthscale
Expand All @@ -36,7 +37,9 @@


def default_mean_covar_factory(
config: Optional[Config] = None, dim: Optional[int] = None
config: Optional[Config] = None,
dim: Optional[int] = None,
stimuli_per_trial: int = 1,
) -> Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]:
"""Default factory for generic GP models

Expand All @@ -55,6 +58,8 @@ def default_mean_covar_factory(
dim is not None
), "Either config or dim must be provided!"

assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!"

fixed_mean = False
lengthscale_prior = "gamma"
outputscale_prior = "box"
Expand Down Expand Up @@ -136,6 +141,9 @@ def default_mean_covar_factory(
outputscale_prior=os_prior,
)

if stimuli_per_trial == 2:
covar = PairwiseKernel(covar)

return mean, covar


Expand Down
4 changes: 0 additions & 4 deletions aepsych/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from .monotonic_rejection_generator import MonotonicRejectionGenerator
from .monotonic_thompson_sampler_generator import MonotonicThompsonSamplerGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator
from .pairwise_optimize_acqf_generator import PairwiseOptimizeAcqfGenerator
from .pairwise_sobol_generator import PairwiseSobolGenerator
from .random_generator import RandomGenerator
from .semi_p import IntensityAwareSemiPGenerator
from .sobol_generator import SobolGenerator
Expand All @@ -28,8 +26,6 @@
"SobolGenerator",
"EpsilonGreedyGenerator",
"ManualGenerator",
"PairwiseOptimizeAcqfGenerator",
"PairwiseSobolGenerator",
"IntensityAwareSemiPGenerator",
"AcqfThompsonSamplerGenerator"
]
Expand Down
7 changes: 3 additions & 4 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
# LICENSE file in the root directory of this source tree.
import abc
from inspect import signature
from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional
from typing import Any, Dict, Generic, Optional, Protocol, runtime_checkable, TypeVar
import re

import torch
from aepsych.config import Config
from aepsych.models.base import AEPsychMixin
from botorch.acquisition import (
AcquisitionFunction,
NoisyExpectedImprovement,
qNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
)


Expand All @@ -40,7 +40,6 @@ class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]):
qLogNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
]
stimuli_per_trial = 1
max_asks: Optional[int] = None

def __init__(
Expand Down
24 changes: 5 additions & 19 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import ModelProtocol
from aepsych.utils_logging import getLogger
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from botorch.optim import optimize_acqf
from botorch.acquisition import (
AcquisitionFunction,
NoisyExpectedImprovement,
qNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
)
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from botorch.optim import optimize_acqf

logger = getLogger()

Expand All @@ -44,7 +44,6 @@ def __init__(
restarts: int = 10,
samps: int = 1000,
max_gen_time: Optional[float] = None,
stimuli_per_trial: int = 1,
) -> None:
"""Initialize OptimizeAcqfGenerator.
Args:
Expand All @@ -63,7 +62,6 @@ def __init__(
self.restarts = restarts
self.samps = samps
self.max_gen_time = max_gen_time
self.stimuli_per_trial = stimuli_per_trial

def _instantiate_acquisition_fn(self, model: ModelProtocol):
if self.acqf == AnalyticExpectedUtilityOfBestOption:
Expand All @@ -83,17 +81,7 @@ def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Ten
np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
"""

if self.stimuli_per_trial == 2:
qbatch_points = self._gen(
num_points=num_points * 2, model=model, **gen_options
)

# output of super() is (q, dim) but the contract is (num_points, dim, 2)
# so we need to split q into q and pairs and then move the pair dim to the end
return qbatch_points.reshape(num_points, 2, -1).swapaxes(-1, -2)

else:
return self._gen(num_points=num_points, model=model, **gen_options)
return self._gen(num_points=num_points, model=model, **gen_options)

def _gen(
self, num_points: int, model: ModelProtocol, **gen_options
Expand Down Expand Up @@ -124,7 +112,6 @@ def from_config(cls, config: Config):
classname = cls.__name__
acqf = config.getobj(classname, "acqf", fallback=None)
extra_acqf_args = cls._get_acqf_options(acqf, config)
stimuli_per_trial = config.getint(classname, "stimuli_per_trial")
restarts = config.getint(classname, "restarts", fallback=10)
samps = config.getint(classname, "samps", fallback=1000)
max_gen_time = config.getfloat(classname, "max_gen_time", fallback=None)
Expand All @@ -135,5 +122,4 @@ def from_config(cls, config: Config):
restarts=restarts,
samps=samps,
max_gen_time=max_gen_time,
stimuli_per_trial=stimuli_per_trial,
)
25 changes: 0 additions & 25 deletions aepsych/generators/pairwise_optimize_acqf_generator.py

This file was deleted.

26 changes: 0 additions & 26 deletions aepsych/generators/pairwise_sobol_generator.py

This file was deleted.

25 changes: 6 additions & 19 deletions aepsych/generators/sobol_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
ub: Union[np.ndarray, torch.Tensor],
dim: Optional[int] = None,
seed: Optional[int] = None,
stimuli_per_trial: int = 1,
):
"""Iniatialize SobolGenerator.
Args:
Expand All @@ -38,13 +37,8 @@ def __init__(
seed (int, optional): Random seed.
"""
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
self.lb = self.lb.repeat(stimuli_per_trial)
self.ub = self.ub.repeat(stimuli_per_trial)
self.stimuli_per_trial = stimuli_per_trial
self.seed = seed
self.engine = SobolEngine(
dimension=self.dim * stimuli_per_trial, scramble=True, seed=self.seed
)
self.engine = SobolEngine(dimension=self.dim, scramble=True, seed=self.seed)

def gen(
self,
Expand All @@ -59,16 +53,7 @@ def gen(
"""
grid = self.engine.draw(num_points)
grid = self.lb + (self.ub - self.lb) * grid
if self.stimuli_per_trial == 1:
return grid

return torch.tensor(
np.moveaxis(
grid.reshape(num_points, self.stimuli_per_trial, -1).numpy(),
-1,
-self.stimuli_per_trial,
)
)
return grid

@classmethod
def from_config(cls, config: Config):
Expand All @@ -78,8 +63,10 @@ def from_config(cls, config: Config):
ub = config.gettensor(classname, "ub")
dim = config.getint(classname, "dim", fallback=None)
seed = config.getint(classname, "seed", fallback=None)
stimuli_per_trial = config.getint(classname, "stimuli_per_trial")

return cls(
lb=lb, ub=ub, dim=dim, seed=seed, stimuli_per_trial=stimuli_per_trial
lb=lb,
ub=ub,
dim=dim,
seed=seed,
)
5 changes: 5 additions & 0 deletions aepsych/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .pairwisekernel import PairwiseKernel
from .rbf_partial_grad import RBFKernelPartialObsGrad

__all__ = ["PairwiseKernel", "RBFKernelPartialObsGrad"]
85 changes: 85 additions & 0 deletions aepsych/kernels/pairwisekernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from gpytorch.kernels import Kernel
from linear_operator import to_linear_operator


class PairwiseKernel(Kernel):
"""
Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling
functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K).

Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K')
where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d).

"""

def __init__(self, latent_kernel, is_partial_obs=False, **kwargs):
super(PairwiseKernel, self).__init__(**kwargs)

self.latent_kernel = latent_kernel
self.is_partial_obs = is_partial_obs

def forward(self, x1, x2, diag=False, **params):
r"""
TODO: make last_batch_dim work properly

d must be 2*k for integer k, k is the dimension of the latent space
Args:
:attr:`x1` (Tensor `n x d` or `b x n x d`):
First set of data
:attr:`x2` (Tensor `m x d` or `b x m x d`):
Second set of data
:attr:`diag` (bool):
Should the Kernel compute the whole kernel, or just the diag?

Returns:
:class:`Tensor` or :class:`gpytorch.lazy.LazyTensor`.
The exact size depends on the kernel's evaluation mode:

* `full_covar`: `n x m` or `b x n x m`
* `diag`: `n` or `b x n`
"""
if self.is_partial_obs:
d = x1.shape[-1] - 1
assert d == x2.shape[-1] - 1, "tensors not the same dimension"
assert d % 2 == 0, "dimension must be even"

k = int(d / 2)

# special handling for kernels that (also) do funky
# things with the input dimension
deriv_idx_1 = x1[..., -1][:, None]
deriv_idx_2 = x2[..., -1][:, None]

a = torch.cat((x1[..., :k], deriv_idx_1), dim=1)
b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1)
c = torch.cat((x2[..., :k], deriv_idx_2), dim=1)
d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1)

else:
d = x1.shape[-1]

assert d == x2.shape[-1], "tensors not the same dimension"
assert d % 2 == 0, "dimension must be even"

k = int(d / 2)

a = x1[..., :k]
b = x1[..., k:]
c = x2[..., :k]
d = x2[..., k:]

if not diag:
return (
to_linear_operator(self.latent_kernel(a, c, diag=diag, **params))
+ to_linear_operator(self.latent_kernel(b, d, diag=diag, **params))
- to_linear_operator(self.latent_kernel(b, c, diag=diag, **params))
- to_linear_operator(self.latent_kernel(a, d, diag=diag, **params))
)
else:
return (
self.latent_kernel(a, c, diag=diag, **params)
+ self.latent_kernel(b, d, diag=diag, **params)
- self.latent_kernel(b, c, diag=diag, **params)
- self.latent_kernel(a, d, diag=diag, **params)
)
2 changes: 0 additions & 2 deletions aepsych/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .monotonic_rejection_gp import MonotonicRejectionGP
from .multitask_regression import IndependentMultitaskGPRModel, MultitaskGPRModel
from .ordinal_gp import OrdinalGPModel
from .pairwise_probit import PairwiseProbitModel
from .semi_p import (
HadamardSemiPModel,
semi_p_posterior_transform,
Expand All @@ -26,7 +25,6 @@
"GPClassificationModel",
"MonotonicRejectionGP",
"GPRegressionModel",
"PairwiseProbitModel",
"OrdinalGPModel",
"MonotonicProjectionGP",
"MultitaskGPRModel",
Expand Down
Loading
Loading