[Relax][Training] Loss functions #112

Merged (17 commits) on Jan 31, 2023
3 changes: 2 additions & 1 deletion python/tvm/relax/__init__.py
@@ -27,6 +27,7 @@
from . import expr_functor
from . import struct_info
from . import utils
from . import training

# Expr

@@ -88,4 +89,4 @@
)

# Training utils
from .training import optimizer
from .training import loss, optimizer
2 changes: 1 addition & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -634,7 +634,7 @@ def nll_loss(
The weight of each target value.
If not specified, it is treated as if having all ones.

reduction : string
reduction : str
The reduction method to apply to the output.
Possible values are "mean", "sum" and "none".

2 changes: 2 additions & 0 deletions python/tvm/relax/training/__init__.py
@@ -17,3 +17,5 @@
"""The Relax training APIs."""

from . import optimizer
from . import utils
from . import loss
20 changes: 20 additions & 0 deletions python/tvm/relax/training/_ffi_api.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.relax.training"""
import tvm._ffi

tvm._ffi._init_api("relax.training", __name__)
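For context: tvm._ffi._init_api looks up every global PackedFunc whose registered name starts with the given prefix ("relax.training") and attaches it to this module. A minimal sketch of what that enables, using a hypothetical function name (this file itself registers nothing):

from tvm.relax.training import _ffi_api

# Suppose the C++ side registers a packed function as "relax.training.SomeHelper";
# after _init_api runs, it is reachable as an attribute of this module:
# result = _ffi_api.SomeHelper(args)  # hypothetical name, for illustration only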
251 changes: 251 additions & 0 deletions python/tvm/relax/training/loss.py
@@ -0,0 +1,251 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Loss functions library for relax."""

from typing import Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on

from ..block_builder import BlockBuilder
from ..expr import Expr, Var, Function, StructInfo

from ..op import abs, sum, mean, subtract, multiply
from ..op.nn import log_softmax, nll_loss


def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var:
if isinstance(param, StructInfo):
param = Var(param_name, param)
if not isinstance(param, Var):
        raise TypeError(f"The type of param should be Var or StructInfo, but got {type(param)}")
return Var(param.name_hint, param.struct_info)


class Loss:
r"""Base class of all loss.

Parameters
----------
loss_name : str
The name of the loss function.

reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to the output. Can be "mean", "sum" or "none".

none : no reduction will be applied,
        mean : the output will be averaged,
sum : the output will be summed.
"""

_valid_reductions = ["mean", "sum", "none"]

def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
self._loss_name = loss_name
self._reduction = reduction

if self._reduction not in self._valid_reductions:
            raise ValueError(f"Reduction can only be one of these values: {self._valid_reductions}")

def _with_reduction(self, expr: Expr) -> Expr:
"""Add a reduction to the final loss.

Parameters
----------
expr : Expr
            The loss expr.

        Returns
        -------
        Expr
            The loss expression after the reduction is applied.
        """
if self._reduction == "sum":
expr = sum(expr)
elif self._reduction == "mean":
expr = mean(expr)
elif self._reduction != "none":
            raise ValueError(f"Reduction can only be one of these values: {self._valid_reductions}")
return expr
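# Illustrative note, not part of this diff: given an element-wise loss tensor
# `lv`, _with_reduction returns `lv` unchanged for reduction="none", emits
# sum(lv) for "sum", and mean(lv) for "mean", using the relax.op reductions
# imported at the top of this file.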


class L1Loss(Loss):
r"""Mean element-wise absolute value difference.

Parameters
----------
reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to the output. Can be "mean", "sum" or "none".

none : no reduction will be applied,
        mean : the output will be averaged,
sum : the output will be summed.
"""

def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
super().__init__("l1_loss", reduction)

def __call__(
self,
predictions: Union[Var, StructInfo],
targets: Union[Var, StructInfo],
) -> Function:
"""Get the relax function of L1Loss. If the parameters are
struct info, it will create corresponding variables.

Parameters
----------
predictions : Union[Var, StructInfo]
The predictions of the model in the calculation of loss.
targets : Union[Var, StructInfo]
The ground truth in the calculation of loss.

        Returns
        -------
        Function
            The relax function of L1Loss with the loss name as its global symbol.
"""
bb = BlockBuilder()

predictions = _create_param_var(predictions, "predictions")
targets = _create_param_var(targets, "targets")

with bb.function(self._loss_name, [predictions, targets]):
with bb.dataflow():
lv = abs(subtract(predictions, targets))
loss = bb.emit_output(self._with_reduction(lv))
bb.emit_func_output(loss)

return bb.get()[self._loss_name]
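# Example usage, illustrative only (relax.TensorStructInfo is assumed from the
# surrounding relax API and is not part of this diff):
#
#     sinfo = relax.TensorStructInfo([4, 10], "float32")
#     l1_func = L1Loss(reduction="sum")(sinfo, sinfo)
#     # l1_func is a relax.Function named "l1_loss" with parameters
#     # (predictions, targets) that returns the summed absolute error.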


class MSELoss(Loss):
r"""Measures the element-wise mean squared error.

Parameters
----------
reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to the output. Can be "mean", "sum" or "none".

none : no reduction will be applied,
        mean : the output will be averaged,
sum : the output will be summed.
"""

def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
super().__init__("mse_loss", reduction)

def __call__(
self,
predictions: Union[Var, StructInfo],
targets: Union[Var, StructInfo],
) -> Function:
"""Get the relax function of MSELoss. If the parameters are
struct info, it will create corresponding variables.

Parameters
----------
predictions : Union[Var, StructInfo]
The predictions of the model in the calculation of loss.
targets : Union[Var, StructInfo]
The ground truth in the calculation of loss.

        Returns
        -------
        Function
            The relax function of MSELoss with the loss name as its global symbol.
"""
bb = BlockBuilder()

predictions = _create_param_var(predictions, "predictions")
targets = _create_param_var(targets, "targets")

with bb.function(self._loss_name, [predictions, targets]):
with bb.dataflow():
lv = subtract(predictions, targets)
lv = multiply(lv, lv)
loss = bb.emit_output(self._with_reduction(lv))
bb.emit_func_output(loss)

return bb.get()[self._loss_name]
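# Example usage, illustrative only: with reduction="none" the returned function
# keeps the unreduced element-wise squared error:
#
#     sinfo = relax.TensorStructInfo([4, 10], "float32")
#     mse_func = MSELoss(reduction="none")(sinfo, sinfo)  # output keeps shape (4, 10)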


class CrossEntropyLoss(Loss):
r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss.

Parameters
----------
reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to the output. Can be "mean", "sum" or "none".

none : no reduction will be applied,
        mean : the output will be averaged,
sum : the output will be summed.

ignore_index : int
Specifies a target value that is ignored and does not contribute to the input gradient.
"""

ignore_index: int

def __init__(
self,
reduction: Literal["mean", "sum", "none"] = "mean",
ignore_index: int = -100,
) -> None:
super().__init__("cross_entropy_loss", reduction)
self.ignore_index = ignore_index

def __call__(
self,
predictions: Union[Var, StructInfo],
targets: Union[Var, StructInfo],
weights: Optional[Union[Var, StructInfo]] = None,
) -> Function:
"""Get the relax function of CrossEntropyLoss. If the parameters are
struct info, it will create corresponding variables.

Parameters
----------
predictions : Union[Var, StructInfo]
The predictions of the model in the calculation of loss.

targets : Union[Var, StructInfo]
The ground truth in the calculation of loss.

weights : Optional[Union[Var, StructInfo]]
            A manual rescaling weight given to each class. It has to be a Tensor of size C.

        Returns
        -------
        Function
            The relax function of CrossEntropyLoss with the loss name as its global symbol.
"""
bb = BlockBuilder()

predictions = _create_param_var(predictions, "predictions")
targets = _create_param_var(targets, "targets")

arg_list = [predictions, targets]
        if weights is not None:
weights = _create_param_var(weights, "weights")
arg_list.append(weights)

with bb.function(self._loss_name, arg_list):
with bb.dataflow():
logits = bb.emit(log_softmax(predictions))
loss = bb.emit_output(
nll_loss(logits, targets, weights, self._reduction, self.ignore_index)
)
bb.emit_func_output(loss)

return bb.get()[self._loss_name]
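Taken together, a minimal end-to-end sketch of the new API. The relax.TensorStructInfo constructor and the shapes below are assumptions about the surrounding Relax API, not something this diff adds:

from tvm import relax
from tvm.relax.training.loss import CrossEntropyLoss

# Struct info for a batch of 4 predictions over 10 classes, integer class
# targets, and an optional per-class weight vector of size C = 10.
pred_sinfo = relax.TensorStructInfo([4, 10], "float32")
target_sinfo = relax.TensorStructInfo([4], "int64")
weight_sinfo = relax.TensorStructInfo([10], "float32")

ce = CrossEntropyLoss(reduction="mean", ignore_index=-1)
func = ce(pred_sinfo, target_sinfo, weight_sinfo)

# `func` is a relax.Function named "cross_entropy_loss" that applies
# log_softmax to the predictions and then nll_loss with the configured
# reduction and ignore_index.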