diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 13d13ff24c28..b0bdc598f4e4 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -417,42 +417,6 @@ def _rsub(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.subtract(rhs, lhs))
 
-    ########## Linear Algebra ##########
-
-    def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
-
-        args = self.retrieve_args(node)
-
-        data = args[0]
-        # Default ord=2 if not supplied
-        ord_val = args[1] if len(args) > 1 else 2.0
-        dim = args[2] if len(args) > 2 else None
-        keepdim = args[3] if len(args) > 3 else False
-
-        # If ord_val is a Python float/int, wrap it in a Relax const
-        # so that it matches data's dtype.
-        dtype = data.struct_info.dtype
-        ord_expr = (
-            ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
-        )
-        # Reciprocal
-        reci_expr = (
-            relax.op.divide(relax.const(1.0, dtype), ord_expr)
-            if isinstance(ord_val, relax.Expr)
-            else relax.const(1.0 / float(ord_val), dtype)
-        )
-
-        # abs(data)
-        abs_data = self.block_builder.emit(relax.op.abs(data))
-        # abs_data^ord
-        abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
-        # sum over dim
-        reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
-        # (sum(...))^(1/ord)
-        norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
-
-        return norm_val
-
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -980,16 +944,22 @@ def _norm(self, node: fx.Node) -> relax.Var:
         elif order == "fro":
             return self.block_builder.emit(
                 relax.op.sqrt(
-                    relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims),
+                    relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims)
                 )
             )
         else:
-            reci_order = relax.const(1 / order, dtype=dtype)
-            order = relax.const(order, dtype=dtype)
+            ord_expr = (
+                order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype)
+            )
+            reci_order = (
+                relax.op.divide(relax.const(1.0, dtype), ord_expr)
+                if isinstance(order, relax.Expr)
+                else relax.const(1.0 / order, dtype=dtype)
+            )
             return self.block_builder.emit(
                 relax.op.power(
                     relax.op.sum(
-                        relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims
+                        relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims
                     ),
                     reci_order,
                 )
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ed6740a25ef2..a178982acd06 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -369,7 +369,7 @@ def create_convert_map(
             "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
             "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
             # linear algebra
-            "linalg_vector_norm.default": self._linalg_vector_norm,
+            "linalg_vector_norm.default": self._norm,
             # neural network
             "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
             "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a386a989f00e..c6ead5aaccfb 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4379,6 +4379,118 @@ def main(
     verify_model(Narrow(), example_args, {}, Expected)
 
 
+def test_norm():
+    class Norm(Module):
+        def __init__(self, p, dim=None, keepdim=False):
+            super().__init__()
+            self.p = p
+            self.dim = dim
+            self.keepdim = keepdim
+
+        def forward(self, x):
+            return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected4:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected5:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32"))
+                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
+                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(
+                    lv2, R.const(-0.25, "float32")
+                )
+                gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected6:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
+                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32"))
+                gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    norms = [
+        ((float("inf"), None, False), Expected1),
+        ((float("-inf"), None, False), Expected2),
+        ((float(2), None, False), Expected3),
+        ((float(1.0), None, False), Expected4),
+        ((float(-4), None, True), Expected5),
+        ((float(0.5), None, True), Expected6),
+    ]
+
+    example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
+
+    for (p, dim, keepdim), expected in norms:
+        verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)
+
+
 def test_eye():
     class Eye1(Module):
         def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e8db6af34709..ce889678055e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4938,19 +4938,16 @@ def main(
             return gv
 
     norms = [
-        (float("inf"), None, False),
-        (float("-inf"), None, False),
-        (float(2), None, False),
-        (float(1.0), None, False),
-        (float(-4), None, True),
-        (float(0.5), None, True),
-        ("fro", None, False),
+        ((float("inf"), None, False), Expected1),
+        ((float("-inf"), None, False), Expected2),
+        ((float(2), None, False), Expected3),
+        ((float(1.0), None, False), Expected4),
+        ((float(-4), None, True), Expected5),
+        ((float(0.5), None, True), Expected6),
+        (("fro", None, False), Expected7),
     ]
 
-    for norm, expected in zip(
-        norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
-    ):
-        p, dim, keepdim = norm
+    for (p, dim, keepdim), expected in norms:
         verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
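With this change, `torch.norm` from the FX tracer and `linalg_vector_norm.default` from exported programs both lower through the single `_norm` helper: `R.max`/`R.min` of `R.abs` for the infinity orders, `sqrt(sum(x * x))` for `"fro"`, and `(sum(|x|^p))^(1/p)` otherwise, with `1/p` computed explicitly so negative and fractional orders such as `p=-4` and `p=0.5` work. The sketch below is not part of the patch; it mirrors that lowering in plain NumPy (the helper name `reference_norm` is ours, purely illustrative) and cross-checks it against `torch.norm` for the orders the new tests exercise. It uses float64 so the NumPy and PyTorch results agree to default `np.isclose` tolerances.

```python
# Standalone sketch, not part of the patch: a NumPy mirror of the op
# sequence the unified _norm emits, checked against torch.norm.
import numpy as np
import torch


def reference_norm(x, p, dim=None, keepdim=False):
    """NumPy mirror of the Relax ops emitted by _norm (illustrative only)."""
    if p == float("inf"):
        return np.abs(x).max(axis=dim, keepdims=keepdim)  # R.max(R.abs(x), ...)
    if p == float("-inf"):
        return np.abs(x).min(axis=dim, keepdims=keepdim)  # R.min(R.abs(x), ...)
    if p == "fro":
        return np.sqrt((x * x).sum(axis=dim, keepdims=keepdim))  # R.sqrt(R.sum(x * x))
    # General case: (sum(|x|^p)) ** (1/p); p may be negative or fractional,
    # which is why the patch builds the reciprocal instead of hard-coding 0.5.
    return (np.abs(x) ** p).sum(axis=dim, keepdims=keepdim) ** (1.0 / p)


x = torch.randn(1, 3, 5, 3, dtype=torch.float64)
for p in [float("inf"), float("-inf"), 2.0, 1.0, -4.0, 0.5, "fro"]:
    got = reference_norm(x.numpy(), p)
    want = torch.norm(x, p=p).item()
    assert np.isclose(got, want), (p, got, want)
```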