Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads"

Pull Request resolved: pytorch#79819

Approved by: https://github.com/mruberry
Chillee authored and pytorchmergebot committed Jun 20, 2022
1 parent 26b5129 commit f3665dd
Showing 5 changed files with 61 additions and 27 deletions.
6 changes: 1 addition & 5 deletions test/test_meta.py
@@ -400,8 +400,6 @@ def run_meta_crossref(
torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::mode
torch.multinomial: {bf16, f32, f64}, # aten::multinomial, aten::multinomial.out
torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense, aten::mvlgamma.out
torch.nanmean: {bf16, f16, f32, f64},
torch.nanquantile: {f32, f64},
torch.nn.functional.conv1d: {bf16, f32, f64, i64},
torch.nn.functional.conv2d: {bf16, f32, f64, i64},
torch.nn.functional.conv_transpose1d: {f32, f64, i64},
@@ -465,9 +463,9 @@ def run_meta_crossref(
torch.functional.cdist: {f32, f64},
torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8},
torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},
torch.nn.functional.cross_entropy: {bf16, f32, f64},
torch.nn.functional.interpolate: {bf16, f32, f64, u8},
torch.nanmean: {bf16, f16, f32, f64}, # TODO(chilli): Doesn't seem to work for some reason?
torch.nn.functional.nll_loss: {bf16, f32, f64}, # TODO
torch.linalg.pinv: {f32, f64},
torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
@@ -627,8 +625,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.log_sigmoid_forward.output: {bf16, f64, f32},
aten.logcumsumexp.default: {bf16, f64, f32},
aten.logcumsumexp.out: {bf16, f64, f32},
aten.logical_not.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.logical_not_.default: {bf16, f16, f64, f32},
aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.max_pool3d_with_indices.default: {f64, f32},
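The deletions above remove torch.logical_not and the aten.logical_not overloads from the expected-failure tables in test_meta.py, meaning meta-tensor execution is now expected to succeed for them. A minimal sanity check, assuming a build that contains this commit:

import torch

# logical_not is now covered by the new ref (see torch/_refs/__init__.py below),
# so no dedicated meta kernel is required.
a = torch.empty(4, device="meta", dtype=torch.float32)
out = torch.logical_not(a)
print(out.dtype, out.shape)  # torch.bool torch.Size([4])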
10 changes: 0 additions & 10 deletions torch/_decomp/decompositions.py
@@ -1005,11 +1005,6 @@ def _fused_dropout_decomposition(input, p, generator=None):
return (res, mask)


@register_decomposition(aten.logical_not)
def logical_not(self: Tensor) -> Tensor:
return ~self.to(dtype=torch.bool)


@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
@@ -1166,11 +1161,6 @@ def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
return result.log().add(maxes_squeezed)


@register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
return torch.sum(torch.diag(self))


# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper_multi('output', 'buffer')
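The two decompositions removed here are relocated to torch/_refs/__init__.py (below). For reference, their semantics restated as eager-mode identities (a sketch, not code from the commit):

import torch

# Removed logical_not decomposition: ~self.to(dtype=torch.bool)
a = torch.tensor([0, 1, 2])
assert torch.equal(~a.to(dtype=torch.bool), torch.logical_not(a))

# Removed trace decomposition: torch.sum(torch.diag(self))
m = torch.arange(9.0).reshape(3, 3)
assert torch.equal(torch.sum(torch.diag(m)), torch.trace(m))  # tensor(12.)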
8 changes: 7 additions & 1 deletion torch/_prims/context.py
@@ -27,7 +27,13 @@ def torch_to_refs_map():
(torch.nn.functional, torch._refs.nn.functional),
(torch.special, torch._refs.special),
]
r = {}
r: Dict[Any, Any] = {
torch.Tensor.__invert__: torch._refs.bitwise_not,
torch.Tensor.__xor__: torch._refs.bitwise_xor,
torch.Tensor.__and__: torch._refs.bitwise_and,
torch.Tensor.__or__: torch._refs.bitwise_or,
torch.Tensor.__eq__: torch._refs.eq,
}
for mod_torch, mod_refs in modules:
for s in mod_refs.__all__: # type: ignore[attr-defined]
r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
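The new dict entries teach torch_to_refs_map to translate Tensor dunder methods, so operator calls such as a & b made inside refs (see the method-overload rewrites below) also resolve to refs under the prims tracing context. A quick check of the mapping (a sketch; torch._prims and torch._refs are private modules):

import torch
import torch._refs
from torch._prims.context import torch_to_refs_map

r = torch_to_refs_map()
assert r[torch.Tensor.__and__] is torch._refs.bitwise_and
assert r[torch.Tensor.__invert__] is torch._refs.bitwise_not
assert r[torch.Tensor.__eq__] is torch._refs.eq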
38 changes: 28 additions & 10 deletions torch/_refs/__init__.py
@@ -88,6 +88,7 @@
"square",
"tan",
"tanh",
"trace",
#
# Elementwise Binary References
#
@@ -119,6 +120,7 @@
# 'ldexp',
"le",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"lt",
@@ -996,10 +998,10 @@ def _lcm(a: TensorLikeType, b: TensorLikeType):

def _logical_and(a: TensorLikeType, b: TensorLikeType):
if not utils.is_boolean_dtype(a.dtype):
a = ne(a, 0)
a = a != 0
if not utils.is_boolean_dtype(b.dtype):
b = ne(b, 0)
return bitwise_and(a, b)
b = b != 0
return a & b


logical_and = _make_elementwise_binary_reference(
@@ -1009,12 +1011,21 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType):
)


@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not
)
def logical_not(a: TensorLikeType):
if not utils.is_boolean_dtype(a.dtype):
return a == 0
return ~a


def _logical_or(a: TensorLikeType, b: TensorLikeType):
if not utils.is_boolean_dtype(a.dtype):
a = ne(a, 0)
a = a != 0
if not utils.is_boolean_dtype(b.dtype):
b = ne(b, 0)
return bitwise_or(a, b)
b = b != 0
return a | b


logical_or = _make_elementwise_binary_reference(
@@ -1026,10 +1037,10 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType):

def _logical_xor(a: TensorLikeType, b: TensorLikeType):
if not utils.is_boolean_dtype(a.dtype):
a = ne(a, 0)
a = a != 0
if not utils.is_boolean_dtype(b.dtype):
b = ne(b, 0)
return bitwise_xor(a, b)
b = b != 0
return a ^ b


# TODO: skip unnecessary conversion of long to float
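The rewrites above swap explicit ref calls (ne, bitwise_and, bitwise_or, bitwise_xor) for method overloads (!=, &, |, ^), which the torch_to_refs_map additions in torch/_prims/context.py can now intercept; observable behavior is unchanged. For the new logical_not ref, both branches in eager mode (a sketch):

import torch

a = torch.tensor([0, 1, 2])      # non-boolean dtype: logical_not(a) == (a == 0)
print(torch.logical_not(a))      # tensor([ True, False, False])

b = torch.tensor([True, False])  # boolean dtype: logical_not(b) == ~b
print(torch.logical_not(b))      # tensor([False,  True])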
@@ -2614,6 +2625,13 @@ def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
return item(all(eq(a, b))) # type: ignore[return-value]


# populate the decomp table
@register_decomposition(torch.ops.aten.trace)
def trace(self: TensorLikeType) -> TensorLikeType:
utils.check(
self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
)
return torch.sum(torch.diag(self, 0))


import torch._refs.nn.functional
import torch._refs.special
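The trace ref doubles as a decomposition (diag followed by sum) and enforces the matrix check through utils.check, which raises RuntimeError on failure. Refs run on ordinary tensors, so a quick consistency check against the eager op (a sketch; torch._refs is a private module):

import torch
import torch._refs

m = torch.arange(9.0).reshape(3, 3)
assert torch.equal(torch._refs.trace(m), torch.trace(m))  # both tensor(12.)

try:
    torch._refs.trace(torch.randn(3, 4, 5))  # ndim != 2
except RuntimeError as e:
    assert "expected a matrix" in str(e)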
26 changes: 25 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -3662,6 +3662,10 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
requires_grad=requires_grad))),)


def error_inputs_trace(op, device):
yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix")


def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
cases = (((S, S, S), (2, 1, 0.5)),
@@ -4330,7 +4334,6 @@ def error_inputs_embedding(op_info, device, **kwargs):
def error_inputs_t(op_info, device, **kwargs):
yield ErrorInput(
SampleInput(torch.randn(2, 3, 4, 5, device=device)),
error_type=RuntimeError,
error_regex="expects a tensor with <= 2",
)

@@ -17634,6 +17637,7 @@ def error_inputs_mean(op_info, device, **kwargs):
dtypes=all_types_and_complex(),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
error_inputs_func=error_inputs_trace,
supports_inplace_autograd=False,
supports_out=False,
supports_forward_ad=True,
@@ -20620,6 +20624,16 @@ def __init__(
),
)
),
ElementwiseUnaryPythonRefInfo(
"_refs.logical_not",
torch_opinfo_name="logical_not",
skips=(
DecorateInfo(
# NotImplementedError: argument of type: <class 'complex'>
unittest.skip("Fails aten complex and nvfuser doesn't support eq(a, 0)"), 'TestCommon', 'test_python_ref_executor'
),
)
),
ElementwiseBinaryPythonRefInfo(
"_refs.logical_or",
torch_opinfo_name="logical_or",
@@ -21193,6 +21207,16 @@ def __init__(
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
),
),
PythonRefInfo(
"_refs.trace",
torch_opinfo_name="trace",
decorators=(
# TODO: torch.diag is currently not supported by either refs, meta funcs, or NVFuser
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
DecorateInfo(unittest.skip("diag is not supported by meta"), 'TestCommon', 'test_python_ref_meta'),
DecorateInfo(unittest.skip("diag is not supported by nvfuser"), 'TestCommon', 'test_python_ref_executor'),
),
),
#
# Tensor Creation Reference OpInfos
#
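The new error_inputs_trace above feeds the ErrorInputs test, which calls the op on each ErrorInput and asserts it raises a matching error (RuntimeError by default, which is also why error_inputs_t can drop its explicit error_type). In spirit, with check_error_input as a hypothetical, simplified stand-in for that harness:

import re
import torch

def check_error_input(op, arg, error_regex, error_type=RuntimeError):
    # Hypothetical sketch of what the ErrorInputs test asserts per entry.
    try:
        op(arg)
    except error_type as e:
        assert re.search(error_regex, str(e)), f"unexpected message: {e}"
    else:
        raise AssertionError("expected op to raise")

check_error_input(torch.trace, torch.randn(3, 4, 5), "expected a matrix")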
