Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get_scalar_constant_value #643

Conversation

brandonwillard
Copy link
Member

@brandonwillard brandonwillard commented Oct 31, 2021

This PR adds missing canonicalizations and misc. updates that relate to get_scalar_constant_value.

get_scalar_constant_value uses a confusing and poorly defined abstraction: it attempts to return the non-symbolic value of a symbolic scalar graph and/or the "underlying" scalar value of a graph. The latter concept of an "underlying" scalar value is ambiguous and very limited, since it only applies in a few cases—like the outputs of Alloc Ops, because they use a scalar input (i.e. the "underlying" scalar value) to construct an array/tensor.

  • Replace uses of get_scalar_constant_value for Alloc and Second

Its implementation is a hard to follow (and maintain) set of conditions, iterations, and recursions—all of which are a mix of existing/missing constants-based canonicalizations and re-implemented constant folding steps.

The goal of this PR is to further the replacement/removal of get_scalar_constant_value with a more complete set of optimizations, and, when/if necessary, explicit use of constant folding.

This PR also implements some of the missing Shape-related canonicalizations mentioned in #642.

@brandonwillard brandonwillard added enhancement New feature or request graph rewriting refactor This issue involves refactoring labels Oct 31, 2021
@brandonwillard brandonwillard self-assigned this Oct 31, 2021
@brandonwillard brandonwillard force-pushed the refactor-get_scalar_constant_value branch 4 times, most recently from b208cec to 090661c Compare November 1, 2021 00:56
@codecov
Copy link

codecov bot commented Nov 1, 2021

Codecov Report

Merging #643 (3a3b122) into main (eedc5e8) will increase coverage by 0.04%.
The diff coverage is 90.84%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #643      +/-   ##
==========================================
+ Coverage   77.10%   77.15%   +0.04%     
==========================================
  Files         156      156              
  Lines       46909    47007      +98     
  Branches    10259    10281      +22     
==========================================
+ Hits        36170    36269      +99     
+ Misses       8158     8151       -7     
- Partials     2581     2587       +6     
Impacted Files Coverage Δ
aesara/configdefaults.py 71.39% <ø> (-0.08%) ⬇️
aesara/ifelse.py 49.71% <ø> (+0.14%) ⬆️
aesara/tensor/exceptions.py 100.00% <ø> (ø)
aesara/tensor/basic.py 85.57% <81.25%> (+0.22%) ⬆️
aesara/tensor/subtensor_opt.py 84.17% <86.66%> (+0.26%) ⬆️
aesara/tensor/basic_opt.py 84.71% <87.67%> (-0.16%) ⬇️
aesara/tensor/math_opt.py 85.83% <91.11%> (-0.33%) ⬇️
aesara/graph/features.py 65.95% <100.00%> (+3.32%) ⬆️
aesara/graph/fg.py 88.70% <100.00%> (ø)
aesara/graph/opt.py 63.21% <100.00%> (+0.04%) ⬆️
... and 10 more

aesara/tensor/basic.py Outdated Show resolved Hide resolved
aesara/tensor/basic.py Outdated Show resolved Hide resolved
aesara/tensor/basic.py Outdated Show resolved Hide resolved
@brandonwillard brandonwillard force-pushed the refactor-get_scalar_constant_value branch 2 times, most recently from a341c88 to ddd9ebe Compare November 10, 2021 19:39
* Convert tests to pytest `parametrized` tests
* Use `pytest.raises`
* Remove unused timing code and methods
These changes "flatten" the nested conditions that lead to the replacement
logic.

They also clarify the tests that are labeled as being for
`local_add_specialize`, which actually aren't.  Also, the unrelated tests
implicitly relied on the canonicalizations built into
`get_scalar_constant_value`; now, the actual canonicalizations are required in
anticipation of `get_scalar_constant_value`'s replacement.
@brandonwillard brandonwillard force-pushed the refactor-get_scalar_constant_value branch from ddd9ebe to 4c4eb86 Compare November 14, 2021 19:58
@brandonwillard brandonwillard force-pushed the refactor-get_scalar_constant_value branch from 4c4eb86 to 3a3b122 Compare November 14, 2021 23:21
@brandonwillard brandonwillard merged commit 620edab into aesara-devs:main Nov 15, 2021
@brandonwillard brandonwillard deleted the refactor-get_scalar_constant_value branch November 15, 2021 17:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request graph rewriting refactor This issue involves refactoring
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants