Skip to content

Commit

Permalink
fix tanh gradient and update tests to use downstream gradient (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
altanh authored and Lokiiiiii committed Mar 1, 2021
1 parent aa06a4a commit 816c99c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def sigmoid_grad(orig, grad):
@register_gradient("tanh")
def tanh_grad(orig, grad):
"""Returns grad * (1 - tanh(x) * tanh(x))."""
return [grad * ones_like(orig) - orig * orig]
return [grad * (ones_like(orig) - orig * orig)]


@register_gradient("nn.relu")
Expand Down
52 changes: 27 additions & 25 deletions tests/python/relay/test_op_grad_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,42 +42,44 @@ def check_single_op(opfunc, ref, dtype):
shape = (10, 4)
tp = relay.TensorType(shape, dtype)
x = relay.var("x", tp)
y = opfunc(x)
g = relay.var("g", tp)
y = opfunc(x) * g

if ref is not None:
data = np.random.rand(*shape).astype(dtype)
ref_grad = ref(data)
fwd_func = relay.Function([x], y)
grad_in = np.random.rand(*shape).astype(dtype)
ref_grad = ref(data, grad_in)
fwd_func = relay.Function([x, g], y)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))

for target, ctx in tvm.testing.enabled_targets():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
op_res, (op_grad, _) = intrp.evaluate(bwd_func)(data, grad_in)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)

for opfunc, ref in [
(tvm.relay.log, lambda x: 1 / x),
(tvm.relay.exp, np.exp),
(tvm.relay.sigmoid, lambda x: sigmoid(x) * (1 - sigmoid(x))),
(tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
(tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
(tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
(tvm.relay.erf, lambda x: 2.0 / (np.pi ** (0.5)) * np.exp(-x * x)),
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
(tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
(tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
(tvm.relay.cosh, lambda x: np.sinh(x)),
(tvm.relay.sinh, lambda x: np.cosh(x)),
(tvm.relay.asin, lambda x: 1.0 / (1.0 - x ** 2) ** (1.0 / 2.0)),
(tvm.relay.acos, lambda x: -1.0 / (1.0 - x ** 2.0) ** (1.0 / 2.0)),
(tvm.relay.acosh, lambda x: 1.0 / (x ** 2 - 1.0) ** (1.0 / 2.0)),
(tvm.relay.asinh, lambda x: 1.0 / (x ** 2 + 1.0) ** (1.0 / 2.0)),
(tvm.relay.atanh, lambda x: -1.0 / (x ** 2 - 1.0)),
(tvm.relay.log, lambda x, g: g * (1 / x)),
(tvm.relay.exp, lambda x, g: g * np.exp(x)),
(tvm.relay.sigmoid, lambda x, g: g * sigmoid(x) * (1 - sigmoid(x))),
(tvm.relay.tanh, lambda x, g: g * (1 - np.tanh(x) * np.tanh(x))),
(tvm.relay.sqrt, lambda x, g: g * 0.5 * np.power(x, -0.5)),
(tvm.relay.abs, lambda x, g: np.where(x < 0, -g, g)),
(relay.nn.relu, lambda x, g: np.where(x < 0, np.zeros_like(x), g)),
(tvm.relay.erf, lambda x, g: g * (2.0 / (np.pi ** (0.5)) * np.exp(-x * x))),
(tvm.relay.cos, lambda x, g: g * -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x, g: g * np.cos(x)),
(tvm.relay.tan, lambda x, g: g * (1.0 / (np.cos(x) ** 2))),
(tvm.relay.atan, lambda x, g: g * (1 / (1 + np.power(x, 2.0)))),
(tvm.relay.log2, lambda x, g: g * (1 / (np.log(2) * x))),
(tvm.relay.log10, lambda x, g: g * (1 / (np.log(10) * x))),
(tvm.relay.cosh, lambda x, g: g * (np.sinh(x))),
(tvm.relay.sinh, lambda x, g: g * (np.cosh(x))),
(tvm.relay.asin, lambda x, g: g * (1.0 / (1.0 - x ** 2) ** (1.0 / 2.0))),
(tvm.relay.acos, lambda x, g: g * (-1.0 / (1.0 - x ** 2.0) ** (1.0 / 2.0))),
(tvm.relay.acosh, lambda x, g: g * (1.0 / (x ** 2 - 1.0) ** (1.0 / 2.0))),
(tvm.relay.asinh, lambda x, g: g * (1.0 / (x ** 2 + 1.0) ** (1.0 / 2.0))),
(tvm.relay.atanh, lambda x, g: g * (-1.0 / (x ** 2 - 1.0))),
]:
for dtype in ("float32", "float64"):
check_single_op(opfunc, ref, dtype)
Expand Down

0 comments on commit 816c99c

Please sign in to comment.