Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,12 @@ def convert(node: fx.Node):

return convert

def _where(self, node: fx.Node) -> relax.Var:
condition = self.env[node.args[0]]
x = self.env[node.args[1]]
y = self.env[node.args[2]]
return self.block_builder.emit(relax.op.where(condition, x, y))

########## Manipulation ##########

def _cat(self, node: fx.Node) -> relax.Var:
Expand All @@ -967,6 +973,17 @@ def _chunk(self, node: fx.Node) -> relax.Var:
relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
)

def _cumprod(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]

dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
if "dtype" in node.kwargs:
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
else:
dtype = None

return self.block_builder.emit(relax.op.cumprod(x, dim, dtype))

def _cumsum(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]

Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def _fetch_attr(self, model, target: str):

########## Unary Ops ##########

def _reciprocal(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))

def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
Expand Down Expand Up @@ -708,6 +712,7 @@ def create_convert_map(
"logical_not": self._unary_op(relax.op.logical_not),
"log_softmax": self._log_softmax,
"neg": self._unary_op(relax.op.negative),
"reciprocal": self._reciprocal,
"relu": self._unary_op(relax.op.nn.relu),
"round": self._round,
"rsqrt": self._unary_op(relax.op.rsqrt),
Expand Down Expand Up @@ -784,11 +789,13 @@ def create_convert_map(
# search
"argmax": self._argmax_argmin(relax.op.argmax),
"argmin": self._argmax_argmin(relax.op.argmin),
"where": self._where,
# tensor manipulation
"cat": self._cat,
"chunk": self._chunk,
"concat": self._cat,
"contiguous": lambda node: self.env[node.args[0]],
"cumprod": self._cumprod,
"cumsum": self._cumsum,
"expand": self._expand,
"expand_as.default": self._expand_as,
Expand Down
65 changes: 65 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2339,6 +2339,27 @@ def main(
verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)

# reciprocal
class Reciprocal(Module):
def forward(self, input):
return torch.reciprocal(input)

@tvm.script.ir_module
class expected_reciprocal:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
R.const(1.0, "float32"), input_1
)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Reciprocal(), input_info, {}, expected_reciprocal)

# relu
class ReLU0(Module):
def __init__(self):
Expand Down Expand Up @@ -4315,5 +4336,49 @@ def main(
verify_model(Prod(), [([5, 3], "float32")], {}, Expected)


def test_cumprod():
class Cumprod(Module):
def forward(self, x):
return torch.cumprod(x, 0)

@tvm.script.ir_module
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="float32"),
) -> R.Tensor((5, 3), dtype="float32"):
with R.dataflow():
lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
gv: R.Tensor((5, 3), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Cumprod(), [([5, 3], "float32")], {}, Expected)


def test_where():
class Where(Module):
def forward(self, condition, x, y):
return torch.where(condition, x, y)

@tvm.script.ir_module
class Expected:
@R.function
def main(
inp_0: R.Tensor((5, 3), dtype="bool"),
inp_1: R.Tensor((5, 3), dtype="float32"),
inp_2: R.Tensor((5, 3), dtype="float32"),
) -> R.Tensor((5, 3), dtype="float32"):
with R.dataflow():
lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
gv: R.Tensor((5, 3), dtype="float32") = lv
R.output(gv)
return gv

verify_model(
Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
)


if __name__ == "__main__":
tvm.testing.main()
Loading