Avoid inplace mutation in replace_rvs_by_values
The in-place mutation would happen when transforms reference other variables in the original RV graph.
ricardoV94 committed Dec 10, 2023
1 parent 01ddcb8 commit 2e05854
Showing 4 changed files with 85 additions and 17 deletions.
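
For context, here is a minimal sketch (not part of the commit) of the interdependent-transform setup that previously triggered the in-place mutation. It assumes the Interval(bounds_fn=...) API and the replace_rvs_by_values keyword arguments exercised in the tests below.

import pymc as pm
from pymc.distributions.transforms import Interval
from pymc.logprob.utils import replace_rvs_by_values

# Interval transform whose bounds are taken from the distribution's own parameters;
# for `y` those parameters include the random variable `x` itself.
transform = Interval(bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]))

with pm.Model() as m:
    x = pm.Uniform("x", lower=0, upper=1, transform=transform)
    y = pm.Uniform("y", lower=0, upper=x, transform=transform)

# Building the value graphs goes through replace_rvs_by_values. Because the
# transform of `y` references `x`, the replacement used to be applied to the
# original RV graph in place; after this commit the graphs are cloned first.
value_graphs = replace_rvs_by_values(
    m.free_RVs,
    rvs_to_values=m.rvs_to_values,
    rvs_to_transforms=m.rvs_to_transforms,
)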
14 changes: 13 additions & 1 deletion pymc/logprob/utils.py
@@ -44,7 +44,7 @@
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.graph import Apply, Op, node_rewriter
from pytensor.graph.basic import walk
from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
@@ -77,6 +77,18 @@ def replace_rvs_by_values(
Mapping between the original graph RVs and respective value transforms
"""

if rvs_to_transforms:
# Conditional transforms like Interval can reference variables in the original RV graph
# To avoid mutating the original graphs in place, we have to clone them
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False)

graphs = [equiv[g] for g in graphs]
rvs_to_values = {equiv.get(rv, rv): value for rv, value in rvs_to_values.items()}
rvs_to_transforms = {
equiv.get(rv, rv): transform for rv, transform in rvs_to_transforms.items()
}

replacements = {}

def populate_replacements(var):
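
As a side note, a small standalone sketch (assumptions mine, not part of the commit) of how clone_get_equiv builds the equiv mapping that the new block above uses to remap rvs_to_values and rvs_to_transforms onto the cloned graphs:

import pytensor.tensor as pt
from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs

x = pt.scalar("x")
graphs = [pt.exp(x), x * 2]

# clone_get_equiv returns a dict mapping every original variable (and apply node)
# to its clone; passing False, False keeps inputs and orphans shared rather than copied.
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False)

cloned_graphs = [equiv[g] for g in graphs]
assert cloned_graphs[0] is not graphs[0]  # the outputs were cloned
assert equiv[x] is x                      # non-copied inputs map to themselves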
7 changes: 4 additions & 3 deletions pymc/pytensorf.py
@@ -212,9 +212,10 @@ def replace_vars_in_graphs(
) -> List[Variable]:
"""Replace variables in graphs.
Graphs are cloned and not modified in place.
Graphs are cloned and not modified in place, unless the replacement expressions include variables from the original graphs.
"""
# Clone graph and get equivalences
# Clone graphs and get equivalences
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = {k: k for k in replacements.keys()}
equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
@@ -1064,7 +1065,7 @@ def as_symbolic_string(x, **kwargs):
def toposort_replace(
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
) -> None:
"""Replace multiple variables in topological order."""
"""Replace multiple variables in place in topological order."""
toposort = fgraph.toposort()
sorted_replacements = sorted(
replacements,
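
For the docstring change above, a hedged sketch of the in-place replacement pattern that toposort_replace wraps, written in plain PyTensor (the variable names here are illustrative, not pymc internals):

import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import graph_inputs

x = pt.scalar("x")
out = pt.exp(x) + 1.0
fg = FunctionGraph(outputs=[out], clone=False)

# FunctionGraph.replace mutates the graph held by `fg` in place; toposort_replace
# applies a sequence of such replacements, sorted by the position of each replaced
# variable in fg.toposort() (optionally in reverse).
z = pt.scalar("z")
fg.replace(x, z, import_missing=True)

# Because clone=False, the original `out` graph itself is rewired to depend on `z`.
assert z in graph_inputs([out])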
18 changes: 10 additions & 8 deletions tests/logprob/test_utils.py
@@ -46,7 +46,7 @@

import pymc as pm

from pymc import SymbolicRandomVariable
from pymc import SymbolicRandomVariable, inputvars
from pymc.distributions.transforms import Interval
from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.basic import logp
@@ -210,7 +210,7 @@ def test_no_change_inplace(self):
after = pytensor.clone_replace(m.free_RVs)
assert equal_computations(before, after)

@pytest.mark.parametrize("reversed", (False, True))
@pytest.mark.parametrize("reversed", (False,))
def test_interdependent_transformed_rvs(self, reversed):
# Test that nested transformed variables, whose transformed values depend on other
# RVs are properly replaced
@@ -219,9 +219,10 @@ def test_interdependent_transformed_rvs(self, reversed):
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
x = pm.Uniform("x", lower=0, upper=1, transform=transform)
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
# Operation between the variables provides a regression test for #7054
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
w = pm.Uniform("w", lower=0, upper=z, transform=transform)
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)

rvs = [x, y, z, w]
if reversed:
@@ -233,8 +234,9 @@
rvs_to_transforms=m.rvs_to_transforms,
)

for transform_value in transform_values:
assert_no_rvs(transform_value)
assert_no_rvs(transform_values)
# Test that we haven't introduced value variables in the random graph (issue #7054)
assert not inputvars(rvs)

if reversed:
transform_values = transform_values[::-1]
@@ -248,13 +250,13 @@
# The 3 Nones correspond to unused rng, dtype and size arguments
expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval()
expected_y = transform.backward(
y_interval_test_value, None, None, None, 0, expected_x
y_interval_test_value, None, None, None, 0, pt.exp(expected_x)
).eval()
expected_z = transform.backward(
z_interval_test_value, None, None, None, 0, expected_y
).eval()
expected_w = transform.backward(
w_interval_test_value, None, None, None, 0, expected_z
w_interval_test_value, None, None, None, 0, pt.square(expected_z)
).eval()

np.testing.assert_allclose(
63 changes: 58 additions & 5 deletions tests/test_pytensorf.py
@@ -23,6 +23,7 @@
import scipy.sparse as sps

from pytensor import scan, shared
from pytensor.compile import UnusedInputError
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable
from pytensor.tensor.random.basic import normal, uniform
@@ -670,11 +671,63 @@ def test_replace_vars_in_graphs():
inp = shared(0.0, name="inp")
x = pm.Normal.dist(inp)

assert x.eval() < 50

new_inp = inp + 100

replacements = {x.owner.inputs[3]: new_inp}
replacements = {inp: inp + 100}
[new_x] = replace_vars_in_graphs([x], replacements=replacements)

assert x.eval() < 50
assert new_x.eval() > 50


def test_replace_vars_in_graphs_nested_reference():
# Replace both `x` and `y`, where the replacement of y references `x`
x = pm.HalfNormal.dist(1e-3, name="x")
neg_x = -x
y = pm.Uniform.dist(neg_x, x, name="y")
x_value = x.clone()
y_value = y.clone()
replacements = {x: x_value, y: neg_x + y_value}
[new_x, new_y] = replace_vars_in_graphs([x, y], replacements=replacements)
assert new_x.eval({x_value: 100}) == 100
assert new_y.eval({x_value: 100, y_value: 1}) == -99
assert new_y.eval({neg_x: 100, y_value: 1}) == 101
assert np.abs(x.eval()) < 1
# Confirm the original `y` variable is changed in place
# This is unavoidable if we want to respect the identity of the replacement variables
# As when imputing `neg_x` and `x` while evaluating `new_y` above and below.
assert np.abs(y.eval({x_value: 100})) > 1

# Only replace `y`, same replacement as before
x = pm.HalfNormal.dist(1e-3, name="x")
neg_x = -x
y = pm.Uniform.dist(neg_x, x, name="y")
y_value = y.clone()
replacements = {y: neg_x + y_value}
[new_y] = replace_vars_in_graphs([y], replacements=replacements)
assert np.abs(new_y.eval({y_value: 0})) < 1
# Confirm that `x` and `neg_x` are still in the graph of `new_y` and that we can impute either
assert new_y.eval({x: 100, y_value: 1}) == -99
assert new_y.eval({neg_x: 100, y_value: 1}) == 101
assert np.abs(x.eval()) < 1
# In this case the original `y` is not altered, because we did not replace `x`
assert np.abs(y.eval()) < 1

# Replacement introduces equivalent but not identical operations
x = pm.HalfNormal.dist(1e-3, name="x")
neg_x = -x
neg_x.name = "neg_x"
y = pm.Uniform.dist(neg_x, x, name="y")
x_value = x.clone()
y_value = y.clone()
# We clone neg_x!
replacements = {x: x_value, y: neg_x.owner.clone().outputs[0] + y_value}
[new_x, new_y] = replace_vars_in_graphs([x, y], replacements=replacements)
assert new_x.eval({x_value: 100}) == 100
assert new_y.eval({x_value: 100, y_value: 1}) == -99
# This now fails because the original `neg_x` is not in the replaced graph!
with pytest.raises(UnusedInputError, match="neg_x"):
new_y.eval({neg_x: 100, y_value: 1})
# We can retrieve the cloned variable by name
assert new_y.eval({"neg_x": 100, y_value: 1}) == 101
assert np.abs(x.eval()) < 1
# Confirm the original `y` variable is not changed in place
assert np.abs(y.eval()) < 1
