Skip to content

Commit

Permalink
missed a reference to relay.dyn.reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Jun 29, 2020
1 parent 7d11ffc commit 4737479
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/python/relay/test_pass_dynamic_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_dynamic_to_static_reshape():
def verify_reshape(shape, newshape, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(newshape, "float32"))
z = relay.dyn.reshape(x, relay.shape_of(y))
z = relay.reshape(x, relay.shape_of(y))
func = run_infer_type(relay.Function([x, y], z))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

Expand All @@ -66,8 +66,8 @@ def test_dynamic_to_static_double_reshape():
def verify_reshape(shape, newshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(newshape, "float32"))
z = relay.dyn.reshape(x, relay.shape_of(y))
z = relay.dyn.reshape(z, relay.shape_of(x))
z = relay.reshape(x, relay.shape_of(y))
z = relay.reshape(z, relay.shape_of(x))
func = run_infer_type(relay.Function([x, y], z))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

Expand All @@ -88,10 +88,10 @@ def test_dynamic_to_static_quad_reshape():
def verify_reshape(shape, newshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(newshape, "float32"))
z1 = relay.dyn.reshape(x, relay.shape_of(y))
z2 = relay.dyn.reshape(z1, relay.shape_of(x))
z3 = relay.dyn.reshape(z2, relay.shape_of(z1))
z4 = relay.dyn.reshape(z3, relay.shape_of(z2))
z1 = relay.reshape(x, relay.shape_of(y))
z2 = relay.reshape(z1, relay.shape_of(x))
z3 = relay.reshape(z2, relay.shape_of(z1))
z4 = relay.reshape(z3, relay.shape_of(z2))
func = run_infer_type(relay.Function([x, y], z4))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

Expand Down

0 comments on commit 4737479

Please sign in to comment.