Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Detailed Documentation for ZeroSumNormal Distribution #7433

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 69 additions & 20 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def backward(self, value, *inputs):
def forward(self, value, *inputs):
y = pt.zeros(value.shape)
y = pt.set_subtensor(y[..., 0], value[..., 0])
y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
log_value = pt.log(value[..., 1:] - value[..., :-1])
y = pt.set_subtensor(y[..., 1:], log_value)
return y

def log_jac_det(self, value, *inputs):
Expand All @@ -116,8 +117,9 @@ def log_jac_det(self, value, *inputs):

class SumTo1(Transform):
"""
Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
This Transformation operates on the last dimension of the input tensor.
Transforms K - 1 dimensional simplex space (k values in [0,1] and that
sum to 1) to a K - 1 vector of values in [0,1]. This Transformation
operates on the last dimension of the input tensor.
"""

name = "sumto1"
Expand All @@ -140,7 +142,8 @@ def log_jac_det(self, value, *inputs):

class CholeskyCovPacked(Transform):
"""
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
Transforms the diagonal elements of
the LKJCholeskyCov distribution to be on the
log scale
"""

Expand All @@ -157,10 +160,14 @@ def __init__(self, n):
self.diag_idxs = pt.arange(1, n + 1).cumsum() - 1

def backward(self, value, *inputs):
return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs]))
diag_values = value[..., self.diag_idxs]
exp_values = pt.exp(diag_values)
return pt.set_subtensor(value[..., self.diag_idxs], exp_values)

def forward(self, value, *inputs):
return pt.set_subtensor(value[..., self.diag_idxs], pt.log(value[..., self.diag_idxs]))
diag_values = value[..., self.diag_idxs]
log_values = pt.log(diag_values)
return pt.set_subtensor(value[..., self.diag_idxs], log_values)

def log_jac_det(self, value, *inputs):
return pt.sum(value[..., self.diag_idxs], axis=-1)
Expand All @@ -180,8 +187,9 @@ def log_jac_det(self, value, *inputs):


class Interval(IntervalTransform):
"""Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the
``transform`` argument of a random variable.
"""
Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use
in the ``transform`` argument of a random variable.

Parameters
----------
Expand All @@ -192,15 +200,15 @@ class Interval(IntervalTransform):
Upper bound of the interval transform. Must be a constant finite value.
By default (``upper=None``), the interval is not bounded above.
bounds_fn : callable, optional
Alternative to lower and upper. Must return a tuple of lower and upper bounds
as a symbolic function of the respective distribution inputs. If one of lower or
upper is ``None``, the interval is unbounded on that edge.

.. warning:: Expressions returned by `bounds_fn` should depend only on the
distribution inputs or other constants. Expressions that depend on nonlocal
variables, such as other distributions defined in the model context will
likely break sampling.
Alternative to lower and upper. Must return a tuple of lower and upper
bounds as a symbolic function of the respective distribution inputs. If
one of lower or upper is ``None``,the interval is unbounded on
that edge.

.. warning:: Expressions returned by `bounds_fn` should depend only on
the distribution inputs or other constants. Expressions that depend
on nonlocal variables, such as other distributions defined in the
model context will likely break sampling.

Examples
--------
Expand All @@ -220,10 +228,14 @@ def get_bounds(rng, size, mu, sigma):
return 0, None

with pm.Model():
interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)
interval = pm.distributions.transforms.Interval(
bounds_fn=get_bounds
)

x = pm.Normal("x", transform=interval)

Create a lower-bounded interval transform that depends on a distribution parameter
Create a lower-bounded interval transform that depends on a
distribution parameter

.. code-block:: python

Expand Down Expand Up @@ -267,10 +279,47 @@ class ZeroSumTransform(Transform):
"""
Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``.

This transform is useful when modeling distributions where the sum of certain dimensions
must be zero, such as in some types of constrained latent variable models or in certain
types of signal processing applications.

Parameters
----------
zerosum_axes : list of ints
Must be a list of integers (positive or negative).
zerosum_axes : list of int
List of integers specifying the axes along which the random samples should sum to zero.
Positive integers indicate dimensions in the standard order, while negative integers
can be used to reference dimensions from the end of the shape.

Examples
--------
Suppose you want to ensure that the last dimension of a tensor sums to zero. You can use
`ZeroSumTransform` as follows:

.. code-block:: python

import pymc as pm

with pm.Model() as model:
# Create a 2D variable with the last axis constrained to sum to zero
x = pm.Normal("x", shape=(10, 5), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[-1]))

Methods
-------
forward(value, *rv_inputs)
Transforms the input tensor to ensure that the specified axes sum to zero.

backward(value, *rv_inputs)
Computes the inverse transform to convert back to the original space where the sum was zero.

log_jac_det(value, *rv_inputs)
Returns the log Jacobian determinant of the transform. For this transform, it is zero.

Notes
-----
The `extend_axis` and `extend_axis_rev` methods are used internally to handle the transformation:
- `extend_axis`: Extends the axis by adding an additional element to ensure zero-sum constraint.
- `extend_axis_rev`: Reverses the extension operation applied by `extend_axis`.

"""

name = "zerosum"
Expand Down
Loading