torch binary operations broken for scalar inputs in _fix_promotion #85

Closed
@asford

Description

The dtype promotion check in _fix_promotion does not handle Python scalar inputs: it unconditionally accesses .dtype on both arguments.
This breaks binary operators called with a float scalar, e.g. add(t, 1.0).

This can be fixed by accessing dtype via getattr with a None default, or by first checking that the input is not a scalar.
Happy to provide a PR.

Minimal repro, in version 1.4:

import torch
import numpy
import array_api_compat as aac

print(aac.__version__)  # __version__ is a string, not a callable

t = torch.arange(10)
n = numpy.arange(10)

numpy.add(n, 1.0)
torch.add(t, 1.0)

aac.get_namespace(n).add(n, 1.0)
aac.get_namespace(t).add(t, 1.0)

Raises:

      9 torch.add(t, 1.0)
     11 aac.get_namespace(n).add(n, 1.0)
---> 12 aac.get_namespace(t).add(t, 1.0)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
     89 @wraps(f)
     90 def _f(x1, x2, /, **kwargs):
---> 91     x1, x2 = _fix_promotion(x1, x2)
     92     return f(x1, x2, **kwargs)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104, in _fix_promotion(x1, x2, only_scalar)
    103 def _fix_promotion(x1, x2, only_scalar=True):
--> 104     if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
    105         return x1, x2
    106     # If an argument is 0-D pytorch downcasts the other argument

AttributeError: 'float' object has no attribute 'dtype'

I would expect the same behavior as torch.add, which accepts the float scalar.

See:
https://gist.github.com/asford/ee688d59f0747a6507b9670a83fa7c47
