[Relax][Training] Loss functions (#112)
This PR introduces loss functions for Relax training and provides a tool,
`append_loss`, which lets users append a loss function after a forward
function.

Some previous discussion of `append_loss` can be found in #111.

Currently supported:
- L1Loss
- MSELoss
- CrossEntropyLoss
SiriusNEO authored Jan 31, 2023
1 parent 77445d4 commit 38ac587
Showing 12 changed files with 1,037 additions and 87 deletions.
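
A minimal usage sketch, not part of this commit: calling a loss object with struct info builds a standalone Relax function whose global symbol is the loss name. The `relax.TensorStructInfo` shapes and dtypes below are illustrative assumptions, not values taken from this diff.

from tvm import relax
from tvm.relax.training.loss import L1Loss, MSELoss

# Illustrative struct info for predictions and targets (assumed shapes/dtypes).
pred_sinfo = relax.TensorStructInfo((4, 4), "float32")
targ_sinfo = relax.TensorStructInfo((4, 4), "float32")

# Calling the loss object returns a relax.Function named after the loss
# ("l1_loss" / "mse_loss"), with the chosen reduction applied at the end.
l1_fn = L1Loss(reduction="sum")(pred_sinfo, targ_sinfo)
mse_fn = MSELoss()(pred_sinfo, targ_sinfo)  # reduction defaults to "mean"

print(l1_fn)  # prints the TVMScript of the generated loss function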
3 changes: 2 additions & 1 deletion python/tvm/relax/__init__.py
@@ -27,6 +27,7 @@
from . import expr_functor
from . import struct_info
from . import utils
+from . import training

# Expr

@@ -88,4 +89,4 @@
)

# Training utils
-from .training import optimizer
+from .training import loss, optimizer
2 changes: 1 addition & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -634,7 +634,7 @@ def nll_loss(
        The weight of each target value.
        If not specified, it is treated as if having all ones.

-    reduction : string
+    reduction : str
        The reduction method to apply to the output.
        Possible values are "mean", "sum" and "none".
2 changes: 2 additions & 0 deletions python/tvm/relax/training/__init__.py
@@ -17,3 +17,5 @@
"""The Relax training APIs."""

from . import optimizer
+from . import utils
+from . import loss
20 changes: 20 additions & 0 deletions python/tvm/relax/training/_ffi_api.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.relax.training"""
import tvm._ffi

tvm._ffi._init_api("relax.training", __name__)
251 changes: 251 additions & 0 deletions python/tvm/relax/training/loss.py
@@ -0,0 +1,251 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Loss functions library for relax."""

from typing import Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on

from ..block_builder import BlockBuilder
from ..expr import Expr, Var, Function, StructInfo

from ..op import abs, sum, mean, subtract, multiply
from ..op.nn import log_softmax, nll_loss


def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var:
    # Normalize the argument into a fresh Var carrying the given struct info.
    if isinstance(param, StructInfo):
        param = Var(param_name, param)
    if not isinstance(param, Var):
        raise TypeError(
            "The type of param should be Var or StructInfo, but got " + str(type(param))
        )
    return Var(param.name_hint, param.struct_info)


class Loss:
    r"""Base class of all loss functions.

    Parameters
    ----------
    loss_name : str
        The name of the loss function.

    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.
    """

    _valid_reductions = ["mean", "sum", "none"]

    def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        self._loss_name = loss_name
        self._reduction = reduction

        if self._reduction not in self._valid_reductions:
            raise ValueError("Reduction can only be one of these values: ", self._valid_reductions)

    def _with_reduction(self, expr: Expr) -> Expr:
        """Add a reduction to the final loss.

        Parameters
        ----------
        expr : Expr
            The loss expr.
        """
        if self._reduction == "sum":
            expr = sum(expr)
        elif self._reduction == "mean":
            expr = mean(expr)
        elif self._reduction != "none":
            raise ValueError("Reduction can only be one of these values: ", self._valid_reductions)
        return expr


class L1Loss(Loss):
    r"""Mean element-wise absolute value difference.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.
    """

    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        super().__init__("l1_loss", reduction)

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
    ) -> Function:
        """Get the relax function of L1Loss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.

        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        Returns
        -------
        The relax function of L1Loss with the loss name as its global symbol.
        """
        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        with bb.function(self._loss_name, [predictions, targets]):
            with bb.dataflow():
                lv = abs(subtract(predictions, targets))
                loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]


class MSELoss(Loss):
    r"""Measures the element-wise mean squared error.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.
    """

    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        super().__init__("mse_loss", reduction)

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
    ) -> Function:
        """Get the relax function of MSELoss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.

        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        Returns
        -------
        The relax function of MSELoss with the loss name as its global symbol.
        """
        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        with bb.function(self._loss_name, [predictions, targets]):
            with bb.dataflow():
                lv = subtract(predictions, targets)
                lv = multiply(lv, lv)
                loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]


class CrossEntropyLoss(Loss):
    r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.

    ignore_index : int
        Specifies a target value that is ignored and does not contribute to the input gradient.
    """

    ignore_index: int

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none"] = "mean",
        ignore_index: int = -100,
    ) -> None:
        super().__init__("cross_entropy_loss", reduction)
        self.ignore_index = ignore_index

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
        weights: Optional[Union[Var, StructInfo]] = None,
    ) -> Function:
        """Get the relax function of CrossEntropyLoss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.

        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        weights : Optional[Union[Var, StructInfo]]
            A manual rescaling weight given to each class. It has to be a Tensor of size C.

        Returns
        -------
        The relax function of CrossEntropyLoss with the loss name as its global symbol.
        """
        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        arg_list = [predictions, targets]
        if weights:
            weights = _create_param_var(weights, "weights")
            arg_list.append(weights)

        with bb.function(self._loss_name, arg_list):
            with bb.dataflow():
                logits = bb.emit(log_softmax(predictions))
                loss = bb.emit_output(
                    nll_loss(logits, targets, weights, self._reduction, self.ignore_index)
                )
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]
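
A companion sketch along the same lines, again not part of this diff (shapes, dtypes, and the `ignore_index` value are assumptions): CrossEntropyLoss composes log_softmax with nll_loss inside the generated function and accepts an optional per-class weight tensor of size C.

from tvm import relax
from tvm.relax.training.loss import CrossEntropyLoss

N, C = 2, 4  # assumed batch size and number of classes
pred_sinfo = relax.TensorStructInfo((N, C), "float32")
targ_sinfo = relax.TensorStructInfo((N,), "int64")
weight_sinfo = relax.TensorStructInfo((C,), "float32")

# weights is optional; reduction and ignore_index are forwarded to nll_loss
# inside the emitted function.
ce_fn = CrossEntropyLoss(reduction="mean", ignore_index=-1)(
    pred_sinfo, targ_sinfo, weight_sinfo
)
print(ce_fn)  # relax.Function with global symbol "cross_entropy_loss"

Per the PR description, such a function is meant to be attached to a model's forward function with the `append_loss` utility provided in the new training `utils` module.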