Skip to content

Commit

Permalink
quanitze operation expanded to take const argument (#6127)
Browse files Browse the repository at this point in the history
* quanitze operation expanded to take const argument

* amendments

used get_tensor_expr, added _test_forward_quantize_dequantize_const test
  • Loading branch information
d-smirnov authored Aug 28, 2020
1 parent 34647ed commit 4c9a391
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2764,7 +2764,7 @@ def convert_quantize(self, op):
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
in_expr = self.get_expr(input_tensor.tensor_idx)
in_expr = self.get_tensor_expr(input_tensor)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
Expand Down
27 changes: 27 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1907,11 +1907,38 @@ def representative_data_gen():
rtol=1e-5, atol=1e-2)


def _test_quantize_dequantize_const(data):
""" One iteration of quantize and dequantize """

# Keras model to force TFLite converter to insert 2 TFLite quantize ops.
# First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
# Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
data_in = tf.keras.layers.Input(shape=data.shape[1:])
relu = tf.keras.layers.ReLU()(data_in)
add = tf.keras.layers.Add()([data, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
for i in range(1):
yield [data]

tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-2)


def test_forward_quantize_dequantize():
""" Quantize Dequantize """
data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
_test_quantize_dequantize(data)
_test_quantize_dequantize_const(data)


#######################################################################
Expand Down

0 comments on commit 4c9a391

Please sign in to comment.