Skip to content

Commit

Permalink
Allow vector-valued indices for switch/ifelse mixture sub-graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored and brandonwillard committed Dec 16, 2022
1 parent 21787e9 commit 8c9c0f3
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 65 deletions.
61 changes: 20 additions & 41 deletions aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,48 +313,21 @@ def switch_mixture_replace(fgraph, node):
return None # pragma: no cover

old_mixture_rv = node.default_output()
# idx, component_1, component_2 = node.inputs

mixture_rvs = []

for component_rv in node.inputs[1:]:
if not (
component_rv.owner
and isinstance(component_rv.owner.op, MeasurableVariable)
and component_rv not in rv_map_feature.rv_values
):
return None
new_node = assign_custom_measurable_outputs(component_rv.owner)
out_idx = component_rv.owner.outputs.index(component_rv)
new_comp_rv = new_node.outputs[out_idx]
mixture_rvs.append(new_comp_rv)

"""
Unlike mixtures generated via `at.stack`, switch/ifelse-defined mixture
sub-graphs have only one condition (i.e. a single index). However, this
condition can be non-scalar for `Switch` `Op`s.
"""
mix_op = MixtureRV(
2,
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
)
new_node = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
# Add an extra dimension to the indices so that the `MixtureRV` we
# construct represents a valid
# `at.stack(node.inputs[1:])[f(node.inputs[0])]`, for some function `f`,
# that's equivalent to `at.switch(*node.inputs)`.
out_shape = at.broadcast_shape(
*(tuple(v.shape) for v in node.inputs[1:]), arrays_are_shapes=True
)
switch_indices = (node.inputs[0],) + tuple(at.arange(s) for s in out_shape)

new_mixture_rv = new_node.default_output()

if aesara.config.compute_test_value != "off":
if not hasattr(old_mixture_rv.tag, "test_value"):
compute_test_value(node)

new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value

if old_mixture_rv.name:
new_mixture_rv.name = f"{old_mixture_rv.name}-mixture"
# Construct the proxy/intermediate mixture representation
switch_stack = at.stack(node.inputs[::-1])[switch_indices]
switch_stack.name = old_mixture_rv.name

return [new_mixture_rv]
return mixture_replace.transform(fgraph, switch_stack.owner)


@node_rewriter((IfElse,))
Expand Down Expand Up @@ -394,9 +367,15 @@ def ifelse_mixture_replace(fgraph, node):
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
)
new_node = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)

if node.inputs[0].ndim == 0:
# Use `as_nontensor_scalar` for scalar conditions so that the resulting graph
# is identical to the mixture sub-graphs created using `at.stack` and
# `Subtensor` indexing
new_node = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)
else:
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))

new_mixture_rv = new_node.default_output()

Expand Down
133 changes: 109 additions & 24 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,25 +233,6 @@ def test_hetero_mixture_binomial(p_val, size):
(),
0,
),
(
(
np.array(0, dtype=aesara.config.floatX),
np.array(1, dtype=aesara.config.floatX),
),
(
np.array(0.5, dtype=aesara.config.floatX),
np.array(0.5, dtype=aesara.config.floatX),
),
(
np.array(100, dtype=aesara.config.floatX),
np.array(1, dtype=aesara.config.floatX),
),
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
(),
(),
(),
0,
),
(
(
np.array(0, dtype=aesara.config.floatX),
Expand Down Expand Up @@ -683,14 +664,118 @@ def test_mixture_with_DiracDelta():
assert M_rv in logp_res


@pytest.mark.parametrize("op", [at.switch, ifelse])
def test_switch_ifelse_mixture(op):
@pytest.mark.parametrize(
"op, X_args, Y_args, p_val, comp_size, idx_size",
[
[op] + list(test_args)
for op in [at.switch, ifelse]
for test_args in [
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(),
(),
),
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(),
(6,),
),
(
(
np.array([10, 20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array([-10, -20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array([0.9, 0.1], dtype=aesara.config.floatX),
(2,),
(2,),
),
(
(
np.array([10, 20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array([-10, -20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array([0.9, 0.1], dtype=aesara.config.floatX),
None,
None,
),
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(2, 3),
(2, 3),
),
(
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(2, 3),
(),
),
(
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(3,),
(3,),
),
]
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
],
)
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
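"""
`comp_size` is the `size` passed to `srng.normal` for both mixture components,
and `idx_size` is the `size` passed to `srng.bernoulli` for the index variable.

NOTE(review): this text previously described a single `size` argument that was
both the input to `srng.normal` and the expected size of the mixture RV
`Z1_rv`; confirm how `comp_size` and `idx_size` relate to that expected size.
"""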
srng = at.random.RandomStream(29833)

X_rv = srng.normal(-10.0, 0.1, name="X")
Y_rv = srng.normal(10.0, 0.1, name="Y")
X_rv = srng.normal(*X_args, size=comp_size, name="X")
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")

I_rv = srng.bernoulli(0.5, name="I")
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Expand Down

0 comments on commit 8c9c0f3

Please sign in to comment.