[Fix][Frontend] Fix FX importer for nn operators (tlc-pack#52)
MasterJH5574 authored Dec 17, 2022
1 parent 987e208 commit f03f9d2
Showing 1 changed file with 3 additions and 3 deletions.
python/tvm/relax/frontend/pytorch_fx.py (6 changes: 3 additions & 3 deletions)
@@ -240,7 +240,7 @@ def _flatten(self, node: fx.node.Node) -> relax.Var:
         start_dim = node.args[1] if len(node.args) >= 2 else 0
         end_dim = node.args[2] if len(node.args) == 3 else -1
         assert start_dim == 1 and end_dim == -1
-        return self.bb.emit(relax.op.flatten(x))
+        return self.bb.emit(relax.op.nn.flatten(x))
 
     def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -395,7 +395,7 @@ def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
         return self.bb.emit(relax.op.nn.matmul(a, b, out_dtype="float32"))
 
     def _gelu(self, node: fx.node.Node) -> relax.Var:
-        return self.bb.emit(relax.op.gelu(self.env[node.args[0]]))
+        return self.bb.emit(relax.op.nn.gelu(self.env[node.args[0]]))
 
     def _interpolate(self, node: fx.node.Node) -> relax.Var:
         # torch.nn.functional.interpolate(
@@ -523,7 +523,7 @@ def _view(self, node: fx.node.Node) -> relax.Var:
 
     def _silu(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
-        return self.bb.emit(relax.op.silu(x))
+        return self.bb.emit(relax.op.nn.silu(x))
 
     def _group_norm(self, node: fx.node.Node) -> relax.Var:
         # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)
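
For context, a minimal sketch (not part of the commit) of a PyTorch module whose torch.fx trace exercises the three handlers this fix touches. torch.fx.symbolic_trace is standard PyTorch; how the resulting GraphModule is fed to the importer in pytorch_fx.py is not shown here, since that entry point is outside this diff.

import torch
import torch.fx

class ThreeOps(torch.nn.Module):
    # Each op below maps to one of the corrected handlers.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.silu(x)        # _silu -> relax.op.nn.silu
        x = torch.nn.functional.gelu(x)        # _gelu -> relax.op.nn.gelu
        # start_dim=1 with the default end_dim=-1 satisfies the
        # assertion in _flatten shown in the first hunk above.
        return torch.flatten(x, start_dim=1)   # _flatten -> relax.op.nn.flatten

graph_module = torch.fx.symbolic_trace(ThreeOps())
print(graph_module.graph)  # the fx graph the importer walks node by node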
