From 3150db767ae6e190216903fd46d2a9b1c2672621 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 17 Jul 2020 15:39:57 -0700 Subject: [PATCH] [Relay][Dyn] Add dynamic reshape grad (#6080) * add dynamic reshape grad * fix lint * fix unit tests, warning --- python/tvm/relay/op/_tensor_grad.py | 12 +++++++ python/tvm/relay/testing/__init__.py | 36 ++++++++++++++++--- .../relay/dyn/test_dynamic_op_level3.py | 9 ++++- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 849d0a3f26d4..3e87f6078664 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -514,6 +514,18 @@ def reshape_grad(orig, grad): return [reshape_like(grad, orig.args[0])] +@register_gradient("dyn.reshape") +def dyn_reshape_grad(orig, grad): + """Gradient of dyn_reshape""" + return [reshape_like(grad, orig.args[0]), zeros_like(orig.args[1])] + + +@register_gradient("shape_of") +def shape_of_grad(orig, grad): + """Gradient of shape_of""" + return [zeros_like(orig.args[0])] + + @register_gradient("cast") def cast_grad(orig, grad): x = orig.args[0] diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index a53e9d7b31ef..0204e5bb5146 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -26,7 +26,6 @@ import tvm.relay.op as op from tvm.relay import Prelude - from . import mlp from . import resnet from . 
import resnet_3d @@ -47,6 +46,7 @@ from .py_converter import to_python, run_as_python from ..transform import gradient + def run_opt_pass(expr, opt_pass, import_prelude=False): assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) @@ -65,7 +65,14 @@ def _np_randn_from_type(t, scale=1, mean=0): return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype) -def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0): +def check_grad(func, + inputs=None, + test_inputs=None, + eps=1e-6, + atol=1e-5, + rtol=1e-3, + scale=None, + mean=0): """Perform numerical gradient checking given a relay function. Compare analytical gradients to numerical gradients derived from two-sided approximation. Note @@ -80,6 +87,11 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me Optional user-provided input parameters to use. If not given, will generate random normal inputs scaled to be close to the chosen epsilon value to avoid numerical precision loss. + test_inputs: List[np.array] + The inputs to test for gradient matching. Useful in cases where some inputs are not + differentiable, such as symbolic inputs to dynamic ops. If not given, all inputs are + tested. + eps: float The epsilon value to use for computing numerical gradient approximation. @@ -109,6 +121,9 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me # Generate random inputs on the same scale as epsilon to avoid numerical precision loss. 
inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params] + if test_inputs is None: + test_inputs = inputs + for target, ctx in ctx_list(): intrp = relay.create_executor(ctx=ctx, target=target) @@ -116,9 +131,20 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me _, grads = intrp.evaluate(bwd_func)(*inputs) grads = [grad.asnumpy().astype("float64") for grad in grads] + # Throw out gradients we aren't testing + if inputs != test_inputs: + tmp = [] + # find the gradient that corresponds to every test input + for test_input in test_inputs: + for i, grad in enumerate(grads): + if inputs[i] is test_input: + tmp.append(grad) + break + grads = tmp + # Get numeric gradients for each dimension of each param, using two-sided approximation. approx_grads = [] - for x in inputs: + for x in test_inputs: approx_grad = np.zeros(x.shape) for i in np.ndindex(*x.shape): x_i = x[i] @@ -129,7 +155,6 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me x[i] = x_i approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps)) approx_grads.append(approx_grad) - # Compare gradients by checking that relative difference is below tolerance. 
for grad, approx_grad in zip(grads, approx_grads): np.testing.assert_allclose(grad, approx_grad, atol=atol, rtol=rtol) @@ -142,13 +167,16 @@ def rand(dtype, *shape): def count_ops(expr): """count number of times a given op is called in the graph""" class OpCounter(tvm.relay.ExprVisitor): + """OpCounter""" def visit_call(self, call): if hasattr(call, 'op'): self.node_counter[call.op.name] += 1 return super().visit_call(call) + def count(self, expr): self.node_set = {} self.node_counter = collections.Counter() self.visit(expr) return self.node_counter + return OpCounter().count(expr) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index e63f9b8cd722..ff98c480f8f1 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -44,7 +44,11 @@ def verify_reshape(shape, newshape, oshape): func = relay.Function([x, y], z) x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + x_data = np.ones(shape).astype("float32") ref_res = np.reshape(x_data, oshape) + check_grad(run_infer_type(func), + inputs=[x_data, np.array(newshape).astype("int64")], + test_inputs=[x_data], eps=1e-3) verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res) verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) @@ -66,6 +70,8 @@ def verify_reshape(shape, newshape, oshape): x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32") ref_res = np.reshape(x_data, oshape) + check_grad(run_infer_type(func), + inputs=[x_data, y_data], eps=1e-3) verify_func(func, [x_data, y_data], ref_res) verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) @@ -79,6 +85,7 @@ def verify_tile(dshape, reps): func = relay.Function([x, r], z) x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") 
ref_res = np.tile(x_data, reps=reps) + reps_data = np.array(reps).astype("float32") verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res) verify_tile((2, 3, 4), (3, 2, 1)) verify_tile((2, 3, 4), (1, 2)) @@ -111,4 +118,4 @@ def verify_zeros_ones(shape, dtype): test_dyn_reshape() test_dyn_shape_reshape() test_dyn_tile() - test_dyn_zeros_ones() \ No newline at end of file + test_dyn_zeros_ones()