[Op][NN] cross_entropy, log_softmax, nll_loss #94

Merged
merged 10 commits on Jan 15, 2023
13 changes: 13 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -184,6 +184,19 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
}
}; // struct DropoutAttrs

/*! \brief Attributes used in nll_loss operator */
struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
String reduction;
int ignore_index;

TVM_DECLARE_ATTRS(NLLLossAttrs, "relax.attrs.NLLLossAttrs") {
TVM_ATTR_FIELD(reduction).set_default("mean").describe(
    "The reduction method to apply to the output. Can be "
    "'none', 'mean' or 'sum'.");
TVM_ATTR_FIELD(ignore_index).describe("The target value to ignore.");
}
}; // struct NLLLossAttrs

} // namespace relax
} // namespace tvm

127 changes: 126 additions & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -327,7 +327,9 @@ def silu(data: Expr) -> Expr:
def softmax(data: Expr, axis: int = -1) -> Expr:
r"""Computes softmax.

.. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)}
.. math::

\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

Parameters
----------
@@ -351,6 +353,34 @@ def softmax(data: Expr, axis: int = -1) -> Expr:
return _ffi_api.softmax(data, axis) # type: ignore


def log_softmax(data: Expr, axis: int = -1) -> Expr:
r"""Computes log softmax.

.. math::

\text{log\_softmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}\right)

.. note::
This operator can be optimized away for inference.

Parameters
----------
data: relax.Expr
The input data to the operator.

axis: int
The axis to sum over when computing log softmax.
If not specified, it is by default the last axis of the input tensor.
Supports negative indexing.

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.log_softmax(data, axis) # type: ignore
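
As a sanity check on the formula above, here is a small NumPy reference (an illustration only, not part of this patch) that mirrors the docstring and uses the usual max-subtraction trick for numerical stability:

```python
import numpy as np

def log_softmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # log_softmax(x_i) = x_i - log(sum_j exp(x_j)), computed stably
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

x = np.array([[1.0, 2.0, 3.0]])
expected = np.log(np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True))
np.testing.assert_allclose(log_softmax_ref(x), expected, rtol=1e-6)
```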


def batch_norm(
data: Expr,
gamma: Expr,
@@ -525,3 +555,98 @@ def dropout(data: Expr, rate: float = 0.5) -> Expr:
mask tensor (1.0 where element not dropped, 0.0 where dropped)
"""
return _ffi_api.dropout(data, rate) # type: ignore


def cross_entropy_without_logits(predictions: Expr, labels: Expr) -> Expr:
r"""CrossEntropy without logits between the predictions and labels.

The predictions and labels must have the same shape. When ndim >= 2, the
first dimension is regarded as the batch size N; in this case the
computed result is divided by N to perform a mean reduction.

.. math::

\text{cross\_entropy\_without\_logits}(x_i, y_i) = \frac{\sum_i -y_i \log x_i}{N}

Parameters
----------
predictions : relax.Expr
The predictions.

labels : relax.Expr
The labels (the ground truth values).

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.cross_entropy_without_logits(predictions, labels) # type: ignore
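
A rough NumPy sketch of the formula above (illustrative only, not this PR's implementation; the batch handling assumes ndim >= 2 as described in the docstring):

```python
import numpy as np

def cross_entropy_without_logits_ref(pred: np.ndarray, labels: np.ndarray) -> np.ndarray:
    # sum_i -y_i * log(x_i), divided by the batch size N when ndim >= 2
    loss = -np.sum(labels * np.log(pred))
    n = pred.shape[0] if pred.ndim >= 2 else 1
    return loss / n

pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(cross_entropy_without_logits_ref(pred, labels))  # (-log 0.7 - log 0.8) / 2
```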


def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
r"""CrossEntropy with logits between the predictions and labels.

The shape requirements are the same as for cross_entropy_without_logits.

.. math::

\text{cross\_entropy\_with\_logits}(x_i, y_i) = \frac{\sum_i -x_i \cdot y_i}{N}

Parameters
----------
predictions : relax.Expr
The predictions.

labels : relax.Expr
The labels (the ground truth values).

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore
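
The with-logits variant drops the log, so composing it with log_softmax gives the usual softmax cross entropy. A minimal sketch of the formula (again an illustration, not part of the patch):

```python
import numpy as np

def cross_entropy_with_logits_ref(pred: np.ndarray, labels: np.ndarray) -> np.ndarray:
    # sum_i -x_i * y_i, divided by the batch size N when ndim >= 2
    loss = -np.sum(pred * labels)
    n = pred.shape[0] if pred.ndim >= 2 else 1
    return loss / n

# e.g. cross_entropy_with_logits_ref(log_softmax_ref(logits), one_hot_labels)
# reproduces the standard softmax cross entropy.
```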


def nll_loss(
predictions: Expr,
targets: Expr,
weights: Optional[Expr] = None,
reduction: str = "mean",
ignore_index: int = -100,
) -> Expr:
"""Negative log likelihood loss.

`output[n, i_1, i_2, ..., i_k] = -p * w`, where
- `p = predictions[n, t, i_1, i_2, ..., i_k]`,
- `t = targets[n, i_1, i_2, ..., i_k]`,
- `w = weights[n, i_1, i_2, ..., i_k] if t != ignore_index else 0`

result = reduction(output)

Parameters
----------
predictions : relax.Expr
The predictions.

targets : relax.Expr
The target value of each prediction.

weights : Optional[relax.Expr]
The weight of each target value.
If not specified, it is treated as if having all ones.

reduction : string
The reduction method to apply to the output.
Possible values are "mean", "sum" and "none".

ignore_index : int
The target value to ignore.

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.nll_loss(predictions, targets, weights, reduction, ignore_index) # type: ignore
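
To make the reduction semantics concrete, here is a rough NumPy reference for the 1-D case. This is a sketch under my own assumptions, not this PR's implementation: it uses the conventional PyTorch-style per-class weighting `w = weights[t]` and divides by the total weight of counted targets for the "mean" reduction.

```python
import numpy as np

def nll_loss_ref(predictions, targets, weights=None, reduction="mean", ignore_index=-100):
    # predictions: (N, C) log-probabilities; targets: (N,) class indices
    n, c = predictions.shape
    if weights is None:
        weights = np.ones(c)
    out = np.zeros(n)
    total_weight = 0.0
    for i in range(n):
        t = targets[i]
        if t == ignore_index:
            continue  # ignored targets contribute zero loss and zero weight
        out[i] = -predictions[i, t] * weights[t]
        total_weight += weights[t]
    if reduction == "none":
        return out
    if reduction == "sum":
        return out.sum()
    return out.sum() / total_weight  # "mean": divide by the total weight of counted targets
```

Predictions are typically log-probabilities, so `nll_loss_ref(log_softmax_ref(logits), targets)` reproduces softmax cross entropy over class indices.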