Skip to content

Commit

Permalink
[FRONTEND][PYTORCH] Support fo nn.SiLU added (#8753)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alperen Bag authored Aug 15, 2021
1 parent 3ebd353 commit e12ddca
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,10 @@ def selu(self, inputs, input_types):
alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
)

def silu(self, inputs, input_types):
data = inputs[0]
return data * _op.tensor.sigmoid(data)

def log_sigmoid(self, inputs, input_types):
data = inputs[0]
return _op.log(_op.tensor.sigmoid(data))
Expand Down Expand Up @@ -2623,6 +2627,7 @@ def create_convert_map(self):
"aten::celu": self.celu,
"aten::gelu": self.gelu,
"aten::selu": self.selu,
"aten::silu": self.silu,
"aten::log_sigmoid": self.log_sigmoid,
"aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d,
"aten::adaptive_max_pool2d": self.adaptive_max_pool_2d,
Expand Down
8 changes: 8 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,14 @@ def test_forward_selu():
verify_model(torch.nn.SELU().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_silu():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.SiLU().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_softplus():
torch.set_grad_enabled(False)
Expand Down

0 comments on commit e12ddca

Please sign in to comment.