Skip to content

Commit

Permalink
Add ZarrTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Oct 16, 2024
1 parent 5352798 commit b25ae02
Showing 1 changed file with 302 additions and 0 deletions.
302 changes: 302 additions & 0 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any

import numcodecs
import numpy as np
import zarr

from pytensor.tensor.variable import TensorVariable
from zarr.storage import BaseStore
from zarr.sync import Synchronizer

from pymc.backends.arviz import (
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.backends.base import BaseTrace
from pymc.model.core import Model, modelcontext
from pymc.step_methods.compound import (
BlockedStep,
CompoundStep,
StatsBijection,
get_stats_dtypes_shapes_from_steps,
)
from pymc.util import get_default_varnames


class ZarrChain(BaseTrace):
def __init__(
self,
store: BaseStore | MutableMapping,
stats_bijection: StatsBijection,
synchronizer: Synchronizer | None = None,
model: Model | None = None,
vars: Sequence[TensorVariable] | None = None,
test_point: Sequence[dict[str, np.ndarray]] | None = None,
):
super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
self.draw_idx = 0
self._posterior = zarr.open_group(
store, synchronizer=synchronizer, path="posterior", mode="a"
)
self._sample_stats = zarr.open_group(
store, synchronizer=synchronizer, path="sample_stats", mode="a"
)
self._sampling_state = zarr.open_group(
store, synchronizer=synchronizer, path="_sampling_state", mode="a"
)
self.stats_bijection = stats_bijection

def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):
self.chain = chain

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
chain = self.chain
draw_idx = self.draw_idx
for var_name, var_value in zip(self.varnames, self.fn(draw)):
self._posterior[var_name].set_orthogonal_selection(
(chain, draw_idx),
var_value,
)
for var_name, var_value in self.stats_bijection.map(stats).items():
self._sample_stats[var_name].set_orthogonal_selection(
(chain, draw_idx),
var_value,
)
self.draw_idx += 1

def record_sampling_state(self, step):
self._sampling_state.sampling_state.set_coordinate_selection(
self.chain, np.array([step.sampling_state], dtype="object")
)
self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx)


FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None


def get_fill_value_and_codec(
dtype: Any,
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
_dtype = np.dtype(dtype)
if np.issubdtype(_dtype, np.floating):
return (np.nan, _dtype, None)
elif np.issubdtype(_dtype, np.integer):
return (-1_000_000, _dtype, None)
elif np.issubdtype(_dtype, "bool"):
return (False, _dtype, None)
elif np.issubdtype(_dtype, "str"):
return ("", _dtype, None)
elif np.issubdtype(_dtype, "datetime64"):
return (np.datetime64(0), _dtype, None)
elif np.issubdtype(_dtype, "timedelta64"):
return (np.timedelta(0), _dtype, None)
else:
return (None, _dtype, numcodecs.Pickle())


class ZarrTrace:
def __init__(
self,
store: BaseStore | MutableMapping | None = None,
synchronizer: Synchronizer | None = None,
model: Model | None = None,
vars: Sequence[TensorVariable] | None = None,
include_transformed: bool = False,
):
model = modelcontext(model)
self.model = model

self.synchronizer = synchronizer
self.root = zarr.group(
store=store,
overwrite=True,
synchronizer=synchronizer,
)
self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model)

if vars is None:
vars = model.unobserved_value_vars

unnamed_vars = {var for var in vars if var.name is None}
if unnamed_vars:
raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
self.varnames = get_default_varnames(
[var.name for var in vars], include_transformed=include_transformed
)
self.vars = [var for var in vars if var.name in self.varnames]

self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")

# Get variable shapes. Most backends will need this
# information.
test_point = model.initial_point()
var_values = list(zip(self.varnames, self.fn(test_point)))
self.var_dtype_shapes = {var: (value.dtype, value.shape) for var, value in var_values}
self._is_base_setup = False

@property
def posterior(self):
return self.root.posterior

@property
def sample_stats(self):
return self.root.sample_stats

