diff --git a/test/test_meta.py b/test/test_meta.py index cbf0fc05cd0cb..b6a8990fcd7e7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -400,8 +400,6 @@ def run_meta_crossref( torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::mode torch.multinomial: {bf16, f32, f64}, # aten::multinomial, aten::multinomial.out torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense, aten::mvlgamma.out - torch.nanmean: {bf16, f16, f32, f64}, - torch.nanquantile: {f32, f64}, torch.nn.functional.conv1d: {bf16, f32, f64, i64}, torch.nn.functional.conv2d: {bf16, f32, f64, i64}, torch.nn.functional.conv_transpose1d: {f32, f64, i64}, @@ -465,9 +463,9 @@ def run_meta_crossref( torch.functional.cdist: {f32, f64}, torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8}, torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8}, - torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, torch.nn.functional.cross_entropy: {bf16, f32, f64}, torch.nn.functional.interpolate: {bf16, f32, f64, u8}, + torch.nanmean: {bf16, f16, f32, f64}, # TODO(chilli): Doesn't seem to work for some reason? torch.nn.functional.nll_loss: {bf16, f32, f64}, # TODO torch.linalg.pinv: {f32, f64}, torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, @@ -627,8 +625,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): aten.log_sigmoid_forward.output: {bf16, f64, f32}, aten.logcumsumexp.default: {bf16, f64, f32}, aten.logcumsumexp.out: {bf16, f64, f32}, - aten.logical_not.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, - aten.logical_not_.default: {bf16, f16, f64, f32}, aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.max_pool3d_with_indices.default: {f64, f32}, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 9c42256c420bb..153b94510d85f 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1005,11 +1005,6 @@ def _fused_dropout_decomposition(input, p, generator=None): return (res, mask) -@register_decomposition(aten.logical_not) -def logical_not(self: Tensor) -> Tensor: - return ~self.to(dtype=torch.bool) - - @register_decomposition(aten.xlogy.Tensor) @pw_cast_for_int_to_real def xlogy(self: Tensor, other: Tensor) -> Tensor: @@ -1166,11 +1161,6 @@ def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor: return result.log().add(maxes_squeezed) -@register_decomposition(aten.trace.default) -def trace(self: Tensor) -> Tensor: - return torch.sum(torch.diag(self)) - - # nb: Should use acc_t, not op_math @register_decomposition(aten.log_sigmoid_forward) @out_wrapper_multi('output', 'buffer') diff --git a/torch/_prims/context.py b/torch/_prims/context.py index a37d888b95048..c17b44efce890 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -27,7 +27,13 @@ def torch_to_refs_map(): (torch.nn.functional, torch._refs.nn.functional), (torch.special, torch._refs.special), ] - r = {} + r: Dict[Any, Any] = { + torch.Tensor.__invert__: torch._refs.bitwise_not, + torch.Tensor.__xor__: torch._refs.bitwise_xor, + torch.Tensor.__and__: torch._refs.bitwise_and, + torch.Tensor.__or__: torch._refs.bitwise_or, + torch.Tensor.__eq__: torch._refs.eq, + } for mod_torch, mod_refs in modules: for s in mod_refs.__all__: # type: ignore[attr-defined] r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index e731db5d5952c..e8b86e87af050 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -88,6 +88,7 @@ "square", "tan", "tanh", + "trace", # # Elementwise Binary References # @@ -119,6 +120,7 @@ # 'ldexp', "le", "logical_and", + "logical_not", "logical_or", "logical_xor", "lt", @@ -996,10 +998,10 @@ def _lcm(a: TensorLikeType, b: TensorLikeType): def _logical_and(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): - a = ne(a, 0) + a = a != 0 if not utils.is_boolean_dtype(b.dtype): - b = ne(b, 0) - return bitwise_and(a, b) + b = b != 0 + return a & b logical_and = _make_elementwise_binary_reference( @@ -1009,12 +1011,21 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType): ) +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not +) +def logical_not(a: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + return a == 0 + return ~a + + def _logical_or(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): - a = ne(a, 0) + a = a != 0 if not utils.is_boolean_dtype(b.dtype): - b = ne(b, 0) - return bitwise_or(a, b) + b = b != 0 + return a | b logical_or = _make_elementwise_binary_reference( @@ -1026,10 +1037,10 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType): def _logical_xor(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): - a = ne(a, 0) + a = a != 0 if not utils.is_boolean_dtype(b.dtype): - b = ne(b, 0) - return bitwise_xor(a, b) + b = b != 0 + return a ^ b # TODO: skip unnecessary conversion of long to float @@ -2614,6 +2625,13 @@ def equal(a: TensorLikeType, b: TensorLikeType) -> bool: return item(all(eq(a, b))) # type: ignore[return-value] -# populate the decomp table +@register_decomposition(torch.ops.aten.trace) +def trace(self: TensorLikeType) -> TensorLikeType: + utils.check( + self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}" + ) + return torch.sum(torch.diag(self, 0)) + + import torch._refs.nn.functional import torch._refs.special diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b9f0476b31886..2b4b59eeaf15a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3662,6 +3662,10 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs): requires_grad=requires_grad))),) +def error_inputs_trace(op, device): + yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix") + + def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) cases = (((S, S, S), (2, 1, 0.5)), @@ -4330,7 +4334,6 @@ def error_inputs_embedding(op_info, device, **kwargs): def error_inputs_t(op_info, device, **kwargs): yield ErrorInput( SampleInput(torch.randn(2, 3, 4, 5, device=device)), - error_type=RuntimeError, error_regex="expects a tensor with <= 2", ) @@ -17634,6 +17637,7 @@ def error_inputs_mean(op_info, device, **kwargs): dtypes=all_types_and_complex(), dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + error_inputs_func=error_inputs_trace, supports_inplace_autograd=False, supports_out=False, supports_forward_ad=True, @@ -20620,6 +20624,16 @@ def __init__( ), ) ), + ElementwiseUnaryPythonRefInfo( + "_refs.logical_not", + torch_opinfo_name="logical_not", + skips=( + DecorateInfo( + # NotImplementedError: argument of type: + unittest.skip("Fails aten complex and nvfuser doesn't support eq(a, 0)"), 'TestCommon', 'test_python_ref_executor' + ), + ) + ), ElementwiseBinaryPythonRefInfo( "_refs.logical_or", torch_opinfo_name="logical_or", @@ -21193,6 +21207,16 @@ def __init__( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), ), ), + PythonRefInfo( + "_refs.trace", + torch_opinfo_name="trace", + decorators=( + # TODO: torch.diag is currently not supported by either refs, meta funcs, or NVFuser + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + DecorateInfo(unittest.skip("diag is not supported by meta"), 'TestCommon', 'test_python_ref_meta'), + DecorateInfo(unittest.skip("diag is not supported by nvfuser"), 'TestCommon', 'test_python_ref_executor'), + ), + ), # # Tensor Creation Reference OpInfos #