[Relay] Add grads (apache#3857)
* Add gradient implementations

* Add docstrings to fix lint errors
SWu authored and wweic committed Sep 16, 2019
1 parent cbc98a8 commit 88cc58c
Showing 5 changed files with 175 additions and 12 deletions.
79 changes: 72 additions & 7 deletions python/tvm/relay/op/_tensor_grad.py
@@ -17,16 +17,25 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from topi.util import get_const_tuple

from topi.nn.util import get_pad_tuple
from ..expr import const, Tuple, TupleGetItem
from topi.util import get_const_tuple

from ..expr import Tuple, TupleGetItem, const
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
from .transform import collapse_sum_like, broadcast_to_like, where, transpose, reshape, tile, \
strided_slice
from .tensor import exp, negative, power, less, cos, sin
from .tensor import zeros_like, ones_like
from . import nn as _nn
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
from .transform import (
broadcast_to_like,
collapse_sum_like,
reshape,
reshape_like,
strided_slice,
tile,
transpose,
where,
)


@register_gradient("log")
@@ -250,3 +259,59 @@ def conv2d_grad(orig, grad):
end=[None, None, filter_h, filter_w])

return [backward_data, backward_weight]


@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
"""Gradient of softmax"""
return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig]
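The closed form above follows from the softmax Jacobian J[i, j] = y_i * (delta_ij - y_j): the vector-Jacobian product collapses to (grad - sum(grad * y)) * y along the softmax axis. As an illustration only (not part of the diff; 1-D shapes assumed), a NumPy check of that identity:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.random.randn(16)
y = softmax(x)
grad = np.random.randn(16)          # incoming adjoint

# closed form used by softmax_grad
dx = (grad - (grad * y).sum()) * y

# explicit Jacobian J = diag(y) - y y^T, applied to the adjoint
jac = np.diag(y) - np.outer(y, y)
assert np.allclose(dx, jac @ grad)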


@register_gradient("nn.bias_add")
def bias_grad(orig, grad):
"""Returns grad"""
data, bias = orig.args
return [collapse_sum_like(grad, data),
collapse_sum_like(grad, bias)]
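Both outputs are the incoming gradient collapsed back to the corresponding argument's shape; for the bias this means summing over the broadcast (batch) axes. A NumPy analogue using the shapes from the new bias_add test below (illustration only):

import numpy as np

grad = np.random.randn(1, 16)    # adjoint of the bias_add output
data_grad = grad                 # same shape as data, nothing to collapse
bias_grad = grad.sum(axis=0)     # collapsed to the bias shape (16,)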


@register_gradient("nn.dense")
def dense_grad(orig, grad):
"""Returns [grad' @ weight, data @ grad']"""
data, weight = orig.args
return [collapse_sum_like(transpose(grad) * weight, data),
collapse_sum_like(data * transpose(grad), weight)]
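nn.dense computes data @ weight.T, so the data adjoint is grad @ weight and the weight adjoint is grad.T @ data. The broadcast-multiply-then-collapse form above reduces to exactly those products when the batch dimension is 1, which is the case the new tests exercise; a NumPy check under that assumption (illustration only):

import numpy as np

data = np.random.randn(1, 8)
weight = np.random.randn(16, 8)
grad = np.random.randn(1, 16)                              # adjoint of the (1, 16) output

# matrix-product form of the adjoints
data_grad_ref = grad @ weight                              # (1, 8)
weight_grad_ref = grad.T @ data                            # (16, 8)

# broadcast form used by dense_grad, collapsed back to the argument shapes
data_grad = (grad.T * weight).sum(axis=0, keepdims=True)   # collapse_sum_like(..., data)
weight_grad = data * grad.T                                # already weight-shaped, collapse is a no-op

assert np.allclose(data_grad, data_grad_ref)
assert np.allclose(weight_grad, weight_grad_ref)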


@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims"""
data = orig.args[0]
return [reshape_like(grad, data)]
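batch_flatten only collapses the non-batch axes into one, so the backward pass just reshapes the incoming gradient back to the original layout; reshape_like does that without needing the concrete shape. A NumPy analogue with the shape from the new test below (illustration only):

import numpy as np

data = np.random.randn(1, 2, 3, 4)
flat = data.reshape(data.shape[0], -1)   # forward: batch_flatten -> (1, 24)
grad = np.random.randn(*flat.shape)      # adjoint of the flattened output
data_grad = grad.reshape(data.shape)     # backward: reshape_like(grad, data)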


@register_gradient("transpose")
def transpose_grad(orig, grad):
"""Returns grad transposed over the complement of original transpose axes"""
orig_axes = orig.attrs.axes
if orig_axes:
dims = len(orig_axes)
new_axes = [0] * dims
for i in range(dims):
new_axes[int(orig_axes[i])] = i
else:
new_axes = None
return [transpose(grad, axes=new_axes)]
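new_axes is the inverse permutation of orig_axes (new_axes[orig_axes[i]] = i), so transposing the gradient by it undoes the forward transpose. A quick check of the loop above with the axes used in the new test (illustration only):

orig_axes = (0, 2, 3, 1)
new_axes = [0] * len(orig_axes)
for i, a in enumerate(orig_axes):
    new_axes[a] = i
assert new_axes == [0, 3, 1, 2]   # applying (0, 2, 3, 1) then (0, 3, 1, 2) restores the original order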


@register_gradient("negative")
def negative_grad(orig, grad):
"""Returns -grad"""
return [-grad]


@register_gradient("sum")
def sum_grad(orig, grad):
"""Returns grad broadcasted to data dims"""
data = orig.args[0]
return [broadcast_to_like(grad, data)]
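broadcast_to_like replicates the reduced gradient back over the summed axes of the input. A NumPy analogue for two of the sum variants covered by the new level-4 test (illustration only):

import numpy as np

data = np.random.randn(4, 2)

# full reduction: the adjoint of a scalar sum is broadcast over the whole input
grad = np.array(1.0)
data_grad = np.broadcast_to(grad, data.shape)

# axis=-1, keepdims=True: the kept axis is broadcast back over the summed dimension
grad = np.random.randn(4, 1)
data_grad = np.broadcast_to(grad, data.shape)
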
19 changes: 18 additions & 1 deletion tests/python/relay/test_op_grad_level1.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list, run_infer_type


def sigmoid(x):
one = np.ones_like(x)
@@ -30,6 +32,7 @@ def relu(x):
np.maximum(x_copy, 0, x_copy)
return x_copy


def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
@@ -93,6 +96,20 @@ def check_binary_op(opfunc, ref):
check_binary_op(opfunc, ref)


def test_softmax_grad():
data = relay.var("data", relay.TensorType((1, 16), "float64"))
fwd_func = relay.Function([data], relay.nn.softmax(data))
check_grad(fwd_func)
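check_grad, imported from tvm.relay.testing, drives all of these new tests: it differentiates the forward function with the gradient pass and compares the results against numerical gradients on random inputs. A rough, hypothetical sketch of the numerical side of that comparison (helper name and details assumed, not the actual implementation):

import numpy as np

def numerical_grad(f, x, eps=1e-6):
    """Central finite differences of a scalar-valued f at x."""
    g = np.zeros_like(x)
    for i in np.ndindex(*x.shape):
        d = np.zeros_like(x)
        d[i] = eps
        g[i] = (f(x + d) - f(x - d)) / (2 * eps)
    return g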


def test_bias_add_grad():
data = relay.var("data", relay.TensorType((1, 16), "float32"))
bias = relay.var("bias", relay.TensorType((16,), "float32"))
fwd_func = relay.Function([data, bias], relay.nn.bias_add(data, bias))
check_grad(fwd_func)


if __name__ == "__main__":
test_unary_op()
test_binary_op()
test_bias_add_grad()
31 changes: 28 additions & 3 deletions tests/python/relay/test_op_grad_level2.py
@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm

import topi
import topi.testing
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list, check_grad
from tvm.relay.testing import run_infer_type


def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
@@ -129,7 +129,32 @@ def test_conv2d_grad():
verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')


def verify_dense_grad(d_shape, w_shape):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
weight = relay.var("weight", relay.TensorType(w_shape, "float32"))
fwd_func = relay.Function([data, weight], relay.nn.dense(data, weight))
check_grad(fwd_func)


def test_dense_grad():
verify_dense_grad((1, 8), (16, 8))
verify_dense_grad((1, 4), (3, 4))


def verify_batch_flatten_grad(d_shape):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
check_grad(fwd_func)


def test_batch_flatten_grad():
verify_batch_flatten_grad((1, 2, 3, 4))
verify_batch_flatten_grad((1, 8))


if __name__ == "__main__":
test_max_pool2d_grad()
test_avg_pool2d_grad()
test_conv2d_grad()
test_dense_grad()
test_batch_flatten_grad()
23 changes: 22 additions & 1 deletion tests/python/relay/test_op_grad_level3.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list, run_infer_type


def test_clip():
ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
@@ -38,5 +40,24 @@ def test_clip():
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def verify_transpose_grad(d_shape, axes=None):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
fwd_func = relay.Function([data], relay.transpose(data, axes=axes))
check_grad(fwd_func)


def test_transpose_grad():
verify_transpose_grad((1, 2, 3, 4))
verify_transpose_grad((1, 2, 3, 4), axes=(0, 2, 3, 1))


def test_negative_grad():
data = relay.var("data", relay.TensorType((10, 4), "float32"))
fwd_func = relay.Function([data], relay.negative(data))
check_grad(fwd_func)


if __name__ == "__main__":
test_clip()
test_transpose_grad()
test_negative_grad()
35 changes: 35 additions & 0 deletions tests/python/relay/test_op_grad_level4.py
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing import check_grad


def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
fwd_func = relay.Function([data], relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude))
check_grad(fwd_func)


def test_sum_grad():
verify_sum_grad((4, 2))
verify_sum_grad((4, 2), axis=-1, keepdims=True)
verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)



if __name__ == "__main__":
test_sum_grad()
