[PyTorch] [Relay] Add l1 and mse loss function for pytorch frontend (apache#11978)

* add l1 and mse loss function for pytorch frontend

* fix CI
Yuanjing Shi authored Jul 1, 2022
1 parent beea0d2 commit ec39199
Showing 3 changed files with 70 additions and 3 deletions.
33 changes: 32 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -932,6 +932,35 @@ def cross_entropy_loss_with_logits(self, inputs, input_types):
    assert weights is None, "weight not supported in cross_entropy_loss"
    return _op.nn.cross_entropy_with_logits(_op.nn.log_softmax(input), target)

def l1_loss(self, inputs, input_types):
    assert len(inputs) == 3
    [predictions, targets, reduction] = inputs
    delta = _op.abs(_op.subtract(predictions, targets))
    if reduction == 0:
        # reduction = "none"
        return delta
    elif reduction == 1:
        # reduction = "mean"
        return _op.mean(delta)
    else:
        # reduction = "sum"
        return _op.sum(delta)

def mse_loss(self, inputs, input_types):
    assert len(inputs) == 3
    [predictions, targets, reduction] = inputs
    delta = _op.subtract(predictions, targets)
    delta = _op.power(delta, _expr.const(2, input_types[0]))
    if reduction == 0:
        # reduction = "none"
        return delta
    elif reduction == 1:
        # reduction = "mean"
        return _op.mean(delta)
    else:
        # reduction = "sum"
        return _op.sum(delta)

def hard_sigmoid(self, inputs, input_types):
    def _relu6(x):
        return _op.tensor.clip(x, 0.0, 6.0)
@@ -3200,7 +3229,6 @@ def create_convert_map(self):
        "aten::silu": self.silu,
        "aten::glu": self.glu,
        "aten::log_sigmoid": self.log_sigmoid,
        "aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
        "aten::adaptive_avg_pool1d": functools.partial(
            self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d
        ),
@@ -3374,6 +3402,9 @@ def create_convert_map(self):
        "aten::nll_loss": self.nll_loss,
        "aten::nll_loss2d": self.nll_loss,
        "aten::nll_loss_nd": self.nll_loss,
        "aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
        "aten::l1_loss": self.l1_loss,
        "aten::mse_loss": self.mse_loss,
        "aten::flip": self.flip,
        "aten::gru": self.gru,
        "aten::lstm": self.lstm,
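For context, the two new converters build ordinary Relay expressions out of `_op.abs`/`_op.subtract`/`_op.power` plus the `mean`/`sum` reductions, with the integer `reduction` argument encoded as 0 = "none", 1 = "mean", otherwise "sum" (per the comments above). Below is a minimal standalone sketch, not part of this commit, that builds and runs the equivalent Relay expression for the L1 case; the helper name `l1_loss_expr` and the graph-executor invocation are illustrative assumptions.

```python
# Illustrative sketch only: reproduces what the new l1_loss converter emits,
# outside the frontend, using public tvm.relay ops.
import numpy as np
import tvm
from tvm import relay


def l1_loss_expr(predictions, targets, reduction=1):
    # reduction encoding mirrors the converter: 0 = "none", 1 = "mean", else "sum"
    delta = relay.abs(relay.subtract(predictions, targets))
    if reduction == 0:
        return delta
    if reduction == 1:
        return relay.mean(delta)
    return relay.sum(delta)


p = relay.var("predictions", shape=(10, 3), dtype="float32")
t = relay.var("targets", shape=(10, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([p, t], l1_loss_expr(p, t)))

pred_np = np.random.rand(10, 3).astype("float32")
targ_np = np.random.rand(10, 3).astype("float32")
out = relay.create_executor("graph", mod=mod).evaluate()(pred_np, targ_np)
# Expected to agree with np.abs(pred_np - targ_np).mean() up to float tolerance.
```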
4 changes: 2 additions & 2 deletions python/tvm/topi/nn/softmax.py
@@ -129,12 +129,12 @@ def log_softmax(x, axis=-1):
    Parameters
    ----------
    data : tvm.te.Tensor
        2-D input data
        N-D input data

    Returns
    -------
    output : tvm.te.Tensor
        2-D output with same shape
        N-D output with same shape
    """
    shape = x.shape
    if axis < 0:
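The docstring fix above brings the text in line with the behavior: `log_softmax` reduces over one `axis` of an N-D tensor. A quick way to check this from the Relay level, under the assumption that an `llvm` target is available and that N-D inputs lower cleanly through the default schedules (as the updated docstring indicates), is sketched here.

```python
# Sketch: exercise log_softmax on a 3-D input through Relay, which lowers to the
# topi implementation documented above. Shapes and names are illustrative.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.log_softmax(x, axis=-1)))

x_np = np.random.rand(2, 3, 4).astype("float32")
out = relay.create_executor("graph", mod=mod).evaluate()(x_np).numpy()

# Reference: log_softmax(x) = x - log(sum(exp(x), axis=-1, keepdims=True))
ref = x_np - np.log(np.exp(x_np).sum(axis=-1, keepdims=True))
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
```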
36 changes: 36 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -4177,6 +4177,42 @@ def test_cross_entropy_loss():
    verify_model(torch.nn.CrossEntropyLoss().eval(), input_data=[predictions, targets])


def test_forward_l1_loss():
    torch.set_grad_enabled(False)
    N, C = 10, 3
    predictions = torch.rand((N, C)).float()
    targets = torch.rand((N, C)).float()
    verify_model(torch.nn.L1Loss().eval(), input_data=[predictions, targets])
    verify_model(torch.nn.L1Loss(reduction="sum").eval(), input_data=[predictions, targets])
    verify_model(torch.nn.L1Loss(reduction="none").eval(), input_data=[predictions, targets])

    # multidimension l1 loss
    d1, d2 = 2, 3
    predictions = torch.rand((N, C, d1, d2)).float()
    targets = torch.rand((N, C, d1, d2)).float()
    verify_model(torch.nn.L1Loss().eval(), input_data=[predictions, targets])
    verify_model(torch.nn.L1Loss(reduction="sum").eval(), input_data=[predictions, targets])
    verify_model(torch.nn.L1Loss(reduction="none").eval(), input_data=[predictions, targets])


def test_forward_mse_loss():
    torch.set_grad_enabled(False)
    N, C = 10, 3
    predictions = torch.rand((N, C)).float()
    targets = torch.rand((N, C)).float()
    verify_model(torch.nn.MSELoss().eval(), input_data=[predictions, targets])
    verify_model(torch.nn.MSELoss(reduction="sum").eval(), input_data=[predictions, targets])
    verify_model(torch.nn.MSELoss(reduction="none").eval(), input_data=[predictions, targets])

    # multidimension mse loss
    d1, d2 = 2, 3
    predictions = torch.rand((N, C, d1, d2)).float()
    targets = torch.rand((N, C, d1, d2)).float()
    verify_model(torch.nn.MSELoss().eval(), input_data=[predictions, targets])
    verify_model(torch.nn.MSELoss(reduction="sum").eval(), input_data=[predictions, targets])
    verify_model(torch.nn.MSELoss(reduction="none").eval(), input_data=[predictions, targets])


@tvm.testing.uses_gpu
def test_forward_flip():
    torch.set_grad_enabled(False)
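For readers who want to reproduce what `verify_model` does under the hood, here is a hedged end-to-end sketch (not part of the test suite): trace `torch.nn.MSELoss`, import it through `relay.frontend.from_pytorch` (which now dispatches `aten::mse_loss` to the converter added above), and compare against PyTorch. The input names and the graph-executor call are illustrative assumptions.

```python
# Sketch: trace an MSELoss module and run it through the PyTorch frontend.
import numpy as np
import torch
import tvm
from tvm import relay

pred = torch.rand(10, 3)
targ = torch.rand(10, 3)
traced = torch.jit.trace(torch.nn.MSELoss().eval(), (pred, targ))

# Input names are arbitrary labels for the Relay vars; shapes must match the trace.
mod, params = relay.frontend.from_pytorch(
    traced, [("predictions", (10, 3)), ("targets", (10, 3))]
)

out = (
    relay.create_executor("graph", mod=mod, params=params)
    .evaluate()(pred.numpy(), targ.numpy())
)
ref = torch.nn.functional.mse_loss(pred, targ).numpy()
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)
```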
