diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index d5b891088933..09b1435aac0f 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Backend compiler related feature registration""" +"""Gradient definitions for Relay operators""" from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.error import OpError @@ -527,10 +527,7 @@ def softmax_grad(orig, grad): @register_gradient("nn.log_softmax") def log_softmax_grad(orig, grad): """Gradient of log_softmax""" - x = orig.args[0] - sm = _nn.softmax(x, axis=orig.attrs.axis) - grad = grad / sm - return softmax_grad(sm, grad) + return [grad - _sum(grad, axis=orig.attrs.axis, keepdims=True) * exp(orig)] @register_gradient("nn.bias_add") @@ -596,6 +593,12 @@ def cast_grad(orig, grad): return [cast_like(grad, x)] +@register_gradient("cast_like") +def cast_like_grad(orig, grad): + x, like = orig.args + return [cast_like(grad, x), zeros_like(like)] + + @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" @@ -873,3 +876,52 @@ def less_equal_grad(orig, grad): Returns the gradient of less_equal. """ return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] + + +@register_gradient("not_equal") +def not_equal_grad(orig, grad): + """ + Returns the gradient of not_equal (just zeros). + """ + return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] + + +@register_gradient("strided_slice") +def strided_slice_grad(orig, grad): + """ + Returns the gradient of strided_slice, which is equal to grad where the + input was sliced and zero elsewhere. + """ + assert orig.attrs.axes is None, "grad for strided_slice with axes is not yet supported" + x = orig.args[0] + begin = get_const_tuple(orig.attrs.begin) + end = get_const_tuple(orig.attrs.end) + strides = get_const_tuple(orig.attrs.strides) + if orig.attrs.slice_mode == "size": + # convert sizes to ending indices and ignore strides + end = list(end) + for i, (start, size) in enumerate(zip(begin, end)): + if size == -1: + end[i] = int(x.checked_type.shape[i]) + else: + end[i] = start + size + strides = None + else: + assert orig.attrs.slice_mode == "end" + return [strided_set(zeros_like(x), grad, begin, end, strides)] + + +@register_gradient("one_hot") +def one_hot_grad(orig, grad): + """ + Returns the gradient of one_hot, which is the sum of grad at on and off + indices for on_value and off_value respectively. + """ + indices, on_value, off_value = orig.args + + g_zeros = zeros_like(grad) + on_mask = equal(orig, on_value) + grad_on = _sum(where(on_mask, grad, g_zeros)) + grad_off = _sum(where(on_mask, g_zeros, grad)) + + return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, off_value)] diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index 4a6ffb933881..e2145f77b366 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np from tvm import relay from tvm.relay.testing import check_grad @@ -72,5 +73,28 @@ def test_reverse_reshape_grad(): check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0)))) +def test_one_hot_grad(): + indices_shape = (3, 4) + depth = 5 + axis = -1 + + for indices_dtype in ["int32", "int64"]: + for val_dtype in ["float32", "float64"]: + inputs = [ + np.random.randint(depth, size=indices_shape, dtype=indices_dtype), + np.array(np.random.randn() * 1e-5).astype(val_dtype), + np.array(np.random.randn() * 1e-5).astype(val_dtype), + ] + test_inputs = inputs[1:] + + indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype) + on_val = relay.var("on_val", shape=tuple(), dtype=val_dtype) + off_val = relay.var("off_val", shape=tuple(), dtype=val_dtype) + y = relay.one_hot(indices, on_val, off_val, depth, axis, val_dtype) + f = relay.Function([indices, on_val, off_val], y) + + check_grad(f, inputs=inputs, test_inputs=test_inputs) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 821e10f97e21..ae3fc2641a25 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -69,6 +69,13 @@ def test_cast_grad(): check_grad(fwd_func) +def test_cast_like_grad(): + data = relay.var("data", shape=(10, 4), dtype="float32") + like = relay.var("like", shape=(1,), dtype="float64") + fwd_func = relay.Function([data, like], relay.cast_like(data, like)) + check_grad(fwd_func) + + def test_copy_grad(): data = relay.var("data", relay.TensorType((10, 4), "float64")) fwd_func = relay.Function([data], relay.copy(data)) diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index 0f73e89c94ad..17d30cacac41 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -86,5 +86,38 @@ def test_less_equal_grad(): check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6) +def test_not_equal_grad(): + x_type = relay.TensorType((2, 3, 4), "float32") + y_type = relay.TensorType((3, 1), "float32") + # We need to generate inputs far apart to get correct numerical gradients + # (otherwise adding epsilon may change comparison result). The gradient + # should always be zero for both inputs. + inputs = [ + np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype), + np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype), + ] + + x = relay.var("x", type_annotation=x_type) + y = relay.var("y", type_annotation=y_type) + fwd_func = relay.Function([x, y], relay.not_equal(x, y)) + check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6) + + +def test_strided_slice_grad(): + def check(sh, dtype, begin, end, strides, slice_mode): + x = relay.var("x", shape=sh, dtype=dtype) + f = relay.Function( + [x], + relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode), + ) + check_grad(f) + + check((2, 3, 4), "float32", (0, 1, 0), (-1, -1, 1), (1, 1, 1), "size") + check((2, 3, 4), "float32", (0, 1, 0), (2, 3, 1), (1, 1, 1), "end") + # check that strides are properly ignored when using "size" mode + check((2, 3, 4), "float32", (0, 0, 0), (-1, -1, -1), (1, 1, 2), "size") + check((2, 3, 4), "float32", (0, 0, 0), (2, 3, 4), (1, 1, 2), "end") + + if __name__ == "__main__": pytest.main()