From 409c70cf2c856a83cf5aca843fac0f6568c8753f Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 10 Dec 2020 23:20:37 +0000 Subject: [PATCH] Add softplus operator conversion to Onnx. --- python/tvm/relay/frontend/onnx.py | 12 +++++++++ tests/python/frontend/onnx/test_forward.py | 29 ++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0b6ebdb5d5c2..f0d7e2d21d40 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2095,6 +2095,17 @@ def _impl_v11(cls, inputs, attr, params): return result +class Softplus(OnnxOpConverter): + """Operator converter for Softplus.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + data_dtype = infer_type(data).checked_type.dtype + data = _op.exp(data) + _expr.const(1, dtype=data_dtype) + return _op.log(data) + + class Loop(OnnxOpConverter): """Operator converter for Loop""" @@ -2371,6 +2382,7 @@ def _get_convert_map(opset): "Sum": Sum.get_converter(opset), "Mean": Mean.get_converter(opset), "Clip": Clip.get_converter(opset), + "Softplus": Softplus.get_converter(opset), # softmax default axis is different in onnx "Softmax": Softmax.get_converter(opset), "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1e0b729cbef0..d7a07f7271a9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3983,6 +3983,34 @@ def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pa verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], pads=pads) +@tvm.testing.uses_gpu +def test_softplus(): + def verify_softplus(indata): + node = helper.make_node( + "Softplus", + inputs=["X"], + outputs=["Y"], + ) + + graph = helper.make_graph( + [node], + "softplus_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))], + ) + + model = helper.make_model(graph, producer_name="softplus_test") + + verify_with_ort_with_inputs(model, [indata], dtype="float32", use_vm=True, opset=11) + + # Simple case with all signs. + input_data = np.array([[-1, 0, 1]], dtype=np.float32) + verify_softplus(input_data) + # More fancy case. + input_data = np.random.randn(1, 32, 32, 3).astype("float32") + verify_softplus(input_data) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4061,3 +4089,4 @@ def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pa test_loop() test_size() test_maxunpool() + test_softplus()