Skip to content

Commit

Permalink
[Relay][Dyn] Add dynamic reshape grad (#6080)
Browse files Browse the repository at this point in the history
* add dynamic reshape grad

* fix lint

* fix unit tests, warning
  • Loading branch information
Matthew Brookhart authored Jul 17, 2020
1 parent ccacb1e commit 3150db7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
12 changes: 12 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,18 @@ def reshape_grad(orig, grad):
return [reshape_like(grad, orig.args[0])]


@register_gradient("dyn.reshape")
def dyn_reshape_grad(orig, grad):
    """Gradient of dyn.reshape.

    Parameters
    ----------
    orig
        The original call; its first two arguments are read.
    grad
        The gradient flowing back from the output of the call.

    Returns
    -------
    list
        Two entries, one per argument of the original call: ``grad``
        reshaped like the first argument, and zeros shaped like the second
        argument.
    """
    data = orig.args[0]
    # Presumably the runtime target-shape tensor -- confirm against the op
    # definition. It receives an all-zero gradient.
    newshape = orig.args[1]
    data_grad = reshape_like(grad, data)
    newshape_grad = zeros_like(newshape)
    return [data_grad, newshape_grad]


@register_gradient("shape_of")
def shape_of_grad(orig, grad):
    """Gradient of shape_of.

    Parameters
    ----------
    orig
        The original call; only its first argument is read.
    grad
        The incoming gradient. It is not used: the result is always zeros.

    Returns
    -------
    list
        A single entry holding zeros shaped like the first argument of the
        original call.
    """
    operand = orig.args[0]
    return [zeros_like(operand)]


@register_gradient("cast")
def cast_grad(orig, grad):
x = orig.args[0]
Expand Down
36 changes: 32 additions & 4 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import tvm.relay.op as op
from tvm.relay import Prelude


from . import mlp
from . import resnet
from . import resnet_3d
Expand All @@ -47,6 +46,7 @@
from .py_converter import to_python, run_as_python
from ..transform import gradient


def run_opt_pass(expr, opt_pass, import_prelude=False):
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
Expand All @@ -65,7 +65,14 @@ def _np_randn_from_type(t, scale=1, mean=0):
return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)


def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
def check_grad(func,
inputs=None,
test_inputs=None,
eps=1e-6,
atol=1e-5,
rtol=1e-3,
scale=None,
mean=0):
"""Perform numerical gradient checking given a relay function.
Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
Expand All @@ -80,6 +87,11 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
Optional user-provided input parameters to use. If not given, will generate random normal
inputs scaled to be close to the chosen epsilon value to avoid numerical precision loss.
test_inputs: List[np.array]
The inputs to test for gradient matching. Useful in cases where some inputs are not
differentiable, such as symbolic inputs to dynamic ops. If not given, all inputs are
tested.
eps: float
The epsilon value to use for computing numerical gradient approximation.
Expand Down Expand Up @@ -109,16 +121,30 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
# Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]

if test_inputs is None:
test_inputs = inputs

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)

# Get analytic gradients.
_, grads = intrp.evaluate(bwd_func)(*inputs)
grads = [grad.asnumpy().astype("float64") for grad in grads]

# Throw out gradients we aren't testing
if inputs != test_inputs:
tmp = []
# find the gradient that corresponds to every test input
for test_input in test_inputs:
for i, grad in enumerate(grads):
if inputs[i] is test_input:
tmp.append(grad)
break
grads = tmp

# Get numeric gradients for each dimension of each param, using two-sided approximation.
approx_grads = []
for x in inputs:
for x in test_inputs:
approx_grad = np.zeros(x.shape)
for i in np.ndindex(*x.shape):
x_i = x[i]
Expand All @@ -129,7 +155,6 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
x[i] = x_i
approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps))
approx_grads.append(approx_grad)

# Compare gradients by checking that relative difference is below tolerance.
for grad, approx_grad in zip(grads, approx_grads):
np.testing.assert_allclose(grad, approx_grad, atol=atol, rtol=rtol)
Expand All @@ -142,13 +167,16 @@ def rand(dtype, *shape):
def count_ops(expr):
    """Count how many times each op is called in the graph.

    Parameters
    ----------
    expr
        The expression to traverse.

    Returns
    -------
    collections.Counter
        Maps ``call.op.name`` to the number of call nodes visited with that
        name.
    """
    class OpCounter(tvm.relay.ExprVisitor):
        """Visitor that tallies call nodes by the name of their op."""

        def count(self, expr):
            # State is (re)initialised on every count() so each run starts
            # from an empty tally.
            self.node_set = {}
            self.node_counter = collections.Counter()
            self.visit(expr)
            return self.node_counter

        def visit_call(self, call):
            # NOTE(review): assumes call.op exposes ``.name`` -- confirm for
            # callees that are not registered operators.
            if hasattr(call, 'op'):
                self.node_counter[call.op.name] += 1
            # Delegate to the base visitor after recording this call.
            return super().visit_call(call)

    counter = OpCounter()
    return counter.count(expr)
9 changes: 8 additions & 1 deletion tests/python/relay/dyn/test_dynamic_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def verify_reshape(shape, newshape, oshape):

func = relay.Function([x, y], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x_data = np.ones(shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
check_grad(run_infer_type(func),
inputs=[x_data, np.array(newshape).astype("int64")],
test_inputs=[x_data], eps=1e-3)
verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res)
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
Expand All @@ -66,6 +70,8 @@ def verify_reshape(shape, newshape, oshape):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
ref_res = np.reshape(x_data, oshape)
check_grad(run_infer_type(func),
inputs=[x_data, y_data], eps=1e-3)
verify_func(func, [x_data, y_data], ref_res)
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
Expand All @@ -79,6 +85,7 @@ def verify_tile(dshape, reps):
func = relay.Function([x, r], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.tile(x_data, reps=reps)
reps_data = np.array(reps).astype("float32")
verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
Expand Down Expand Up @@ -111,4 +118,4 @@ def verify_zeros_ones(shape, dtype):
test_dyn_reshape()
test_dyn_shape_reshape()
test_dyn_tile()
test_dyn_zeros_ones()
test_dyn_zeros_ones()

0 comments on commit 3150db7

Please sign in to comment.