From 331bea8346f294d03fb55c23843dfccda12bba54 Mon Sep 17 00:00:00 2001 From: handar423 Date: Fri, 29 May 2020 18:17:24 +0800 Subject: [PATCH] fix small bug about dense_grad --- python/tvm/relay/op/_tensor_grad.py | 4 ++-- tests/python/relay/test_op_grad_level2.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 8ba10207020e9..55908c76c4789 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -472,8 +472,8 @@ def bias_add_grad(orig, grad): def dense_grad(orig, grad): """Returns [grad' @ weight, data @ grad']""" data, weight = orig.args - return [collapse_sum_like(transpose(grad) * weight, data), - collapse_sum_like(data * transpose(grad), weight)] + return [collapse_sum_like(_nn.dense(grad, transpose(weight)), data), + collapse_sum_like(_nn.dense(transpose(grad), transpose(data)), weight)] @register_gradient("reshape") diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 2b5a1c29e0ded..d898451ff6ac1 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -162,6 +162,7 @@ def verify_dense_grad(d_shape, w_shape): def test_dense_grad(): verify_dense_grad((1, 8), (16, 8)) verify_dense_grad((1, 4), (3, 4)) + verify_dense_grad((5, 4), (3, 4)) def verify_batch_flatten_grad(d_shape):