diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 8adb811f5a1d..462bce573892 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1483,6 +1483,50 @@ def _impl(inputs, input_types): return _impl +def _rsub(): + def _impl(inputs, input_types): + # TODO: Figure out a better way to get typing to work for tensor + scalar + type0 = input_types[0] + if isinstance(inputs[1], _expr.Expr): + type0 = input_types[1] + + type1 = input_types[1] + if isinstance(inputs[0], _expr.Expr): + type1 = input_types[0] + + data1 = _convert_elemwise_input(inputs[0], type0) + data0 = _convert_elemwise_input(inputs[1], type1) + alpha = _expr.const(float(inputs[2])) + + return get_relay_op("subtract")(data0, alpha * data1) + return _impl + + +def _embedding(): + def _impl(inputs, input_types): + weight = inputs[0] + indices = inputs[1] + + return _op.take(weight, indices.astype('int32'), axis=0) + return _impl + + +def _one_hot(): + def _impl(inputs, input_types): + indices = inputs[0].astype('int32') + num_classes = inputs[1] + if num_classes == -1: + msg = "Inferring the number of classes is not yet supported." + raise NotImplementedError(msg) + + dtype = 'int32' + on_value = tvm.relay.const(1.0, dtype) + off_value = tvm.relay.const(0.0, dtype) + + return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype) + return _impl + + # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1696,6 +1740,9 @@ def _get_convert_map(prelude): "aten::Float" : _Float(), "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), + "aten::rsub" : _rsub(), + "aten::embedding" : _embedding(), + "aten::one_hot" : _one_hot(), "aten::mm" : _matmul(), "relay::tensor_array_stack" : _tensor_array_stack(prelude), "aten::add" : _add(prelude), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2712f7267b6e..dd7eb33aeac1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1474,6 +1474,56 @@ def forward(self, *args): verify_model(Variance5().float().eval(), input_data=input_data) +def test_forward_rsub(): + torch.set_grad_enabled(False) + + class Rsub1(Module): + def forward(self, *args): + return torch.rsub(args[0], args[1]) + + class Rsub2(Module): + def forward(self, *args): + return torch.rsub(args[0], args[1], alpha=0.5) + + d1 = torch.rand([1, 3]).float() + d2 = torch.rand([1, 3]).float() + d3 = torch.rand([1, 3]).int() + verify_model(Rsub1().float().eval(), input_data=[d1, d2]) + verify_model(Rsub1().float().eval(), input_data=[d1, d3]) + verify_model(Rsub2().float().eval(), input_data=[d1, d2]) + verify_model(Rsub2().float().eval(), input_data=[d1, d3]) + + +def test_forward_embedding(): + torch.set_grad_enabled(False) + + input_data = torch.randint(0, 10, [2, 4]).long() + verify_model(torch.nn.Embedding(10, 3).float().eval(), input_data=input_data) + + input_data = torch.randint(0, 4, [2, 3, 4]).long() + verify_model(torch.nn.Embedding(4, 5, sparse=False).float().eval(), input_data=input_data) + + input_data = torch.randint(0, 4, [2, 3, 4]).long() + verify_model(torch.nn.Embedding(4, 5, sparse=True).float().eval(), input_data=input_data) + + +def test_forward_onehot(): + torch.set_grad_enabled(False) + + class OneHot1(Module): + def forward(self, *args): + return torch.nn.functional.one_hot(args[0], num_classes=3) + + class OneHot2(Module): + def forward(self, *args): + return torch.nn.functional.one_hot(args[0], num_classes=5) + + input_data = torch.arange(0, 5) % 3 + verify_model(OneHot1().float().eval(), input_data=input_data) + + input_data = torch.arange(0, 5) % 4 + verify_model(OneHot2().float().eval(), input_data=input_data) + def test_forward_isfinite(): torch.set_grad_enabled(False) @@ -1995,6 +2045,9 @@ def forward(self, *args): test_forward_add() test_forward_subtract() test_forward_multiply() + test_forward_rsub() + test_forward_onehot() + test_forward_embedding() test_forward_reshape() test_forward_reciprocal() test_forward_repeat()