Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzilin committed May 30, 2021
1 parent d8f111b commit 4b717d6
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 17 deletions.
2 changes: 1 addition & 1 deletion include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the batch_to_space_nd operation
* \return The negative log likelihood loss of the predictions and targets.
*/
inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
std::string reduction = "mean", int ignore_index = -100,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,6 +2532,7 @@ def create_convert_map(self):
"aten::sort": self.sort,
"aten::_unique2": self.unique,
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
}

def update_convert_map(self, custom_map):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@
from .matrix_set_diag import matrix_set_diag
from .space_to_batch_nd import space_to_batch_nd_python
from .batch_to_space_nd import batch_to_space_nd_python
from .nll_loss import nll_loss
73 changes: 73 additions & 0 deletions python/tvm/topi/testing/nll_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""NLLLoss in python"""
import numpy as np


def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100):
    """Reference NumPy implementation of the negative log likelihood loss.

    For every position ``(n, i_1, ..., i_k)`` of ``targets``::

        t = targets[n, i_1, ..., i_k]
        output[n, i_1, ..., i_k] = -predictions[n, t, i_1, ..., i_k] * weights[t]

    Positions whose target equals ``ignore_index`` keep an output of 0 and
    do not contribute to the weight total used by the "mean" reduction.

    Parameters
    ----------
    predictions : numpy.ndarray
        (k+2)-D with shape (N, C, d_1, d_2, ..., d_k),
        where C is the number of target classes.

    targets : numpy.ndarray
        (k+1)-D with shape (N, d_1, d_2, ..., d_k).
        The target class of each position.

    weights : numpy.ndarray
        1-D with shape (C,).
        The weight of each target class.

    reduction : string
        The reduction method to apply to the output.
        Can be "mean", "sum" or "none"; any other value is treated like "none".

    ignore_index : int
        The target value to ignore.

    Returns
    -------
    output : numpy.ndarray
        For "mean", the sum of the per-position losses divided by the sum of
        the weights of the non-ignored targets; for "sum", the sum of the
        per-position losses; otherwise the unreduced array with the same
        shape as `targets`.
    """
    losses = np.zeros(targets.shape)
    total_weight = 0.0
    for position in np.ndindex(targets.shape):
        label = targets[position]
        if label == ignore_index:
            # Ignored targets leave a zero loss and add no weight.
            continue
        # Insert the class id right after the batch axis to address predictions.
        pred_position = position[:1] + (label,) + position[1:]
        losses[position] = -predictions[pred_position] * weights[label]
        total_weight += weights[label]

    if reduction == "mean":
        return np.sum(losses) / total_weight
    if reduction == "sum":
        return np.sum(losses)
    return losses
52 changes: 36 additions & 16 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1104,22 +1104,42 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const auto* weights = types[2].as<TensorTypeNode>();
const NLLLossAttrs* param = attrs.as<NLLLossAttrs>();
if (predictions == nullptr || targets == nullptr || weights == nullptr) return false;
ICHECK(predictions->shape.size() - targets->shape.size() == 1)
<< "NLLLossRel: predictions should be one dimension larger than targets, "
<< "predictions shape = " << predictions->shape << ", "
<< "targets shape = " << targets->shape;
ICHECK(weights->shape.size() == 1)
<< "NLLLossRel: weights should be a one dimension Tensor with its length "
<< "the number of classes, but Tensor of dimension " << weights->shape.size()
<< " were provided.";
ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0]))
<< "NLLLossRel: the second dimension of predictions should be the number of classes, "
<< "which is the length of weights, "
<< "predictions shape = " << predictions->shape << ", "
<< "weights shape = " << weights->shape;
ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float())
<< "NLLLossRel: predictions and weights should be of the same floating type.";
ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type.";
if (!(predictions->shape.size() - targets->shape.size() == 1)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: predictions should be one"
<< " dimension larger than targets,"
<< "predictions shape = " << predictions->shape
<< ", targets shape = " << targets->shape);
return false;
}
if (!(weights->shape.size() == 1)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: weights should be a one dimension"
<< " Tensor with its length the number of classes,"
<< " but Tensor of dimension " << weights->shape.size()
<< " were provided.");
return false;
}
if (!reporter->AssertEQ(predictions->shape[1], weights->shape[0])) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: the second dimension of predictions"
<< " should be the number of classes, "
<< "which is the length of weights, "
<< "predictions shape = " << predictions->shape
<< ", weights shape = " << weights->shape);
return false;
}
if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: predictions and weights should"
<< " be of the same floating type.");
return false;
}
if (!targets->dtype.is_int()) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: targets should be of int type.");
return false;
}
// assign output type
if (param->reduction == "none") {
reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype));
Expand Down
24 changes: 24 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3866,6 +3866,29 @@ def test_fn(is_sorted, return_inverse, return_counts):
verify_trace_model(test_fn(True, False, True), [in_data], targets)


def test_forward_nll_loss():
    """Run torch.nn.NLLLoss through verify_model for several configurations.

    The default module, a weighted module, a module with ``ignore_index=1``
    and modules with "sum" and "none" reductions are each checked on 2-D
    predictions and on 4-D predictions (the ``aten::nll_loss2d`` case).
    """
    torch.set_grad_enabled(False)
    N, C = 10, 3
    weights = torch.tensor([1, 2, 3]).float()

    def _verify_all_variants(predictions, targets):
        # A fresh module is built for every call, one per configuration.
        for kwargs in (
            {},
            {"weight": weights},
            {"ignore_index": 1},
            {"reduction": "sum"},
            {"reduction": "none"},
        ):
            verify_model(torch.nn.NLLLoss(**kwargs).eval(), input_data=[predictions, targets])

    _verify_all_variants(torch.rand((N, C)).float(), torch.randint(0, C, (N,)))

    # multidimension nll loss (aten::nll_loss2d)
    d1, d2 = 2, 3
    _verify_all_variants(torch.rand((N, C, d1, d2)).float(), torch.randint(0, C, (N, d1, d2)))


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
Expand Down Expand Up @@ -4007,6 +4030,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
test_unique()
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()

# Model tests
test_resnet18()
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,49 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT")


@tvm.testing.uses_gpu
def test_nll_loss():
    """Check relay.nn.nll_loss type inference and numerical results.

    Each case infers the output type of the relay expression, then evaluates
    it with the "graph" and "debug" executors on every enabled target and
    compares the result with the NumPy reference tvm.topi.testing.nll_loss.
    """

    def _get_oshape(target_shape, reduction):
        # Only reduction="none" keeps the per-target shape; every other
        # reduction is expected to produce a scalar (empty shape).
        if reduction == "none":
            return target_shape
        else:
            return []

    def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
        C = prediction_shape[1]
        # Targets have the prediction shape with the class axis (axis 1) removed.
        target_shape = prediction_shape[:1] + prediction_shape[2:]

        predictions = relay.var("predictions", relay.TensorType(prediction_shape, dtype))
        targets = relay.var("targets", relay.TensorType(target_shape, "int32"))
        weights = relay.var("weights", relay.TensorType((C,), dtype))
        out = relay.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
        checked = run_infer_type(out)
        assert checked.checked_type == relay.ty.TensorType(
            _get_oshape(target_shape, reduction), dtype
        )
        func = relay.Function([predictions, targets, weights], out)
        predictions_np = np.random.uniform(size=prediction_shape).astype(dtype)
        targets_np = np.random.randint(0, C, target_shape).astype("int32")
        weights_np = np.random.uniform(size=(C,)).astype(dtype)
        out_np = tvm.topi.testing.nll_loss(
            predictions_np, targets_np, weights_np, reduction, ignore_index
        )

        for target, dev in tvm.testing.enabled_targets():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, device=dev, target=target)
                out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
                tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

    _verify((10, 5))
    _verify((10, 5, 2, 2))
    _verify((10, 5), reduction="sum")
    _verify((10, 5), reduction="none")
    _verify((10, 5), ignore_index=3)
    _verify((10, 5), dtype="float64")


if __name__ == "__main__":
test_adaptive_pool()
test_collapse_sum_like()
Expand All @@ -590,3 +633,4 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
test_one_hot()
test_ndarray_size()
test_matrix_set_diag()
test_nll_loss()
70 changes: 70 additions & 0 deletions tests/python/topi/python/test_topi_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for loss operators."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
import tvm.topi.testing

import tvm.testing


def verify_nll_loss(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
    """Build topi.nn.nll_loss and compare it with the NumPy reference.

    Parameters
    ----------
    prediction_shape : tuple of int
        Shape of the predictions; axis 1 is used as the number of classes and
        the targets take this shape with axis 1 removed.

    reduction : string
        Reduction passed to both the topi operator and the reference.

    ignore_index : int
        Target value passed as ``ignore_index`` to both implementations.

    dtype : string
        Data type of the predictions and weights; targets are always "int32".
    """
    num_classes = prediction_shape[1]
    target_shape = prediction_shape[:1] + prediction_shape[2:]
    predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype)
    targets = te.placeholder(shape=target_shape, name="targets", dtype="int32")
    weights = te.placeholder(shape=(num_classes,), name="weights", dtype=dtype)
    nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)

    for target, dev in tvm.testing.enabled_targets():
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            schedule = tvm.topi.testing.get_injective_schedule(target)(nll_loss_result)
        built = tvm.build(
            schedule, [predictions, targets, weights, nll_loss_result], target, name="nll_loss"
        )

        # Random inputs are drawn in the order predictions, targets, weights.
        predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype)
        targets_npy = np.random.randint(0, num_classes, target_shape).astype("int32")
        weights_npy = np.random.uniform(size=(num_classes,)).astype(dtype)
        expected = tvm.topi.testing.nll_loss(
            predictions_npy, targets_npy, weights_npy, reduction, ignore_index
        )

        predictions_nd = tvm.nd.array(predictions_npy, dev)
        targets_nd = tvm.nd.array(targets_npy, dev)
        weights_nd = tvm.nd.array(weights_npy, dev)
        # The output buffer takes the reference's shape and the operator's dtype.
        out_nd = tvm.nd.array(np.empty(expected.shape).astype(nll_loss_result.dtype), dev)
        built(predictions_nd, targets_nd, weights_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), expected)


@tvm.testing.uses_gpu
def test_nll_loss():
    """Run verify_nll_loss over the shape, reduction, ignore_index and dtype cases."""
    cases = [
        ((10, 5), {}),
        ((10, 5, 2, 2), {}),
        ((10, 5), {"reduction": "sum"}),
        ((10, 5), {"reduction": "none"}),
        ((10, 5), {"ignore_index": 3}),
        ((10, 5), {"dtype": "float64"}),
    ]
    for prediction_shape, options in cases:
        verify_nll_loss(prediction_shape, **options)


# Run the NLL loss test directly when this file is executed as a script.
if __name__ == "__main__":
test_nll_loss()

0 comments on commit 4b717d6

Please sign in to comment.