Skip to content

Commit

Permalink
Do not mutate GlobalVar's checked_type field. (apache#2026)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and tqchen committed Oct 29, 2018
1 parent d79633a commit d915318
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ class TypeInferencer::Resolver : public ExprMutator {
}

Expr VisitExpr_(const GlobalVarNode* op) final {
return AttachCheckedType(op);
return GetRef<GlobalVar>(op);
}

Expr VisitExpr_(const OpNode* op) final {
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,16 @@ def f(x) {
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a

def test_global_var_cow_issue():
env = relay.env.Environment({})
gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32'))
env[gv] = func
# They should both point to the same global variable if global variables are
# stable across type checking.
assert gv == func.body.op

if __name__ == "__main__":
test_free_expr()
test_dual_op()
Expand All @@ -134,3 +144,4 @@ def f(x) {
test_free_expr()
test_type_args()
test_self_reference()
test_global_var_cow_issue()

0 comments on commit d915318

Please sign in to comment.