From f6369934c17f47f50ee8b8ef287162a95a093ddc Mon Sep 17 00:00:00 2001
From: zilinzhu
Date: Tue, 18 May 2021 13:48:34 +0800
Subject: [PATCH] enrich the docs and rename parameters

---
 include/tvm/topi/nn.h                | 30 +++++++--------
 python/tvm/relay/frontend/pytorch.py | 10 ++---
 python/tvm/relay/op/nn/_nn.py        |  4 +-
 python/tvm/relay/op/nn/nn.py         | 20 ++++++----
 python/tvm/topi/nn/loss.py           | 17 ++++++---
 src/relay/op/nn/nn.cc                | 55 +++++++++++++++-------------
 6 files changed, 76 insertions(+), 60 deletions(-)

diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index f649e1ee49a6d..0328de8a99756 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -647,9 +647,9 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
 /*!
  * \brief Negative log likelihood loss.
  *
- * \param input The input tensor.
- * \param target The target tensor.
- * \param weight A manual rescaling weight given to each class.
+ * \param predictions The prediction tensor.
+ * \param targets The target tensor.
+ * \param weights A manual rescaling weight given to each class.
  * \param reduction The reduction method to apply to the output.
  * \param ignore_index The target value to ignore.
  * \param name The name of the operation.
@@ -657,31 +657,31 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
  *
- * \return A Tensor whose op member is the batch_to_space_nd operation
+ * \return A Tensor whose op member is the nll_loss operation
  */
-inline Tensor nll_loss(const Tensor& input, const Tensor& target, const Tensor& weight,
+inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
                        std::string reduction = "mean", int ignore_index = -100,
                        const std::string name = "nll_loss", const std::string tag = kBroadcast) {
   auto T = tvm::te::compute(
-      target->shape,
+      targets->shape,
       [&](const tvm::Array<tvm::tir::Var>& target_indices) {
-        auto c = target(target_indices);
-        tvm::Array<tvm::PrimExpr> input_indices;
+        auto c = targets(target_indices);
+        tvm::Array<tvm::PrimExpr> pred_indices;
         for (size_t i = 0; i < target_indices.size(); i++) {
-          input_indices.push_back(target_indices[i]);
+          pred_indices.push_back(target_indices[i]);
           if (i == 0) {
-            input_indices.push_back(c);
+            pred_indices.push_back(c);
           }
         }
-        return tvm::tir::Select(c != ignore_index, -input(input_indices) * weight(c),
-                                tvm::tir::make_const(input->dtype, 0));
+        return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
+                                tvm::tir::make_const(predictions->dtype, 0));
       },
       name, tag);
   if (reduction == "mean") {
     auto W = tvm::te::compute(
-        target->shape,
+        targets->shape,
         [&](const tvm::Array<tvm::tir::Var>& target_indices) {
-          auto c = target(target_indices);
-          return tvm::tir::Select(c != ignore_index, weight(c),
-                                  tvm::tir::make_const(input->dtype, 0));
+          auto c = targets(target_indices);
+          return tvm::tir::Select(c != ignore_index, weights(c),
+                                  tvm::tir::make_const(predictions->dtype, 0));
         },
         name, tag);
     return topi::divide(topi::sum(T, {}), topi::sum(W, {}));
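As a reference for the kernel above, here is a rough NumPy sketch of the same
computation; the helper name nll_loss_ref and all variable names are
illustrative, not part of TVM. It builds the per-element losses (T), the
per-element weights (W), and reduces them the same way.

    import numpy as np

    def nll_loss_ref(predictions, targets, weights, reduction="mean", ignore_index=-100):
        # predictions: (N, C, d_1, ..., d_k) log-probabilities,
        # targets: (N, d_1, ..., d_k) int class ids, weights: (C,)
        out = np.zeros(targets.shape, dtype=predictions.dtype)
        w_sum = np.zeros(targets.shape, dtype=predictions.dtype)
        for idx in np.ndindex(*targets.shape):
            c = targets[idx]
            if c == ignore_index:
                continue  # ignored entries contribute zero, like the Select above
            pred_idx = (idx[0], c) + idx[1:]  # class index sits right after the batch axis
            out[idx] = -predictions[pred_idx] * weights[c]
            w_sum[idx] = weights[c]
        if reduction == "mean":
            return out.sum() / w_sum.sum()  # matches topi::divide(sum(T), sum(W))
        if reduction == "sum":
            return out.sum()
        return out  # reduction == "none"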
"mean" else: reduction = "sum" - if weight is None: - weight = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) - return _op.nn.nll_loss(input, target, weight, reduction, ignore_index) + if weights is None: + weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) + return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) # Operator mappings def create_convert_map(self): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 11f681e872bcf..8b9002e6a650a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -878,8 +878,8 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): # nll_loss @reg.register_compute("nn.nll_loss") def compute_nll_loss(attrs, inputs, out_dtype): - input, target, weights = inputs - return [topi.nn.nll_loss(input, target, weights, attrs.reduction, attrs.ignore_index)] + predictions, targets, weights = inputs + return [topi.nn.nll_loss(predictions, targets, weights, attrs.reduction, attrs.ignore_index)] reg.register_reduce_schedule("nn.nll_loss") diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b310a223b800b..8ce0d30168a01 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2973,22 +2973,28 @@ def cross_entropy_with_logits(predictions, targets): return _make.cross_entropy_with_logits(predictions, targets) -def nll_loss(input, target, weight, reduction="mean", ignore_index=-100): +def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100): """Negative log likelihood loss. + output{n, i_1, i_2, ..., i_k} = predictions{n, t, i_1, i_2, i_k} + where t = target{n, i_1, i_2, ..., i_k} + + result = reduction(output) + Parameters ---------- - input : tvm.relay.Expr - The input. + predictions : tvm.relay.Expr + The predictions. - target : tvm.relay.Expr - The target value of the input. + targets : tvm.relay.Expr + The target value of each prediction. - weight : tvm.relay.Expr + weights : tvm.relay.Expr The weight of each target value. reduction : string The reduction method to apply to the output. + Possible values are "mean", "sum" and "none". ignore_index : int The target value to ignore. @@ -2998,7 +3004,7 @@ def nll_loss(input, target, weight, reduction="mean", ignore_index=-100): result : tvm.relay.Expr The computed result. """ - return _make.nll_loss(input, target, weight, reduction, ignore_index) + return _make.nll_loss(predictions, targets, weights, reduction, ignore_index) def depth_to_space(data, block_size, layout="NCHW", mode="DCR"): diff --git a/python/tvm/topi/nn/loss.py b/python/tvm/topi/nn/loss.py index f1c3cf47f4f6a..da1ad6e0f66fd 100644 --- a/python/tvm/topi/nn/loss.py +++ b/python/tvm/topi/nn/loss.py @@ -15,25 +15,30 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name,unused-argument -"""TVM operator negative log likelihood loss compute.""" +"""Loss functions definitions.""" from __future__ import absolute_import from . import cpp -def nll_loss(input, target, weight, reduction, ignore_index): +def nll_loss(predictions, targets, weights, reduction, ignore_index): """Negative log likelihood loss on the input data. 
diff --git a/python/tvm/topi/nn/loss.py b/python/tvm/topi/nn/loss.py
index f1c3cf47f4f6a..da1ad6e0f66fd 100644
--- a/python/tvm/topi/nn/loss.py
+++ b/python/tvm/topi/nn/loss.py
@@ -15,25 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-argument
-"""TVM operator negative log likelihood loss compute."""
+"""Loss function definitions."""
 from __future__ import absolute_import
 from . import cpp
 
 
-def nll_loss(input, target, weight, reduction, ignore_index):
+def nll_loss(predictions, targets, weights, reduction, ignore_index):
     """Negative log likelihood loss on the input data.
 
+    output{n, i_1, i_2, ..., i_k} = -predictions{n, t, i_1, i_2, ..., i_k} * weights{t}
+      where t = targets{n, i_1, i_2, ..., i_k}
+
+    result = reduction(output)
+
     Parameters
     ----------
-    input : tvm.te.Tensor
+    predictions : tvm.te.Tensor
         (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
         where C is the number of target classes
 
-    target : tvm.te.Tensor
+    targets : tvm.te.Tensor
         (k+1)-D with shape (N, d_1, d_2, ..., d_k)
-        The target value of the input.
+        The target value of each prediction.
 
-    weight : tvm.te.Tensor
+    weights : tvm.te.Tensor
         1-D with shape (C,)
         The weight of each target value.
 
@@ -50,4 +55,4 @@ def nll_loss(input, target, weight, reduction, ignore_index):
         a scalar if the reduction type is "mean" or "sum",
-        otherwise the same shape as `target`.
+        otherwise the same shape as `targets`.
     """
-    return cpp.nn.nll_loss(input, target, weight, reduction, ignore_index)
+    return cpp.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
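A sketch of calling the TOPI compute directly through the TE API, assuming
topi.nn.nll_loss is exposed on the Python side exactly as this file defines it;
placeholder names and shapes are illustrative.

    import tvm
    from tvm import te, topi

    preds = te.placeholder((4, 5), name="predictions", dtype="float32")
    tgts = te.placeholder((4,), name="targets", dtype="int32")
    wgts = te.placeholder((5,), name="weights", dtype="float32")
    # With reduction="mean" the result is a 0-d tensor (a scalar).
    loss = topi.nn.nll_loss(preds, tgts, wgts, reduction="mean", ignore_index=-100)
    s = te.create_schedule(loss.op)
    f = tvm.build(s, [preds, tgts, wgts, loss], target="llvm")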
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 37981850b0b7b..67e996f555a46 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -1097,53 +1097,58 @@ TVM_REGISTER_NODE_TYPE(NLLLossAttrs);
 
 bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 4);
-  const auto* input = types[0].as<TensorTypeNode>();
-  const auto* target = types[1].as<TensorTypeNode>();
-  const auto* weight = types[2].as<TensorTypeNode>();
+  ICHECK_EQ(types.size(), 4) << "NLLLossRel expects 4 types, but " << types.size()
+                             << " were provided.";
+  const auto* predictions = types[0].as<TensorTypeNode>();
+  const auto* targets = types[1].as<TensorTypeNode>();
+  const auto* weights = types[2].as<TensorTypeNode>();
   const NLLLossAttrs* param = attrs.as<NLLLossAttrs>();
-  if (input == nullptr || target == nullptr || weight == nullptr) return false;
-  ICHECK(input->shape.size() - target->shape.size() == 1)
-      << "NLLLossRel: input should be one dimension larger than target, "
-      << "input shape = " << input->shape << ", "
-      << "target shape = " << target->shape;
-  ICHECK(weight->shape.size() == 1);
-  ICHECK(reporter->AssertEQ(input->shape[1], weight->shape[0]))
-      << "NLLLossRel: the second dimension of input should be the number of classes, "
-      << "which is the length of weight, "
-      << "input shape = " << input->shape << ", "
-      << "weight shape = " << weight->shape;
-  ICHECK(input->dtype == weight->dtype && input->dtype.is_float());
-  ICHECK(target->dtype.is_int());
+  if (predictions == nullptr || targets == nullptr || weights == nullptr) return false;
+  ICHECK(predictions->shape.size() - targets->shape.size() == 1)
+      << "NLLLossRel: predictions should be one dimension larger than targets, "
+      << "predictions shape = " << predictions->shape << ", "
+      << "targets shape = " << targets->shape;
+  ICHECK(weights->shape.size() == 1)
+      << "NLLLossRel: weights should be a one-dimensional Tensor whose length is "
+      << "the number of classes, but a Tensor of dimension " << weights->shape.size()
+      << " was provided.";
+  ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0]))
+      << "NLLLossRel: the second dimension of predictions should be the number of classes, "
+      << "which is the length of weights, "
+      << "predictions shape = " << predictions->shape << ", "
+      << "weights shape = " << weights->shape;
+  ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float())
+      << "NLLLossRel: predictions and weights should be of the same floating type.";
+  ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type.";
   // assign output type
   if (param->reduction == "none") {
-    reporter->Assign(types[3], TensorType(target->shape, input->dtype));
+    reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype));
   } else {
-    reporter->Assign(types[3], TensorType({}, input->dtype));
+    reporter->Assign(types[3], TensorType({}, predictions->dtype));
   }
   return true;
 }
 
-// Handler to create a call to the padding op used by front-end FFI
-Expr MakeNLLLoss(Expr input, Expr target, Expr weight, String reduction, int ignore_index) {
+// Handler to create a call to the nll_loss op used by front-end FFI
+Expr MakeNLLLoss(Expr predictions, Expr targets, Expr weights, String reduction, int ignore_index) {
   auto attrs = make_object<NLLLossAttrs>();
   attrs->reduction = reduction;
   attrs->ignore_index = ignore_index;
   static const Op& op = Op::Get("nn.nll_loss");
-  return Call(op, {input, target, weight}, Attrs(attrs), {});
+  return Call(op, {predictions, targets, weights}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.nll_loss").set_body_typed(MakeNLLLoss);
 
 RELAY_REGISTER_OP("nn.nll_loss")
     .describe(R"code(
-Negative log likelihood loss for given input and target.
+Negative log likelihood loss for the given prediction and target.
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<NLLLossAttrs>()
     .set_num_inputs(3)
-    .add_argument("input", "Tensor", "The input tensor.")
-    .add_argument("target", "Tensor", "The target tensor.")
-    .add_argument("weight", "Tensor", "The weight of each target values.")
+    .add_argument("predictions", "Tensor", "The prediction tensor.")
+    .add_argument("targets", "Tensor", "The target tensor.")
+    .add_argument("weights", "Tensor", "The weight of each target value.")
     .add_type_rel("NLLLoss", NLLLossRel);
 
 bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
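To observe the output-type rule that NLLLossRel implements, one can run Relay
type inference on a small graph; the shapes here are arbitrary. With
reduction="none" the result keeps the targets shape, while "mean" and "sum"
yield a scalar.

    import tvm
    from tvm import relay

    preds = relay.var("predictions", shape=(2, 3, 4), dtype="float32")
    tgts = relay.var("targets", shape=(2, 4), dtype="int32")
    wgts = relay.var("weights", shape=(3,), dtype="float32")
    out = relay.nn.nll_loss(preds, tgts, wgts, reduction="none", ignore_index=-100)
    mod = tvm.IRModule.from_expr(relay.Function([preds, tgts, wgts], out))
    mod = relay.transform.InferType()(mod)
    print(mod["main"].ret_type)  # Tensor[(2, 4), float32]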