From e0ed97a4d5c6bd029139799944378430beecdd92 Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Fri, 25 Jun 2021 23:46:48 +0800
Subject: [PATCH] [Relay, TOPI] Add negative log likelihood loss (nll_loss) op
 (#8056)

* add nll_loss
* enrich the doc and rename parameters
* update upon review
* add tests
* update based on reviews
* update upon reviews
* update upon reviews
---
 include/tvm/relay/attrs/nn.h                  | 13 +++
 include/tvm/topi/nn.h                         | 48 +++++++++++
 python/tvm/relay/frontend/pytorch.py          | 16 ++++
 python/tvm/relay/op/nn/_nn.py                 | 11 +++
 python/tvm/relay/op/nn/nn.py                  | 36 +++++++++
 python/tvm/relay/op/op_attrs.py               |  5 ++
 python/tvm/topi/nn/__init__.py                |  1 +
 python/tvm/topi/nn/loss.py                    | 60 ++++++++++++++
 python/tvm/topi/testing/__init__.py           |  1 +
 python/tvm/topi/testing/nll_loss.py           | 72 +++++++++++++++++
 src/relay/op/nn/nn.cc                         | 80 +++++++++++++++++++
 src/topi/nn.cc                                |  4 +
 tests/python/frontend/pytorch/test_forward.py | 24 ++++++
 tests/python/relay/test_op_level10.py         | 41 ++++++++++
 tests/python/topi/python/test_topi_loss.py    | 70 ++++++++++++++++
 15 files changed, 482 insertions(+)
 create mode 100644 python/tvm/topi/nn/loss.py
 create mode 100644 python/tvm/topi/testing/nll_loss.py
 create mode 100644 tests/python/topi/python/test_topi_loss.py

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index a58bb8750c14..dc202674eb08 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1424,6 +1424,19 @@ struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
   }
 };  // struct BatchToSpaceNDAttrs
 
+/*! \brief Attributes used in NLLLoss operator */
+struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
+  std::string reduction;
+  int ignore_index;
+
+  TVM_DECLARE_ATTRS(NLLLossAttrs, "relay.attrs.NLLLossAttrs") {
+    TVM_ATTR_FIELD(reduction).set_default("mean").describe(
+        "The reduction method to apply to the output. Can be "
+        "'none', 'mean' or 'sum'.");
+    TVM_ATTR_FIELD(ignore_index).describe("The target value to ignore.");
+  }
+};  // struct NLLLossAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index d3328c59afb4..90c1c09a070b 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -29,6 +29,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 #include <tvm/topi/detail/constant_utils.h>
+#include <tvm/topi/reduction.h>
 #include <tvm/topi/tags.h>
 #include <tvm/topi/transform.h>
 
@@ -642,6 +643,53 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
   out = strided_slice(out, begin_idx, end_idx, strides);
   return out;
 }
+
+/*!
+ * \brief Negative log likelihood loss.
+ *
+ * \param predictions The prediction tensor.
+ * \param targets The target tensor.
+ * \param weights A manual rescaling weight given to each class.
+ * \param reduction The reduction method to apply to the output.
+ * \param ignore_index The target value to ignore.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return The negative log likelihood loss of the predictions and targets.
+ */
+inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
+                       std::string reduction = "mean", int ignore_index = -100,
+                       const std::string name = "nll_loss", const std::string tag = kBroadcast) {
+  auto T = tvm::te::compute(
+      targets->shape,
+      [&](const tvm::Array<tvm::tir::Var>& target_indices) {
+        auto c = targets(target_indices);
+        tvm::Array<tvm::PrimExpr> pred_indices;
+        pred_indices.push_back(target_indices[0]);  // batch index
+        pred_indices.push_back(c);                  // class index
+        for (size_t i = 1; i < target_indices.size(); i++) {
+          pred_indices.push_back(target_indices[i]);  // indices for multidimensional loss
+        }
+        return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
+                                tvm::tir::make_const(predictions->dtype, 0));
+      },
+      name, tag);
+  if (reduction == "mean") {
+    auto W = tvm::te::compute(
+        targets->shape,
+        [&](const tvm::Array<tvm::tir::Var>& target_indices) {
+          auto c = targets(target_indices);
+          return tvm::tir::Select(c != ignore_index, weights(c),
+                                  tvm::tir::make_const(predictions->dtype, 0));
+        },
+        name, tag);
+    return topi::divide(topi::sum(T, {}), topi::sum(W, {}));
+  } else if (reduction == "sum") {
+    return topi::sum(T, {});
+  } else {  // reduction == "none"
+    return T;
+  }
+}
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_NN_H_
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 44d4b0c66216..b95913a7add2 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2320,6 +2320,20 @@ def unique(self, inputs, input_types):
         unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
         return (unique_sliced, inverse_indices)
 
