From 4b2ccded4d6a4788c176253fa9037091c6019b11 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Mon, 19 Jul 2021 04:55:04 -0700 Subject: [PATCH 1/2] src/runtime/module.cc (#8496) --- src/runtime/module.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/runtime/module.cc b/src/runtime/module.cc index acc7fc7286d1..cff65452671e 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -113,7 +113,10 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { if (pf == nullptr) { const PackedFunc* f = Registry::Get(name); ICHECK(f != nullptr) << "Cannot find function " << name - << " in the imported modules or global registry"; + << " in the imported modules or global registry." + << " If this involves ops from a contrib library like" + << " cuDNN, ensure TVM was built with the relevant" + << " library."; return f; } else { import_cache_.insert(std::make_pair(name, std::make_shared(pf))); From 6d88bdd3ea36d9e3167eb92e167bd05f997c7fb3 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 19 Jul 2021 09:18:49 -0600 Subject: [PATCH 2/2] Update Docker CI (#8193) * add failing onnx tets * point jenkins at new docker * support convtranspose opset 11 autopadding * Don't force output shape for conv transpose tests, add 1D and 3D cases * disable test until CI update complete * try updating docker images again * skip a test until update complete * next try at docker images * manage TF memory use in TF1 tests * support explicit padding for NCHW TF padding test * Update to tagged tlcpack images Thanks, Andrew! Co-authored-by: Andrew Reusch Co-authored-by: Andrew Reusch --- Jenkinsfile | 8 +- python/tvm/relay/frontend/onnx.py | 64 ++++ tests/python/frontend/onnx/test_forward.py | 311 +++++++++++++++--- .../frontend/tensorflow/test_forward.py | 9 + 4 files changed, 340 insertions(+), 52 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index f26b148085fb..65ccbf27326f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,12 +45,12 @@ // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = "tlcpack/ci-lint:v0.66" -ci_gpu = "tlcpack/ci-gpu:v0.75" -ci_cpu = "tlcpack/ci-cpu:v0.74" +ci_gpu = "tlcpack/ci-gpu:v0.76" +ci_cpu = "tlcpack/ci-cpu:v0.75" ci_wasm = "tlcpack/ci-wasm:v0.71" ci_i386 = "tlcpack/ci-i386:v0.73" -ci_qemu = "tlcpack/ci-qemu:v0.05" -ci_arm = "tlcpack/ci-arm:v0.05" +ci_qemu = "tlcpack/ci-qemu:v0.06" +ci_arm = "tlcpack/ci-arm:v0.06" // <--- End of regex-scanned config. // tvm libraries diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cea538f3ea85..42bde838859a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -585,6 +585,70 @@ def _impl_v1(cls, inputs, attr, params): out = _op.nn.bias_add(out, inputs[2]) return out + @classmethod + def _impl_v11(cls, inputs, attr, params): + # get number of channels + out_type = infer_type(inputs[1]) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + channels = out_shapes[0][1] + attr["channels"] = channels + groups = attr.get("group", 1) + + if "kernel_shape" not in attr: + attr["kernel_shape"] = out_shapes[0][2:] + + attr["groups"] = groups + # infer pads for auto_pad + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + kernel_shape = attr["kernel_shape"] + kndim = len(kernel_shape) + dilations = attr.get("dilations", [1] * kndim) + output_padding = attr.get("output_padding", [0] * kndim) + strides = attr["strides"] + total_pad = [0] * kndim + for i in range(kndim): + total_pad[i] = ( + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - strides[i] + ) + left = [p // 2 for p in total_pad] + right = [total_pad[i] - left[i] for i in range(kndim)] + if "LOWER" in attr["auto_pad"]: + pad = left + right + else: + pad = right + left + attr["pads"] = pad + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(ndim - 2)]) + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out = AttrCvt( + op_name=dimension_picker("conv", "_transpose"), + transforms={ + "kernel_shape": "kernel_size", + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), + }, + disables=["output_shape"], + custom_check=dimension_constraint(), + )([data, inputs[1]], attr, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + class GlobalAveragePool(OnnxOpConverter): """Operator converter for GlobalAveragePool""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c891088807f5..9328a82271d7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1117,7 +1117,14 @@ def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="floa ) model = helper.make_model(graph, producer_name="gemm_test") - verify_with_ort_with_inputs(model, input_values, freeze_params=freeze_params, dtype=dtype) + atol = 1e-5 + rtol = 1e-5 + if dtype == "float16": + atol = 1e-3 + rtol = 1e-3 + verify_with_ort_with_inputs( + model, input_values, freeze_params=freeze_params, dtype=dtype, atol=atol, rtol=rtol + ) @tvm.testing.uses_gpu @@ -2583,7 +2590,6 @@ def repeat(N, D): def verify_convtranspose_with_padding( x_shape, w_shape, - y_shape, padding, kernel_shape, strides, @@ -2619,12 +2625,12 @@ def verify_convtranspose_with_padding( helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["?"] * len(x_shape))], ) model = helper.make_model(graph, producer_name="convtranspose_pad_test") - verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, convert_to_static=True) + verify_with_ort(model, [x_shape, w_shape], use_vm=True, convert_to_static=True) def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): @@ -2652,7 +2658,7 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): ) model = helper.make_model(graph, producer_name="convtranspose_test") - verify_with_ort(model, [x_shape, w_shape], y_shape) + verify_with_ort(model, [x_shape, w_shape], y_shape, opset=11) @tvm.testing.uses_gpu @@ -2669,14 +2675,13 @@ def test_convtranspose(): def repeat(N, D): return tuple([N for _ in range(D)]) - # TODO(mbrookhart): onnxruntime in CI only supports 2D, - # find something else to test 1D and 3D against + # TODO(mbrookhart): onnxruntime in CI only supports 2D, and 1D and 3D + # Once onnxruntime update is complete for D in [2]: # Convolution with padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), 2 * repeat(1, D), repeat(3, D), repeat(1, D), @@ -2686,62 +2691,58 @@ def repeat(N, D): verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), 2 * repeat(0, D), repeat(3, D), repeat(1, D), repeat(1, D), ) - # Convolution with autopadding - verify_convtranspose_with_padding( - (1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), - None, - repeat(3, D), - repeat(1, D), - repeat(1, D), - auto_pad="SAME_UPPER", - ) - # Convolution with valid autopadding - verify_convtranspose_with_padding( - (1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), - None, - repeat(3, D), - repeat(1, D), - repeat(1, D), - auto_pad="VALID", - ) # Convolution with unset padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), 2 * repeat(0, D), repeat(3, D), repeat(1, D), repeat(1, D), True, ) - # Convolution with non uniform stride - verify_convtranspose_with_padding( - (1, 1) + repeat(5, D), - (1, 1) + repeat(3, D), - (1, 1) + repeat(9, D), - None, - repeat(3, D), - repeat(2, D), - repeat(1, D), - auto_pad="SAME_UPPER", - ) + ## TODO(mbrookhart): renable autopad tests when CI ONNX + ## and ONNX runtime match versions + # # Convolution with autopadding + # verify_convtranspose_with_padding( + # (1, 1) + repeat(5, D), + # (1, 1) + repeat(3, D), + # None, + # repeat(3, D), + # repeat(1, D), + # repeat(1, D), + # auto_pad="SAME_UPPER", + # ) + # # Convolution with valid autopadding + # verify_convtranspose_with_padding( + # (1, 1) + repeat(5, D), + # (1, 1) + repeat(3, D), + # None, + # repeat(3, D), + # repeat(1, D), + # repeat(1, D), + # auto_pad="VALID", + # ) + # # Convolution with non uniform stride + # verify_convtranspose_with_padding( + # (1, 1) + repeat(5, D), + # (1, 1) + repeat(3, D), + # None, + # repeat(3, D), + # repeat(2, D), + # repeat(1, D), + # auto_pad="SAME_UPPER", + # ) # Convolution with dilation # TODO(mbrookhart): Relay doesn't currently support convtranspose with dilation # verify_convtranspose_with_padding( # (1, 1) + repeat(5, D), # (1, 1) + repeat(3, D), - # (1, 1) + repeat(5, D), # 2 * repeat(2, D), # repeat(3, D), # repeat(1, D), @@ -3985,7 +3986,7 @@ def verify_cond_loop(): trip_count = np.array(40).astype(np.int64) cond = np.array(1).astype(bool) input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True, opset=11) def verify_count_loop(): @@ -4040,7 +4041,7 @@ def verify_count_loop(): trip_count = np.array(5).astype(np.int64) cond = np.array(1).astype(bool) input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True, opset=11) def verify_tensor_loop(shapeless_output=False): @@ -4102,7 +4103,7 @@ def verify_tensor_loop(shapeless_output=False): cond = np.array(1).astype(bool) input_vals = [trip_count, cond, y] verify_with_ort_with_inputs( - loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True + loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True, opset=11 ) @@ -4429,9 +4430,32 @@ def verify_eyelike(indata): onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) unsupported_onnx_tests = [ + "test_adagrad/", + "test_adagrad_multiple/", + "test_adam/", + "test_adam_multiple/", + "test_argmax_default_axis_example_select_last_index/", + "test_argmax_default_axis_random_select_last_index/", + "test_argmax_keepdims_example_select_last_index/", + "test_argmax_keepdims_random_select_last_index/", + "test_argmax_negative_axis_keepdims_example_select_last_index/", + "test_argmax_negative_axis_keepdims_random_select_last_index/", + "test_argmax_no_keepdims_example_select_last_index/", + "test_argmax_no_keepdims_random_select_last_index/", + "test_argmin_default_axis_example_select_last_index/", + "test_argmin_default_axis_random_select_last_index/", + "test_argmin_keepdims_example_select_last_index/", + "test_argmin_keepdims_random_select_last_index/", + "test_argmin_negative_axis_keepdims_example_select_last_index/", + "test_argmin_negative_axis_keepdims_random_select_last_index/", + "test_argmin_no_keepdims_example_select_last_index/", + "test_argmin_no_keepdims_random_select_last_index/", + "test_cast_BFLOAT16_to_FLOAT/", "test_cast_DOUBLE_to_FLOAT16/", + "test_cast_FLOAT_to_BFLOAT16/", "test_cast_FLOAT_to_STRING/", "test_cast_STRING_to_FLOAT/", + "test_celu/", "test_compress_0/", "test_compress_1/", "test_compress_default_axis/", @@ -4447,17 +4471,109 @@ def verify_eyelike(indata): "test_cumsum_2d_negative_axis/", "test_det_2d/", "test_det_nd/", + "test_dropout_default/", + "test_dropout_default_mask/", + "test_dropout_default_mask_ratio/", + "test_dropout_default_ratio/", + "test_einsum_batch_diagonal/", + "test_einsum_batch_matmul/", + "test_einsum_inner_prod/", + "test_einsum_sum/", + "test_einsum_transpose/", + "test_greater_equal/", + "test_greater_equal_bcast/", + "test_hardmax_axis_0/", + "test_hardmax_axis_1/", + "test_hardmax_default_axis/", + "test_if_seq/", + "test_less_equal/", + "test_less_equal_bcast/", + "test_logsoftmax_axis_0/", + "test_logsoftmax_axis_0_expanded/", + "test_logsoftmax_axis_1/", + "test_logsoftmax_axis_1_expanded/", + "test_logsoftmax_axis_2_expanded/", + "test_logsoftmax_default_axis/", + "test_logsoftmax_default_axis_expanded/", + "test_logsoftmax_example_1_expanded/", + "test_logsoftmax_large_number_expanded/", + "test_logsoftmax_negative_axis_expanded/", + "test_loop11/", + "test_loop13_seq/", "test_matmulinteger/", "test_maxpool_2d_same_lower/", "test_maxpool_2d_same_upper/", "test_maxpool_with_argmax_2d_precomputed_pads/", "test_maxpool_with_argmax_2d_precomputed_strides/", "test_maxunpool_export_with_output_shape/", + "test_momentum/", + "test_momentum_multiple/", "test_mvn/", + "test_nesterov_momentum/", + "test_nllloss_NC/", + "test_nllloss_NC_expanded/", + "test_nllloss_NCd1/", + "test_nllloss_NCd1_expanded/", + "test_nllloss_NCd1_ii/", + "test_nllloss_NCd1_ii_expanded/", + "test_nllloss_NCd1_mean_weight_negative_ii/", + "test_nllloss_NCd1_mean_weight_negative_ii_expanded/", + "test_nllloss_NCd1_weight/", + "test_nllloss_NCd1_weight_expanded/", + "test_nllloss_NCd1_weight_ii/", + "test_nllloss_NCd1_weight_ii_expanded/", + "test_nllloss_NCd1d2/", + "test_nllloss_NCd1d2_expanded/", + "test_nllloss_NCd1d2_no_weight_reduction_mean_ii/", + "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded/", + "test_nllloss_NCd1d2_reduction_mean/", + "test_nllloss_NCd1d2_reduction_mean_expanded/", + "test_nllloss_NCd1d2_reduction_sum/", + "test_nllloss_NCd1d2_reduction_sum_expanded/", + "test_nllloss_NCd1d2_with_weight/", + "test_nllloss_NCd1d2_with_weight_expanded/", + "test_nllloss_NCd1d2_with_weight_reduction_mean/", + "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded/", + "test_nllloss_NCd1d2_with_weight_reduction_sum/", + "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded/", + "test_nllloss_NCd1d2_with_weight_reduction_sum_ii/", + "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded/", + "test_nllloss_NCd1d2d3_none_no_weight_negative_ii/", + "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded/", + "test_nllloss_NCd1d2d3_sum_weight_high_ii/", + "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded/", + "test_nllloss_NCd1d2d3d4d5_mean_weight/", + "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded/", + "test_nllloss_NCd1d2d3d4d5_none_no_weight/", + "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded/", + "test_pow_types_float/", + "test_pow_types_float32_int32/", + "test_pow_types_float32_int64/", + "test_pow_types_float32_uint32/", + "test_pow_types_float32_uint64/", + "test_pow_types_int/", + "test_pow_types_int32_float32/", + "test_pow_types_int32_int32/", + "test_pow_types_int64_float32/", + "test_pow_types_int64_int64/", "test_qlinearmatmul_2D/", "test_qlinearmatmul_3D/", + "test_reduce_sum_default_axes_keepdims_example/", + "test_reduce_sum_default_axes_keepdims_random/", + "test_reduce_sum_do_not_keepdims_example/", + "test_reduce_sum_do_not_keepdims_random/", + "test_reduce_sum_empty_axes_input_noop_example/", + "test_reduce_sum_empty_axes_input_noop_random/", + "test_reduce_sum_keepdims_example/", + "test_reduce_sum_keepdims_random/", + "test_reduce_sum_negative_axes_keepdims_example/", + "test_reduce_sum_negative_axes_keepdims_random/", + "test_resize_downsample_sizes_cubic/", + "test_resize_downsample_sizes_linear_pytorch_half_pixel/", + "test_resize_downsample_sizes_nearest/", "test_resize_tf_crop_and_resize/", - ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 + "test_resize_upsample_sizes_cubic/", + "test_resize_upsample_sizes_nearest/", "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", @@ -4465,8 +4581,94 @@ def verify_eyelike(indata): "test_round/", "test_scan9_sum/", "test_scan_sum/", + "test_sce_NCd1_mean_weight_negative_ii/", + "test_sce_NCd1_mean_weight_negative_ii_expanded/", + "test_sce_NCd1_mean_weight_negative_ii_log_prob/", + "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded/", + "test_sce_NCd1d2d3_sum_weight_high_ii/", + "test_sce_NCd1d2d3_sum_weight_high_ii_expanded/", + "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob/", + "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded/", + "test_sce_NCd1d2d3d4d5_mean_weight/", + "test_sce_NCd1d2d3d4d5_mean_weight_expanded/", + "test_sce_NCd1d2d3d4d5_mean_weight_log_prob/", + "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded/", + "test_sce_NCd1d2d3d4d5_none_no_weight/", + "test_sce_NCd1d2d3d4d5_none_no_weight_expanded/", + "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob/", + "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded/", + "test_sce_mean/", + "test_sce_mean_3d/", + "test_sce_mean_3d_expanded/", + "test_sce_mean_3d_log_prob/", + "test_sce_mean_3d_log_prob_expanded/", + "test_sce_mean_expanded/", + "test_sce_mean_log_prob/", + "test_sce_mean_log_prob_expanded/", + "test_sce_mean_no_weight_ii/", + "test_sce_mean_no_weight_ii_3d/", + "test_sce_mean_no_weight_ii_3d_expanded/", + "test_sce_mean_no_weight_ii_3d_log_prob/", + "test_sce_mean_no_weight_ii_3d_log_prob_expanded/", + "test_sce_mean_no_weight_ii_4d/", + "test_sce_mean_no_weight_ii_4d_expanded/", + "test_sce_mean_no_weight_ii_4d_log_prob/", + "test_sce_mean_no_weight_ii_4d_log_prob_expanded/", + "test_sce_mean_no_weight_ii_expanded/", + "test_sce_mean_no_weight_ii_log_prob/", + "test_sce_mean_no_weight_ii_log_prob_expanded/", + "test_sce_mean_weight/", + "test_sce_mean_weight_expanded/", + "test_sce_mean_weight_ii/", + "test_sce_mean_weight_ii_3d/", + "test_sce_mean_weight_ii_3d_expanded/", + "test_sce_mean_weight_ii_3d_log_prob/", + "test_sce_mean_weight_ii_3d_log_prob_expanded/", + "test_sce_mean_weight_ii_4d/", + "test_sce_mean_weight_ii_4d_expanded/", + "test_sce_mean_weight_ii_4d_log_prob/", + "test_sce_mean_weight_ii_4d_log_prob_expanded/", + "test_sce_mean_weight_ii_expanded/", + "test_sce_mean_weight_ii_log_prob/", + "test_sce_mean_weight_ii_log_prob_expanded/", + "test_sce_mean_weight_log_prob/", + "test_sce_mean_weight_log_prob_expanded/", + "test_sce_none/", + "test_sce_none_expanded/", + "test_sce_none_log_prob/", + "test_sce_none_log_prob_expanded/", + "test_sce_none_weights/", + "test_sce_none_weights_expanded/", + "test_sce_none_weights_log_prob/", + "test_sce_none_weights_log_prob_expanded/", + "test_sce_sum/", + "test_sce_sum_expanded/", + "test_sce_sum_log_prob/", + "test_sce_sum_log_prob_expanded/", + "test_sequence_insert_at_back/", + "test_sequence_insert_at_front/", "test_simple_rnn_defaults/", "test_simple_rnn_with_initial_bias/", + "test_softmax_axis_0/", + "test_softmax_axis_0_expanded/", + "test_softmax_axis_1/", + "test_softmax_axis_1_expanded/", + "test_softmax_axis_2_expanded/", + "test_softmax_default_axis/", + "test_softmax_default_axis_expanded/", + "test_softmax_example_expanded/", + "test_softmax_large_number_expanded/", + "test_softmax_negative_axis_expanded/", + "test_split_variable_parts_1d/", + "test_split_variable_parts_2d/", + "test_split_variable_parts_default_axis/", + "test_split_zero_size_splits/", + "test_squeeze/", + "test_squeeze_negative_axes/", "test_strnormalizer_export_monday_casesensintive_lower/", "test_strnormalizer_export_monday_casesensintive_nochangecase/", "test_strnormalizer_export_monday_casesensintive_upper/", @@ -4480,9 +4682,22 @@ def verify_eyelike(indata): "test_tfidfvectorizer_tf_onlybigrams_levelempty/", "test_tfidfvectorizer_tf_onlybigrams_skip5/", "test_tfidfvectorizer_tf_uniandbigrams_skip5/", + "test_training_dropout/", + "test_training_dropout_default/", + "test_training_dropout_default_mask/", + "test_training_dropout_mask/", + "test_training_dropout_zero_ratio/", + "test_training_dropout_zero_ratio_mask/", "test_unique_sorted_with_axis/", "test_unique_sorted_with_axis_3d/", "test_unique_sorted_with_negative_axis/", + "test_unsqueeze_axis_0/", + "test_unsqueeze_axis_1/", + "test_unsqueeze_axis_2/", + "test_unsqueeze_negative_axes/", + "test_unsqueeze_three_axes/", + "test_unsqueeze_two_axes/", + "test_unsqueeze_unsorted_axes/", "test_upsample_nearest/", ] diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 9bbb6ca9364e..a57c402d212c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -29,6 +29,13 @@ import tensorflow.compat.v1 as tf except ImportError: import tensorflow as tf + +# Only allow TF to run on half the GPU RAM to save the other half +# For TVM +gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) +sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) +sess.close() + from tensorflow.python.framework import constant_op from tensorflow.python.framework import graph_util from tensorflow.python.ops import nn_ops @@ -318,6 +325,8 @@ def _test_pooling(input_shape, **kwargs): if is_gpu_available(): if len(input_shape) == 4: input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + if isinstance(kwargs["padding"], list): + kwargs["padding"] = [kwargs["padding"][ii] for ii in (0, 3, 1, 2)] kwargs["data_format"] = "NCHW" _test_pooling_iteration(input_shape, **kwargs)