diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp
index 51e005c2901b93..0a98ee5c46ca77 100644
--- a/aten/src/ATen/native/Col2Im.cpp
+++ b/aten/src/ATen/native/Col2Im.cpp
@@ -144,7 +144,7 @@ static void col2im_out_cpu_template(
 
   output.resize_({batch_size, n_output_plane, output_height, output_width});
 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
       input.scalar_type(), "col2im_out_cpu", [&] {
         Tensor input_n = Tensor();
         Tensor output_n = Tensor();
diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp
index 25eb4d6787240a..dac2ee6e3f103a 100644
--- a/aten/src/ATen/native/Im2Col.cpp
+++ b/aten/src/ATen/native/Im2Col.cpp
@@ -94,7 +94,7 @@ static void im2col_out_cpu_template(
 
   output.resize_({batch_size, n_output_plane, output_length});
 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf,
       input.scalar_type(), "im2col_out_cpu", [&] {
         Tensor input_n;
         Tensor output_n;
diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu
index e390666b307aa1..bb6d4748deb1f5 100644
--- a/aten/src/ATen/native/cuda/Col2Im.cu
+++ b/aten/src/ATen/native/cuda/Col2Im.cu
@@ -102,7 +102,7 @@ void col2im_out_cuda_template(
   output.resize_({batch_size, n_output_plane, output_height, output_width});
   int64_t output_batch_stride = output.stride(0);
 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
       input.scalar_type(), "col2im_out_cuda", [&] {
         int64_t height_col = (output_height + 2 * pad_height -
                               (dilation_height * (kernel_height - 1) + 1)) /
diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu
index 796eda97b37331..312ad893c0d81e 100644
--- a/aten/src/ATen/native/cuda/Im2Col.cu
+++ b/aten/src/ATen/native/cuda/Im2Col.cu
@@ -103,7 +103,7 @@ static void im2col_out_cuda_template(
   output.resize_({batch_size, n_output_plane, output_length});
 
   // Launch kernel
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
       input.scalar_type(), "im2col_out_cuda", [&] {
         Tensor input_n;
         Tensor output_n;
diff --git a/docs/source/masked.rst b/docs/source/masked.rst
index 086e1f6c8b25c2..00dadd2abf6885 100644
--- a/docs/source/masked.rst
+++ b/docs/source/masked.rst
@@ -283,11 +283,9 @@ The following ops are currently supported:
     kron
     meshgrid
     narrow
-    nn.functional.unfold
     ravel
     select
     split
-    stack
     t
     transpose
     vsplit
@@ -296,7 +294,6 @@
     Tensor.expand_as
     Tensor.reshape
     Tensor.reshape_as
-    Tensor.unfold
     Tensor.view
 
 Other functions
diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py
index 50d03b306bb0cc..0817822be457b2 100644
--- a/test/test_maskedtensor.py
+++ b/test/test_maskedtensor.py
@@ -68,18 +68,6 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
     if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
         raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")
 
-def _compare_forward_backward(data, mask, fn):
-    mt = masked_tensor(data, mask, requires_grad=True)
-    masked_res = fn(mt)
-    masked_res.sum().backward()
-
-    t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_()
-    tensor_res = fn(t)
-    tensor_res.sum().backward()
-
-    _compare_mt_t(masked_res, tensor_res)
-    _compare_mt_t(mt.grad, t.grad, atol=1e-06)
-
 def _create_random_mask(shape, device):
     return make_tensor(shape, device=device, dtype=torch.bool)
 
@@ -178,8 +166,15 @@ def test_softmax(self, device):
             ],
             device=device
         )
+        mt = masked_tensor(data, mask, requires_grad=True)
+        masked_res = torch.softmax(mt, -1)
+        masked_res.sum().backward()
+        xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_()
+        tensor_res = torch.softmax(xinf, -1)
+        tensor_res.sum().backward()
 
-        _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1))
+        _compare_mt_t(masked_res, tensor_res)
+        _compare_mt_t(mt.grad, xinf.grad, atol=1e-06)
 
     def test_where(self, device):
         data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
@@ -199,35 +194,6 @@ def test_where(self, device):
         _compare_mt_t(mx.grad, x.grad)
         _compare_mt_t(my.grad, y.grad)
 
-    def test_unfold(self, device):
-        data = torch.rand(5, 5, device=device)
-        mask = torch.rand(5, 5, device=device) > 0.5
-        _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2))
-
-    def test_nn_unfold(self, device):
-        data = torch.rand(2, 5, 3, 4, device=device)
-        mask = torch.rand(2, 5, 3, 4, device=device) > 0.5
-        _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3)))
-
-    def test_stack(self, device):
-        masked_tensors = [
-            masked_tensor(
-                torch.rand(2, 5, 3, 4, device=device),
-                torch.rand(2, 5, 3, 4, device=device) > 0.5,
-                requires_grad=True,
-            ) for _ in range(3)
-        ]
-
-        data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors]
-        masked_res = torch.stack(masked_tensors)
-        tensor_res = torch.stack(data_tensors)
-
-        masked_res.sum().backward()
-        tensor_res.sum().backward()
-        _compare_mt_t(masked_res, tensor_res)
-        for mt, t in zip(masked_tensors, data_tensors):
-            _compare_mt_t(mt.grad, t.grad, atol=1e-06)
-
     def test_to_sparse(self, device):
         for sample in _generate_sample_data(device=device):
             data = sample.input
diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py
index ba13f50c1fee9c..4a2e79456c86db 100644
--- a/torch/masked/maskedtensor/passthrough.py
+++ b/torch/masked/maskedtensor/passthrough.py
@@ -30,11 +30,6 @@
     torch.ops.aten._reshape_alias,
     torch.ops.aten.cat,
     torch.ops.aten.unsqueeze,
-    torch.ops.aten.unfold,
-    torch.ops.aten.unfold_backward,
-    torch.ops.aten.im2col,
-    torch.ops.aten.col2im,
-    torch.ops.aten.stack,
 ]
 
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fecab41a741f33..3049cd8d8f5753 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15266,8 +15266,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            autodiff_nonfusible_nodes=["aten::hardswish"]),
     OpInfo('nn.functional.unfold',
            aten_name='im2col',
-           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_nn_unfold,
            # Runs very slowly on slow gradcheck - alternatively reduce input sizes
            gradcheck_fast_mode=True,