+    def nll_loss(self, inputs, input_types):
+        assert len(inputs) == 5
+        [predictions, targets, weights, reduction, ignore_index] = inputs
+        num_class = self.infer_shape(predictions)[1]
+        if reduction == 0:
+            reduction = "none"
+        elif reduction == 1:
+            reduction = "mean"
+        else:
+            reduction = "sum"
+        if weights is None:
+            weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
+        return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2532,6 +2546,8 @@ def create_convert_map(self):
             "aten::argsort": self.argsort,
             "aten::sort": self.sort,
             "aten::_unique2": self.unique,
+            "aten::nll_loss": self.nll_loss,
+            "aten::nll_loss2d": self.nll_loss,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c6c4f4bfb959..04d38ce39422 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -886,6 +886,17 @@ def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
 reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
 
 
+# nll_loss
+@reg.register_compute("nn.nll_loss")
+def compute_nll_loss(attrs, inputs, out_dtype):
+    predictions, targets, weights = inputs
+    return [topi.nn.nll_loss(predictions, targets, weights, attrs.reduction, attrs.ignore_index)]
+
+
+reg.register_reduce_schedule("nn.nll_loss")
+reg.register_pattern("nn.nll_loss", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # depth_to_space
 @reg.register_compute("nn.depth_to_space")
 def compute_depth_to_space(attrs, inputs, out_dtype):
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index caf1f187fad3..bef899eeaaab 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2974,6 +2974,42 @@ def cross_entropy_with_logits(predictions, targets):
     return _make.cross_entropy_with_logits(predictions, targets)
 
 
+def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
+    """Negative log likelihood loss.
+
+    output{n, i_1, i_2, ..., i_k} = -p * w
+      where t = target{n, i_1, i_2, ..., i_k}
+            p = predictions{n, t, i_1, i_2, ..., i_k}
+            w = weights{t} if t != ignore_index else 0
+
+    result = reduction(output)
+
+    Parameters
+    ----------
+    predictions : tvm.relay.Expr
+      The predictions.
+
+    targets : tvm.relay.Expr
+      The target value of each prediction.
+
+    weights : tvm.relay.Expr
+      The weight of each target value.
+
+    reduction : string
+      The reduction method to apply to the output.
+      Possible values are "mean", "sum" and "none".
+
+    ignore_index : int
+      The target value to ignore.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+      The computed result.
+    """
+    return _make.nll_loss(predictions, targets, weights, reduction, ignore_index)
+
+
 def depth_to_space(data, block_size, layout="NCHW", mode="DCR"):
     """Convert channels into spatial blocks.
 
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 6844d133a77e..2e13d1f042a2 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -572,3 +572,8 @@ class ThreefryGenerateAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.UniformAttrs")
 class UniformAttrs(Attrs):
     """Attributes used in UniformAttrs operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.NLLLossAttrs")
+class NLLLossAttrs(Attrs):
+    """Attributes for nn.nll_loss"""
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index 94a5b30c9b76..b5e766adbc12 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -49,3 +49,4 @@
 from .space_to_depth import *
 from .space_to_batch_nd import *
 from .batch_to_space_nd import *
+from .loss import *
diff --git a/python/tvm/topi/nn/loss.py b/python/tvm/topi/nn/loss.py
new file mode 100644
index 000000000000..1d6f588c7d53
--- /dev/null
+++ b/python/tvm/topi/nn/loss.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Loss functions definitions."""
+from __future__ import absolute_import
+from .. import cpp
+
+
+def nll_loss(predictions, targets, weights, reduction, ignore_index):
+    """Negative log likelihood loss on the input data.
+
+    output{n, i_1, i_2, ..., i_k} = -p * w
+      where t = target{n, i_1, i_2, ..., i_k}
+            p = predictions{n, t, i_1, i_2, ..., i_k}
+            w = weights{t} if t != ignore_index else 0
+
+    result = reduction(output)
+
+    Parameters
+    ----------
+    predictions : tvm.te.Tensor
+        (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
+        where C is the number of target classes
+
+    targets : tvm.te.Tensor
+        (k+1)-D with shape (N, d_1, d_2, ..., d_k)
+        The target value of the input.
+
+    weights : tvm.te.Tensor
+        1-D with shape (C,)
+        The weight of each target value.
+
+    reduction : string
+        The reduction method to apply to the output.
+        Can be "mean", "sum" or "none".
+
+    ignore_index : int
+        The target value to ignore.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        a scalar if the reduction type is "mean" or "sum",
+        otherwise the same shape as `targets`.
+    """
+    return cpp.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index ef7d86322be7..afb251417315 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -69,3 +69,4 @@
 from .matrix_set_diag import matrix_set_diag
 from .space_to_batch_nd import space_to_batch_nd_python
 from .batch_to_space_nd import batch_to_space_nd_python
+from .nll_loss import nll_loss
diff --git a/python/tvm/topi/testing/nll_loss.py b/python/tvm/topi/testing/nll_loss.py
new file mode 100644
index 000000000000..fd78f6f56d00
--- /dev/null
+++ b/python/tvm/topi/testing/nll_loss.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""NLLLoss in python"""
+import numpy as np
+
+
+def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
+    """nll_loss operator implemented in numpy.
+
+    output{n, i_1, i_2, ..., i_k} = -p * w
+      where t = target{n, i_1, i_2, ..., i_k}
+            p = predictions{n, t, i_1, i_2, ..., i_k}
+            w = weights{t} if t != ignore_index else 0
+
+    result = reduction(output)
+
+    Parameters
+    ----------
+    predictions : numpy.ndarray
+        (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
+        where C is the number of target classes
+
+    targets : numpy.ndarray
+        (k+1)-D with shape (N, d_1, d_2, ..., d_k)
+        The target value of the input.
+
+    weights : numpy.ndarray
+        1-D with shape (C,)
+        The weight of each target value.
+
+    reduction : string
+        The reduction method to apply to the output.
+        Can be "mean", "sum" or "none".
+
+    ignore_index : int
+        The target value to ignore.
+
+    Returns
+    -------
+    output : numpy.ndarray
+        a scalar if the reduction type is "mean" or "sum",
+        otherwise the same shape as `targets`.
+ """ + res = np.zeros(targets.shape) + weight_sum = 0.0 + for index in np.ndindex(targets.shape): + class_id = targets[index] + if class_id != ignore_index: + index_list = list(index) + pred_index = tuple(index_list[:1] + [class_id] + index_list[1:]) + res[index] = -predictions[pred_index] * weights[class_id] + weight_sum += weights[class_id] + if reduction == "mean": + return np.sum(res) / weight_sum + if reduction == "sum": + return np.sum(res) + return res diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 32c0a21d46c7..281fc5093325 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1068,6 +1068,7 @@ Dilate data with given dilation value (0 by default). .set_support_level(10) .add_type_rel("Dilate", DilateRel); +// relay.nn.cross_entropy_with_logits // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { static const Op& op = Op::Get("nn.cross_entropy_with_logits"); @@ -1091,6 +1092,85 @@ Accept logits. // Depth to space and space to depth TVM_REGISTER_NODE_TYPE(SubPixelAttrs); +// relay.nn.nll_loss +TVM_REGISTER_NODE_TYPE(NLLLossAttrs); + +bool NLLLossRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 4) << "NLLLossRel expects 4 types, but " << types.size() + << " were provided."; + const auto* predictions = types[0].as(); + const auto* targets = types[1].as(); + const auto* weights = types[2].as(); + const NLLLossAttrs* param = attrs.as(); + if (predictions == nullptr || targets == nullptr || weights == nullptr) return false; + if (!(predictions->shape.size() - targets->shape.size() == 1)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: predictions should be one" + << " dimension larger than targets," + << "predictions shape = " << predictions->shape + << ", targets shape = " << targets->shape); + return false; + } + if (!(weights->shape.size() == 1)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: weights should be a one dimension" + << " Tensor with its length the number of classes," + << " but Tensor of dimension " << weights->shape.size() + << " were provided."); + return false; + } + if (!reporter->AssertEQ(predictions->shape[1], weights->shape[0])) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: the second dimension of predictions" + << " should be the number of classes, " + << "which is the length of weights, " + << "predictions shape = " << predictions->shape + << ", weights shape = " << weights->shape); + return false; + } + if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: predictions and weights should" + << " be of the same floating type."); + return false; + } + if (!targets->dtype.is_int()) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "NLLLossRel: targets should be of int type."); + return false; + } + // assign output type + if (param->reduction == "none") { + reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype)); + } else { + reporter->Assign(types[3], TensorType({}, predictions->dtype)); + } + return true; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakeNLLLoss(Expr predictions, Expr targets, Expr weights, String 
+                 int ignore_index) {
+  auto attrs = make_object<NLLLossAttrs>();
+  attrs->reduction = reduction;
+  attrs->ignore_index = ignore_index;
+  static const Op& op = Op::Get("nn.nll_loss");
+  return Call(op, {predictions, targets, weights}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.nll_loss").set_body_typed(MakeNLLLoss);
+
+RELAY_REGISTER_OP("nn.nll_loss")
+    .describe(R"code(
+Negative log likelihood loss for given prediction and target.
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<NLLLossAttrs>()
+    .set_num_inputs(3)
+    .add_argument("predictions", "Tensor", "The prediction tensor.")
+    .add_argument("targets", "Tensor", "The target tensor.")
+    .add_argument("weights", "Tensor", "The weight of each target value.")
+    .add_type_rel("NLLLoss", NLLLossRel);
+
 bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 2);
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 356f3d2ea18f..2950aee4e90d 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -65,6 +65,10 @@ TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd").set_body([](TVMArgs args, TVMRe
   *rv = batch_to_space_nd(args[0], args[1], args[2], args[3]);
 });
 
+TVM_REGISTER_GLOBAL("topi.nn.nll_loss").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = nll_loss(args[0], args[1], args[2], args[3], args[4]);
+});
+
 /* Ops from nn/dense.h */
 TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = nn::dense(args[0], args[1], args[2], args[3]);
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index ae2bedac0b29..56c0cf71761e 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3875,6 +3875,29 @@ def test_fn(is_sorted, return_inverse, return_counts):
     verify_trace_model(test_fn(True, False, True), [in_data], targets)
 
 
