Skip to content

Commit

Permalink
[Relay][Training] Additional gradients (apache#8307)
Browse files Browse the repository at this point in the history
  • Loading branch information
altanh authored and ylc committed Jan 13, 2022
1 parent ad24093 commit 6a2cd6f
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 5 deletions.
62 changes: 57 additions & 5 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
"""Gradient definitions for Relay operators"""
from tvm.topi.nn.utils import get_pad_tuple
from tvm.topi.utils import get_const_tuple
from tvm.error import OpError
Expand Down Expand Up @@ -527,10 +527,7 @@ def softmax_grad(orig, grad):
@register_gradient("nn.log_softmax")
def log_softmax_grad(orig, grad):
"""Gradient of log_softmax"""
x = orig.args[0]
sm = _nn.softmax(x, axis=orig.attrs.axis)
grad = grad / sm
return softmax_grad(sm, grad)
return [grad - _sum(grad, axis=orig.attrs.axis, keepdims=True) * exp(orig)]


@register_gradient("nn.bias_add")
Expand Down Expand Up @@ -596,6 +593,12 @@ def cast_grad(orig, grad):
return [cast_like(grad, x)]


@register_gradient("cast_like")
def cast_like_grad(orig, grad):
x, like = orig.args
return [cast_like(grad, x), zeros_like(like)]


@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims"""
Expand Down Expand Up @@ -873,3 +876,52 @@ def less_equal_grad(orig, grad):
Returns the gradient of less_equal.
"""
return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]


@register_gradient("not_equal")
def not_equal_grad(orig, grad):
"""
Returns the gradient of not_equal (just zeros).
"""
return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]


@register_gradient("strided_slice")
def strided_slice_grad(orig, grad):
"""
Returns the gradient of strided_slice, which is equal to grad where the
input was sliced and zero elsewhere.
"""
assert orig.attrs.axes is None, "grad for strided_slice with axes is not yet supported"
x = orig.args[0]
begin = get_const_tuple(orig.attrs.begin)
end = get_const_tuple(orig.attrs.end)
strides = get_const_tuple(orig.attrs.strides)
if orig.attrs.slice_mode == "size":
# convert sizes to ending indices and ignore strides
end = list(end)
for i, (start, size) in enumerate(zip(begin, end)):
if size == -1:
end[i] = int(x.checked_type.shape[i])
else:
end[i] = start + size
strides = None
else:
assert orig.attrs.slice_mode == "end"
return [strided_set(zeros_like(x), grad, begin, end, strides)]


@register_gradient("one_hot")
def one_hot_grad(orig, grad):
"""
Returns the gradient of one_hot, which is the sum of grad at on and off
indices for on_value and off_value respectively.
"""
indices, on_value, off_value = orig.args

g_zeros = zeros_like(grad)
on_mask = equal(orig, on_value)
grad_on = _sum(where(on_mask, grad, g_zeros))
grad_off = _sum(where(on_mask, g_zeros, grad))

return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, off_value)]
24 changes: 24 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pytest
import numpy as np

from tvm import relay
from tvm.relay.testing import check_grad
Expand Down Expand Up @@ -72,5 +73,28 @@ def test_reverse_reshape_grad():
check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0))))


def test_one_hot_grad():
indices_shape = (3, 4)
depth = 5
axis = -1

for indices_dtype in ["int32", "int64"]:
for val_dtype in ["float32", "float64"]:
inputs = [
np.random.randint(depth, size=indices_shape, dtype=indices_dtype),
np.array(np.random.randn() * 1e-5).astype(val_dtype),
np.array(np.random.randn() * 1e-5).astype(val_dtype),
]
test_inputs = inputs[1:]

indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype)
on_val = relay.var("on_val", shape=tuple(), dtype=val_dtype)
off_val = relay.var("off_val", shape=tuple(), dtype=val_dtype)
y = relay.one_hot(indices, on_val, off_val, depth, axis, val_dtype)
f = relay.Function([indices, on_val, off_val], y)

check_grad(f, inputs=inputs, test_inputs=test_inputs)


if __name__ == "__main__":
pytest.main([__file__])
7 changes: 7 additions & 0 deletions tests/python/relay/test_op_grad_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def test_cast_grad():
check_grad(fwd_func)


def test_cast_like_grad():
data = relay.var("data", shape=(10, 4), dtype="float32")
like = relay.var("like", shape=(1,), dtype="float64")
fwd_func = relay.Function([data, like], relay.cast_like(data, like))
check_grad(fwd_func)


def test_copy_grad():
data = relay.var("data", relay.TensorType((10, 4), "float64"))
fwd_func = relay.Function([data], relay.copy(data))
Expand Down
33 changes: 33 additions & 0 deletions tests/python/relay/test_op_grad_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,38 @@ def test_less_equal_grad():
check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)


def test_not_equal_grad():
x_type = relay.TensorType((2, 3, 4), "float32")
y_type = relay.TensorType((3, 1), "float32")
# We need to generate inputs far apart to get correct numerical gradients
# (otherwise adding epsilon may change comparison result). The gradient
# should always be zero for both inputs.
inputs = [
np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype),
np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype),
]

x = relay.var("x", type_annotation=x_type)
y = relay.var("y", type_annotation=y_type)
fwd_func = relay.Function([x, y], relay.not_equal(x, y))
check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)


def test_strided_slice_grad():
def check(sh, dtype, begin, end, strides, slice_mode):
x = relay.var("x", shape=sh, dtype=dtype)
f = relay.Function(
[x],
relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode),
)
check_grad(f)

check((2, 3, 4), "float32", (0, 1, 0), (-1, -1, 1), (1, 1, 1), "size")
check((2, 3, 4), "float32", (0, 1, 0), (2, 3, 1), (1, 1, 1), "end")
# check that strides are properly ignored when using "size" mode
check((2, 3, 4), "float32", (0, 0, 0), (-1, -1, -1), (1, 1, 2), "size")
check((2, 3, 4), "float32", (0, 0, 0), (2, 3, 4), (1, 1, 2), "end")


if __name__ == "__main__":
pytest.main()

0 comments on commit 6a2cd6f

Please sign in to comment.