@property
def constant_data(self):
return self.root.constant_data

@property
def observed_data(self):
return self.root.observed_data

@property
def sampling_state(self):
return self.root.sampling_state

def init_trace(self, chains: int, draws: int, step: BlockedStep | CompoundStep):
self.create_group(
name="constant_data",
data_dict=find_constants(self.model),
)

self.create_group(
name="observed_data",
data_dict=find_observations(self.model),
)

self.init_group_with_empty(
group=self.root.create_group(name="posterior", overwrite=True),
var_dtype_and_shape=self.var_dtype_shapes,
chains=chains,
draws=draws,
)
stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
[step] if isinstance(step, BlockedStep) else step.methods
)
self.init_group_with_empty(
group=self.root.create_group(name="sample_stats", overwrite=True),
var_dtype_and_shape=stats_dtypes_shapes,
chains=chains,
draws=draws,
)

self.init_sampling_state_group(chains=chains)

self.straces = [
ZarrChain(
store=self.root.store,
synchronizer=self.synchronizer,
model=self.model,
vars=self.vars,
test_point=None,
stats_bijection=StatsBijection(step.stats_dtypes),
)
]
for chain, strace in enumerate(self.straces):
strace.setup(draws=draws, chain=chain, sampler_vars=None)

def close(self):
for strace in self.straces:
strace._posterior.close()
strace._sample_stats.close()
strace._sampling_state.close()
zarr.consolidate_metadata(self.root.store)
self.root.store.close()

def init_sampling_state_group(self, chains):
state = self.root.create_group(name="_sampling_state", overwrite=True)
sampling_state = state.empty(
name="sampling_state",
overwrite=True,
shape=(chains,),
chunks=(1,),
dtype="object",
object_codec=numcodecs.Pickle(),
)
sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
draw_idx = state.array(
name="draw_idx",
overwrite=True,
data=np.zeros(chains, dtype="int"),
chunks=(1,),
dtype="int",
fill_value=-1,
)
draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
chain = state.array(name="chain", data=range(chains))
chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})

def init_group_with_empty(self, group, var_dtype_and_shape, chains, draws):
group_coords = {"chain": range(chains), "draw": range(draws)}
for name, (dtype, shape) in var_dtype_and_shape.items():
fill_value, dtype, object_codec = get_fill_value_and_codec(dtype)
shape = shape or ()
array = group.full(
name=name,
dtype=dtype,
fill_value=fill_value,
object_codec=object_codec,
shape=(chains, draws, *shape),
chunks=(1, 1, *shape),
)
try:
dims = self.vars_to_dims[name]
for dim in dims:
group_coords[dim] = self.coords[dim]
except KeyError:
dims = []
for i, shape_i in enumerate(shape):
dim = f"{name}_dim_{i}"
dims.append(dim)
group_coords[dim] = list(range(shape_i))
dims = ("chain", "draw", *dims)
array.attrs.update({"_ARRAY_DIMENSIONS": dims})
for dim, coord in group_coords.items():
array = group.array(name=dim, data=coord, fill_value=None)
array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
return group

def create_group(self, name, data_dict):
if data_dict:
group_coords = {}
group = self.root.create_group(name=name, overwrite=True)
for var_name, var_value in data_dict.items():
fill_value, dtype, object_codec = get_fill_value_and_codec(var_value.dtype)
array = group.array(
name=var_name,
data=var_value,
fill_value=fill_value,
dtype=dtype,
object_codec=object_codec,
)
try:
dims = self.vars_to_dims[var_name]
for dim in dims:
group_coords[dim] = self.coords[dim]
except KeyError:
dims = []
for i in range(var_value.ndim):
dim = f"{var_name}_dim_{i}"
dims.append(dim)
group_coords[dim] = list(range(var_value.shape[i]))
array.attrs.update({"_ARRAY_DIMENSIONS": dims})
for dim, coord in group_coords.items():
array = group.array(name=dim, data=coord, fill_value=None)
array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
return group

0 comments on commit b25ae02

Please sign in to comment.