From 12ca224662a630e55a8f6f1f8c5e187d2456051d Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 4 Jul 2023 21:52:36 +0000 Subject: [PATCH] Add hacked_twin overloads for _unsafe indexing functions (#104127) Fixes #104037 This hacky workaround already exists for the normal overloads. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104127 Approved by: https://github.com/ezyang --- pt_ops.bzl | 2 ++ test/jit/test_misc.py | 33 ++++++++++++++++++++ torch/csrc/jit/runtime/register_prim_ops.cpp | 33 ++++++++++++++++++++ 3 files changed, 68 insertions(+) diff --git a/pt_ops.bzl b/pt_ops.bzl index 5deca64ebac32..c680140366cd5 100644 --- a/pt_ops.bzl +++ b/pt_ops.bzl @@ -332,9 +332,11 @@ PT_OPS_PRIM = [ "aten::copy_.float", "aten::backward", "aten::index.Tensor_hacked_twin", + "aten::_unsafe_index.Tensor_hacked_twin", "aten::_index_put_impl_.hacked_twin", "aten::index_put_.hacked_twin", "aten::index_put.hacked_twin", + "aten::_unsafe_index_put.hacked_twin", "aten::to.prim_Device", "aten::to.prim_dtype", "prim::is_cuda", diff --git a/test/jit/test_misc.py b/test/jit/test_misc.py index 16e4d5661382a..b9a881ee7bc80 100644 --- a/test/jit/test_misc.py +++ b/test/jit/test_misc.py @@ -171,6 +171,39 @@ def gen_data(): torch.index_put_(input1, [index1], value1, accumulate=False) self.assertEqual(input, input1) + def test_unsafe_hacked_twin(self): + + def gen_data(): + with freeze_rng_state(): + return torch.randn(10), torch.randint(10, (20,)), torch.randn(20) + + input, index, value, = gen_data() + input1, index1, value1, = gen_data() + out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False) + out2 = torch.index_put(input1, [index1], value1, accumulate=False) + self.assertEqual(out1, out2) + + torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index]) + torch.index_put(input1, [index1], value1, accumulate=False) + self.assertEqual(input, input1) + + def index_put_fn(input, index, value): + return 
torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False) + + input2, index2, value2 = gen_data() + script_index_put_fn = torch.jit.script(index_put_fn) + expect = index_put_fn(input2.clone(), index2, value2) + actual = script_index_put_fn(input2.clone(), index2, value2) + self.assertEqual(expect, actual) + + def index_fn(input, index, value): + return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False) + + script_index_fn = torch.jit.script(index_fn) + expect = index_fn(input2.clone(), index2, value2) + actual = script_index_fn(input2.clone(), index2, value2) + self.assertEqual(expect, actual) + def test_export_opnames_interface(self): @torch.jit.interface diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index a0b0ffed656e3..0a9dc9f498832 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1058,6 +1058,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{ push(stack, std::move(result)); }, aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::_unsafe_index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"), + [](Stack& stack) { + auto indices = pop(stack).to<c10::List<at::Tensor>>(); + c10::List<c10::optional<at::Tensor>> opt_list_indices; + opt_list_indices.reserve(indices.size()); + for (const auto& ten : indices) { + opt_list_indices.push_back(ten); + } + auto self = pop(stack).toTensor(); + auto result = at::_unsafe_index(self, opt_list_indices); + push(stack, std::move(result)); + }, + aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::_index_put_impl_.hacked_twin(Tensor(a!) 
self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"), @@ -1113,6 +1128,24 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{ push(stack, std::move(result)); }, aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::_unsafe_index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"), + [](Stack& stack) { + auto accumulate = pop(stack).toBool(); + auto values = pop(stack).toTensor(); + auto indices = pop(stack).to<c10::List<at::Tensor>>(); + c10::List<c10::optional<at::Tensor>> opt_list_indices; + opt_list_indices.reserve(indices.size()); + for (const auto& ten : indices) { + opt_list_indices.push_back(ten); + } + auto self = pop(stack).toTensor(); + auto result = + at::_unsafe_index_put(self, opt_list_indices, values, accumulate); + push(stack, std::move(result)); + }, + aliasAnalysisFromSchema()), // reference function parse_to_conversion in python_arg_parsing.h OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA(