From d784f8cac4112a16154db289baa92f931eac7107 Mon Sep 17 00:00:00 2001 From: sunway Date: Sat, 25 Sep 2021 04:21:10 +0800 Subject: [PATCH] [Frontend][TFLite] fix #9078 (#9099) Co-authored-by: sunway --- python/tvm/relay/frontend/tflite.py | 2 +- tests/python/frontend/tflite/test_forward.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 250e9c4eb117..93a1dba233f2 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -775,7 +775,7 @@ def convert_softmax(self, op): assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - params = {"axis": 1} # 1 is channel + params = {"axis": -1} # -1 is channel in_expr = self.get_expr(input_tensor_idx) # TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index d4c0b28e4e14..c073681dcbf5 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -3286,6 +3286,7 @@ def _test_softmax(data): def test_forward_softmax(): """Softmax""" _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 2, 3))) ######################################################################