Skip to content

Commit

Permalink
Add hacked_twin overloads for _unsafe indexing functions (pytorch#104127)
Browse files Browse the repository at this point in the history

Fixes pytorch#104037

This hacky workaround already exists for the normal overloads.
Pull Request resolved: pytorch#104127
Approved by: https://github.com/ezyang
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Jul 5, 2023
1 parent 2385dad commit 12ca224
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pt_ops.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,11 @@ PT_OPS_PRIM = [
"aten::copy_.float",
"aten::backward",
"aten::index.Tensor_hacked_twin",
"aten::_unsafe_index.Tensor_hacked_twin",
"aten::_index_put_impl_.hacked_twin",
"aten::index_put_.hacked_twin",
"aten::index_put.hacked_twin",
"aten::_unsafe_index_put.hacked_twin",
"aten::to.prim_Device",
"aten::to.prim_dtype",
"prim::is_cuda",
Expand Down
33 changes: 33 additions & 0 deletions test/jit/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,39 @@ def gen_data():
torch.index_put_(input1, [index1], value1, accumulate=False)
self.assertEqual(input, input1)

def test_unsafe_hacked_twin(self):
    """Check the hacked_twin overloads of the _unsafe indexing ops: eager
    results must match the regular (non-_unsafe) ops, and the ops must be
    scriptable with TorchScript."""

    def gen_data():
        # freeze_rng_state makes both calls below produce identical tensors.
        with freeze_rng_state():
            return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

    input, index, value, = gen_data()
    input1, index1, value1, = gen_data()
    # _unsafe_index_put is functional (non-mutating); compare its result
    # against the regular functional index_put.
    out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False)
    out2 = torch.index_put(input1, [index1], value1, accumulate=False)
    self.assertEqual(out1, out2)

    # _unsafe_index is also non-mutating; its result must match plain
    # advanced indexing. (Previously the result was discarded and a vacuous
    # input == input1 check was made — copy-paste from the mutating test.)
    out1 = torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
    out2 = input1[index1]
    self.assertEqual(out1, out2)

    def index_put_fn(input, index, value):
        return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)

    input2, index2, value2 = gen_data()
    script_index_put_fn = torch.jit.script(index_put_fn)
    expect = index_put_fn(input2.clone(), index2, value2)
    actual = script_index_put_fn(input2.clone(), index2, value2)
    self.assertEqual(expect, actual)

    # Script the _unsafe_index op as well (the original body duplicated
    # index_put_fn here and never exercised _unsafe_index).
    def index_fn(input, index):
        return torch.ops.aten._unsafe_index(input, [index])

    script_index_fn = torch.jit.script(index_fn)
    expect = index_fn(input2.clone(), index2)
    actual = script_index_fn(input2.clone(), index2)
    self.assertEqual(expect, actual)

def test_export_opnames_interface(self):

@torch.jit.interface
Expand Down
33 changes: 33 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),
    // JIT fallback for aten::_unsafe_index with non-optional Tensor[] indices
    // (the "hacked_twin" schema). Adapts the Tensor[] argument to the
    // Tensor?[] list that the real at::_unsafe_index expects, mirroring the
    // pre-existing aten::index.Tensor_hacked_twin workaround.
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_unsafe_index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
        [](Stack& stack) {
          // Arguments are popped in reverse schema order: indices, then self.
          auto indices = pop(stack).to<c10::List<at::Tensor>>();
          // Re-wrap each index tensor as c10::optional to satisfy the
          // Tensor?[] parameter of at::_unsafe_index.
          c10::List<c10::optional<at::Tensor>> opt_list_indices;
          opt_list_indices.reserve(indices.size());
          for (const auto& ten : indices) {
            opt_list_indices.push_back(ten);
          }
          auto self = pop(stack).toTensor();
          auto result = at::_unsafe_index(self, opt_list_indices);
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"),
Expand Down Expand Up @@ -1113,6 +1128,24 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, std::move(result));
},
aliasAnalysisFromSchema()),
    // JIT fallback for aten::_unsafe_index_put with non-optional Tensor[]
    // indices (the "hacked_twin" schema). Converts the Tensor[] indices to
    // the Tensor?[] list required by at::_unsafe_index_put, mirroring the
    // pre-existing aten::index_put.hacked_twin workaround.
    OperatorGeneratorArgs(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_unsafe_index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"),
        [](Stack& stack) {
          // Arguments are popped in reverse schema order:
          // accumulate, values, indices, then self.
          auto accumulate = pop(stack).toBool();
          auto values = pop(stack).toTensor();
          auto indices = pop(stack).to<c10::List<at::Tensor>>();
          // Re-wrap each index tensor as c10::optional to satisfy the
          // Tensor?[] parameter of at::_unsafe_index_put.
          c10::List<c10::optional<at::Tensor>> opt_list_indices;
          opt_list_indices.reserve(indices.size());
          for (const auto& ten : indices) {
            opt_list_indices.push_back(ten);
          }
          auto self = pop(stack).toTensor();
          auto result =
              at::_unsafe_index_put(self, opt_list_indices, values, accumulate);
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
// reference function parse_to_conversion in python_arg_parsing.h
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
Expand Down

0 comments on commit 12ca224

Please sign in to comment.