The dtype promotion check in _fix_promotion does not correctly identify scalar inputs and unconditionally accesses .dtype. This breaks binary operators with Python float scalar inputs. It can be fixed by accessing dtype via getattr with a None default, or by validating that the input is not a scalar.
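A minimal sketch of the getattr approach, based on the _fix_promotion source visible in the traceback below (only the guard changes; the rest of the function is untouched):

def _fix_promotion(x1, x2, only_scalar=True):
    # Python scalars (int/float/bool/complex) have no .dtype; getattr with a
    # None default sends them down the early return, since None is never in
    # _array_api_dtypes, instead of raising AttributeError.
    if (getattr(x1, "dtype", None) not in _array_api_dtypes
            or getattr(x2, "dtype", None) not in _array_api_dtypes):
        return x1, x2
    # ... existing 0-D promotion logic unchanged ...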
Happy to provide a PR.
Minimal repro, in version 1.4, via:
import torch
import numpy
import array_api_compat as aac
aac.__version__  # '1.4'
t = torch.arange(10)
n = numpy.arange(10)
numpy.add(n, 1.0)                 # ok
torch.add(t, 1.0)                 # ok, torch promotes the scalar itself
aac.get_namespace(n).add(n, 1.0)  # ok
aac.get_namespace(t).add(t, 1.0)  # raises AttributeError (traceback below)
Raises:
9 torch.add(t, 1.0)
11 aac.get_namespace(n).add(n, 1.0)
---> 12 aac.get_namespace(t).add(t, 1.0)
File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
89 @wraps(f)
90 def _f(x1, x2, /, **kwargs):
---> 91 x1, x2 = _fix_promotion(x1, x2)
92 return f(x1, x2, **kwargs)
File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104, in _fix_promotion(x1, x2, only_scalar)
103 def _fix_promotion(x1, x2, only_scalar=True):
--> 104 if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
105 return x1, x2
106 # If an argument is 0-D pytorch downcasts the other argument
AttributeError: 'float' object has no attribute 'dtype'
Would expect equivalent behavior to torch.add.
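That is, something like the following (result sketched from torch's scalar promotion rules; the int64 tensor should be promoted to a floating dtype rather than raising):

torch.add(t, 1.0)                 # tensor([ 1.,  2., ..., 10.])
aac.get_namespace(t).add(t, 1.0)  # should return the same once fixed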
See:
https://gist.github.com/asford/ee688d59f0747a6507b9670a83fa7c47
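Until a fix lands, a possible workaround (my suggestion, untested beyond reading the code in the traceback) is to pass the scalar as a 0-D tensor so both arguments carry a .dtype:

xp = aac.get_namespace(t)
xp.add(t, torch.as_tensor(1.0))  # a 0-D tensor has a .dtype, so _fix_promotion's
                                 # existing 0-D handling applies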