
Commit

update based on reviews
zhuzilin committed Jun 11, 2021
1 parent 4b717d6 commit c31fdaf
Showing 4 changed files with 33 additions and 35 deletions.
2 changes: 1 addition & 1 deletion include/tvm/topi/nn.h
@@ -659,7 +659,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
  */
 inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
                        std::string reduction = "mean", int ignore_index = -100,
-                       const std::string name = "nll_loss", const std::string tag = kBroadcast) {
+                       const std::string name = "nll_loss", const std::string tag = "") {
   auto T = tvm::te::compute(
       targets->shape,
       [&](const tvm::Array<tvm::tir::Var>& target_indices) {
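
The C++ change above drops the misleading default `tag = kBroadcast` (nll_loss is not a broadcast op) in favor of an empty tag. As a minimal sketch of what that tag means, using only the public `te.compute` API (the op and names below are illustrative, not the commit's code):

```python
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
# An op created with an empty tag, as nll_loss now is: schedule helpers that
# branch on op.tag (e.g. treating "broadcast"-tagged ops specially) no longer
# misclassify it.
B = te.compute((n,), lambda i: -A[i], name="B", tag="")

print(B.op.tag)  # "" -- the op makes no claim about its compute pattern
```
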
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -894,7 +894,7 @@ def compute_nll_loss(attrs, inputs, out_dtype):
 
 
 reg.register_reduce_schedule("nn.nll_loss")
-reg.register_pattern("nn.nll_loss", OpPattern.OPAQUE)
+reg.register_pattern("nn.nll_loss", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # depth_to_space
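
The Relay registration above relaxes nn.nll_loss from OPAQUE (never fused) to OUT_ELEMWISE_FUSABLE, so elementwise consumers of its output may be fused into it. A quick way to inspect the registered pattern, offered as a sketch rather than part of the commit:

```python
from tvm import relay
from tvm.relay.op import OpPattern

op = relay.op.get("nn.nll_loss")
# TOpPattern is the attribute that reg.register_pattern sets; after this commit
# it should equal OpPattern.OUT_ELEMWISE_FUSABLE (4) rather than OPAQUE (8).
print(int(op.get_attr("TOpPattern")), int(OpPattern.OUT_ELEMWISE_FUSABLE))
```
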
3 changes: 1 addition & 2 deletions python/tvm/topi/testing/nll_loss.py
@@ -69,5 +69,4 @@ def nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-100)
         return np.sum(res) / weight_sum
     if reduction == "sum":
         return np.sum(res)
-    else:
-        return res
+    return res
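
For reference, the semantics this testing helper implements can be written as a self-contained numpy sketch (an illustration of weighted NLL loss with ignore_index, not the file's exact code):

```python
import numpy as np

def nll_loss_ref(predictions, targets, weights, reduction="mean", ignore_index=-100):
    # predictions: (N, C, d1, ...) log-probabilities; targets: (N, d1, ...) class ids.
    res = np.zeros(targets.shape, dtype=predictions.dtype)
    weight_sum = 0.0
    for idx in np.ndindex(targets.shape):
        c = targets[idx]
        if c == ignore_index:
            continue  # ignored targets add neither loss nor weight
        # pick the class channel for this element: (batch, class, spatial...)
        res[idx] = -predictions[(idx[0], c) + idx[1:]] * weights[c]
        weight_sum += weights[c]
    if reduction == "mean":
        return np.sum(res) / weight_sum
    if reduction == "sum":
        return np.sum(res)
    return res  # reduction == "none": the branch simplified by this commit
```

With all-one weights, the "mean" reduction is simply the average loss over the non-ignored elements.
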
61 changes: 30 additions & 31 deletions tests/python/topi/python/test_topi_loss.py
@@ -25,46 +25,45 @@
 import tvm.testing
 
 
-def verify_nll_loss(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
+def verify_nll_loss(
+    dev, target, prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"
+):
     C = prediction_shape[1]
     target_shape = prediction_shape[:1] + prediction_shape[2:]
     predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype)
     targets = te.placeholder(shape=target_shape, name="targets", dtype="int32")
     weights = te.placeholder(shape=(C,), name="weights", dtype=dtype)
-    nll_loss_result = topi.nn.nll_loss(
-        predictions, targets, weights, reduction, ignore_index
-    )
+    nll_loss_result = topi.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
 
+    with tvm.target.Target(target):
+        s = tvm.te.create_schedule(nll_loss_result.op)
+    fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss")
+
-    def check_device(target, dev):
-        print("Running on target: %s" % target)
-        with tvm.target.Target(target):
-            s = tvm.topi.testing.get_injective_schedule(target)(nll_loss_result)
-        fn = tvm.build(s, [predictions, targets, weights, nll_loss_result], target, name="nll_loss")
-        predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype)
-        targets_npy = np.random.randint(0, C, target_shape).astype("int32")
-        weights_npy = np.random.uniform(size=(C,)).astype(dtype)
-        out_npy = tvm.topi.testing.nll_loss(predictions_npy, targets_npy, weights_npy, reduction, ignore_index)
-        predictions_nd = tvm.nd.array(predictions_npy, dev)
-        targets_nd = tvm.nd.array(targets_npy, dev)
-        weights_nd = tvm.nd.array(weights_npy, dev)
-        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev)
-        fn(predictions_nd, targets_nd, weights_nd, out_nd)
-        out_topi = out_nd.asnumpy()
-        tvm.testing.assert_allclose(out_topi, out_npy)
+    predictions_npy = np.random.uniform(size=prediction_shape).astype(dtype)
+    targets_npy = np.random.randint(0, C, target_shape).astype("int32")
+    weights_npy = np.random.uniform(size=(C,)).astype(dtype)
+    out_npy = tvm.topi.testing.nll_loss(
+        predictions_npy, targets_npy, weights_npy, reduction, ignore_index
+    )
 
-    for target, dev in tvm.testing.enabled_targets():
-        check_device(target, dev)
+    predictions_nd = tvm.nd.array(predictions_npy, dev)
+    targets_nd = tvm.nd.array(targets_npy, dev)
+    weights_nd = tvm.nd.array(weights_npy, dev)
+    out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev)
+    fn(predictions_nd, targets_nd, weights_nd, out_nd)
+    out_topi = out_nd.asnumpy()
+    tvm.testing.assert_allclose(out_topi, out_npy)
 
 
-@tvm.testing.uses_gpu
-def test_nll_loss():
-    verify_nll_loss((10, 5,))
-    verify_nll_loss((10, 5, 2, 2))
-    verify_nll_loss((10, 5,), reduction="sum")
-    verify_nll_loss((10, 5,), reduction="none")
-    verify_nll_loss((10, 5,), ignore_index=3)
-    verify_nll_loss((10, 5,), dtype="float64")
+@tvm.testing.parametrize_targets
+def test_nll_loss(dev, target):
+    verify_nll_loss(dev, target, (10, 5))
+    verify_nll_loss(dev, target, (10, 5, 2, 2))
+    verify_nll_loss(dev, target, (10, 5), reduction="sum")
+    verify_nll_loss(dev, target, (10, 5), reduction="none")
+    verify_nll_loss(dev, target, (10, 5), ignore_index=3)
+    verify_nll_loss(dev, target, (10, 5), dtype="float64")
 
 
 if __name__ == "__main__":
-    test_nll_loss()
+    test_nll_loss(tvm.device("cpu"), tvm.target.Target("llvm"))
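
The test rewrite above replaces the hand-rolled check_device/enabled_targets loop with tvm.testing.parametrize_targets, which injects a (target, dev) pair per enabled target when the file runs under pytest. A minimal sketch of the same pattern (the test name and body here are illustrative only):

```python
import numpy as np
import tvm
import tvm.testing

@tvm.testing.parametrize_targets
def test_roundtrip(dev, target):
    # Each enabled target becomes its own parametrized test case; `dev` is the
    # matching runtime device to allocate arrays on.
    x = np.arange(8, dtype="float32")
    x_nd = tvm.nd.array(x, dev)
    tvm.testing.assert_allclose(x_nd.asnumpy(), x)
```

Without pytest, such a test can still be invoked directly with an explicit device and target, which is what the commit's `__main__` block now does.
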
