Support exporting aten::copy_ and aten::index_put to ONNX opset 11 (pytorch#26941)

Summary:
- [x] Add more comments and refactor the logic of `ReshapeToAdvancedIndexingFormat`
- [x] Add more description here: which cases are and aren't supported, and how they are supported.
- [x] Merge pytorch#27186 to enable testing of in-place operators.

We now support exporting aten::copy_ and aten::index_put to ONNX opset 11.
Here's a breakdown of the different indexing cases in PyTorch code.

```
# Case 1: Scalar Indices
x[0, 1, 2] = data

# Case 2: Slice Indices
x[1:3, :, ::2] = data

# Case 3: Ellipsis Indices
x[..., 0] = data

# Case 4: Tensor Indices
ind1 = torch.tensor([0, 2])
ind2 = torch.tensor([1, 1])
x[ind1, ind2] = data

# Case 5: Mixing all the above cases
ind1 = torch.tensor([0, 2])
ind2 = torch.tensor([1, 1])
x[1:3, ind1, ind2, ..., 3] = data
```
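For context, the tensor-index assignment above is sugar for `aten::index_put_`, which is the op this export path handles. A minimal sketch of the equivalence (illustration only, not the exporter's code):

```
import torch

x = torch.zeros(3, 4)
data = torch.tensor([10., 20.])
ind1 = torch.tensor([0, 2])
ind2 = torch.tensor([1, 1])

# `x[ind1, ind2] = data` dispatches to aten::index_put_ with accumulate=False,
# writing data[k] into x[ind1[k], ind2[k]] for each k.
x.index_put_((ind1, ind2), data, accumulate=False)
assert x[0, 1] == 10. and x[2, 1] == 20.
```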

Limitations:

Tensor indices must be consecutive and 1-dimensional.

```
# Supported
ind1 = torch.tensor([0, 2])
ind2 = torch.tensor([1, 1])
x[ind1, ind2] = data

# Not supported
ind1 = torch.tensor([0, 2])
ind2 = torch.tensor([1, 1])
ind3 = torch.tensor([[0], [1]])
x[ind1, :, ind2] = data  # tensor indices are not consecutive
x[ind3] = data           # index tensor is 2-d, not 1-d
```
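A possible workaround for the 2-d index tensor case (an illustration, not part of this PR): flatten the index tensor to 1-d before assigning. Here it writes the same elements, because `data` broadcasts over the extra index dimension:

```
import torch

x = torch.zeros(3, 4)
data = torch.ones(4)
ind3 = torch.tensor([[0], [1]])

# x[ind3] = data         # 2-d index tensor: not exportable
x[ind3.view(-1)] = data  # 1-d index tensor: exportable, writes the same rows
```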

Negative indices are not supported.
```
# Not supported
x[-1] = data
```
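Until negative indices are handled, one workaround (assuming the dimension size is available at trace time) is to normalize the index against the dimension size:

```
import torch

x = torch.zeros(3, 4)
data = torch.ones(4)

# x[-1] = data           # negative index: not supported
x[x.size(0) - 1] = data  # equivalent non-negative index
```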
Pull Request resolved: pytorch#26941

Differential Revision: D17951030

Pulled By: houseroad

fbshipit-source-id: 4357777072f53aa0bc4b297aa1ee53457a7f8dec
BowenBao authored and facebook-github-bot committed Dec 7, 2019
1 parent 8b2bac7 commit ae5af68
Showing 9 changed files with 569 additions and 2 deletions.
168 changes: 168 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -680,6 +680,174 @@ def test_tensor_index_advanced_indexing(self):
    def test_tensor_index_advanced_indexing_consecutive(self):
        self._test_index_generic(lambda input: input[:, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None])

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, ind, update):
                x[ind] = update
                return x

        x = torch.randn(3, 4)
        ind = torch.tensor([1], dtype=torch.long)
        update = torch.ones(4)
        self.run_test(IndexPutModel(), (x, ind, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_accumulate(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, ind, update):
                return x.index_put((ind, ), update, accumulate=True)

        x = torch.randn(3, 4)
        ind = torch.tensor([2], dtype=torch.long)
        update = torch.ones(4)
        self.run_test(IndexPutModel(), (x, ind, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_slice_index(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, update):
                x[1:2, 1:3, torch.tensor([1])] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(1, 2, 1)
        self.run_test(IndexPutModel(), (x, update))

        class IndexPutModel2(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), torch.tensor([1, 2])] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.randn(2, 5)
        self.run_test(IndexPutModel2(), (x, update))

        class IndexPutModel3(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), 1:2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1, 1)
        self.run_test(IndexPutModel3(), (x, update))

        class IndexPutModel4(torch.nn.Module):
            def forward(self, x, update):
                x[torch.tensor([0, 2]), 2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1)
        self.run_test(IndexPutModel4(), (x, update))

        class IndexPutModel5(torch.nn.Module):
            def forward(self, x, update):
                x[1:3, torch.tensor([0, 2]), 2] += update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.tensor([10, 15]).view(2, 1)
        self.run_test(IndexPutModel5(), (x, update))

        class IndexPutModel6(torch.nn.Module):
            def forward(self, x, update):
                x[1:3, 0] = update
                return x

        x = torch.randn(3, 4, 5)
        update = torch.arange(2 * 5).to(torch.float).view(2, 5)
        self.run_test(IndexPutModel6(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_index_put_ellipsis(self):
        class IndexPutModel(torch.nn.Module):
            def forward(self, x, update):
                x[..., torch.tensor([2, 1, 3]), 2:4] += update
                return x

        x = torch.randn(3, 4, 5, 6, 7)
        update = torch.randn(3, 1, 1, 3, 2)
        self.run_test(IndexPutModel(), (x, update))

        class IndexPutModel2(torch.nn.Module):
            def forward(self, x, update):
                x[2, ..., torch.tensor([2, 1, 3]), 2:4] += update
                return x

        x = torch.randn(3, 4, 5, 6, 7)
        update = torch.randn(4, 1, 3, 2)
        self.run_test(IndexPutModel2(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_(self):
        class CopyModel(torch.nn.Module):
            def forward(self, x, data):
                x[1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.randn(2, 4)
        self.run_test(CopyModel(), (x, update))

        # mixed slice and select
        class CopyModel2(torch.nn.Module):
            def forward(self, x, data):
                x[1:3, 0] = data
                return x

        x = torch.randn(3, 4)
        update = torch.tensor([0], dtype=torch.float32)
        self.run_test(CopyModel2(), (x, update))

        update = torch.tensor([2, 3], dtype=torch.float32)
        self.run_test(CopyModel2(), (x, update))

        update = torch.randn(2)
        self.run_test(CopyModel2(), (x, update))

        class CopyModel3(torch.nn.Module):
            def forward(self, x, data):
                x[1, 1:3] = data
                return x

        x = torch.randn(3, 4)
        update = torch.tensor([0], dtype=torch.float32)
        self.run_test(CopyModel3(), (x, update))

        update = torch.tensor([2, 3], dtype=torch.float32)
        self.run_test(CopyModel3(), (x, update))

        update = torch.randn(2)
        self.run_test(CopyModel3(), (x, update))

        update = torch.randn(1, 2)
        self.run_test(CopyModel3(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_copy_ellipsis(self):
        class CopyModel(torch.nn.Module):
            def forward(self, x, update):
                x[..., 1] = update
                return x

        x = torch.randn(2, 3, 4)
        update = torch.ones(1)
        self.run_test(CopyModel(), (x, update))

        x = torch.randn(2, 3, 4, 5, 6)
        update = torch.ones(1)
        self.run_test(CopyModel(), (x, update))

        class CopyModel2(torch.nn.Module):
            def forward(self, x, update):
                x[2, ..., 1:3] = update
                return x

        x = torch.randn(3, 4, 5, 6)
        update = torch.ones(1)
        self.run_test(CopyModel2(), (x, update))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_flip(self):
        class MyModule(torch.nn.Module):
1 change: 1 addition & 0 deletions tools/build_variables.py
@@ -315,6 +315,7 @@ def add_torch_libs():
"torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp",
"torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp",
"torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp",
"torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.cpp",
"torch/csrc/jit/passes/remove_inplace_ops.cpp",
"torch/csrc/jit/passes/utils/check_alias_annotation.cpp",
"torch/csrc/jit/python_arg_flatten.cpp",
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
@@ -78,6 +78,7 @@ set(TORCH_PYTHON_SRCS
    ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/constant_fold.cpp
    ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/scalar_type_analysis.cpp
    ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
    ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.cpp
    ${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp
    ${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp
    ${TORCH_SRC_DIR}/csrc/jit/python_ir.cpp
5 changes: 3 additions & 2 deletions torch/csrc/autograd/VariableTypeManual.cpp
@@ -158,20 +158,21 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
   if(torch::jit::tracer::isTracing()) {
     const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
     auto& graph = state.graph;
-    if (state.force_outplace) {
+    if (state.force_outplace && self.storage().use_count() <= 1) {
       // if you have no views of self, then an in place copy is equivalent to
       // making sure we expand src to the same size as self
       jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
       jit::tracer::addInputs(node, "src", src);
       jit::tracer::addInputs(node, "self", self);
       graph->insertNode(node);
-      jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
       output = node->output();
     } else {
       output = graph->insert(
           jit::aten::copy_,
           {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
       jit::tracer::recordSourceLocation(output->node());
     }
+    jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
   }
   // TODO: once copy is exposed in Declarations.yaml we may be able to bind
   // it automatically
2 changes: 2 additions & 0 deletions torch/csrc/jit/init.cpp
@@ -33,6 +33,7 @@
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
#include <torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/quantization.h>
#include <torch/csrc/jit/passes/remove_expands.h>
@@ -132,6 +133,7 @@ void initJITBindings(PyObject* module) {
          },
          pybind11::return_value_policy::move)
      .def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
      .def("_jit_pass_onnx_prepare_inplace_ops_for_onnx", PrepareInplaceOpsForONNX)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_dce",