Skip to content

Commit

Permalink
fix: deduplicate run_opt_pass/run_infer_type test helpers into tvm.relay.testing
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jul 2, 2019
1 parent 1e8ca97 commit d69ca11
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 31 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,17 @@
from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
import tvm.relay as relay
from tvm.relay import transform


def run_opt_pass(expr, opt_pass):
    """Apply a single Relay optimization pass to an expression.

    The expression is wrapped in a fresh ``relay.Module``, the pass is run
    on that module, and the transformed entry function is unwrapped again.

    Parameters
    ----------
    expr : relay.Expr
        The expression (or Function) to transform.
    opt_pass : transform.Pass
        The pass to apply.

    Returns
    -------
    relay.Expr
        The transformed entry function if ``expr`` was a Function,
        otherwise the transformed function's body.
    """
    assert isinstance(opt_pass, transform.Pass)
    transformed_mod = opt_pass(relay.Module.from_expr(expr))
    entry = transformed_mod[transformed_mod.entry_func]
    if isinstance(expr, relay.Function):
        return entry
    return entry.body


def run_infer_type(expr):
    """Run type inference on ``expr`` and return the typed expression.

    Convenience wrapper around :func:`run_opt_pass` with the
    ``InferType`` pass.
    """
    infer_pass = transform.InferType()
    return run_opt_pass(expr, infer_pass)
8 changes: 0 additions & 8 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@
from tvm.relay import transform


def run_opt_pass(expr, opt_pass):
    """Run ``opt_pass`` on ``expr`` via a temporary module.

    NOTE(review): this duplicates the helper of the same name in
    ``tvm.relay.testing`` — consider importing it from there instead.
    """
    assert isinstance(opt_pass, transform.Pass)
    wrapper_mod = opt_pass(relay.Module.from_expr(expr))
    entry = wrapper_mod[wrapper_mod.entry_func]
    if isinstance(expr, relay.Function):
        return entry
    return entry.body


def test_fuse_simple():
"""Simple testcase."""
def before():
Expand Down
10 changes: 1 addition & 9 deletions tests/python/relay/test_pass_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,7 @@
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr


def run_infer_type(expr):
    """Run the InferType pass on ``expr`` and return the typed expression.

    Parameters
    ----------
    expr : relay.Expr
        The expression (or Function) to type-infer.

    Returns
    -------
    relay.Expr
        The typed entry function if ``expr`` was a Function, otherwise
        the typed function's body.
    """
    # Fix: the original assigned `mod = relay.Module.from_expr(expr)` twice
    # on consecutive lines; the first assignment was dead code.
    mod = relay.Module.from_expr(expr)
    mod = transform.InferType()(mod)
    entry = mod[mod.entry_func]
    return entry if isinstance(expr, relay.Function) else entry.body
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type


def rand(dtype='float32', *shape):
Expand Down
27 changes: 13 additions & 14 deletions tests/python/relay/test_pass_to_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import alpha_equal, infer_type, detect_feature
from tvm.relay.ir_pass import to_cps, un_cps
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform

Expand All @@ -42,13 +42,12 @@ def test_recursion():
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
func = infer_type(func, mod=mod)
cps_func = infer_type(un_cps(infer_type(to_cps(func, mod=mod), mod=mod)), mod=mod)
print(mod)
print(cps_func)
mod[mod.entry_func] = func
mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod)
mod[mod.entry_func] = un_cps(mod[mod.entry_func])
ex = create_executor(mod=mod)
i_nd = rand(dtype, *shape)
forward = ex.evaluate(cps_func)(i_nd)
forward = ex.evaluate(mod.entry_func)(i_nd)
tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())


Expand All @@ -57,12 +56,12 @@ def test_recursion():
# cps and pe can completely eliminate the allocation of reference.
def test_cps_pe():
def destroy_ref(x):
x = infer_type(x)
x = run_infer_type(x)
x = to_cps(x)
x = infer_type(x)
x = run_infer_type(x)
y = un_cps(x)
y = infer_type(y)
x = transform.OptimizeOnExpr(x, [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])
y = run_infer_type(y)
x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
assert Feature.fRefCreate not in detect_feature(x)
unit = relay.Function([], relay.const(0., dtype='float32'))
f_ref = relay.Var("f_ref")
Expand All @@ -82,7 +81,7 @@ def destroy_ref(x):
destroy_ref(F)

G = relay.Function([cond], relay.If(cond, one, two))
G = relay.ir_pass.gradient(G)
G = relay.transform.gradient(G)
destroy_ref(G)

x = relay.var("x", shape=(1, 16))
Expand All @@ -92,7 +91,7 @@ def destroy_ref(x):
H = relay.If(cond, x, y)
H = relay.add(H, z)
H = relay.Function([cond,x,y,z], H)
H = relay.ir_pass.gradient(H)
H = relay.transform.gradient(H)
destroy_ref(H)


Expand Down

0 comments on commit d69ca11

Please sign in to comment.