Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay, TOPI] Add negative log likelihood loss (nll_loss) op #8056

Merged
merged 7 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the batch_to_space_nd operation
* \return The negative log likelihood loss of the predictions and targets.
*/
inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
std::string reduction = "mean", int ignore_index = -100,
zhuzilin marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,6 +2532,7 @@ def create_convert_map(self):
"aten::sort": self.sort,
"aten::_unique2": self.unique,
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
}

def update_convert_map(self, custom_map):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@
from .matrix_set_diag import matrix_set_diag
from .space_to_batch_nd import space_to_batch_nd_python
from .batch_to_space_nd import batch_to_space_nd_python
from .nll_loss import nll_loss
73 changes: 73 additions & 0 deletions python/tvm/topi/testing/nll_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""NLLLoss in python"""
import numpy as np


def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
"""nll_loss operator implemented in numpy.

output{n, i_1, i_2, ..., i_k} = -p * w
where t = target{n, i_1, i_2, ..., i_k}
p = predictions{n, t, i_1, i_2, i_k}
w = weights{n, i_1, i_2, ..., i_k} if t != ignore_index else 0

result = reduction(output)

Parameters
----------
predictions : numpy.ndarray
(k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
where C is the number of target classes

targets : numpy.ndarray
(k+1)-D with shape (N, d_1, d_2, ..., d_k)
The target value of the input.

weights : numpy.ndarray
1-D with shape (C,)
The weight of each target value.

reduction : string
The reduction method to apply to output.
Can be "mean", "sum" or "none".

ignore_index : int
The target value to ignore.

Returns
-------
output : numpy.ndarray
a scalar if the reduction type is "mean" or "sum",
otherwise the same shape as `target`.
"""
res = np.zeros(targets.shape)
weight_sum = 0.0
for index in np.ndindex(targets.shape):
class_id = targets[index]
if class_id != ignore_index:
index_list = list(index)
pred_index = tuple(index_list[:1] + [class_id] + index_list[1:])
res[index] = -predictions[pred_index] * weights[class_id]
weight_sum += weights[class_id]
if reduction == "mean":
return np.sum(res) / weight_sum
if reduction == "sum":
return np.sum(res)
else:
return res
52 changes: 36 additions & 16 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1104,22 +1104,42 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const auto* weights = types[2].as<TensorTypeNode>();
const NLLLossAttrs* param = attrs.as<NLLLossAttrs>();
if (predictions == nullptr || targets == nullptr || weights == nullptr) return false;
ICHECK(predictions->shape.size() - targets->shape.size() == 1)
<< "NLLLossRel: predictions should be one dimension larger than targets, "
<< "predictions shape = " << predictions->shape << ", "
<< "targets shape = " << targets->shape;
ICHECK(weights->shape.size() == 1)
<< "NLLLossRel: weights should be a one dimension Tensor with its length "
<< "the number of classes, but Tensor of dimension " << weights->shape.size()
<< " were provided.";
ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0]))
<< "NLLLossRel: the second dimension of predictions should be the number of classes, "
<< "which is the length of weights, "
<< "predictions shape = " << predictions->shape << ", "
<< "weights shape = " << weights->shape;
ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float())
<< "NLLLossRel: predictions and weights should be of the same floating type.";
ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type.";
if (!(predictions->shape.size() - targets->shape.size() == 1)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: predictions should be one"
<< " dimension larger than targets,"
<< "predictions shape = " << predictions->shape
<< ", targets shape = " << targets->shape);
return false;
}
if (!(weights->shape.size() == 1)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: weights should be a one dimension"
<< " Tensor with its length the number of classes,"
<< " but Tensor of dimension " << weights->shape.size()
<< " were provided.");
return false;
}
if (!reporter->AssertEQ(predictions->shape[1], weights->shape[0])) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: the second dimension of predictions"
<< " should be the number of classes, "
<< "which is the length of weights, "
<< "predictions shape = " << predictions->shape
<< ", weights shape = " << weights->shape);
return false;
}
if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: predictions and weights should"
<< " be of the same floating type.");
return false;
}
if (!targets->dtype.is_int()) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: targets should be of int type.");
return false;
}
// assign output type
if (param->reduction == "none") {
reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype));
Expand Down
24 changes: 24 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3866,6 +3866,29 @@ def test_fn(is_sorted, return_inverse, return_counts):
verify_trace_model(test_fn(True, False, True), [in_data], targets)


def test_forward_nll_loss():
torch.set_grad_enabled(False)
N, C = 10, 3
predictions = torch.rand((N, C)).float()
targets = torch.randint(0, 3, (N,))
weights = torch.tensor([1, 2, 3]).float()
verify_model(torch.nn.NLLLoss().eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(weight=weights).eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(ignore_index=1).eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(reduction="sum").eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets])

# multidimension nll loss (aten::nll_loss2d)
d1, d2 = 2, 3
predictions = torch.rand((N, C, d1, d2)).float()
targets = torch.randint(0, 3, (N, d1, d2))
verify_model(torch.nn.NLLLoss().eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(weight=weights).eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(ignore_index=1).eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(reduction="sum").eval(), input_data=[predictions, targets])
verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets])


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
Expand Down Expand Up @@ -4007,6 +4030,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
test_unique()
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()

# Model tests
test_resnet18()
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,49 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT")


@tvm.testing.uses_gpu
def test_nll_loss():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch this to use parameterize_targets.

def _get_oshape(target_shape, reduction):
if reduction == "none":
return target_shape
else:
return []

def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
C = prediction_shape[1]
target_shape = prediction_shape[:1] + prediction_shape[2:]

predictions = relay.var("predictions", relay.TensorType(prediction_shape, dtype))
targets = relay.var("targets", relay.TensorType(target_shape, "int32"))
weights = relay.var("weights", relay.TensorType((C,), dtype))
ignore_index_const = relay.const(ignore_index)
out = relay.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
checked = run_infer_type(out)
assert checked.checked_type == relay.ty.TensorType(
_get_oshape(target_shape, reduction), dtype
)
func = relay.Function([predictions, targets, weights], out)
predictions_np = np.random.uniform(size=prediction_shape).astype(dtype)
targets_np = np.random.randint(0, C, target_shape).astype("int32")
weights_np = np.random.uniform(size=(C,)).astype(dtype)
out_np = tvm.topi.testing.nll_loss(
predictions_np, targets_np, weights_np, reduction, ignore_index
)

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, device=dev, target=target)
out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

_verify((10, 5))
_verify((10, 5, 2, 2))
_verify((10, 5), reduction="sum")
_verify((10, 5), reduction="none")
_verify((10, 5), ignore_index=3)
_verify((10, 5), dtype="float64")


if __name__ == "__main__":
test_adaptive_pool()
test_collapse_sum_like()
Expand All @@ -590,3 +633,4 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
test_one_hot()
test_ndarray_size()
test_matrix_set_diag()
test_nll_loss()
70 changes: 70 additions & 0 deletions tests/python/topi/python/test_topi_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for loss operators."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
import tvm.topi.testing

import tvm.testing


def verify_nll_loss(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
C = prediction_shape[1]
target_shape = prediction_shape[:1] + prediction_shape[2:]
predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype)
targets = te.placeholder(shape=target_shape, name="targets", dtype="int32")
weights = te.placeholder(shape=(C,), name="weights", dtype=dtype)
nll_loss_result = topi.nn.nll_loss(
predictions, targets, weights, reduction, ignore_index
)

def check_device(target, dev):
print("Running on target: %s" % target)
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(nll_loss_result)
zhuzilin marked this conversation as resolved.
Show resolved Hide resolved
fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss")
predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype)
targets_npy = np.random.randint(0, C, target_shape).astype("int32")
weights_npy = np.random.uniform(size=(C,)).astype(dtype)
out_npy = tvm.topi.testing.nll_loss(predictions_npy, targets_npy, weights_npy, reduction, ignore_index)
predictions_nd = tvm.nd.array(predictions_npy, dev)
targets_nd = tvm.nd.array(targets_npy, dev)
weights_nd = tvm.nd.array(weights_npy, dev)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev)
fn(predictions_nd, targets_nd, weights_nd, out_nd)
out_topi = out_nd.asnumpy()
tvm.testing.assert_allclose(out_topi, out_npy)

for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


@tvm.testing.uses_gpu
def test_nll_loss():
zhuzilin marked this conversation as resolved.
Show resolved Hide resolved
verify_nll_loss((10, 5,))
verify_nll_loss((10, 5, 2, 2))
verify_nll_loss((10, 5,), reduction="sum")
verify_nll_loss((10, 5,), reduction="none")
verify_nll_loss((10, 5,), ignore_index=3)
verify_nll_loss((10, 5,), dtype="float64")


if __name__ == "__main__":
test_nll_loss()