[einsum] Fix opt_einsum defaults to be more reasonable (#86985)
Fixes the confusing situation mentioned in #85224 (comment) by

- setting better out-of-the-box defaults (based on whether opt-einsum is available)
- changing warnings to errors now that we have better defaults

Test plan:
- Ran einsum tests locally + CI
- Uninstalled opt-einsum and ran through setting
     - `enabled` to False (no error)
     - `strategy` to anything that's not None (raises a ValueError)
     - `strategy` to None (no-ops)
- Installed opt-einsum and ran through setting
     - `enabled` to False (no error)
     - `enabled` to True (no error; a no-op, since it already defaults to True with `strategy` 'auto')
     - `strategy` to a random string (raises a ValueError)
     - `strategy` to None (no-ops; remains 'auto')
     - `strategy` to 'greedy' (is set to 'greedy'; see the snippet below)
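
A minimal sketch of the behavior exercised above, assuming opt-einsum is installed:

```python
import torch

# New module-level defaults when opt-einsum is available.
print(torch.backends.opt_einsum.enabled)   # True
print(torch.backends.opt_einsum.strategy)  # 'auto'

# Valid strategies: 'auto', 'greedy', 'optimal'.
torch.backends.opt_einsum.strategy = 'greedy'

try:
    torch.backends.opt_einsum.strategy = 'random string'
except ValueError:
    pass  # invalid strategies now raise instead of warning
```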
Pull Request resolved: #86985
Approved by: https://github.com/soulitzer
janeyx99 committed Oct 17, 2022
1 parent 492d572 commit 7612969
Showing 1 changed file with 13 additions and 12 deletions.

torch/backends/opt_einsum/__init__.py:

```diff
@@ -24,10 +24,10 @@ def get_opt_einsum() -> Any:
 
 def _set_enabled(_enabled: bool) -> None:
     if not is_available() and _enabled:
-        warnings.warn('opt_einsum is not available, so setting `enabled` to True will not reap '
-                      'the benefits of calculating an optimal path for einsum. torch.einsum will '
-                      'fall back to contracting from left to right. To enable this optimal path '
-                      'calculation, please install opt-einsum.')
+        raise ValueError(f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap '
+                         'the benefits of calculating an optimal path for einsum. torch.einsum will '
+                         'fall back to contracting from left to right. To enable this optimal path '
+                         'calculation, please install opt-einsum.')
     global enabled
     enabled = _enabled
 
@@ -38,12 +38,13 @@ def _get_enabled() -> bool:
 
 def _set_strategy(_strategy: str) -> None:
     if not is_available():
-        raise ValueError('opt_einsum is not available, so `strategy` cannot be set. Please install opt-einsum or '
-                         'unset `strategy`.')
+        raise ValueError(f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. '
+                         'torch.einsum will bypass path calculation and simply contract from left to right. '
+                         'Please install opt_einsum or unset `strategy`.')
     if not enabled:
-        warnings.warn('opt_einsum is not enabled, so setting a `strategy` will not make a meaningful change. '
-                      'torch.einsum will bypass path calculation and simply contract from left to right. '
-                      'Please set `enabled` to `True` as well or unset `strategy`.')
+        raise ValueError(f'opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. '
+                         'torch.einsum will bypass path calculation and simply contract from left to right. '
+                         'Please set `enabled` to `True` as well or unset `strategy`.')
     if _strategy not in ['auto', 'greedy', 'optimal']:
         raise ValueError(f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}')
     global strategy
@@ -64,7 +65,7 @@ def set_flags(_enabled=None, _strategy=None):
 
 
 @contextmanager
-def flags(enabled=False, strategy='auto'):
+def flags(enabled=None, strategy=None):
     with __allow_nonbracketed_mutation():
         orig_flags = set_flags(enabled, strategy)
         try:
@@ -94,5 +95,5 @@ def __init__(self, m, name):
 # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
 sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
 
-enabled = True
-strategy = 'auto'
+enabled = True if is_available() else False
+strategy = 'auto' if is_available() else None
```
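
For reference, the `flags` context manager touched in the diff temporarily applies both settings and restores the previous values on exit. A minimal usage sketch, assuming opt-einsum is installed (with this change, `enabled=True` raises when it is not):

```python
import torch

a, b, c = torch.rand(8, 4), torch.rand(4, 6), torch.rand(6, 2)

# Override both flags for the duration of the block; prior values are restored afterward.
with torch.backends.opt_einsum.flags(enabled=True, strategy='optimal'):
    out = torch.einsum('ij,jk,kl->il', a, b, c)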
