[Relax][Training] Loss functions (#112)
This PR introduces loss functions for relax training and provides a tool `append_loss` that lets users append a loss computation after a forward function. Some earlier discussion of `append_loss` can be found in #111. Currently supported:
- L1Loss
- MSELoss
- CrossEntropyLoss
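As a quick sketch of the intended usage (it assumes the module lands at `tvm.relax.training.loss`, consistent with the `from . import loss` added to the package `__init__.py` below):

```python
from tvm import relax
from tvm.relax.training.loss import L1Loss

# Describing the inputs with StructInfo lets __call__ create the
# parameter Vars automatically (see _create_param_var below).
sinfo = relax.TensorStructInfo((3, 5), "float32")
l1_fn = L1Loss(reduction="mean")(sinfo, sinfo)
print(l1_fn)  # a relax.Function with global symbol "l1_loss"
```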
Showing 12 changed files with 1,037 additions and 87 deletions.
@@ -17,3 +17,5 @@
"""The Relax training APIs."""

from . import optimizer
from . import utils
from . import loss
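With the package `__init__.py` updated as above, the new module should be importable alongside the existing `optimizer` and `utils` modules (a sketch, assuming this hunk belongs to `tvm/relax/training/__init__.py`; the file path is not shown on this page):

```python
from tvm.relax.training import loss

l1 = loss.L1Loss()                        # reduction defaults to "mean"
mse = loss.MSELoss(reduction="sum")
ce = loss.CrossEntropyLoss(ignore_index=-1)
```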
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.relax.training"""
import tvm._ffi

tvm._ffi._init_api("relax.training", __name__)
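For context, `tvm._ffi._init_api("relax.training", __name__)` binds every global PackedFunc whose registered name starts with `relax.training.` as an attribute of this module. A hedged illustration (`AppendLoss` is a hypothetical name; the functions actually registered on the C++ side are not shown in this diff):

```python
from tvm.relax.training import _ffi_api

# A C++ registration such as
#   TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(...);
# would surface here as a module-level callable:
#   _ffi_api.AppendLoss(...)  # hypothetical name and signature
```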
@@ -0,0 +1,251 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Loss functions library for relax."""

from typing import Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on

from ..block_builder import BlockBuilder
from ..expr import Expr, Var, Function, StructInfo

from ..op import abs, sum, mean, subtract, multiply
from ..op.nn import log_softmax, nll_loss

def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var:
    if isinstance(param, StructInfo):
        param = Var(param_name, param)
    if not isinstance(param, Var):
        raise TypeError(f"The type of param should be Var or StructInfo, but got {type(param)}")
    return Var(param.name_hint, param.struct_info)

class Loss:
    r"""Base class of all losses.

    Parameters
    ----------
    loss_name : str
        The name of the loss function.

    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.
    """

    _valid_reductions = ["mean", "sum", "none"]

    def __init__(self, loss_name: str, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        self._loss_name = loss_name
        self._reduction = reduction

        if self._reduction not in self._valid_reductions:
            raise ValueError(f"Reduction can only be one of these values: {self._valid_reductions}")

    def _with_reduction(self, expr: Expr) -> Expr:
        """Add a reduction to the final loss.

        Parameters
        ----------
        expr : Expr
            The loss expr.
        """
        if self._reduction == "sum":
            expr = sum(expr)
        elif self._reduction == "mean":
            expr = mean(expr)
        elif self._reduction != "none":
            raise ValueError(f"Reduction can only be one of these values: {self._valid_reductions}")
        return expr

class L1Loss(Loss):
    r"""Mean element-wise absolute value difference.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.
    """

    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        super().__init__("l1_loss", reduction)

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
    ) -> Function:
        """Get the relax function of L1Loss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.
        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        Returns
        -------
        The relax function of L1Loss with the loss name as its global symbol.
        """
        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        with bb.function(self._loss_name, [predictions, targets]):
            with bb.dataflow():
                lv = abs(subtract(predictions, targets))
                loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]

class MSELoss(Loss):
    r"""Measures the element-wise mean squared error.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.
    """

    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        super().__init__("mse_loss", reduction)

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
    ) -> Function:
        """Get the relax function of MSELoss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.
        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        Returns
        -------
        The relax function of MSELoss with the loss name as its global symbol.
        """
        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        with bb.function(self._loss_name, [predictions, targets]):
            with bb.dataflow():
                lv = subtract(predictions, targets)
                lv = multiply(lv, lv)
                loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]

class CrossEntropyLoss(Loss):
    r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.

    ignore_index : int
        Specifies a target value that is ignored and does not contribute to the input gradient.
    """

    ignore_index: int

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none"] = "mean",
        ignore_index: int = -100,
    ) -> None:
        super().__init__("cross_entropy_loss", reduction)
        self.ignore_index = ignore_index

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
        weights: Optional[Union[Var, StructInfo]] = None,
    ) -> Function:
        """Get the relax function of CrossEntropyLoss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.
        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.
        weights : Optional[Union[Var, StructInfo]]
            A manual rescaling weight given to each class. It has to be a Tensor of size C.

        Returns
        -------
        The relax function of CrossEntropyLoss with the loss name as its global symbol.
        """
        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        arg_list = [predictions, targets]
        if weights is not None:
            weights = _create_param_var(weights, "weights")
            arg_list.append(weights)

        with bb.function(self._loss_name, arg_list):
            with bb.dataflow():
                logits = bb.emit(log_softmax(predictions))
                loss = bb.emit_output(
                    nll_loss(logits, targets, weights, self._reduction, self.ignore_index)
                )
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]
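Putting it together, a minimal end-to-end sketch for the combined log_softmax + nll_loss path (the shapes, dtypes, and import path here are illustrative assumptions):

```python
from tvm import relax
from tvm.relax.training.loss import CrossEntropyLoss

N, C = 2, 4  # batch size and number of classes (illustrative)
pred_sinfo = relax.TensorStructInfo((N, C), "float32")
target_sinfo = relax.TensorStructInfo((N,), "int64")

# With the default reduction="mean", nll_loss performs the reduction
# itself, so Loss._with_reduction is not needed on this path.
ce_fn = CrossEntropyLoss(ignore_index=-1)(pred_sinfo, target_sinfo)
print(ce_fn)  # a relax.Function with global symbol "cross_entropy_loss"
```

New losses can follow the same pattern: subclass `Loss`, pick a `loss_name`, build the loss expression with a `BlockBuilder`, and apply `self._with_reduction` before emitting the output.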