diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 3f7a96544a65..8c8a4a1ddcd3 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -635,9 +635,11 @@ def _convert_pooling( _op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout ) if pool_type == "GlobalAveragePooling2D": - return _convert_flatten( - _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout - ) + global_avg_pool2d = _op.nn.global_avg_pool2d(inexpr, **global_pool_params) + keep_dims = len(keras_layer.input.shape) == len(keras_layer.output.shape) + if keep_dims: + return global_avg_pool2d + return _convert_flatten(global_avg_pool2d, keras_layer, etab, data_layout) pool_h, pool_w = keras_layer.pool_size stride_h, stride_w = keras_layer.strides params = { diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 9121721d8ea2..7267b725483d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -963,6 +963,10 @@ def representative_data_gen(): input_node = subgraph.Tensors(model_input).Name().decode("utf-8") tflite_output = run_tflite_graph(tflite_model_quant, data) + if tf.__version__ < LooseVersion("2.9"): + input_node = data_in.name.replace(":0", "") + else: + input_node = "serving_default_" + data_in.name + ":0" tvm_output = run_tvm_graph(tflite_model_quant, data, input_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2 @@ -1997,10 +2001,12 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8): # TFLite 2.6.x upgrade support if tf.__version__ < LooseVersion("2.6.1"): in_node = ["serving_default_input_int8"] - else: + elif tf.__version__ < LooseVersion("2.9"): in_node = ( ["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"] ) + else: + in_node = "serving_default_input" tvm_output = run_tvm_graph(tflite_model_quant, data, in_node) tvm.testing.assert_allclose( @@ -2028,8 +2034,10 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8): tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype ) tflite_output = run_tflite_graph(tflite_model_quant, data) - in_node = ["tfl.quantize"] - + if tf.__version__ < LooseVersion("2.9"): + in_node = ["tfl.quantize"] + else: + in_node = "serving_default_input" tvm_output = run_tvm_graph(tflite_model_quant, data, in_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2 @@ -2110,7 +2118,10 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8): tf.math.cos, data, int_quant_dtype=int_quant_dtype ) tflite_output = run_tflite_graph(tflite_model_quant, data) - in_node = ["tfl.quantize"] + if tf.__version__ < LooseVersion("2.9"): + in_node = ["tfl.quantize"] + else: + in_node = "serving_default_input" tvm_output = run_tvm_graph(tflite_model_quant, data, in_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2 @@ -3024,7 +3035,6 @@ def _test_quantize_dequantize(data): add = tf.keras.layers.Add()([data_in, relu]) concat = tf.keras.layers.Concatenate(axis=0)([relu, add]) keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat) - input_name = data_in.name.split(":")[0] # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): @@ -3034,7 +3044,11 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) + if tf.__version__ < LooseVersion("2.9"): + in_node = data_in.name.split(":")[0] + else: + in_node = "serving_default_" + data_in.name + ":0" + tvm_output = run_tvm_graph(tflite_model_quant, data, in_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2 ) @@ -3051,7 +3065,6 @@ def _test_quantize_dequantize_const(data): add = tf.keras.layers.Add()([data, relu]) concat = tf.keras.layers.Concatenate(axis=0)([relu, add]) keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat) - input_name = data_in.name.split(":")[0] # To create quantized values with dynamic range of activations, needs representative dataset def representative_data_gen(): @@ -3061,7 +3074,11 @@ def representative_data_gen(): tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True) tflite_output = run_tflite_graph(tflite_model_quant, data) - tvm_output = run_tvm_graph(tflite_model_quant, data, input_name) + if tf.__version__ < LooseVersion("2.9"): + in_node = data_in.name.split(":")[0] + else: + in_node = "serving_default_" + data_in.name + ":0" + tvm_output = run_tvm_graph(tflite_model_quant, data, in_node) tvm.testing.assert_allclose( np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2 )