From b40068bf5fa0ef0f9a1ad8c9d936933eaade60b3 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Fri, 8 Apr 2022 14:44:46 +0100 Subject: [PATCH 1/5] [CMSIS-NN] Moved TFLite model making to common area Change-Id: Ic4dbc1919ff0b481c05daf7e57cf9b055c714c9c --- python/tvm/relay/testing/tflite.py | 152 ++++++++++++++++++ .../contrib/test_cmsisnn/test_conv2d.py | 18 ++- tests/python/contrib/test_cmsisnn/utils.py | 1 - 3 files changed, 162 insertions(+), 9 deletions(-) create mode 100644 python/tvm/relay/testing/tflite.py diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py new file mode 100644 index 000000000000..ef0141d8b5f1 --- /dev/null +++ b/python/tvm/relay/testing/tflite.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Common utilities for creating TFLite models""" +from distutils.version import LooseVersion +import numpy as np +import pytest +import tvm + +pytest.importorskip("tflite") +pytest.importorskip("tensorflow") +import tflite.Model +import tensorflow as tf + + +class TFLiteModel: + """Creates TFLite Model and facilitates reference data generation""" + + def __init__(self, dtype): + self.serial_model = None # This is what TFLite convert() provides + self.dtype = dtype # This is the dtype of graph inputs + self.shape_dict = {} + self.dtype_dict = {} + + @tf.function + def conv2d_single_function(self, ifm_tensor, args): + """Returns TFLite Conv2d layer""" + assert len(args) == 6, "Conv2D needs (ifm_shape, kernel_shape, strides, padding, dilation)" + _, kernel_shape, strides, padding, dilation, activation = args + op = tf.nn.conv2d( + ifm_tensor, + filters=tf.constant( + np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), + dtype=tf.float32, + ), + strides=[1, strides[0], strides[1], 1], + padding=padding, + dilations=dilation, + ) + if activation == "RELU": + op = tf.nn.relu(op) + elif activation == "NONE": + pass + else: + assert False, "Unsupported activation {}".format(activation) + return op + + def create_tflite_model(self, op_type, *args): + """Returns TFLite serial graph, Relay module, Relay params based on op_type""" + concrete_func = None + input_shape = None + if op_type == "conv2d_single": + input_shape = args[0] + ifm_tensor = tf.TensorSpec(input_shape, dtype=tf.float32, name="input") + concrete_func = self.conv2d_single_function.get_concrete_function(ifm_tensor, args) + else: + assert False, "Unsupported op_type {}".format(op_type) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(input_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = 
representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + self.serial_model = converter.convert() + self.shape_dict = {"input": input_shape} + self.dtype_dict = {"input": self.dtype} + + def convert_to_relay(self): + """Converts TFLite serialized graph into Relay""" + assert self.serial_model is not None, "TFLite model is empty!" + + tflite_model = tflite.Model.Model.GetRootAsModel(self.serial_model, 0) + relay_module, relay_params = tvm.relay.frontend.from_tflite( + tflite_model, self.shape_dict, self.dtype_dict + ) + return relay_module, relay_params + + def generate_randomized_input_data(self, seed, shape, dtype): + """Generates randomized input numpy arrays based on shape and dtype.""" + random_state = np.random.RandomState(seed) + random_data = None + if dtype == np.float32: + random_data = random_state.uniform(-1, 1, shape).astype(dtype) + else: + low = np.iinfo(dtype).min + high = np.iinfo(dtype).max + 1 + random_data = random_state.randint(low, high, shape, dtype) + return random_data + + # pylint: disable=import-outside-toplevel + def generate_reference_data(self): + """ + This method uses TFLite reference kernels to generate reference output. + It returns randomized inputs and reference outputs. + """ + assert self.serial_model is not None, "TFLite model was not created." 
+ + output_tolerance = None + if tf.__version__ < LooseVersion("2.5.0"): + output_tolerance = 1 + interpreter = tf.lite.Interpreter(model_content=self.serial_model) + else: + output_tolerance = 0 + interpreter = tf.lite.Interpreter( + model_content=self.serial_model, + experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF, + experimental_preserve_all_tensors=False, + ) + + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Generate predictable randomized input + seed = 0 + input_data = {} + for input_detail in input_details: + input_values = self.generate_randomized_input_data( + seed, input_detail["shape"], input_detail["dtype"] + ) + interpreter.set_tensor(input_detail["index"], input_values) + input_data.update({input_detail["name"]: input_values}) + + interpreter.invoke() + + # Obtain the expected output from interpreter + expected_output_data = {} + for output_detail in output_details: + expected_output_data.update( + {output_detail["name"]: interpreter.get_tensor(output_detail["index"])} + ) + + return input_data, expected_output_data, output_tolerance diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 6c8f53666e95..b4ef586a09ac 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -35,14 +35,12 @@ from utils import ( skip_if_no_reference_system, make_module, - create_conv2d_tflite_relay_models, get_range_for_dtype_str, get_same_padding, get_conv2d_qnn_params, make_qnn_relu, assert_partitioned_function, assert_no_external_function, - generate_ref_data_tflite, ) @@ -314,25 +312,29 @@ def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding, interface_api = "c" use_unpacked_api = True test_runner = AOT_USMP_CORSTONE300_RUNNER - dtype = "int8" - tflite_model, relay_mod, params = 
create_conv2d_tflite_relay_models( - ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype + + from tvm.relay.testing.tflite import TFLiteModel + + tfl_model = TFLiteModel(dtype) + tfl_model.create_tflite_model( + "conv2d_single", ifm_shape, kernel_shape, strides, padding, dilation, activation ) + relay_mod, relay_params = tfl_model.convert_to_relay() - cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params) # validate pattern matching assert_partitioned_function(relay_mod, cmsisnn_mod) # validate CMSIS-NN output against TFLite output - input_map, output_map, output_tolerance = generate_ref_data_tflite(tflite_model) + input_map, output_map, output_tolerance = tfl_model.generate_reference_data() compile_and_run( AOTTestModel( module=cmsisnn_mod, inputs=input_map, outputs=output_map, - params=params, + params=relay_params, output_tolerance=output_tolerance, ), test_runner, diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index 6bd375db1ff2..34f2196ab7a9 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -226,7 +226,6 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): if fused_activation_fn == "RELU": return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) - def generate_random_input_data(seed, shape, dtype): """ Generates randomized input numpy arrays based on shape and dtype From 1a7cfba5362798bf778a5d44fbaacce6e4027760 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Mon, 11 Apr 2022 11:59:25 +0100 Subject: [PATCH 2/5] Fixed lint issues with tensorflow import Change-Id: I7a520beec9c244e9c790d3e82733c2fb476f7e5e --- python/tvm/relay/testing/tflite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py index 
ef0141d8b5f1..ed3903eae599 100644 --- a/python/tvm/relay/testing/tflite.py +++ b/python/tvm/relay/testing/tflite.py @@ -22,8 +22,8 @@ pytest.importorskip("tflite") pytest.importorskip("tensorflow") -import tflite.Model -import tensorflow as tf +import tflite.Model # pylint: disable=wrong-import-position +import tensorflow as tf # pylint: disable=wrong-import-position class TFLiteModel: From f2a3484235ea549eaaad9c4f3c548fcdee777b3a Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Wed, 20 Apr 2022 12:57:06 +0100 Subject: [PATCH 3/5] Resolved merge conflict with main Change-Id: Iefe58dd321efae6eae26cd54a31c5923d0f1e32b --- tests/python/contrib/test_cmsisnn/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index 34f2196ab7a9..6bd375db1ff2 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -226,6 +226,7 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): if fused_activation_fn == "RELU": return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) + def generate_random_input_data(seed, shape, dtype): """ Generates randomized input numpy arrays based on shape and dtype From 4d775952d860124c2cdd3d581ef54c7ec994fb9e Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Tue, 26 Apr 2022 11:38:00 +0100 Subject: [PATCH 4/5] Made TFLite layer creation explicit Change-Id: I7fbf6a5a2163c1fada49477f86d84f1bc09bd57c --- python/tvm/relay/testing/tflite.py | 81 ++++++----- .../contrib/test_cmsisnn/test_conv2d.py | 5 +- tests/python/contrib/test_cmsisnn/utils.py | 131 ------------------ 3 files changed, 47 insertions(+), 170 deletions(-) diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py index ed3903eae599..2ff44819f48e 100644 --- a/python/tvm/relay/testing/tflite.py +++ b/python/tvm/relay/testing/tflite.py @@ -35,44 +35,53 @@ def __init__(self, dtype): 
self.shape_dict = {} self.dtype_dict = {} - @tf.function - def conv2d_single_function(self, ifm_tensor, args): - """Returns TFLite Conv2d layer""" - assert len(args) == 6, "Conv2D needs (ifm_shape, kernel_shape, strides, padding, dilation)" - _, kernel_shape, strides, padding, dilation, activation = args - op = tf.nn.conv2d( - ifm_tensor, - filters=tf.constant( - np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), - dtype=tf.float32, - ), - strides=[1, strides[0], strides[1], 1], - padding=padding, - dilations=dilation, - ) - if activation == "RELU": - op = tf.nn.relu(op) - elif activation == "NONE": - pass - else: - assert False, "Unsupported activation {}".format(activation) - return op - - def create_tflite_model(self, op_type, *args): - """Returns TFLite serial graph, Relay module, Relay params based on op_type""" - concrete_func = None - input_shape = None - if op_type == "conv2d_single": - input_shape = args[0] - ifm_tensor = tf.TensorSpec(input_shape, dtype=tf.float32, name="input") - concrete_func = self.conv2d_single_function.get_concrete_function(ifm_tensor, args) - else: - assert False, "Unsupported op_type {}".format(op_type) + def create_conv2d_single(self, kernel_shape, strides, padding, dilation, activation): + @tf.function + def conv2d_single_function(ifm_tensor): + """Returns TFLite Conv2d layer""" + op = tf.nn.conv2d( + ifm_tensor, + filters=tf.constant( + np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), + dtype=tf.float32, + ), + strides=[1, strides[0], strides[1], 1], + padding=padding, + dilations=dilation, + ) + if activation == "RELU": + op = tf.nn.relu(op) + elif activation == "NONE": + pass + else: + assert False, "Unsupported activation {}".format(activation) + return op + + return conv2d_single_function + + def create_tflite_model(self, tfl_function, shapes, ranges=None): + """Creates TFLite serial graph""" + tensor_specs = [] + for i, shape in enumerate(shapes): + input_name = "input_" + str(i) + 
self.shape_dict.update({input_name: shape}) + self.dtype_dict.update({input_name: self.dtype}) + tensor_specs.append(tf.TensorSpec(shape, dtype=tf.float32, name=input_name)) + concrete_func = tfl_function.get_concrete_function(*tensor_specs) + + if not ranges: + ranges = [(0, 1) for _ in shapes] def representative_dataset(): for _ in range(100): - data = np.random.rand(*tuple(input_shape)) - yield [data.astype(np.float32)] + inputs = [] + for i, shape in enumerate(shapes): + data = np.random.uniform( + low=ranges[i][0], high=ranges[i][1], size=tuple(shape) + ).astype("float32") + inputs.append(data) + + yield inputs converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) converter.optimizations = [tf.lite.Optimize.DEFAULT] @@ -81,8 +90,6 @@ def representative_dataset(): converter.inference_input_type = tf.int8 converter.inference_output_type = tf.int8 self.serial_model = converter.convert() - self.shape_dict = {"input": input_shape} - self.dtype_dict = {"input": self.dtype} def convert_to_relay(self): """Converts TFLite serialized graph into Relay""" diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index b4ef586a09ac..47245f60e15e 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -317,9 +317,10 @@ def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding, from tvm.relay.testing.tflite import TFLiteModel tfl_model = TFLiteModel(dtype) - tfl_model.create_tflite_model( - "conv2d_single", ifm_shape, kernel_shape, strides, padding, dilation, activation + conv2d_function = tfl_model.create_conv2d_single( + kernel_shape, strides, padding, dilation, activation ) + tfl_model.create_tflite_model(conv2d_function, [ifm_shape]) relay_mod, relay_params = tfl_model.convert_to_relay() cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params) diff --git a/tests/python/contrib/test_cmsisnn/utils.py 
b/tests/python/contrib/test_cmsisnn/utils.py index 6bd375db1ff2..83c67cd95b1c 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -225,134 +225,3 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): ) if fused_activation_fn == "RELU": return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) - - -def generate_random_input_data(seed, shape, dtype): - """ - Generates randomized input numpy arrays based on shape and dtype - """ - random_state = np.random.RandomState(seed) - if dtype == np.float32: - return random_state.uniform(-1, 1, size).astype(dtype) - else: - low = np.iinfo(dtype).min - high = np.iinfo(dtype).max + 1 - return random_state.randint(low, high, shape, dtype) - - -def generate_ref_data_tflite(model): - """ - This method uses TFLite reference kernels to generate reference output. - Random input generator is used to get the input data. - It returns randomized inputs and reference outputs. 
- """ - import tensorflow as tf - from distutils.version import LooseVersion - - output_tolerance = None - if tf.__version__ < LooseVersion("2.5.0"): - output_tolerance = 1 - interpreter = tf.lite.Interpreter(model_content=model) - else: - from tensorflow.lite.python.interpreter import OpResolverType - - output_tolerance = 0 - interpreter = tf.lite.Interpreter( - model_content=model, - experimental_op_resolver_type=OpResolverType.BUILTIN_REF, - experimental_preserve_all_tensors=False, - ) - - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - # Generate predictable randomized input - seed = 0 - input_data = {} - for input_detail in input_details: - input_values = generate_random_input_data( - seed, input_detail["shape"], input_detail["dtype"] - ) - interpreter.set_tensor(input_detail["index"], input_values) - input_data.update({input_detail["name"]: input_values}) - - interpreter.invoke() - - # Obtain the expected output from interpreter - expected_output_data = {} - for output_detail in output_details: - expected_output_data.update( - {output_detail["name"]: interpreter.get_tensor(output_detail["index"])} - ) - - return input_data, expected_output_data, output_tolerance - - -def create_conv2d_tflite_model(ifm_shape, kernel_shape, strides, dilation, padding, activation): - """This method prepares TFlite graph with a single Conv2d layer""" - import tensorflow as tf - - class Model(tf.Module): - @tf.function - def tf_function(self, x): - # Use tf.nn API to create the model - tf_strides = [1, strides[0], strides[1], 1] - op = tf.nn.conv2d( - x, - filters=tf.constant( - np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), - dtype=tf.float32, - ), - strides=tf_strides, - padding=padding, - dilations=dilation, - ) - if activation: - op = tf.nn.relu(op) - return op - - model = Model() - concrete_func = model.tf_function.get_concrete_function( - tf.TensorSpec(ifm_shape, 
dtype=tf.float32) - ) - - def representative_dataset(): - for _ in range(100): - data = np.random.rand(*tuple(ifm_shape)) - yield [data.astype(np.float32)] - - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.int8 - converter.inference_output_type = tf.int8 - tflite_model = converter.convert() - return tflite_model - - -def create_conv2d_tflite_relay_models( - ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype -): - """ - This method creates a conv2d TFLite layer and prepared TFLite model from it. - Converts that into the Relay module and params. - Returns TFLite model, Relay module and params. - """ - pytest.importorskip("tflite") - import tflite.Model - - serialized_tflite_model = create_conv2d_tflite_model( - ifm_shape, kernel_shape, strides, dilation, padding, activation - ) - - tflite_model = tflite.Model.Model.GetRootAsModel(serialized_tflite_model, 0) - - relay_module, params = relay.frontend.from_tflite( - tflite_model, - shape_dict={"input": ifm_shape}, - dtype_dict={"input": dtype}, - ) - - return serialized_tflite_model, relay_module, params From 13eb13fa3a1503d4f61f2c4a3b5608099ecf1517 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Tue, 26 Apr 2022 13:07:11 +0100 Subject: [PATCH 5/5] Lint fix: added a missing docstring Change-Id: If1fb8bb09c538c04e333ccab65a20cff247a504d --- python/tvm/relay/testing/tflite.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py index 2ff44819f48e..df40130cebaf 100644 --- a/python/tvm/relay/testing/tflite.py +++ b/python/tvm/relay/testing/tflite.py @@ -36,6 +36,8 @@ def __init__(self, dtype): self.dtype_dict = {} def create_conv2d_single(self, kernel_shape, strides, 
padding, dilation, activation): + """Returns tf.function that creates TFLite Conv2d layer""" + @tf.function def conv2d_single_function(ifm_tensor): """Returns TFLite Conv2d layer"""