From 7384adf5d2c2145d5ed4b796aae9a33065f3b90e Mon Sep 17 00:00:00 2001 From: Samuel Date: Sat, 11 Apr 2020 10:32:58 +0530 Subject: [PATCH] [PYTORCH]Abs, Arange, Softplus ops (#5295) * [PYTHON]Abs, Arange, Softplus ops * Review comments updated --- python/tvm/relay/frontend/pytorch.py | 52 +++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 66 +++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b8b32e7d8925..a542ccc48af0 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -57,6 +57,33 @@ def _impl(inputs, input_types): return get_relay_op(name)(data0, data1) return _impl +def _abs(): + def _impl(inputs, input_types): + data = inputs[0] + return _op.abs(data) + return _impl + +def _arange(): + def _impl(inputs, input_types): + if len(inputs) == 5: + dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1]) + start = _create_typed_const(0, dtype) + stop = _create_typed_const(inputs[0], dtype) + step = _create_typed_const(1, dtype) + elif len(inputs) == 7: + dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) + start = _create_typed_const(inputs[0], dtype) + stop = _create_typed_const(inputs[1], dtype) + step = _create_typed_const(inputs[2], dtype) + else: + msg = "Unknown number of arguments (%d) to parse." % (len(inputs)) + raise AssertionError(msg) + return _op.transform.arange(start=start, + stop=stop, + step=step, + dtype=_convert_data_type(dtype)) + return _impl + def _squeeze(): def _impl(inputs, input_types): data = inputs[0] @@ -732,6 +759,13 @@ def _impl(inputs, input_types): return _op.tensor.sigmoid(data) return _impl +def _softplus(): + def _impl(inputs, input_types): + data = inputs[0] + beta = _expr.const(float(inputs[1])) + return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta + return _impl + def _avg_pool2d(): def _impl(inputs, input_types): data = inputs[0] @@ -1044,6 +1078,21 @@ def _impl(inputs, input_types): return _impl # Helper functions for operator implementation +def _convert_dtype_value(val): + convert_torch_dtype_map = {7:"torch.float64", + 6:"torch.float32", + 5:"torch.float16", + 4:"torch.int64", + 3:"torch.int32", + 2:"torch.int16", + 1:"torch.int8", + 0:"torch.unit8", + None:"torch.int64"} # Default is torch.int64 + if val in convert_torch_dtype_map: + return convert_torch_dtype_map[val] + else: + msg = "Torch data type value %d is not handled yet." % (val) + raise NotImplementedError(msg) def _convert_data_type(input_type): if input_type in ["double", "torch.float64"]: @@ -1118,6 +1167,8 @@ def _wrap_const(c): "aten::pow" : _elemwise("power"), "aten::div" : _elemwise("divide"), "aten::div_" : _elemwise("divide"), + "aten::abs" : _abs(), + "aten::arange" : _arange(), "aten::ones" : _ones(), "aten::zeros" : _zeros(), "aten::reciprocal" : _reciprocal(), @@ -1167,6 +1218,7 @@ def _wrap_const(c): "aten::clone" : _clone(), "aten::log_softmax" : _log_softmax(), "aten::sigmoid" : _sigmoid(), + "aten::softplus" : _softplus(), "aten::avg_pool2d" : _avg_pool2d(), "aten::avg_pool3d" : _avg_pool3d(), "aten::dropout" : _dropout(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4226463e9527..d60ab9eeec5f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -375,6 +375,54 @@ def forward(self, *args): verify_model(Squeeze1().float().eval(), input_data=input_data) verify_model(Squeeze2().float().eval(), input_data=input_data) +def test_forward_arange(): + torch.set_grad_enabled(False) + + class Arange1(Module): + def forward(self, *args): + return torch.arange(5) + class Arange2(Module): + def forward(self, *args): + return torch.arange(2.5) + class Arange3(Module): + def forward(self, *args): + return torch.arange(1, 4) + class Arange4(Module): + def forward(self, *args): + return torch.arange(1, 2.5, 0.5) + class Arange5(Module): + def forward(self, *args): + return torch.arange(1, 2, 1, dtype=torch.int32) + class Arange6(Module): + def forward(self, *args): + return torch.arange(start=1, end=6, step=2) + class Arange7(Module): + def forward(self, *args): + return torch.arange(1, 4, dtype=torch.float32) + class Arange8(Module): + def forward(self, *args): + return torch.arange(1, 2, 1, dtype=torch.int16) + + verify_model(Arange1().float().eval()) + verify_model(Arange2().float().eval()) + verify_model(Arange3().float().eval()) + verify_model(Arange4().float().eval()) + verify_model(Arange5().float().eval()) + verify_model(Arange6().float().eval()) + verify_model(Arange7().float().eval()) + verify_model(Arange8().float().eval()) + +def test_forward_abs(): + torch.set_grad_enabled(False) + input_shape = [2, 1, 10, 1, 10] + + class Abs1(Module): + def forward(self, *args): + return args[0].abs() + + input_data = torch.rand(input_shape).float() + verify_model(Abs1().float().eval(), input_data=input_data) + def test_forward_concatenate(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -445,6 +493,20 @@ def test_forward_selu(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.SELU().eval(), input_data=input_data) +def test_forward_softplus(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.Softplus().eval(), input_data=input_data) + verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data) + verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data) + +def test_forward_softsign(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.Softsign().eval(), input_data=input_data) + def test_forward_log_sigmoid(): torch.set_grad_enabled(False) input_shape = [10, 10] @@ -1254,6 +1316,8 @@ def forward(self, xs): test_forward_view() test_forward_select() test_forward_clone() + test_forward_softplus() + test_forward_softsign() test_forward_logsoftmax() test_forward_sigmoid() test_forward_dense() @@ -1264,6 +1328,8 @@ def forward(self, xs): test_forward_mean() test_forward_expand() test_forward_pow() + test_forward_abs() + test_forward_arange() test_forward_chunk() test_forward_split() test_upsample()