diff --git a/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py b/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py
index 329042df782..8e175f7dd3f 100644
--- a/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py
+++ b/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py
@@ -144,6 +144,10 @@ def transform_fn(data_item):
 calibration_dataset = nncf.Dataset(val_dataset, transform_fn)
 tf_quantized_model = nncf.quantize(tf_model, calibration_dataset)
 
+# Removes auxiliary layers and operations added during the quantization process,
+# resulting in a clean, fully quantized model ready for deployment.
+tf_quantized_model = nncf.strip(tf_quantized_model)
+
 ###############################################################################
 # Benchmark performance, calculate compression rate and validate accuracy
 
diff --git a/examples/quantization_aware_training/tensorflow/mobilenet_v2/README.md b/examples/quantization_aware_training/tensorflow/mobilenet_v2/README.md
new file mode 100644
index 00000000000..8c42efa055e
--- /dev/null
+++ b/examples/quantization_aware_training/tensorflow/mobilenet_v2/README.md
@@ -0,0 +1,31 @@
+# Quantization-Aware Training of MobileNet v2 TensorFlow Model
+
+This example demonstrates how to use the Post-Training Quantization API of the Neural Network Compression Framework (NNCF) to quantize a TensorFlow model and then fine-tune it (quantization-aware training), using [MobileNet v2](https://huggingface.co/alexsu52/mobilenet_v2_imagenette) pretrained on the [Imagenette](https://github.com/fastai/imagenette) dataset as an example.
+
+The example includes the following steps:
+
+- Loading the [Imagenette](https://github.com/fastai/imagenette) dataset (~340 Mb) and the [MobileNet v2 TensorFlow model](https://huggingface.co/alexsu52/mobilenet_v2_imagenette) pretrained on this dataset.
+- Quantizing the model using the NNCF Post-Training Quantization algorithm.
+- Fine-tuning the quantized model for three epochs to improve its metrics.
+- Output of the following characteristics of the quantized model:
+  - Accuracy drop of the quantized model (INT8) over the pre-trained model (FP32)
+  - Compression rate of the quantized model file size relative to the pre-trained model file size
+  - Performance speed up of the quantized model (INT8)
+
+## Install requirements
+
+At this point it is assumed that you have already installed NNCF. You can find information on installing NNCF [here](https://github.com/openvinotoolkit/nncf#user-content-installation).
+
+To work with the example you should install the corresponding Python package dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Run Example
+
+The example does not require any additional preparation: it downloads the dataset and the pretrained model on its own.
+
+```bash
+python main.py
+```
diff --git a/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py b/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py
new file mode 100644
index 00000000000..bd0f6870b02
--- /dev/null
+++ b/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import subprocess
+from pathlib import Path
+from typing import List
+
+import openvino as ov
+import tensorflow as tf
+import tensorflow_datasets as tfds
+from rich.progress import track
+
+import nncf
+
+ROOT = Path(__file__).parent.resolve()
+WEIGHTS_URL = "https://huggingface.co/alexsu52/mobilenet_v2_imagenette/resolve/main/tf_model.h5"
+DATASET_CLASSES = 10
+
+
+def validate(model: ov.Model, val_loader: tf.data.Dataset) -> tf.Tensor:
+    compiled_model = ov.compile_model(model, device_name="CPU")
+    output = compiled_model.outputs[0]
+
+    metric = tf.keras.metrics.CategoricalAccuracy(name="acc@1")
+    for images, labels in track(val_loader):
+        pred = compiled_model(images.numpy())[output]
+        metric.update_state(labels, pred)
+
+    return metric.result()
+
+
+def run_benchmark(model_path: Path, shape: List[int]) -> float:
+    command = [
+        "benchmark_app",
+        "-m", model_path.as_posix(),
+        "-d", "CPU",
+        "-api", "async",
+        "-t", "15",
+        "-shape", str(shape),
+    ]  # fmt: skip
+    cmd_output = subprocess.check_output(command, text=True)  # nosec
+    print(*cmd_output.splitlines()[-8:], sep="\n")
+    match = re.search(r"Throughput\: (.+?) FPS", cmd_output)
+    return float(match.group(1))
+
+
+def get_model_size(ir_path: Path, m_type: str = "Mb") -> float:
+    xml_size = ir_path.stat().st_size
+    bin_size = ir_path.with_suffix(".bin").stat().st_size
+    for t in ["bytes", "Kb", "Mb"]:
+        if m_type == t:
+            break
+        xml_size /= 1024
+        bin_size /= 1024
+    model_size = xml_size + bin_size
+    # Report sizes in the requested unit instead of a hardcoded "Mb".
+    print(f"Model graph (xml):   {xml_size:.3f} {m_type}")
+    print(f"Model weights (bin): {bin_size:.3f} {m_type}")
+    print(f"Model size:          {model_size:.3f} {m_type}")
+    return model_size
+
+
+###############################################################################
+# Create a TensorFlow model and dataset
+
+
+def center_crop(image: tf.Tensor, image_size: int, crop_padding: int) -> tf.Tensor:
+    shape = tf.shape(image)
+    image_height = shape[0]
+    image_width = shape[1]
+
+    padded_center_crop_size = tf.cast(
+        ((image_size / (image_size + crop_padding)) * tf.cast(tf.minimum(image_height, image_width), tf.float32)),
+        tf.int32,
+    )
+
+    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
+    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
+
+    image = tf.image.crop_to_bounding_box(
+        image,
+        offset_height=offset_height,
+        offset_width=offset_width,
+        target_height=padded_center_crop_size,
+        target_width=padded_center_crop_size,
+    )
+    image = tf.image.resize(image, [image_size, image_size], method=tf.image.ResizeMethod.BILINEAR)
+    return image
+
+
+def preprocess_for_eval(image, label):
+    image = center_crop(image, 224, 32)
+    image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
+    image = tf.image.convert_image_dtype(image, tf.float32)
+    label = tf.one_hot(label, DATASET_CLASSES)
+    return image, label
+
+
+def preprocess_for_train(image, label):
+    image = tf.image.resize_with_crop_or_pad(image, 256, 256)
+    image = tf.image.random_crop(image, [224, 224, 3])
+    image = tf.image.convert_image_dtype(image, tf.float32)
+    label = tf.one_hot(label, DATASET_CLASSES)
+    return image, label
+
+
+train_dataset = tfds.load("imagenette/320px-v2", split="train", shuffle_files=True, as_supervised=True)
tfds.load("imagenette/320px-v2", split="train", shuffle_files=True, as_supervised=True) +train_dataset = train_dataset.map(preprocess_for_train).shuffle(1024).batch(128) + +val_dataset = tfds.load("imagenette/320px-v2", split="validation", shuffle_files=False, as_supervised=True) +val_dataset = val_dataset.map(preprocess_for_eval).batch(128) + +weights_path = tf.keras.utils.get_file("mobilenet_v2_imagenette_weights.h5", WEIGHTS_URL, cache_subdir="models") +tf_model = tf.keras.applications.MobileNetV2(weights=weights_path, classes=DATASET_CLASSES) + +############################################################################### +# Quantize a Tensorflow model +# +# The transformation function transforms a data item into model input data. +# +# To validate the transform function use the following code: +# >> for data_item in val_loader: +# >> model(transform_fn(data_item)) + + +def transform_fn(data_item): + images, _ = data_item + return images + + +# The calibration dataset is a small, no label, representative dataset +# (~100-500 samples) that is used to estimate the range, i.e. (min, max) of all +# floating point activation tensors in the model, to initialize the quantization +# parameters. +# +# The easiest way to define a calibration dataset is to use a training or +# validation dataset and a transformation function to remove labels from the data +# item and prepare model input data. The quantize method uses a small subset +# (default: 300 samples) of the calibration dataset. + +calibration_dataset = nncf.Dataset(val_dataset, transform_fn) +tf_quantized_model = nncf.quantize(tf_model, calibration_dataset) + +tf_quantized_model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), + loss=tf.keras.losses.CategoricalCrossentropy(), + metrics=[tf.keras.metrics.CategoricalAccuracy()], +) + +tf_quantized_model.fit(train_dataset, epochs=3, verbose=1) + +# Removes auxiliary layers and operations added during the quantization process, +# resulting in a clean, fully quantized model ready for deployment. 
+stripped_model = nncf.strip(tf_quantized_model)
+
+###############################################################################
+# Benchmark performance, calculate compression rate and validate accuracy
+
+ov_model = ov.convert_model(tf_model, share_weights=False)
+ov_quantized_model = ov.convert_model(stripped_model, share_weights=False)
+
+fp32_ir_path = ROOT / "mobilenet_v2_fp32.xml"
+ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False)
+print(f"[1/7] Save FP32 model: {fp32_ir_path}")
+fp32_model_size = get_model_size(fp32_ir_path)
+
+int8_ir_path = ROOT / "mobilenet_v2_int8.xml"
+ov.save_model(ov_quantized_model, int8_ir_path)
+print(f"[2/7] Save INT8 model: {int8_ir_path}")
+int8_model_size = get_model_size(int8_ir_path)
+
+print("[3/7] Benchmark FP32 model:")
+fp32_fps = run_benchmark(fp32_ir_path, shape=[1, 224, 224, 3])
+print("[4/7] Benchmark INT8 model:")
+int8_fps = run_benchmark(int8_ir_path, shape=[1, 224, 224, 3])
+
+print("[5/7] Validate OpenVINO FP32 model:")
+fp32_top1 = validate(ov_model, val_dataset)
+print(f"Accuracy @ top1: {fp32_top1:.3f}")
+
+print("[6/7] Validate OpenVINO INT8 model:")
+int8_top1 = validate(ov_quantized_model, val_dataset)
+print(f"Accuracy @ top1: {int8_top1:.3f}")
+
+print("[7/7] Report:")
+print(f"Accuracy drop: {fp32_top1 - int8_top1:.3f}")
+print(f"Model compression rate: {fp32_model_size / int8_model_size:.3f}")
+# https://docs.openvino.ai/latest/openvino_docs_optimization_guide_dldt_optimization_guide.html
+print(f"Performance speed up (throughput mode): {int8_fps / fp32_fps:.3f}")
diff --git a/examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt b/examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt
new file mode 100644
index 00000000000..e1e3c89fe46
--- /dev/null
+++ b/examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt
@@ -0,0 +1,5 @@
+tensorflow~=2.12.0; python_version < '3.9'
+tensorflow~=2.15.1; python_version >= '3.9'
+tensorflow-datasets
+tqdm
+openvino==2024.6
diff --git a/nncf/common/strip.py b/nncf/common/strip.py
index b642e84b598..c306a3f4ca9 100644
--- a/nncf/common/strip.py
+++ b/nncf/common/strip.py
@@ -40,5 +40,9 @@ def strip(model: TModel, do_copy: bool = True) -> TModel:
         from nncf.torch.strip import strip as strip_pt
 
         return strip_pt(model, do_copy)  # type: ignore
+    elif model_backend == BackendType.TENSORFLOW:
+        from nncf.tensorflow.strip import strip as strip_tf
+
+        return strip_tf(model, do_copy)  # type: ignore
 
     raise nncf.UnsupportedBackendError(f"Method `strip` does not support for {model_backend.value} backend.")
diff --git a/nncf/tensorflow/quantization/quantize_model.py b/nncf/tensorflow/quantization/quantize_model.py
index 85435c67380..02da5a28ab4 100644
--- a/nncf/tensorflow/quantization/quantize_model.py
+++ b/nncf/tensorflow/quantization/quantize_model.py
@@ -176,7 +176,6 @@ def quantize_impl(
         ]
     )
 
-    compression_ctrl, compressed_model = create_compressed_model(model=model, config=nncf_config)
-    stripped_model = compression_ctrl.strip_model(compressed_model)
+    _, compressed_model = create_compressed_model(model=model, config=nncf_config)
 
-    return stripped_model
+    return compressed_model
diff --git a/nncf/tensorflow/strip.py b/nncf/tensorflow/strip.py
new file mode 100644
index 00000000000..72a1e1123ca
--- /dev/null
+++ b/nncf/tensorflow/strip.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Type
+
+import tensorflow as tf
+
+from nncf.common.utils.backend import copy_model
+from nncf.tensorflow.graph.model_transformer import TFModelTransformer
+from nncf.tensorflow.graph.transformations.commands import TFOperationWithWeights
+from nncf.tensorflow.graph.transformations.commands import TFRemovalCommand
+from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout
+from nncf.tensorflow.graph.utils import collect_wrapped_layers
+from nncf.tensorflow.layers.operation import NNCFOperation
+from nncf.tensorflow.quantization.quantizers import AsymmetricQuantizer
+from nncf.tensorflow.quantization.quantizers import SymmetricQuantizer
+from nncf.tensorflow.quantization.utils import apply_overflow_fix_to_layer
+from nncf.tensorflow.sparsity.magnitude.operation import BinaryMask
+from nncf.tensorflow.sparsity.rb.operation import RBSparsifyingWeight
+from nncf.tensorflow.sparsity.utils import apply_mask
+
+
+def strip(model: tf.keras.Model, do_copy: bool = True) -> tf.keras.Model:
+    """
+    Implementation of the nncf.strip() function for the TF backend.
+
+    :param model: The compressed model.
+    :param do_copy: If True (default), will return a copy of the currently associated model object. If False,
+        will return the currently associated model object "stripped" in-place.
+    :return: The stripped model.
+    """
+    # Check whether the model has been compressed by NNCF; if not, there is nothing to strip.
+    wrapped_layers = collect_wrapped_layers(model)
+    if not wrapped_layers:
+        return model
+
+    if do_copy:
+        model = copy_model(model)
+        wrapped_layers = collect_wrapped_layers(model)
+
+    op_to_priority: Dict[Type[NNCFOperation], int] = {
+        SymmetricQuantizer: 1,
+        AsymmetricQuantizer: 1,
+        BinaryMask: 2,
+        RBSparsifyingWeight: 2,
+    }
+
+    # The priority map is keyed by operation classes, so look up by the
+    # operation's type rather than by the operation instance itself.
+    key_fn = lambda op: op_to_priority.get(type(op), 0)
+
+    transformation_layout = TFTransformationLayout()
+    for wrapped_layer in wrapped_layers:
+        for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
+            for op in sorted(ops.values(), key=key_fn, reverse=True):
+                # quantization
+                if isinstance(op, (SymmetricQuantizer, AsymmetricQuantizer)) and op.half_range:
+                    apply_overflow_fix_to_layer(wrapped_layer, weight_attr, op)
+                # sparsity, pruning
+                if isinstance(op, (BinaryMask, RBSparsifyingWeight)):
+                    apply_mask(wrapped_layer, weight_attr, op)
+                    transformation_layout.register(
+                        TFRemovalCommand(
+                            target_point=TFOperationWithWeights(
+                                wrapped_layer.name, weights_attr_name=weight_attr, operation_name=op.name
+                            )
+                        )
+                    )
+    if transformation_layout.transformations:
+        model = TFModelTransformer(model).transform(transformation_layout)
+
+    return model
diff --git a/tests/tensorflow/pruning/test_strip.py b/tests/tensorflow/pruning/test_strip.py
index 37dc596dfde..e97cea2c27c 100644
--- a/tests/tensorflow/pruning/test_strip.py
+++ b/tests/tensorflow/pruning/test_strip.py
@@ -12,6 +12,7 @@
 import pytest
 import tensorflow as tf
 
+import nncf
 from tests.tensorflow.helpers import TFTensorListComparator
 from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test
 from tests.tensorflow.helpers import get_empty_config
@@ -58,3 +59,46 @@ def test_do_copy(do_copy, enable_quantization):
 
     # Transform model for pruning creates copy of the model in both cases
     assert id(inference_model) != id(compression_model)
+
+
+@pytest.mark.parametrize("enable_quantization", (True, False), ids=("with_quantization", "no_quantization"))
+def test_strip_api(enable_quantization):
+    input_shape = (1, 8, 8, 3)
+    model = get_concat_test_model(input_shape)
+
+    config = get_empty_config(input_sample_sizes=input_shape)
+    config.update(
+        {"compression": [{"algorithm": "filter_pruning", "pruning_init": 0.5, "params": {"prune_first_conv": True}}]}
+    )
+    if enable_quantization:
+        config["compression"].append({"algorithm": "quantization", "preset": "mixed"})
+
+    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config)
+    input_tensor = tf.ones(input_shape)
+
+    x_a = compressed_model(input_tensor)
+
+    stripped_model = nncf.strip(compressed_model)
+    x_b = stripped_model(input_tensor)
+
+    TFTensorListComparator.check_equal(x_a, x_b)
+
+
+@pytest.mark.parametrize("do_copy", (True, False))
+@pytest.mark.parametrize("enable_quantization", (True, False), ids=("with_quantization", "no_quantization"))
+def test_strip_api_do_copy(do_copy, enable_quantization):
+    input_shape = (1, 8, 8, 3)
+    model = get_concat_test_model(input_shape)
+
+    config = get_empty_config(input_sample_sizes=input_shape)
+    config.update(
+        {"compression": [{"algorithm": "filter_pruning", "pruning_init": 0.5, "params": {"prune_first_conv": True}}]}
+    )
+    if enable_quantization:
+        config["compression"].append({"algorithm": "quantization", "preset": "mixed"})
+
+    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config, force_no_init=True)
+    stripped_model = nncf.strip(compressed_model, do_copy=do_copy)
+
+    # Transform model for pruning creates a copy of the model in both cases
+    assert id(stripped_model) != id(compressed_model)
diff --git a/tests/tensorflow/quantization/test_strip.py b/tests/tensorflow/quantization/test_strip.py
index 915aeba9f4a..89d538882f4 100644
--- a/tests/tensorflow/quantization/test_strip.py
+++ b/tests/tensorflow/quantization/test_strip.py
@@ -12,6 +12,7 @@
 import pytest
 import tensorflow as tf
 
+import nncf
 from tests.tensorflow.helpers import TFTensorListComparator
 from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test
 from tests.tensorflow.helpers import get_basic_two_conv_test_model
@@ -52,3 +53,39 @@ def test_do_copy(do_copy):
         assert id(inference_model) != id(compression_model)
     else:
         assert id(inference_model) == id(compression_model)
+
+
+def test_strip_api():
+    model = get_basic_two_conv_test_model()
+    config = get_basic_quantization_config()
+    config["compression"] = {
+        "algorithm": "quantization",
+        "preset": "mixed",
+    }
+
+    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config)
+
+    input_tensor = tf.ones([1, 4, 4, 1])
+    x_a = compressed_model(input_tensor)
+
+    stripped_model = nncf.strip(compressed_model)
+    x_b = stripped_model(input_tensor)
+
+    TFTensorListComparator.check_equal(x_a, x_b)
+
+
+@pytest.mark.parametrize("do_copy", (True, False))
+def test_strip_api_do_copy(do_copy):
+    model = get_basic_two_conv_test_model()
+    config = get_basic_quantization_config()
+    config["compression"] = {
+        "algorithm": "quantization",
+        "preset": "mixed",
+    }
+    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config, force_no_init=True)
+    stripped_model = nncf.strip(compressed_model, do_copy=do_copy)
+
+    if do_copy:
+        assert id(stripped_model) != id(compressed_model)
+    else:
+        assert id(stripped_model) == id(compressed_model)
diff --git a/tests/tensorflow/sparsity/test_strip.py b/tests/tensorflow/sparsity/test_strip.py
index acdbfca9801..b818e70a445 100644
--- a/tests/tensorflow/sparsity/test_strip.py
+++ b/tests/tensorflow/sparsity/test_strip.py
@@ -12,6 +12,7 @@
 import pytest
 import tensorflow as tf
 
+import nncf
 from tests.tensorflow.helpers import TFTensorListComparator
 from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test
 from tests.tensorflow.helpers import get_basic_conv_test_model
@@ -65,3 +66,52 @@ def test_do_copy(do_copy, enable_quantization):
 
     # Transform model for sparsity creates copy of the model in both cases
     assert id(inference_model) != id(compression_model)
+
+
+@pytest.mark.parametrize("enable_quantization", (True, False), ids=("with_quantization", "no_quantization"))
+def test_strip_api(enable_quantization):
+    input_shape = (1, 4, 4, 1)
+    model = get_basic_conv_test_model()
+    config = get_empty_config(input_sample_sizes=input_shape)
+
+    config.update({"compression": [{"algorithm": "magnitude_sparsity"}]})
+    if enable_quantization:
+        config["compression"].append(
+            {
+                "algorithm": "quantization",
+                "preset": "mixed",
+                "initializer": {
+                    "batchnorm_adaptation": {
+                        "num_bn_adaptation_samples": 0,
+                    }
+                },
+            }
+        )
+
+    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config)
+
+    input_tensor = tf.ones(input_shape)
+    x_a = compressed_model(input_tensor)
+
+    stripped_model = nncf.strip(compressed_model)
+    x_b = stripped_model(input_tensor)
+
+    TFTensorListComparator.check_equal(x_a, x_b)
+
+
+@pytest.mark.parametrize("do_copy", (True, False))
+@pytest.mark.parametrize("enable_quantization", (True, False), ids=("with_quantization", "no_quantization"))
+def test_strip_api_do_copy(do_copy, enable_quantization):
+    input_shape = (1, 4, 4, 1)
+    model = get_basic_conv_test_model(input_shape=input_shape[1:])
+
+    config = get_empty_config(input_sample_sizes=input_shape)
+    config.update({"compression": [{"algorithm": "magnitude_sparsity"}]})
+    if enable_quantization:
+        config["compression"].append({"algorithm": "quantization", "preset": "mixed"})
+
+    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config, force_no_init=True)
+    stripped_model = nncf.strip(compressed_model, do_copy=do_copy)
+
+    # Transform model for sparsity creates a copy of the model in both cases
+    assert id(compressed_model) != id(stripped_model)
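
For reference, a minimal usage sketch of the user-facing flow this change enables. The toy model, the random calibration data, and the sample count below are illustrative placeholders, not part of the diff:

```python
import numpy as np
import tensorflow as tf

import nncf

# Illustrative stand-ins for a real model and calibration set.
model = tf.keras.applications.MobileNetV2(weights=None, classes=10)
calibration_data = [np.random.rand(1, 224, 224, 3).astype(np.float32) for _ in range(10)]

# With this change, nncf.quantize() returns a Keras model that still carries
# the auxiliary NNCF operations, so it can be fine-tuned further (QAT).
quantized_model = nncf.quantize(model, nncf.Dataset(calibration_data))
# ... optionally fine-tune `quantized_model` with compile()/fit() here ...

# nncf.strip() now dispatches to the new TF backend and removes the auxiliary
# operations, leaving a deployable quantized model. do_copy=True (the default)
# returns a cleaned copy; do_copy=False strips the associated model in place.
stripped_model = nncf.strip(quantized_model)
```

Note that, as the new pruning and sparsity tests assert, strips that rewrite the graph return a new model object even with `do_copy=False`.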