[Op][NN] cross_entropy, log_softmax, nll_loss (#94)
After discussing the loss, a good approach is `log_softmax` + `nll_loss`. This PR introduces these two operators and tests them.
For `nll_loss`, here are some basic shape descriptions that may help with review, along with an important reference: https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss
```
def nll_loss(
    predictions: Expr,
    targets: Expr,
    weights: Optional[Expr] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
) -> Expr:

Notations:
N: minibatch size
C: number of classes
K: number of input dimensions

Shape:
    weights: (C,) (always)

  without minibatch:
    predictions: (C,)
    targets: ()
    output: ()
    
  with minibatch N:
    predictions: (N, C)
    targets: (N,)
    output: (N,) (reduction=none)
    output: () (reduction=mean/sum)
  
  with minibatch N and high dimension input d1, d2, ..., dk:
    predictions: (N, C, d1, d2, ..., dk)
    targets: (N, d1, d2, ..., dk)
    output: (N, d1, d2, ..., dk) (reduction=none)
    output: () (reduction=mean/sum)
```
Our inference rule trusts `predictions`: we assert equality when the other arguments carry enough shape information, and otherwise do best-effort inference. Please check the code for details. A reference sketch of these shape rules follows below.
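
To make the shape table concrete, here is a minimal NumPy reference sketch of the minibatch cases (the helper name `nll_loss_ref` is hypothetical and this is not the Relax implementation):

```
import numpy as np

def nll_loss_ref(predictions, targets, weights=None, reduction="mean", ignore_index=-100):
    # predictions: (N, C, d1, ..., dk), targets: (N, d1, ..., dk), weights: (C,)
    C = predictions.shape[1]
    if weights is None:
        weights = np.ones(C)
    safe_t = np.clip(targets, 0, C - 1)            # keep indices valid even for ignored targets
    w = np.where(targets == ignore_index, 0.0, weights[safe_t])
    p = np.take_along_axis(predictions, np.expand_dims(safe_t, axis=1), axis=1)
    out = -np.squeeze(p, axis=1) * w               # shape (N, d1, ..., dk)
    if reduction == "none":
        return out
    if reduction == "sum":
        return out.sum()
    return out.sum() / w.sum()                     # "mean": weighted mean over non-ignored entries

rng = np.random.default_rng(0)
preds = np.moveaxis(np.log(rng.dirichlet(np.ones(5), size=(4, 6, 7))), -1, 1)  # (N=4, C=5, d1=6, d2=7)
tgts = rng.integers(0, 5, size=(4, 6, 7))                                      # (N, d1, d2)
print(nll_loss_ref(preds, tgts, reduction="none").shape)   # (4, 6, 7)
print(np.shape(nll_loss_ref(preds, tgts)))                 # () -- scalar for reduction="mean"
```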

This PR also reintroduces the cross entropy operator, since it was dropped when rebasing onto tlc. Given that torch defines cross entropy differently from ours, we use the names `cross_entropy_without_logits` and `cross_entropy_with_logits` to make them less confusing and to align with Relay.
SiriusNEO authored and MasterJH5574 committed Jan 31, 2023
1 parent 185fa09 commit 65003de
Showing 7 changed files with 1,228 additions and 35 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -184,6 +184,19 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
}
}; // struct DropoutAttrs

/*! \brief Attributes used in nll_loss operator */
struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
String reduction;
int ignore_index;

TVM_DECLARE_ATTRS(NLLLossAttrs, "relax.attrs.NLLLossAttrs") {
TVM_ATTR_FIELD(reduction).set_default("mean").describe(
"The reduction method to apply to the output. Can be"
"'none', 'mean' or 'sum'.");
TVM_ATTR_FIELD(ignore_index).describe("The target value to ignore.");
}
}; // struct NLLLossAttrs

} // namespace relax
} // namespace tvm

127 changes: 126 additions & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -324,7 +324,9 @@ def silu(data: Expr) -> Expr:
def softmax(data: Expr, axis: int = -1) -> Expr:
r"""Computes softmax.
.. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)}
.. math::
\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
Parameters
----------
@@ -348,6 +350,34 @@ def softmax(data: Expr, axis: int = -1) -> Expr:
return _ffi_api.softmax(data, axis) # type: ignore


def log_softmax(data: Expr, axis: int = -1) -> Expr:
r"""Computes log softmax.
.. math::
\text{log\_softmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)
.. note::
This operator can be optimized away for inference.
Parameters
----------
data: relax.Expr
The input data to the operator.
axis: int
The axis to sum over when computing log softmax.
If not specified, it is by default the last axis of the input tensor.
Supports negative indexing.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.log_softmax(data, axis) # type: ignore
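
As a side note for reviewers, the log-softmax formula above is usually evaluated in a shifted form to avoid overflow. A small NumPy sketch (illustrative only, not the implementation backing this operator):

```
import numpy as np

def log_softmax_ref(x, axis=-1):
    # Subtracting the max first is numerically stable and mathematically
    # equivalent to log(exp(x) / sum(exp(x))) from the formula above.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

x = np.array([[1000.0, 1001.0, 1002.0]])   # naive exp(x) would overflow here
print(log_softmax_ref(x))                  # [[-2.4076..., -1.4076..., -0.4076...]]
```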


def batch_norm(
data: Expr,
gamma: Expr,
@@ -522,3 +552,98 @@ def dropout(data: Expr, rate: float = 0.5) -> Expr:
mask tensor (1.0 where element not dropped, 0.0 where dropped)
"""
return _ffi_api.dropout(data, rate) # type: ignore


def cross_entropy_without_logits(predictions: Expr, labels: Expr) -> Expr:
r"""CrossEntropy without logits between the predictions and labels.
The shapes of predictions and labels must be the same. When ndim >= 2, the
first dimension is regarded as the batch size N. In this case the computed
result is divided by N to perform a mean reduction.
.. math::
\text{cross\_entropy\_without\_logits}(x_i, y_i) = \frac{\sum_i -y_i \log x_i}{N}
Parameters
----------
predictions : relax.Expr
The predictions.
labels : relax.Expr
The labels (the ground truth values).
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.cross_entropy_without_logits(predictions, labels) # type: ignore


def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
r"""CrossEntropy with logits between the predictions and labels.
The shape requirements are the same as for cross_entropy_without_logits.
.. math::
\text{cross\_entropy\_with\_logits}(x_i, y_i) = \frac{\sum_i -x_i \cdot y_i}{N}
Parameters
----------
predictions : relax.Expr
The predictions.
labels : relax.Expr
The labels (the ground truth values).
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore
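
For reviewers, a NumPy sketch of how the two formulas above relate (variable names are illustrative): `cross_entropy_without_logits` expects probabilities and applies the log itself, while `cross_entropy_with_logits` expects log-scale predictions, so feeding it `log(probs)` reproduces the same value:

```
import numpy as np

rng = np.random.default_rng(0)
N, C = 4, 5
labels = np.eye(C)[rng.integers(0, C, size=N)]      # one-hot ground truth, shape (N, C)
probs = rng.dirichlet(np.ones(C), size=N)           # predicted probabilities, shape (N, C)

ce_without = -(labels * np.log(probs)).sum() / N    # without_logits: log applied to predictions
ce_with = -(np.log(probs) * labels).sum() / N       # with_logits: predictions already log-scale

print(np.isclose(ce_without, ce_with))              # True, since we passed log(probs) as the "logits"
```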


def nll_loss(
predictions: Expr,
targets: Expr,
weights: Optional[Expr] = None,
reduction: str = "mean",
ignore_index: int = -100,
) -> Expr:
"""Negative log likelihood loss.
`output[n, i_1, i_2, ..., i_k] = -p * w`, where
- `p = predictions[n, t, i_1, i_2, ..., i_k]`,
- `t = targets[n, i_1, i_2, ..., i_k]`,
- `w = weights[t] if t != ignore_index else 0`
result = reduction(output)
Parameters
----------
predictions : relax.Expr
The predictions.
targets : relax.Expr
The target value of each prediction.
weights : Optional[relax.Expr]
The weight of each target value.
If not specified, all weights are treated as one.
reduction : string
The reduction method to apply to the output.
Possible values are "mean", "sum" and "none".
ignore_index : int
The target value to ignore.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.nll_loss(predictions, targets, weights, reduction, ignore_index) # type: ignore
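
As a sanity check on how the two new operators are meant to compose (the `log_softmax` + `nll_loss` pairing mentioned in the commit message), here is a small self-contained NumPy sketch for the plain `(N, C)` case; it is independent of the Relax implementation:

```
import numpy as np

rng = np.random.default_rng(0)
N, C = 4, 5
logits = rng.normal(size=(N, C))
targets = rng.integers(0, C, size=N)

log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))  # log_softmax over axis 1
nll = -log_probs[np.arange(N), targets]                                 # nll_loss, reduction="none"
print(nll.mean())                                                       # reduction="mean"

probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print((-np.log(probs[np.arange(N), targets])).mean())                   # same value, computed directly
```

Both prints produce the same value, which is the usual softmax cross entropy of the logits.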
