[Relay][Dyn] Add dynamic reshape grad #6080

Merged · 3 commits · Jul 17, 2020
Changes from all commits
12 changes: 12 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -514,6 +514,18 @@ def reshape_grad(orig, grad):
    return [reshape_like(grad, orig.args[0])]


@register_gradient("dyn.reshape")
def dyn_reshape_grad(orig, grad):
"""Gradient of dyn_reshape"""
return [reshape_like(grad, orig.args[0]), zeros_like(orig.args[1])]


@register_gradient("shape_of")
def shape_of_grad(orig, grad):
"""Gradient of shape_of"""
return [zeros_like(orig.args[0])]


@register_gradient("cast")
def cast_grad(orig, grad):
    x = orig.args[0]
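Taken together, the two new registrations encode a simple rule: a reshape is a one-to-one rearrangement of elements, so the upstream gradient is just reshaped back to the input's shape, while the integer shape argument (and likewise the output of shape_of) is not differentiable and receives a zero gradient. A minimal numpy sketch of the same rule, illustrative only and not part of the PR:

import numpy as np

x = np.random.randn(2, 3, 4)                # data input
s = np.array([8, 3], dtype="int64")         # dynamic shape input
y = x.reshape(tuple(s))                     # forward: dyn.reshape(x, s)

grad_y = np.ones_like(y)                    # upstream gradient
grad_x = grad_y.reshape(x.shape)            # reshape_like(grad, orig.args[0])
grad_s = np.zeros_like(s)                   # zeros_like(orig.args[1])

# shape_of(x) depends only on x's shape, never on its values,
# so its gradient w.r.t. x is identically zero:
grad_x_shape_of = np.zeros_like(x)          # zeros_like(orig.args[0])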
36 changes: 32 additions & 4 deletions python/tvm/relay/testing/__init__.py
@@ -26,7 +26,6 @@
import tvm.relay.op as op
from tvm.relay import Prelude

from . import mlp
from . import resnet
from . import resnet_3d
@@ -47,6 +46,7 @@
from .py_converter import to_python, run_as_python
from ..transform import gradient


def run_opt_pass(expr, opt_pass, import_prelude=False):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
@@ -65,7 +65,14 @@ def _np_randn_from_type(t, scale=1, mean=0):
    return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)


def check_grad(func,
               inputs=None,
               test_inputs=None,
               eps=1e-6,
               atol=1e-5,
               rtol=1e-3,
               scale=None,
               mean=0):
"""Perform numerical gradient checking given a relay function.
Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
@@ -80,6 +87,11 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
        Optional user-provided input parameters to use. If not given, will generate random normal
        inputs scaled to be close to the chosen epsilon value to avoid numerical precision loss.

    test_inputs: List[np.array]
        The inputs to test for gradient matching. Useful in cases where some inputs are not
        differentiable, such as symbolic inputs to dynamic ops. If not given, all inputs are
        tested.

    eps: float
        The epsilon value to use for computing numerical gradient approximation.
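Concretely, a caller hands check_grad every input the function needs but lists only the differentiable ones in test_inputs. A hedged usage sketch, assuming func is a relay.Function that takes a float tensor plus an int64 shape tensor, as in the dyn.reshape test later in this diff:

x_data = np.ones((2, 3, 4), dtype="float32")
shape_data = np.array((8, 3), dtype="int64")
check_grad(run_infer_type(func),
           inputs=[x_data, shape_data],   # everything func takes
           test_inputs=[x_data],          # but only x_data is differentiated
           eps=1e-3)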
@@ -109,16 +121,30 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
        # Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
        inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]

    if test_inputs is None:
        test_inputs = inputs

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)

        # Get analytic gradients.
        _, grads = intrp.evaluate(bwd_func)(*inputs)
        grads = [grad.asnumpy().astype("float64") for grad in grads]

        # Throw out the gradients we aren't testing.
        if inputs != test_inputs:
            tmp = []
            # Find the gradient that corresponds to each test input (matched by identity).
            for test_input in test_inputs:
                for i, grad in enumerate(grads):
                    if inputs[i] is test_input:
                        tmp.append(grad)
                        break
            grads = tmp

        # Get numeric gradients for each dimension of each param, using two-sided approximation.
        approx_grads = []
        for x in test_inputs:
            approx_grad = np.zeros(x.shape)
            for i in np.ndindex(*x.shape):
                x_i = x[i]
@@ -129,7 +155,6 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
                x[i] = x_i
                approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps))
            approx_grads.append(approx_grad)
        # Compare gradients by checking that relative difference is below tolerance.
        for grad, approx_grad in zip(grads, approx_grads):
            np.testing.assert_allclose(grad, approx_grad, atol=atol, rtol=rtol)
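The numeric side of the comparison is a plain central difference: perturb one element by ±eps, re-run the forward function, and divide the change in the summed output by 2 * eps. The loop above distills to the following standalone sketch (with numpy imported as np, and assuming f returns a numpy array):

def numeric_grad(f, x, eps=1e-6):
    """Two-sided approximation of the gradient of sum(f(x)) w.r.t. x."""
    approx = np.zeros(x.shape)
    for i in np.ndindex(*x.shape):
        x_i = x[i]
        x[i] = x_i + eps
        fwd_plus = f(x)
        x[i] = x_i - eps
        fwd_minus = f(x)
        x[i] = x_i                       # restore the perturbed element
        approx[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps))
    return approx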
@@ -142,13 +167,16 @@ def rand(dtype, *shape):
def count_ops(expr):
"""count number of times a given op is called in the graph"""
class OpCounter(tvm.relay.ExprVisitor):
"""OpCounter"""
def visit_call(self, call):
if hasattr(call, 'op'):
self.node_counter[call.op.name] += 1
return super().visit_call(call)

        def count(self, expr):
            self.node_set = {}
            self.node_counter = collections.Counter()
            self.visit(expr)
            return self.node_counter

    return OpCounter().count(expr)
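count_ops walks the expression with a relay ExprVisitor and returns a collections.Counter keyed by op name. A hypothetical usage sketch:

x = relay.var("x", shape=(2, 2))
expr = relay.add(relay.add(x, x), x)
print(count_ops(expr))                  # Counter({'add': 2})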
9 changes: 8 additions & 1 deletion tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -44,7 +44,11 @@ def verify_reshape(shape, newshape, oshape):

        func = relay.Function([x, y], z)
        x_data = np.ones(shape).astype("float32")
        ref_res = np.reshape(x_data, oshape)
        check_grad(run_infer_type(func),
                   inputs=[x_data, np.array(newshape).astype("int64")],
                   test_inputs=[x_data], eps=1e-3)
        verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res)
    verify_reshape((2, 3, 4), (8, 3), (8, 3))
    verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
@@ -66,6 +70,8 @@ def verify_reshape(shape, newshape, oshape):
        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
        y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
        ref_res = np.reshape(x_data, oshape)
        check_grad(run_infer_type(func),
                   inputs=[x_data, y_data], eps=1e-3)
        verify_func(func, [x_data, y_data], ref_res)
    verify_reshape((2, 3, 4), (8, 3), (8, 3))
    verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
@@ -79,6 +85,7 @@ def verify_tile(dshape, reps):
        func = relay.Function([x, r], z)
        x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
        ref_res = np.tile(x_data, reps=reps)
        reps_data = np.array(reps).astype("float32")
        verify_func(func, [x_data, reps_data], ref_res)
    verify_tile((2, 3, 4), (3, 2, 1))
    verify_tile((2, 3, 4), (1, 2))
@@ -111,4 +118,4 @@ def verify_zeros_ones(shape, dtype):
    test_dyn_reshape()
    test_dyn_shape_reshape()
    test_dyn_tile()
    test_dyn_zeros_ones()