enrich the doc and rename parameters
zhuzilin committed May 18, 2021
1 parent ab4c3fe commit f636993
Showing 6 changed files with 76 additions and 60 deletions.
30 changes: 15 additions & 15 deletions include/tvm/topi/nn.h
@@ -647,41 +647,41 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
 /*!
  * \brief Negative log likelihood loss.
  *
- * \param input The input tensor.
- * \param target The target tensor.
- * \param weight A manual rescaling weight given to each class.
+ * \param predictions The prediction tensor.
+ * \param targets The target tensor.
+ * \param weights A manual rescaling weight given to each class.
  * \param reduction The reduction method to apply to the output.
  * \param ignore_index The target value to ignore.
  * \param name The name of the operation.
  * \param tag The tag to mark the operation.
  *
  * \return A Tensor whose op member is the nll_loss operation
  */
-inline Tensor nll_loss(const Tensor& input, const Tensor& target, const Tensor& weight,
+inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
                        std::string reduction = "mean", int ignore_index = -100,
                        const std::string name = "nll_loss", const std::string tag = kBroadcast) {
   auto T = tvm::te::compute(
-      target->shape,
+      targets->shape,
       [&](const tvm::Array<tvm::tir::Var>& target_indices) {
-        auto c = target(target_indices);
-        tvm::Array<tvm::PrimExpr> input_indices;
+        auto c = targets(target_indices);
+        tvm::Array<tvm::PrimExpr> pred_indices;
         for (size_t i = 0; i < target_indices.size(); i++) {
-          input_indices.push_back(target_indices[i]);
+          pred_indices.push_back(target_indices[i]);
           if (i == 0) {
-            input_indices.push_back(c);
+            pred_indices.push_back(c);
           }
         }
-        return tvm::tir::Select(c != ignore_index, -input(input_indices) * weight(c),
-                                tvm::tir::make_const(input->dtype, 0));
+        return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
+                                tvm::tir::make_const(predictions->dtype, 0));
       },
       name, tag);
   if (reduction == "mean") {
     auto W = tvm::te::compute(
-        target->shape,
+        targets->shape,
         [&](const tvm::Array<tvm::tir::Var>& target_indices) {
-          auto c = target(target_indices);
-          return tvm::tir::Select(c != ignore_index, weight(c),
-                                  tvm::tir::make_const(input->dtype, 0));
+          auto c = targets(target_indices);
+          return tvm::tir::Select(c != ignore_index, weights(c),
+                                  tvm::tir::make_const(predictions->dtype, 0));
         },
         name, tag);
     return topi::divide(topi::sum(T, {}), topi::sum(W, {}));
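For readers unfamiliar with the kernel above: it gathers the predicted log-probability of each target class, negates and rescales it, and for "mean" divides by the summed weights of the non-ignored targets. A minimal NumPy sketch of the same semantics (illustrative only; the helper and its names are not part of the commit):

    import numpy as np

    def nll_loss_ref(predictions, targets, weights, reduction="mean", ignore_index=-100):
        # predictions: (N, C, d_1, ..., d_k) log-probabilities
        # targets:     (N, d_1, ..., d_k) class indices
        # weights:     (C,) per-class rescaling factors
        valid = targets != ignore_index
        safe_t = np.where(valid, targets, 0)  # avoid indexing with ignore_index
        # Mirror pred_indices: the target indices with the class id inserted at axis 1.
        gathered = np.take_along_axis(predictions, np.expand_dims(safe_t, 1), axis=1)
        gathered = np.squeeze(gathered, axis=1)
        out = np.where(valid, -gathered * weights[safe_t], 0.0)  # the Select above
        if reduction == "mean":
            return out.sum() / np.where(valid, weights[safe_t], 0.0).sum()
        if reduction == "sum":
            return out.sum()
        return out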
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/pytorch.py
@@ -2307,17 +2307,17 @@ def unique(self, inputs, input_types):

     def nll_loss(self, inputs, input_types):
         assert len(inputs) == 5
-        [input, target, weight, reduction, ignore_index] = inputs
-        num_class = self.infer_shape(input)[1]
+        [predictions, targets, weights, reduction, ignore_index] = inputs
+        num_class = self.infer_shape(predictions)[1]
         if reduction == 0:
             reduction = "none"
         elif reduction == 1:
             reduction = "mean"
         else:
             reduction = "sum"
-        if weight is None:
-            weight = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
-        return _op.nn.nll_loss(input, target, weight, reduction, ignore_index)
+        if weights is None:
+            weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
+        return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)

     # Operator mappings
     def create_convert_map(self):
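As a hedged usage sketch of this converter (module, shapes, and dtypes here are assumptions for illustration): TorchScript passes reduction as an integer, which the code above maps back to the string form expected by relay.nn.nll_loss.

    import torch
    import torch.nn.functional as F
    from tvm import relay

    class NLLModule(torch.nn.Module):
        def forward(self, log_probs, targets):
            # reduction arrives in TorchScript as 0 ("none"), 1 ("mean"), or 2 ("sum")
            return F.nll_loss(log_probs, targets)

    lp = torch.randn(4, 10).log_softmax(dim=1)
    tg = torch.randint(0, 10, (4,))
    scripted = torch.jit.trace(NLLModule(), (lp, tg))
    mod, params = relay.frontend.from_pytorch(
        scripted, [("log_probs", ((4, 10), "float32")), ("targets", ((4,), "int64"))]
    )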
4 changes: 2 additions & 2 deletions python/tvm/relay/op/nn/_nn.py
@@ -878,8 +878,8 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
 # nll_loss
 @reg.register_compute("nn.nll_loss")
 def compute_nll_loss(attrs, inputs, out_dtype):
-    input, target, weights = inputs
-    return [topi.nn.nll_loss(input, target, weights, attrs.reduction, attrs.ignore_index)]
+    predictions, targets, weights = inputs
+    return [topi.nn.nll_loss(predictions, targets, weights, attrs.reduction, attrs.ignore_index)]


 reg.register_reduce_schedule("nn.nll_loss")
20 changes: 13 additions & 7 deletions python/tvm/relay/op/nn/nn.py
@@ -2973,22 +2973,28 @@ def cross_entropy_with_logits(predictions, targets):
     return _make.cross_entropy_with_logits(predictions, targets)


-def nll_loss(input, target, weight, reduction="mean", ignore_index=-100):
+def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
     """Negative log likelihood loss.

+    output{n, i_1, i_2, ..., i_k} = -predictions{n, t, i_1, i_2, ..., i_k} * weights{t},
+    where t = targets{n, i_1, i_2, ..., i_k}
+
+    result = reduction(output)
+
     Parameters
     ----------
-    input : tvm.relay.Expr
-      The input.
+    predictions : tvm.relay.Expr
+      The predictions.

-    target : tvm.relay.Expr
-      The target value of the input.
+    targets : tvm.relay.Expr
+      The target value of each prediction.

-    weight : tvm.relay.Expr
+    weights : tvm.relay.Expr
       The weight of each target value.

+    reduction : string
+      The reduction method to apply to the output.
+      Possible values are "mean", "sum" and "none".
+
     ignore_index : int
       The target value to ignore.

@@ -2998,7 +3004,7 @@ def nll_loss(input, target, weight, reduction="mean", ignore_index=-100):
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.nll_loss(input, target, weight, reduction, ignore_index)
+    return _make.nll_loss(predictions, targets, weights, reduction, ignore_index)


 def depth_to_space(data, block_size, layout="NCHW", mode="DCR"):
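A minimal sketch of invoking the renamed API at the Relay level (variable names and shapes are illustrative assumptions, not part of the commit):

    from tvm import relay

    predictions = relay.var("predictions", shape=(2, 5), dtype="float32")
    targets = relay.var("targets", shape=(2,), dtype="int32")
    weights = relay.var("weights", shape=(5,), dtype="float32")
    # Scalar output for "mean"/"sum"; with reduction="none" the output
    # takes the shape of `targets`.
    loss = relay.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100)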
17 changes: 11 additions & 6 deletions python/tvm/topi/nn/loss.py
@@ -15,25 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-argument
-"""TVM operator negative log likelihood loss compute."""
+"""Loss function definitions."""
 from __future__ import absolute_import
 from . import cpp


-def nll_loss(input, target, weight, reduction, ignore_index):
+def nll_loss(predictions, targets, weights, reduction, ignore_index):
     """Negative log likelihood loss on the input data.

+    output{n, i_1, i_2, ..., i_k} = -predictions{n, t, i_1, i_2, ..., i_k} * weights{t},
+    where t = targets{n, i_1, i_2, ..., i_k}
+
+    result = reduction(output)
+
     Parameters
     ----------
-    input : tvm.te.Tensor
+    predictions : tvm.te.Tensor
         (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
         where C is the number of target classes

-    target : tvm.te.Tensor
+    targets : tvm.te.Tensor
         (k+1)-D with shape (N, d_1, d_2, ..., d_k)
         The target value of each prediction.

-    weight : tvm.te.Tensor
+    weights : tvm.te.Tensor
         1-D with shape (C,)
         The weight of each target value.
@@ -50,4 +55,4 @@ def nll_loss(input, target, weight, reduction, ignore_index):
         a scalar if the reduction type is "mean" or "sum",
         otherwise the same shape as `targets`.
     """
-    return cpp.nn.nll_loss(input, target, weight, reduction, ignore_index)
+    return cpp.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
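The TOPI entry point can also be exercised directly through the TE API; a sketch under assumed shapes (placeholder names are illustrative):

    from tvm import te, topi

    predictions = te.placeholder((2, 5), name="predictions", dtype="float32")
    targets = te.placeholder((2,), name="targets", dtype="int32")
    weights = te.placeholder((5,), name="weights", dtype="float32")
    # "mean" and "sum" produce a 0-d tensor; "none" keeps targets' shape.
    loss = topi.nn.nll_loss(predictions, targets, weights, "mean", -100)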
55 changes: 30 additions & 25 deletions src/relay/op/nn/nn.cc
@@ -1097,53 +1097,58 @@ TVM_REGISTER_NODE_TYPE(NLLLossAttrs);

 bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 4);
-  const auto* input = types[0].as<TensorTypeNode>();
-  const auto* target = types[1].as<TensorTypeNode>();
-  const auto* weight = types[2].as<TensorTypeNode>();
+  ICHECK_EQ(types.size(), 4) << "NLLLossRel expects 4 types, but " << types.size()
+                             << " were provided.";
+  const auto* predictions = types[0].as<TensorTypeNode>();
+  const auto* targets = types[1].as<TensorTypeNode>();
+  const auto* weights = types[2].as<TensorTypeNode>();
   const NLLLossAttrs* param = attrs.as<NLLLossAttrs>();
-  if (input == nullptr || target == nullptr || weight == nullptr) return false;
-  ICHECK(input->shape.size() - target->shape.size() == 1)
-      << "NLLLossRel: input should be one dimension larger than target, "
-      << "input shape = " << input->shape << ", "
-      << "target shape = " << target->shape;
-  ICHECK(weight->shape.size() == 1);
-  ICHECK(reporter->AssertEQ(input->shape[1], weight->shape[0]))
-      << "NLLLossRel: the second dimension of input should be the number of classes, "
-      << "which is the length of weight, "
-      << "input shape = " << input->shape << ", "
-      << "weight shape = " << weight->shape;
-  ICHECK(input->dtype == weight->dtype && input->dtype.is_float());
-  ICHECK(target->dtype.is_int());
+  if (predictions == nullptr || targets == nullptr || weights == nullptr) return false;
+  ICHECK(predictions->shape.size() - targets->shape.size() == 1)
+      << "NLLLossRel: predictions should be one dimension larger than targets, "
+      << "predictions shape = " << predictions->shape << ", "
+      << "targets shape = " << targets->shape;
+  ICHECK(weights->shape.size() == 1)
+      << "NLLLossRel: weights should be a one-dimensional Tensor whose length is "
+      << "the number of classes, but a Tensor of dimension " << weights->shape.size()
+      << " was provided.";
+  ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0]))
+      << "NLLLossRel: the second dimension of predictions should be the number of classes, "
+      << "which is the length of weights, "
+      << "predictions shape = " << predictions->shape << ", "
+      << "weights shape = " << weights->shape;
+  ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float())
+      << "NLLLossRel: predictions and weights should be of the same floating type.";
+  ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type.";
   // assign output type
   if (param->reduction == "none") {
-    reporter->Assign(types[3], TensorType(target->shape, input->dtype));
+    reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype));
   } else {
-    reporter->Assign(types[3], TensorType({}, input->dtype));
+    reporter->Assign(types[3], TensorType({}, predictions->dtype));
   }
   return true;
 }

 // Handler to create a call to the nll_loss op used by front-end FFI
-Expr MakeNLLLoss(Expr input, Expr target, Expr weight, String reduction, int ignore_index) {
+Expr MakeNLLLoss(Expr predictions, Expr targets, Expr weights, String reduction, int ignore_index) {
   auto attrs = make_object<NLLLossAttrs>();
   attrs->reduction = reduction;
   attrs->ignore_index = ignore_index;
   static const Op& op = Op::Get("nn.nll_loss");
-  return Call(op, {input, target, weight}, Attrs(attrs), {});
+  return Call(op, {predictions, targets, weights}, Attrs(attrs), {});
 }

 TVM_REGISTER_GLOBAL("relay.op.nn._make.nll_loss").set_body_typed(MakeNLLLoss);

 RELAY_REGISTER_OP("nn.nll_loss")
     .describe(R"code(
-Negative log likelihood loss for given input and target.
+Negative log likelihood loss for the given predictions and targets.
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<NLLLossAttrs>()
     .set_num_inputs(3)
-    .add_argument("input", "Tensor", "The input tensor.")
-    .add_argument("target", "Tensor", "The target tensor.")
-    .add_argument("weight", "Tensor", "The weight of each target values.")
+    .add_argument("predictions", "Tensor", "The prediction tensor.")
+    .add_argument("targets", "Tensor", "The target tensor.")
+    .add_argument("weights", "Tensor", "The weight of each target value.")
     .add_type_rel("NLLLoss", NLLLossRel);

 bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
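The shape and dtype constraints that NLLLossRel enforces can be checked end to end with type inference; a hedged sketch (the shapes below are illustrative assumptions):

    import tvm
    from tvm import relay

    p = relay.var("p", shape=(4, 3, 8), dtype="float32")  # (N, C, d_1)
    t = relay.var("t", shape=(4, 8), dtype="int32")       # (N, d_1)
    w = relay.var("w", shape=(3,), dtype="float32")       # (C,)
    for reduction in ["none", "mean", "sum"]:
        f = relay.Function([p, t, w], relay.nn.nll_loss(p, t, w, reduction, -100))
        mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
        # "none" -> Tensor[(4, 8), float32]; otherwise a float32 scalar
        print(reduction, mod["main"].ret_type)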
