diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 71554a8a5bab..fe0ae412a228 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -949,6 +949,12 @@ def convert(node: fx.Node):
 
         return convert
 
+    def _where(self, node: fx.Node) -> relax.Var:
+        condition = self.env[node.args[0]]
+        x = self.env[node.args[1]]
+        y = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.where(condition, x, y))
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
@@ -967,6 +973,17 @@ def _chunk(self, node: fx.Node) -> relax.Var:
             relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
         )
 
+    def _cumprod(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+
+        return self.block_builder.emit(relax.op.cumprod(x, dim, dtype))
+
     def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 022a7bffea80..c4008a939688 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,6 +62,10 @@ def _fetch_attr(self, model, target: str):
 
     ########## Unary Ops ##########
 
+    def _reciprocal(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
+
     def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -708,6 +712,7 @@ def create_convert_map(
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
@@ -784,11 +789,13 @@ def create_convert_map(
             # search
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
+            "where": self._where,
             # tensor manipulation
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
             "contiguous": lambda node: self.env[node.args[0]],
+            "cumprod": self._cumprod,
             "cumsum": self._cumsum,
             "expand": self._expand,
             "expand_as.default": self._expand_as,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 726ff6f8e81d..b8d7f0b14e5b 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2339,6 +2339,27 @@ def main(
     verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
     verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
 
+    # reciprocal
+    class Reciprocal(Module):
+        def forward(self, input):
+            return torch.reciprocal(input)
+
+    @tvm.script.ir_module
+    class expected_reciprocal:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    R.const(1.0, "float32"), input_1
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Reciprocal(), input_info, {}, expected_reciprocal)
+
     # relu
     class ReLU0(Module):
         def __init__(self):
@@ -4315,5 +4336,49 @@ def main(
     verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_cumprod():
+    class Cumprod(Module):
+        def forward(self, x):
+            return torch.cumprod(x, 0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cumprod(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_where():
+    class Where(Module):
+        def forward(self, condition, x, y):
+            return torch.where(condition, x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="bool"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()