Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Make check stricter: disallow inserting function with free vars into module. #6313

Merged
merged 2 commits into from
Aug 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op, transform

from .analysis import free_vars

def get_tensor_array_shape(expr, dtype, prelude):
"""Get the static shape of a tensor array if it has fixed rank shape.
Expand All @@ -51,7 +51,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
has dynamic shape.
"""
mod = prelude.mod
mod["main"] = Function([], expr)
mod["main"] = Function(free_vars(expr), expr)
mod = transform.InferType()(mod)
checked_type = mod["main"].body.checked_type
assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
Expand Down
14 changes: 4 additions & 10 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,10 @@ relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::F
// Type check the item before we add it to the module.
auto fv = relay::FreeVars(func);
auto ftv = relay::FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false)
<< std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING) << "There are free type variables: " << ftv
<< " in function: " << AsText(func, false) << std::endl;
}
func = relay::Function(concat(func->params, fv), func->body, func->ret_type,
concat(func->type_params, ftv), func->attrs);
CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv
<< " in function: " << AsText(func, false);
CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv
<< " in function: " << AsText(func, false);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
Expand Down
140 changes: 1 addition & 139 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3852,143 +3852,5 @@ def lstm_cell():
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)


#######################################################################
# Main
# ----
if __name__ == '__main__':
# Transforms
test_forward_slice()
test_forward_transpose()
test_forward_reshape()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_squeeze()
test_forward_pack()
test_forward_size()
test_forward_broadcast_to()
test_forward_fill()
test_forward_crop()
test_forward_resize()
test_forward_crop_and_resize()
test_forward_pad()
test_forward_unpack()
test_forward_gather()
test_forward_gather_nd()
test_forward_stridedslice()
test_forward_split()
test_forward_unstack()
test_forward_tile()
test_forward_top_k_v2()
test_forward_clip_by_value()
test_forward_maximum()
test_forward_minimum()
test_forward_range()
test_forward_right_shift()
test_forward_left_shift()
test_forward_truncatemod()
test_forward_one_hot()
test_forward_atan2()
test_forward_nms()

# Activations
test_forward_sigmoid()
test_forward_relu()
test_forward_leaky_relu()
test_forward_elu()
test_forward_selu()
test_forward_tanh()

# Tensor
test_forward_round()
test_forward_reverse_v2()
test_forward_pow_exp()
test_forward_sign()
test_forward_negative()
test_forward_divide()
test_forward_abs()
test_forward_softplus()
test_forward_sqrt()
test_forward_rsqrt()
test_forward_expand_dims()
test_forward_square()
test_forward_softmax()
test_forward_log_softmax()
test_forward_bias_add()
test_forward_zeros_like()
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
test_forward_isfinite()
test_forward_isinf()
test_forward_unravel_index()
test_forward_unary()

# Reductions
test_forward_argminmax()
test_forward_reduce()
test_forward_mean()

# TensorArray
test_tensor_array_write_read()
test_tensor_array_concat()
test_tensor_array_scatter()
test_tensor_array_gather()
test_tensor_array_size()
test_tensor_array_split()
test_tensor_array_stack()
test_tensor_array_unstack()

# General
test_forward_multi_input()
test_forward_multi_output()
test_forward_variable()
test_placeholder()

# NN
test_forward_convolution()
test_forward_convolution3d()
test_forward_convolution3d_transpose()
test_forward_pooling()
test_forward_concat_v2()
test_forward_lrn()
test_forward_l2_normalize()
test_forward_space_to_batch_nd()
test_forward_batch_to_space_nd()
test_forward_dilation()

# End to End
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_resnetv2()
test_forward_ssd()
test_forward_placeholder()
test_forward_ptb()

# RNN
test_forward_lstm()

# Elementwise
test_forward_ceil()
test_forward_floor()

# Relational ops
test_forward_rel_ops()
test_forward_logical()
test_forward_where()
test_forward_matmul()
test_forward_batch_matmul()

# Internal misc. ops
test_read_variable_op()

# Sharing params case using Mean ops
test_sharing_node()

# StatefulPartitionedCall
test_forward_spop()

# Test dynamic input shape
test_forward_dynamic_input_shape()

test_forward_dynmaic_rnn_lstmblockcell()
pytest.main([__file__])
23 changes: 3 additions & 20 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
"""Test that type checker correcly computes types
for expressions.
"""
import pytest
import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay import Any


def run_infer_type(expr, mod=None):
if not mod:
mod = tvm.IRModule.from_expr(expr)
Expand Down Expand Up @@ -368,26 +368,9 @@ def test_if():
f = relay.Var('f', choice_t)
true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32'))
false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32'))
top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch))
top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
ft = run_infer_type(top)
tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))

if __name__ == "__main__":
test_free_expr()
test_dual_op()
test_single_op()
test_recursion()
test_monomorphic_let()
test_decl()
test_recursion()
test_tuple()
test_incomplete_call()
test_type_args()
test_global_var_recursion()
test_equal()
test_ref()
test_constructor_type()
test_constructor_call()
test_adt_match()
test_let_polymorphism()
test_if()
pytest.main([__file__])
16 changes: 3 additions & 13 deletions tests/python/relay/test_vm_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, missing-docstring, no-else-return
"""Unit tests for the Relay VM serialization and deserialization."""
import pytest
import numpy as np

import tvm
Expand Down Expand Up @@ -291,22 +292,11 @@ def test_vm_shape_of():

newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
args.append(np.array((1, -1), dtype='int64'))
main = relay.reshape(relu_x, newshape=newshape_var)
main = relay.Function([x, newshape_var], relay.reshape(relu_x, newshape=newshape_var))

res = get_serialized_output(main, *args).asnumpy()
tvm.testing.assert_allclose(res.flatten(), data.flatten())


if __name__ == "__main__":
test_serializer()
test_save_load()
test_const()
test_if()
test_loop()
test_tuple()
test_adt_list()
test_adt_compose()
test_closure()
test_synthetic()
test_mobilenet()
test_vm_shape_of()
pytest.main([__file__])
2 changes: 1 addition & 1 deletion tutorials/dev/use_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def example():
z = relay.add(y, c)
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
return relay.Function([x], z2)
return relay.Function([x, weight], z2)

###############################################################################
# Let us register layout alteration for a conv2d op so that we can apply the
Expand Down