+def test_forward_nll_loss():
+    torch.set_grad_enabled(False)
+    N, C = 10, 3
+    predictions = torch.rand((N, C)).float()
+    targets = torch.randint(0, 3, (N,))
+    weights = torch.tensor([1, 2, 3]).float()
+    verify_model(torch.nn.NLLLoss().eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(weight=weights).eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(ignore_index=1).eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(reduction="sum").eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets])
+
+    # multidimensional nll loss (aten::nll_loss2d)
+    d1, d2 = 2, 3
+    predictions = torch.rand((N, C, d1, d2)).float()
+    targets = torch.randint(0, 3, (N, d1, d2))
+    verify_model(torch.nn.NLLLoss().eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(weight=weights).eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(ignore_index=1).eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(reduction="sum").eval(), input_data=[predictions, targets])
+    verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets])
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -4017,6 +4040,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
     test_unique()
     test_hard_swish()
     test_hard_sigmoid()
+    test_forward_nll_loss()
 
     # Model tests
     test_resnet18()
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 96d90b2a4f76..040fa3fb4315 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -577,6 +577,47 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
     _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT")
 
 
+@tvm.testing.parametrize_targets
+def test_nll_loss(dev, target):
+    def _get_oshape(target_shape, reduction):
+        if reduction == "none":
+            return target_shape
+        else:
+            return []
+
+    def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
+        C = prediction_shape[1]
+        target_shape = prediction_shape[:1] + prediction_shape[2:]
+
+        predictions = relay.var("predictions", relay.TensorType(prediction_shape, dtype))
+        targets = relay.var("targets", relay.TensorType(target_shape, "int32"))
+        weights = relay.var("weights", relay.TensorType((C,), dtype))
+        out = relay.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
+        checked = run_infer_type(out)
+        assert checked.checked_type == relay.ty.TensorType(
+            _get_oshape(target_shape, reduction), dtype
+        )
+        func = relay.Function([predictions, targets, weights], out)
+        predictions_np = np.random.uniform(size=prediction_shape).astype(dtype)
+        targets_np = np.random.randint(0, C, target_shape).astype("int32")
+        weights_np = np.random.uniform(size=(C,)).astype(dtype)
+        out_np = tvm.topi.testing.nll_loss(
+            predictions_np, targets_np, weights_np, reduction, ignore_index
+        )
+
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, device=dev, target=target)
+            out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
+            tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
+
+    _verify((10, 5))
+    _verify((10, 5, 2, 2))
+    _verify((10, 5), reduction="sum")
+    _verify((10, 5), reduction="none")
+    _verify((10, 5), ignore_index=3)
+    _verify((10, 5), dtype="float64")
+
+
 if __name__ == "__main__":
     test_adaptive_pool()
     test_collapse_sum_like()
diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py
new file mode 100644
index 000000000000..7fd8238bf0ae
--- /dev/null
+++ b/tests/python/topi/python/test_topi_loss.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for loss operators.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +import tvm.topi.testing + +import tvm.testing + + +def verify_nll_loss( + dev, target, prediction_shape, reduction="mean", ignore_index=-100, dtype="float32" +): + C = prediction_shape[1] + target_shape = prediction_shape[:1] + prediction_shape[2:] + predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype) + targets = te.placeholder(shape=target_shape, name="targets", dtype="int32") + weights = te.placeholder(shape=(C,), name="weights", dtype=dtype) + nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + + with tvm.target.Target(target): + fschedule = tvm.topi.testing.get_reduce_schedule(target) + s = fschedule([nll_loss_result]) + fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") + + predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype) + targets_npy = np.random.randint(0, C, target_shape).astype("int32") + weights_npy = np.random.uniform(size=(C,)).astype(dtype) + out_npy = tvm.topi.testing.nll_loss( + predictions_npy, targets_npy, weights_npy, reduction, ignore_index + ) + + predictions_nd = tvm.nd.array(predictions_npy, dev) + targets_nd = tvm.nd.array(targets_npy, dev) + weights_nd = tvm.nd.array(weights_npy, dev) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) + fn(predictions_nd, targets_nd, weights_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) + + +@tvm.testing.parametrize_targets +def test_nll_loss(dev, target): + verify_nll_loss(dev, target, (10, 5)) + verify_nll_loss(dev, target, (10, 5, 2, 2)) + verify_nll_loss(dev, target, (10, 5), reduction="sum") + verify_nll_loss(dev, target, (10, 5), reduction="none") + verify_nll_loss(dev, target, (10, 5), ignore_index=3) + verify_nll_loss(dev, target, (10, 5), dtype="float64") + + +if __name__ == "__main__": + test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm"))