Skip to content

Commit

Permalink
Identifying mixture sub-graphs defined with an IfElse Op
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama committed Aug 28, 2022
1 parent 8ce3c63 commit d0d4c7b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 5 additions & 5 deletions aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
node_rewriter,
pre_greedy_node_rewriter,
)
from aesara.ifelse import ifelse
from aesara.ifelse import IfElse, ifelse
from aesara.scalar.basic import Switch
from aesara.tensor.basic import Join, MakeVector
from aesara.tensor.elemwise import Elemwise
Expand Down Expand Up @@ -305,14 +305,14 @@ def mixture_replace(fgraph, node):
return [new_mixture_rv]


@node_rewriter((Elemwise,))
def switch_mixture_replace(fgraph, node):
@node_rewriter((Elemwise, IfElse))
def switch_ifelse_mixture_replace(fgraph, node):
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

if not isinstance(node.op.scalar_op, Switch):
if not isinstance(node.op, IfElse) and not isinstance(node.op.scalar_op, Switch):
return None # pragma: no cover

old_mixture_rv = node.default_output()
Expand Down Expand Up @@ -420,7 +420,7 @@ def logprob_MixtureRV(
logprob_rewrites_db.register(
"mixture_replace",
EquilibriumGraphRewriter(
[mixture_replace, switch_mixture_replace],
[mixture_replace, switch_ifelse_mixture_replace],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
0,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
import scipy.stats.distributions as sp
from aesara.ifelse import ifelse
from aesara.graph.basic import Variable, equal_computations
from aesara.tensor.random.basic import CategoricalRV
from aesara.tensor.shape import shape_tuple
Expand Down Expand Up @@ -715,7 +716,8 @@ def test_mixture_with_DiracDelta():
assert m_vv in logp_res


def test_switch_mixture():
@pytest.mark.parametrize("op", [at.switch, ifelse])
def test_switch_ifelse_mixture(op):
srng = at.random.RandomStream(29833)

X_rv = srng.normal(-10.0, 0.1, name="X")
Expand All @@ -725,7 +727,7 @@ def test_switch_mixture():
i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
Z1_rv = op(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"

Expand Down

0 comments on commit d0d4c7b

Please sign in to comment.