Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
def _hardtanh(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
x = args[0]
min_val = node.args[1] if len(args) > 1 else node.kwargs("min_val", -1.0)
max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0)
min_val = node.args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
max_val = node.args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
return self.block_builder.emit(relax.op.clip(x, min_val, max_val))

def _log2(self, node: fx.Node) -> relax.Var:
Expand Down Expand Up @@ -216,6 +216,19 @@ def _slice(self, node: fx.Node) -> relax.Var:
stride = [node.args[4] if len(node.args) > 4 else 1]
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))

def _unflatten(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
dim = node.args[1]
sizes = node.args[2]

x_shape = list(self.shape_of(x))
if dim < 0:
dim += len(x_shape)

new_shape = x_shape[:dim] + sizes + x_shape[dim + 1 :]
return self.block_builder.emit(relax.op.reshape(x, new_shape))

########## Creation ##########

def _one_hot(self, node: fx.Node) -> relax.Var:
Expand Down Expand Up @@ -258,14 +271,17 @@ def create_convert_map(
"cos.default": self._unary_op(relax.op.cos),
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
"dropout_.default": lambda node: self.env[node.args[0]],
"elu.default": self._elu,
"erf.default": self._unary_op(relax.op.erf),
"exp.default": self._unary_op(relax.op.exp),
"floor.default": self._unary_op(relax.op.floor),
"gelu.default": self._gelu,
"hardsigmoid.default": self._hardsigmoid,
"hardswish.default": self._hardswish,
"hardswish_.default": self._hardswish,
"hardtanh.default": self._hardtanh,
"hardtanh_.default": self._hardtanh,
"isfinite.default": self._unary_op(relax.op.isfinite),
"isinf.default": self._unary_op(relax.op.isinf),
"isnan.default": self._unary_op(relax.op.isnan),
Expand All @@ -278,12 +294,14 @@ def create_convert_map(
"neg.default": self._unary_op(relax.op.negative),
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
"relu_.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
"selu.default": self._unary_op(relax.op.nn.selu),
"sigmoid.default": self._unary_op(relax.op.sigmoid),
"sign.default": self._unary_op(relax.op.sign),
"silu.default": self._unary_op(relax.op.nn.silu),
"silu_.default": self._unary_op(relax.op.nn.silu),
"sin.default": self._unary_op(relax.op.sin),
"sinh.default": self._unary_op(relax.op.sinh),
"softmax.int": self._softmax,
Expand All @@ -296,6 +314,7 @@ def create_convert_map(
"triu.default": self._tril_triu(relax.op.triu),
# binary
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
Expand Down Expand Up @@ -393,6 +412,7 @@ def create_convert_map(
"tile.default": self._tile,
"topk.default": self._topk,
"transpose.int": self._transpose,
"unflatten.int": self._unflatten,
"unsqueeze.default": lambda node: self.block_builder.emit(
relax.op.expand_dims(self.env[node.args[0]], node.args[1])
),
Expand Down
53 changes: 53 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ class Dropout2(Module):
def forward(self, input):
return torch.dropout(input, 0.5, train=True)

class Dropout3(Module):
def forward(self, input):
return torch.ops.aten.dropout_(input, 0.5, train=True)

@tvm.script.ir_module
class expected_dropout:
@R.function
Expand All @@ -268,6 +272,7 @@ def main(

verify_model(Dropout1(), example_args, {}, expected_dropout)
verify_model(Dropout2(), example_args, {}, expected_dropout)
verify_model(Dropout3(), example_args, {}, expected_dropout)

# elu
class Elu(Module):
Expand Down Expand Up @@ -383,6 +388,10 @@ class Hardswish2(torch.nn.Module):
def forward(self, input):
return torch.nn.functional.hardswish(input)

class Hardswish3(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.hardswish_(input)

@tvm.script.ir_module
class expected1:
@R.function
Expand All @@ -402,6 +411,7 @@ def main(

verify_model(Hardswish(), example_args, {}, expected1)
verify_model(Hardswish2(), example_args, {}, expected1)
verify_model(Hardswish3(), example_args, {}, expected1)

# hardtanh
test_hardtanh()
Expand Down Expand Up @@ -511,6 +521,10 @@ class ReLU1(Module):
def forward(self, input):
return torch.nn.functional.relu(input)

class ReLU2(Module):
def forward(self, input):
return torch.ops.aten.relu_(input)

@tvm.script.ir_module
class expected_relu:
@R.function
Expand All @@ -526,6 +540,7 @@ def main(

verify_model(ReLU0(), example_args, {}, expected_relu)
verify_model(ReLU1(), example_args, {}, expected_relu)
verify_model(ReLU2(), example_args, {}, expected_relu)

# selu
class Selu1(Module):
Expand Down Expand Up @@ -597,6 +612,10 @@ class SiLU2(Module):
def forward(self, input):
return torch.nn.functional.silu(input)

class SiLU3(Module):
def forward(self, input):
return torch.ops.aten.silu_(input)

@tvm.script.ir_module
class expected_silu:
@R.function
Expand All @@ -612,6 +631,7 @@ def main(

verify_model(SiLU(), example_args, {}, expected_silu)
verify_model(SiLU2(), example_args, {}, expected_silu)
verify_model(SiLU3(), example_args, {}, expected_silu)

# softmax
test_softmax()
Expand All @@ -636,6 +656,10 @@ class Hardtanh2(torch.nn.Module):
def forward(self, input):
return torch.nn.functional.hardtanh(input)

class Hardtanh3(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.hardtanh_(input)

@tvm.script.ir_module
class expected1:
@R.function
Expand All @@ -653,6 +677,7 @@ def main(
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
verify_model(Hardtanh(), example_args, {}, expected1)
verify_model(Hardtanh2(), example_args, {}, expected1)
verify_model(Hardtanh3(), example_args, {}, expected1)


def test_leakyrelu():
Expand Down Expand Up @@ -845,6 +870,7 @@ def main(

operator_binary_1 = [
(operator.add, R.add),
(torch.ops.aten.add_, R.add),
(operator.sub, R.subtract),
(operator.mul, R.multiply),
(operator.truediv, R.divide),
Expand Down Expand Up @@ -3603,6 +3629,33 @@ def main(
verify_model(Select(), example_args, {}, Expected)


def test_unflatten():
class Unflatten(Module):
def forward(self, input):
return torch.ops.aten.unflatten(input, 1, (3, 5))

class Unflatten1(Module):
def forward(self, input):
return torch.ops.aten.unflatten(input, -2, (3, 5))

@tvm.script.ir_module
class Expected:
@R.function
def main(
inp_0: R.Tensor((2, 15, 7), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")):
with R.dataflow():
lv: R.Tensor((2, 3, 5, 7), dtype="float32") = R.reshape(inp_0, [2, 3, 5, 7])
gv: R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)

verify_model(Unflatten(), example_args, {}, Expected)
verify_model(Unflatten1(), example_args, {}, Expected)


def test_gather():
class Gather0(Module):
def forward(self, data, indices):
Expand Down
Loading