Update Scan optimizations and fix Scan + RandomVariable shape inference issues #635

Merged

brandonwillard
Member

@brandonwillard brandonwillard commented Oct 28, 2021

This PR refactors some of the Scan optimizations in order to fix Scan + RandomVariable shape inference issues. These changes are a part of #584 and are a solution to #608.

  • Replace use of FunctionGraph.replace* in add_nitsot_outputs
    This isn't directly related to #608, so it can be taken care of later.

  • Use SpecifyShape with the RandomVariable's size parameter
    This makes it possible for Scan's shape inference to track the shape of the cloned non-sequence variables it makes out of RandomVariable size parameters.

  • Add canonicalizations for SpecifyShape
    These are necessary for broadcastable inference in RandomVariable when SpecifyShape is used.

  • Convert global Scan optimizations to local optimizations.

    It looks like the current Scan optimizations mutate the FunctionGraphs in-place in different stages, and this might be leading to inconsistent intermediate graph states that fail when the shape inference optimization is performed.

    This seems to be the case for the MWE in #608: shape inference errors occur when the cloned size parameter in the Scan's body (i.e. the size variable in the graph constructed by the scan_body function) is encountered. It should be possible to assign the shape of size_at to that variable in the ShapeFeature; however, the outer size_at is nowhere to be found in the FunctionGraph that's being optimized. This implies that the original Scan isn't being replaced by a valid Scan in a single step, but is possibly being changed across multiple FunctionGraph.replace* steps.

    From inspection of aesara.graph.opt, it's clear that this is a distinct possibility, since most of the optimizations are global optimizations with multiple FunctionGraph.replace_all* calls in a single optimization. This approach to graph rewriting is rather undesirable for the reasons mentioned above, so this PR will attempt to remedy the situation.
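
The concern about intermediate graph states can be sketched with a toy model. This is not Aesara's API: ToyGraph, dangling_inputs, global_style, and local_style are hypothetical names made up for illustration, and the listener only stands in for a graph feature such as ShapeFeature that is notified after each replacement. The sketch assumes the failure mode guessed at above, i.e. that a multi-step rewrite exposes a half-updated graph to the listener.

```python
class ToyGraph:
    """A tiny stand-in for a FunctionGraph: name -> (op, input names)."""

    def __init__(self, nodes, listener):
        self.nodes = dict(nodes)
        # Called after every replacement, loosely like a graph feature.
        self.listener = listener

    def replace(self, updates):
        """Apply a batch of node updates, then notify the listener once."""
        self.nodes.update(updates)
        self.listener(self)


def dangling_inputs(graph):
    """Return the inputs that some node refers to but the graph does not define."""
    return sorted(
        inp
        for _, inputs in graph.nodes.values()
        for inp in inputs
        if inp not in graph.nodes
    )


def make_recorder(log):
    """Build a listener that records the dangling inputs it sees on each call."""

    def listener(graph):
        log.append(dangling_inputs(graph))

    return listener


# Hypothetical starting graph: a "scan" node that consumes a random variable.
BASE = {
    "size": ("input", ()),
    "rv": ("normal", ("size",)),
    "scan": ("scan", ("rv",)),
}


def global_style(graph):
    """Rewrite the scan in two in-place steps (the pattern being avoided)."""
    # Step 1: the scan now refers to a variable that is not in the graph yet.
    graph.replace({"scan": ("scan", ("size_clone",))})
    # Step 2: only now is the missing variable added.
    graph.replace({"size_clone": ("clone", ("size",))})


def local_style(graph):
    """Build the complete replacement first, then apply it in one step."""
    graph.replace(
        {
            "size_clone": ("clone", ("size",)),
            "scan": ("scan", ("size_clone",)),
        }
    )


global_log, local_log = [], []
g_global = ToyGraph(BASE, make_recorder(global_log))
g_local = ToyGraph(BASE, make_recorder(local_log))
global_style(g_global)
local_style(g_local)

print(global_log)  # the first notification sees a dangling "size_clone"
print(local_log)   # the single notification sees a consistent graph
```

Both styles end with the same graph; the difference is only in what the listener observes along the way, which is where a shape-inference-like check would fail.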

@brandonwillard brandonwillard self-assigned this Oct 28, 2021
@brandonwillard brandonwillard added the important, refactor, and Scan labels Oct 28, 2021
@brandonwillard brandonwillard marked this pull request as draft October 28, 2021 22:04
@codecov
codecov bot commented Oct 28, 2021

Codecov Report

Merging #635 (afefb01) into main (cbf9112) will decrease coverage by 0.01%.
The diff coverage is 86.44%.

@@            Coverage Diff             @@
##             main     #635      +/-   ##
==========================================
- Coverage   77.16%   77.14%   -0.02%     
==========================================
  Files         156      156              
  Lines       47007    47022      +15     
  Branches    10281    10282       +1     
==========================================
+ Hits        36271    36277       +6     
- Misses       8150     8157       +7     
- Partials     2586     2588       +2     
| Impacted Files | Coverage Δ | |
|---|---|---|
| aesara/compile/ops.py | 83.57% <ø> (-0.12%) | ⬇️ |
| aesara/graph/opt.py | 63.21% <0.00%> (ø) | |
| aesara/scan/op.py | 81.98% <ø> (ø) | |
| aesara/scan/opt.py | 81.30% <ø> (-0.75%) | ⬇️ |
| aesara/tensor/type.py | 91.21% <ø> (-0.03%) | ⬇️ |
| aesara/scan/utils.py | 87.25% <50.00%> (ø) | |
| aesara/tensor/shape.py | 88.40% <60.00%> (-0.42%) | ⬇️ |
| aesara/tensor/basic_opt.py | 84.67% <77.77%> (-0.04%) | ⬇️ |
| aesara/tensor/subtensor_opt.py | 84.33% <91.30%> (+0.16%) | ⬆️ |
| aesara/link/numba/dispatch/random.py | 100.00% <100.00%> (ø) | |

... and 6 more

@brandonwillard brandonwillard force-pushed the fix-scalar-shape-cloning branch 3 times, most recently from 6c61870 to 3bc437e Compare October 29, 2021 17:52
@brandonwillard brandonwillard force-pushed the fix-scalar-shape-cloning branch 2 times, most recently from dc58c1c to f540230 Compare November 15, 2021 19:39
@brandonwillard brandonwillard marked this pull request as ready for review November 15, 2021 19:39
@brandonwillard brandonwillard added the bug label Nov 15, 2021
@brandonwillard brandonwillard force-pushed the fix-scalar-shape-cloning branch 2 times, most recently from 8c17336 to 306528a Compare November 15, 2021 22:31
@brandonwillard brandonwillard force-pushed the fix-scalar-shape-cloning branch from 306528a to e7f48be Compare November 15, 2021 22:40
@brandonwillard brandonwillard force-pushed the fix-scalar-shape-cloning branch from e7f48be to eed637d Compare November 16, 2021 03:07
@brandonwillard brandonwillard changed the title Update Scan optimizations Update Scan optimizations and fix Scan + RandomVariable shape inference issues Nov 16, 2021
@brandonwillard brandonwillard force-pushed the fix-scalar-shape-cloning branch from eed637d to afefb01 Compare November 16, 2021 04:46
@brandonwillard brandonwillard merged commit 117b40c into aesara-devs:main Nov 18, 2021
@brandonwillard brandonwillard deleted the fix-scalar-shape-cloning branch November 18, 2021 16:02
Labels
bug, graph rewriting, important, refactor, Scan
Development

Successfully merging this pull request may close these issues.

Fix PushOutNonSeqScan + RandomVariable shape inference error in Scans (#608)