diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 33eb67a1eb9d7..1a95d30e3ff19 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -659,7 +659,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, */ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights, std::string reduction = "mean", int ignore_index = -100, - const std::string name = "nll_loss", const std::string tag = kBroadcast) { + const std::string name = "nll_loss", const std::string tag = "") { auto T = tvm::te::compute( targets->shape, [&](const tvm::Array& target_indices) { diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index ee83f31516352..04d38ce39422a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -894,7 +894,7 @@ def compute_nll_loss(attrs, inputs, out_dtype): reg.register_reduce_schedule("nn.nll_loss") -reg.register_pattern("nn.nll_loss", OpPattern.OPAQUE) +reg.register_pattern("nn.nll_loss", OpPattern.OUT_ELEMWISE_FUSABLE) # depth_to_space diff --git a/python/tvm/topi/testing/nll_loss.py b/python/tvm/topi/testing/nll_loss.py index b6eeb187d3b78..fd78f6f56d009 100644 --- a/python/tvm/topi/testing/nll_loss.py +++ b/python/tvm/topi/testing/nll_loss.py @@ -69,5 +69,4 @@ def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100) return np.sum(res) / weight_sum if reduction == "sum": return np.sum(res) - else: - return res + return res diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py index 0fb3f392da352..3cb7172adae4e 100644 --- a/tests/python/topi/python/test_topi_loss.py +++ b/tests/python/topi/python/test_topi_loss.py @@ -25,46 +25,45 @@ import tvm.testing -def verify_nll_loss(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"): +def verify_nll_loss( + dev, target, prediction_shape, reduction="mean", ignore_index=-100, dtype="float32" +): C = prediction_shape[1] target_shape = prediction_shape[:1] + prediction_shape[2:] predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype) targets = te.placeholder(shape=target_shape, name="targets", dtype="int32") weights = te.placeholder(shape=(C,), name="weights", dtype=dtype) - nll_loss_result = topi.nn.nll_loss( - predictions, targets, weights, reduction, ignore_index - ) + nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + + with tvm.target.Target(target): + s = tvm.te.create_schedule(nll_loss_result.op) + fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") - def check_device(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(nll_loss_result) - fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss") - predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype) - targets_npy = np.random.randint(0, C, target_shape).astype("int32") - weights_npy = np.random.uniform(size=(C,)).astype(dtype) - out_npy = tvm.topi.testing.nll_loss(predictions_npy, targets_npy, weights_npy, reduction, ignore_index) - predictions_nd = tvm.nd.array(predictions_npy, dev) - targets_nd = tvm.nd.array(targets_npy, dev) - weights_nd = tvm.nd.array(weights_npy, dev) - out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) - fn(predictions_nd, targets_nd, weights_nd, out_nd) - out_topi = out_nd.asnumpy() - tvm.testing.assert_allclose(out_topi, out_npy) + predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype) + targets_npy = np.random.randint(0, C, target_shape).astype("int32") + weights_npy = np.random.uniform(size=(C,)).astype(dtype) + out_npy = tvm.topi.testing.nll_loss( + predictions_npy, targets_npy, weights_npy, reduction, ignore_index + ) - for target, dev in tvm.testing.enabled_targets(): - check_device(target, dev) + predictions_nd = tvm.nd.array(predictions_npy, dev) + targets_nd = tvm.nd.array(targets_npy, dev) + weights_nd = tvm.nd.array(weights_npy, dev) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) + fn(predictions_nd, targets_nd, weights_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) -@tvm.testing.uses_gpu -def test_nll_loss(): - verify_nll_loss((10, 5,)) - verify_nll_loss((10, 5, 2, 2)) - verify_nll_loss((10, 5,), reduction="sum") - verify_nll_loss((10, 5,), reduction="none") - verify_nll_loss((10, 5,), ignore_index=3) - verify_nll_loss((10, 5,), dtype="float64") +@tvm.testing.parametrize_targets +def test_nll_loss(dev, target): + verify_nll_loss(dev, target, (10, 5)) + verify_nll_loss(dev, target, (10, 5, 2, 2)) + verify_nll_loss(dev, target, (10, 5), reduction="sum") + verify_nll_loss(dev, target, (10, 5), reduction="none") + verify_nll_loss(dev, target, (10, 5), ignore_index=3) + verify_nll_loss(dev, target, (10, 5), dtype="float64") if __name__ == "__main__": - test_nll_loss() + test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm"))