
Introduce graph rewrite for mixture sub-graphs defined via IfElse Op #169

Draft
larryshamalama wants to merge 2 commits into main

Conversation

@larryshamalama
Contributor

larryshamalama commented Aug 28, 2022

Closes #76.

Akin to #154, this PR introduces a node_rewriter for IfElse. Effectively, it builds on the recently added switch_mixture_replace to accommodate mixture sub-graphs that are the same in essence but defined with a different Op: IfElse. Below is an example of the new functionality.

import aesara.tensor as at
from aesara.ifelse import ifelse

from aeppl.joint_logprob import joint_logprob

srng = at.random.RandomStream(seed=2320)

I_rv = srng.bernoulli(0.5, name="I")
X_rv = srng.normal(-10, 0.1, name="X")
Y_rv = srng.normal(10, 0.1, name="Y")

Z_rv = ifelse(I_rv, X_rv, Y_rv)
Z_rv.name = "Z"

z_vv = Z_rv.clone()
i_vv = I_rv.clone()

logp = joint_logprob({Z_rv: z_vv, I_rv: i_vv})
print(logp.eval({z_vv: -10, i_vv: 0})) # 0.6904993773247732
print(logp.eval({z_vv: -10, i_vv: 1})) # -19999.309500622676

@codecov

codecov bot commented Aug 28, 2022

Codecov Report

Base: 95.15% // Head: 94.94% // Decreases project coverage by -0.21% ⚠️

Coverage data is based on head (3458414) compared to base (0959489).
Patch coverage: 100.00% of modified lines in pull request are covered.

❗ Current head 3458414 differs from pull request most recent head 0dd44af. Consider uploading reports for the commit 0dd44af to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #169      +/-   ##
==========================================
- Coverage   95.15%   94.94%   -0.22%     
==========================================
  Files          12       12              
  Lines        2023     1878     -145     
  Branches      253      280      +27     
==========================================
- Hits         1925     1783     -142     
+ Misses         56       53       -3     
  Partials       42       42              
Impacted Files Coverage Δ
aeppl/mixture.py 97.72% <100.00%> (+1.13%) ⬆️
aeppl/printing.py 89.68% <0.00%> (-2.09%) ⬇️
aeppl/rewriting.py 94.00% <0.00%> (-0.18%) ⬇️
aeppl/tensor.py 85.71% <0.00%> (-0.14%) ⬇️
aeppl/logprob.py 98.02% <0.00%> (-0.05%) ⬇️
aeppl/transforms.py 96.43% <0.00%> (-0.02%) ⬇️
aeppl/scan.py 94.73% <0.00%> (ø)
aeppl/dists.py 94.50% <0.00%> (ø)
aeppl/cumsum.py 100.00% <0.00%> (ø)
aeppl/abstract.py 100.00% <0.00%> (ø)
... and 4 more


☔ View full report at Codecov.

@larryshamalama
Contributor Author

I added many tests; I used test_hetero_mixture_categorical as inspiration for what to cover, and I have yet to do a thorough filtering of what exactly is needed. I seem to be running into issues with variable lifting, since the computations are not exactly identical, causing equal_computations to yield False. The reasons for the failing tests may be more fundamental; I will look into this gradually.
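
For reference, here is a minimal, self-contained illustration (with made-up graphs, not the actual test graphs) of how sensitive equal_computations is: two graphs that compute the same values still compare unequal when a constant differs in shape or dtype.

import numpy as np
import aesara.tensor as at
from aesara.graph.basic import equal_computations

x = at.vector("x")
g1 = x + np.array(-10.0)    # scalar constant
g2 = x + np.array([-10.0])  # (1,)-shaped constant

# Same values, structurally different graphs
print(equal_computations([g1], [g2]))  # False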

@larryshamalama larryshamalama marked this pull request as draft September 2, 2022 20:42
@larryshamalama
Contributor Author

Here are some notes on why some tests are failing. The discrepancies between graphs occur in the following places:

  • The first inputs to MixtureRV can be NoneConst or TensorConstant{0}.
  • Discrepancy in component shapes (e.g. TensorConstant{(1,) of -10.0} vs. TensorConstant{-10.0}).

This comment serves as a reminder of what I am stuck on before the upcoming meeting.

@larryshamalama
Contributor Author

larryshamalama commented Nov 24, 2022

I'm slowly revisiting this PR after quite some time. I'm investigating one of my failing test cases, and I have probably forgotten many details in the meantime... Consider the following code.

import aesara
import aesara.tensor as at
from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)

X_rv = srng.normal(loc=[10, 20], scale=0.1, size=(2,), name="X")
Y_rv = srng.normal(loc=[-10, -20], scale=0.1, size=(2,), name="Y")

I_rv = srng.bernoulli([0.9, 0.1], size=(2,), name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
aesara.dprint(fgraph.outputs[0])

yields

SpecifyShape [id A]
 |MixtureRV{indices_end_idx=2, out_dtype='float64', out_broadcastable=(False,)} [id B]
 | |TensorConstant{0} [id C]
 | |bernoulli_rv{0, (0,), int64, False}.1 [id D] 'I'
 | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16488AB20>) [id E]
 | | |TensorConstant{(1,) of 2} [id F]
 | | |TensorConstant{4} [id G]
 | | |TensorConstant{[0.9 0.1]} [id H]
 | |normal_rv{0, (0, 0), floatX, False}.1 [id I] 'X'
 | | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x1648899A0>) [id J]
 | | |TensorConstant{(1,) of 2} [id F]
 | | |TensorConstant{11} [id K]
 | | |TensorConstant{[10 20]} [id L]
 | | |TensorConstant{0.1} [id M]
 | |normal_rv{0, (0, 0), floatX, False}.1 [id N] 'Y'
 |   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16488A340>) [id O]
 |   |TensorConstant{(1,) of 2} [id F]
 |   |TensorConstant{11} [id K]
 |   |TensorConstant{[-10 -20]} [id P]
 |   |TensorConstant{0.1} [id M]
 |TensorConstant{2} [id Q]
bernoulli_rv{0, (0,), int64, False}.1 [id D] 'I'

Where does the SpecifyShape Op come from? My understanding is that the switch_mixture_replace rewrite gets called and returns MixtureRV{indices_end_idx=2, out_dtype='float64', out_broadcastable=(False,)}.0 in this case, so the SpecifyShape Op must come from elsewhere...

@rlouf
Member

rlouf commented Nov 24, 2022

Hey, thanks for revisiting the PR! Do you mean running the code on this PR branch? If I run your code snippet on main I get the following result:

Elemwise{switch,no_inplace} [id A]
 |bernoulli_rv{0, (0,), int64, False}.1 [id B] 'I'
 | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0239D0E880>) [id C]
 | |TensorConstant{(1,) of 2} [id D]
 | |TensorConstant{4} [id E]
 | |TensorConstant{[0.9 0.1]} [id F]
 |normal_rv{0, (0, 0), floatX, False}.1 [id G] 'X'
 | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F023B681B60>) [id H]
 | |TensorConstant{(1,) of 2} [id D]
 | |TensorConstant{11} [id I]
 | |TensorConstant{[10 20]} [id J]
 | |TensorConstant{0.1} [id K]
 |normal_rv{0, (0, 0), floatX, False}.1 [id L] 'Y'
   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0239D0DEE0>) [id M]
   |TensorConstant{(1,) of 2} [id D]
   |TensorConstant{11} [id I]
   |TensorConstant{[-10 -20]} [id N]
   |TensorConstant{0.1} [id K]

Have you considered dispatching Switch and Ifelse in separate functions for now? It may make it easier to progress without breaking Switch so you always have a working point of comparison. (you can keep the tests together)
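
For illustration, the split might look roughly like this (a sketch only: the function names mirror the existing switch_mixture_replace, but the bodies are placeholders and the shared helper is hypothetical, so this is not code from the PR):

import aesara.scalar as aes
from aesara.graph.rewriting.basic import node_rewriter
from aesara.ifelse import IfElse
from aesara.tensor.elemwise import Elemwise


def _build_mixture_rv_replacement(fgraph, node, idx, components):
    """Shared helper that would construct the MixtureRV replacement (omitted)."""
    return None  # placeholder in this sketch


@node_rewriter([Elemwise])
def switch_mixture_replace(fgraph, node):
    if not isinstance(node.op.scalar_op, aes.Switch):
        return None
    idx, *components = node.inputs
    return _build_mixture_rv_replacement(fgraph, node, idx, components)


@node_rewriter([IfElse])
def ifelse_mixture_replace(fgraph, node):
    # For the single-output case, IfElse inputs are (condition, then_branch, else_branch)
    idx, *components = node.inputs
    return _build_mixture_rv_replacement(fgraph, node, idx, components)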

You will also need to resolve the (small) merge conflict due to the new joint_logprob interface, but this should be easy.

@larryshamalama
Contributor Author

Hey, thanks for revisiting the PR! Do you mean running the code on this PR branch? If I run your code snippet on main I get the following result:

Yes, running the code on this branch! The graph rewrite for Switch mixtures on main does not work for non-scalar components...

Have you considered dispatching Switch and Ifelse in separate functions for now? It may make it easier to progress without breaking Switch so you always have a working point of comparison. (you can keep the tests together)

Yes, but I felt that the graph rewrite for both would be very similar. I can separate them as I work through them for now...

You will also need to resolve the (small) merge conflict due to the new joint_logprob interface, but this should be easy.

Okay sounds good!

@rlouf
Member

rlouf commented Nov 24, 2022

Yes, but I felt that the graph rewrite for both would be very similar. I can separate them as I work through them for now...

They likely will, and we may want to merge them later. But I think it would be easier for you to move from one stable state to another, changing one thing at a time, and always keeping a reference implementation (Switch) working.

@larryshamalama
Contributor Author

I just rebased my code. I am working on having Switch/IfElse-induced mixture subgraphs yield the same canonical (or IR?) graph, i.e. the graph obtained after running switch_mixture_replace, as if the mixture were defined via a Subtensor. The SpecifyShape Op still comes up in the example above, and I believe that I should modify my rewrite so that it does not show up. Is there a way to check what sequence of graph rewrites has been applied?

@rlouf
Member

rlouf commented Nov 24, 2022

Of course: https://aesara.readthedocs.io/en/latest/extending/graph_rewriting.html#detailed-profiling-of-aesara-rewrites.

Alternatively, you can add a breakpoint here, call aesara.dprint(fgraph), note the index of the SpecifyShape node, and then inspect it with node = fgraph.toposort()[index]; node.tag may contain information about the rewrite that created it.
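
Sketched out, that workflow could look like the following (self-contained; whether a SpecifyShape node actually shows up depends on the branch and Aesara version you are running):

import aesara
import aesara.tensor as at
from aesara.tensor.shape import SpecifyShape
from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)
X_rv = srng.normal(loc=[10, 20], scale=0.1, size=(2,), name="X")
Y_rv = srng.normal(loc=[-10, -20], scale=0.1, size=(2,), name="Y")
I_rv = srng.bernoulli([0.9, 0.1], size=(2,), name="I")
Z_rv = at.switch(I_rv, X_rv, Y_rv)

fgraph, _, _ = construct_ir_fgraph({Z_rv: Z_rv.clone(), I_rv: I_rv.clone()})
aesara.dprint(fgraph)  # locate the SpecifyShape node in the printout

for index, node in enumerate(fgraph.toposort()):
    if isinstance(node.op, SpecifyShape):
        # The tag may record which rewrite introduced this node
        print(index, node, getattr(node.tag, "__dict__", {}))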

@brandonwillard
Member

brandonwillard commented Nov 24, 2022

Of course: https://aesara.readthedocs.io/en/latest/extending/graph_rewriting.html#detailed-profiling-of-aesara-rewrites.

Alternatively, you can add a breakpoint here, call aesara.dprint(fgraph), note the index of the SpecifyShape node, and then inspect it with node = fgraph.toposort()[index]; node.tag may contain information about the rewrite that created it.

There's also with aesara.config.change_flags(optimizer_verbose=True): .... I use that a lot.
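
A tiny sketch of what that looks like in practice, wrapped here around the IR construction from the earlier snippets; while the flag is active, each applied rewrite is printed to stdout:

import aesara
import aesara.tensor as at
from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)
I_rv = srng.bernoulli(0.5, name="I")
X_rv = srng.normal(-10, 0.1, name="X")
Y_rv = srng.normal(10, 0.1, name="Y")
Z_rv = at.switch(I_rv, X_rv, Y_rv)

# Every rewrite applied while building the IR gets logged
with aesara.config.change_flags(optimizer_verbose=True):
    construct_ir_fgraph({Z_rv: Z_rv.clone(), I_rv: I_rv.clone()})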

In general, as @rlouf said, don't be afraid to put breakpoint()s wherever you want. Our pre-commit settings will prevent those from being added to commits, so you don't even need to worry about forgetting them.

@rlouf
Member

rlouf commented Dec 6, 2022

Adding to this, you can set up any reasonable IDE so that, when you run tests, it opens a debugger console whenever it hits a breakpoint or a failure. If you don't have that in place already, spend some time setting it up; it was a huge boost to my productivity.

@larryshamalama
Contributor Author

Adding to this, you can set up any reasonable IDE so that, when you run tests, it opens a debugger console whenever it hits a breakpoint or a failure. If you don't have that in place already, spend some time setting it up; it was a huge boost to my productivity.

Thanks for the tip. I also saw your recent related tweet 😅

As for this PR, I think it's best to close it in order to 1) split the tasks into smaller sub-PRs (I felt like too much was going on at once) and 2) address some other issues that came up. As for the latter, I have divided them into subsections below. Any guidance would be helpful...

Reworking switch_mixture_replace

First, switch_mixture_replace isn't entirely correct... I only considered scalar indices and components. For instance, to make equal_computations yield True, I used an as_nontensor_scalar wrapper for the indices, which would not work in general.

aeppl/aeppl/mixture.py

Lines 340 to 342 in 473c1e6

new_node = mix_op.make_node(
    *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)

SpecifyShape Op

The appearance of the SpecifyShape Op seems to be new... perhaps due to this recent addition to Aesara? Maybe a good first step would be to replace out_broadcastable in MixtureRV with the corresponding static shapes, if available; a rough sketch of what that could look like follows the snippet below. Would that be a reasonable place to start?

aeppl/aeppl/mixture.py

Lines 180 to 197 in 473c1e6

class MixtureRV(Op):
    """A placeholder used to specify a log-likelihood for a mixture sub-graph."""

    __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")

    def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
        super().__init__()
        self.indices_end_idx = indices_end_idx
        self.out_dtype = out_dtype
        self.out_broadcastable = out_broadcastable

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
        )

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This is a stand-in Op.")  # pragma: no cover
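
For concreteness, the static-shape version I have in mind would look roughly like this (a sketch only, assuming TensorType accepts a shape argument with None for unknown dimensions, as in recent Aesara versions; the out_shape name is hypothetical):

from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor.type import TensorType


class MixtureRV(Op):
    """A placeholder used to specify a log-likelihood for a mixture sub-graph."""

    __props__ = ("indices_end_idx", "out_dtype", "out_shape")

    def __init__(self, indices_end_idx, out_dtype, out_shape):
        super().__init__()
        self.indices_end_idx = indices_end_idx
        self.out_dtype = out_dtype
        # Tuple of ints and/or None entries (None means the dimension is unknown)
        self.out_shape = tuple(out_shape)

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(self.out_dtype, shape=self.out_shape)()]
        )

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This is a stand-in Op.")  # pragma: no cover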

Mismatch in MixtureRV shapes generated by Switch vs. at.stack

With the hot fix replacing broadcastable with shape, the MixtureRV shapes seem to differ depending on whether they are generated by a Switch or a Join. Is this because subtensors don't have static shape inference yet? That would be my guess (Aesara issue #922?), but I'm not sure. Below is an example that I created using this branch's additions.

import aesara.tensor as at
from aeppl.rewriting import construct_ir_fgraph
from aeppl.mixture import MixtureRV

srng = at.random.RandomStream(29833)

X_rv = srng.normal([10, 20], 0.1, size=(2,), name="X")
Y_rv = srng.normal([-10, -20], 0.1, size=(2,), name="Y")

I_rv = srng.bernoulli([0.99, 0.01], size=(2,), name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
assert not hasattr(
    fgraph.outputs[0].tag, "test_value"
)  # aesara.config.compute_test_value == "off"
assert fgraph.outputs[0].name is None

Z1_rv.name = "Z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

assert fgraph.outputs[0].name == "Z1-mixture"

# building the identical graph but with a stack to check that mixture computations are identical

Z2_rv = at.stack((X_rv, Y_rv))[I_rv]

fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})

fgraph.outputs[0].type.shape # (2,)
fgraph2.outputs[0].type.shape # (None, None)

IfElse mixture subgraphs

Given that IfElse requires scalar conditions, maybe it would be good to start with those instead of refining switch-mixtures... The short snippet below illustrates the difference in condition shapes. Happy to hear any thoughts about the points above. I feel like there's a lot going on, and it can be challenging to address everything at once (especially given that this is a continuation of this summer's work...)
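
A tiny, self-contained illustration of that constraint (nothing here is specific to this PR): ifelse takes a single boolean scalar and lazily returns one whole branch, whereas at.switch selects entry-wise with a tensor condition.

import aesara.tensor as at
from aesara.ifelse import ifelse

cond = at.scalar("cond", dtype="bool")
x = at.vector("x")
y = at.vector("y")

z_lazy = ifelse(cond, x, y)      # scalar condition, one whole branch is evaluated
z_elem = at.switch(x > 0, x, y)  # entry-wise selection, condition has the branches' shape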

@rlouf
Member

rlouf commented Dec 8, 2022

PRs that touch on core mechanisms in Aesara, or simply that implement big changes, can easily get frustrating. Breaking the problem down like you did is a great reaction to this situation. Do you mind if I keep it open and I come back to you later next week at least with some questions, maybe some insight?

@rlouf rlouf reopened this Dec 8, 2022
@larryshamalama
Contributor Author

PRs that touch on core mechanisms in Aesara, or simply that implement big changes, can easily get frustrating. Breaking the problem down like you did is a great reaction to this situation. Do you mind if I keep it open and I come back to you later next week at least with some questions, maybe some insight?

Of course, not a problem at all!

@larryshamalama
Contributor Author

PRs that touch on core mechanisms in Aesara, or simply that implement big changes, can easily get frustrating. Breaking the problem down like you did is a great reaction to this situation. Do you mind if I keep it open and I come back to you later next week at least with some questions, maybe some insight?

@rlouf Just a quick update: @brandonwillard and I spoke recently, hence the recent force-push. The current focus is to ensure that the mixture indexing operations via at.switch and at.stack match what NumPy would yield with np.where. For now, we leverage mixture_replace to construct tests that should pass, and then further refine the switch_mixture_replace rewrite. Afterwards, we can add the IfElse rewrite, which should be similar to switch_mixture_replace, although branches of differing lengths (but the same number of dimensions) will need a bit of thought.
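
For the record, here is the kind of entry-wise NumPy equivalence being targeted (a small sketch; the stacking order and indexing convention are my own illustration, not necessarily the exact convention the rewrite will settle on):

import numpy as np

I = np.array([1, 0])
X = np.array([10.0, 20.0])
Y = np.array([-10.0, -20.0])

# switch-style selection: entry-wise, take X where I is nonzero, else Y
print(np.where(I, X, Y))                       # [ 10. -20.]

# stack-then-index selection of the same entries, entry-wise
print(np.stack((Y, X))[I, np.arange(I.size)])  # [ 10. -20.]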

@rlouf
Member

rlouf commented Dec 17, 2022

Glad to hear this is back on track!


Successfully merging this pull request may close these issues.

Support mixtures defined with IfElse