ModelWrapper #3068

Draft · wants to merge 4 commits into develop
Changes from all commits
20 changes: 17 additions & 3 deletions nncf/common/factory.py
@@ -9,7 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TypeVar
+import os
+from typing import Any, Dict, Optional, Tuple, TypeVar

 import nncf
 from nncf.common.engine import Engine
@@ -26,13 +27,20 @@

 class NNCFGraphFactory:
     @staticmethod
-    def create(model: TModel) -> NNCFGraph:
+    def create(
+        model: TModel, input_args: Optional[Tuple[Any, ...]] = None, input_kwargs: Optional[Dict[str, Any]] = None
+    ) -> NNCFGraph:
         """
         Factory method to create backend-specific NNCFGraph instance based on the input model.

         :param model: backend-specific model instance
         :return: backend-specific NNCFGraph instance
         """
+        if input_args is None:
+            input_args = ()
+        if input_kwargs is None:
+            input_kwargs = {}
+
         model_backend = get_backend(model)
         if model_backend == BackendType.ONNX:
             from nncf.onnx.graph.nncf_graph_builder import GraphConverter
@@ -47,7 +55,13 @@ def create(model: TModel) -> NNCFGraph:

             return GraphConverter.create_nncf_graph(model)
         if model_backend == BackendType.TORCH:
-            return model.nncf.get_graph()
+            if os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is None:
+                return model.nncf.get_graph()
+            else:
+                from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph
+
+                return build_nncf_graph(model, *input_args, **input_kwargs)
+
         raise nncf.UnsupportedBackendError(
             "Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
         )
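For orientation, a minimal sketch of how the extended factory signature could be exercised with the experimental tracing path enabled; the toy model, input, and environment handling are illustrative, not part of this PR:

import os

import torch

from nncf.common.factory import NNCFGraphFactory

# Opt in to the experimental torch tracing path; when the variable is unset,
# the factory falls back to model.nncf.get_graph() as before.
os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"

model = torch.nn.Linear(4, 2)  # illustrative toy model
example_input = torch.ones(1, 4)  # example input consumed by tracing

# input_args/input_kwargs are forwarded as build_nncf_graph(model, *args, **kwargs).
graph = NNCFGraphFactory.create(model, input_args=(example_input,))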
96 changes: 96 additions & 0 deletions nncf/common/model.py
@@ -0,0 +1,96 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, TypeVar

from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend

TModel = TypeVar("TModel")


@dataclass
class ModelAttributes:
    """
    A class to store model attributes.

    :param example_input_args: Example input arguments for the model.
    :param example_input_kwargs: Example input keyword arguments for the model.
    """

    example_input_args: Optional[Tuple[Any, ...]] = None
    example_input_kwargs: Optional[Dict[str, Any]] = None


class ModelWrapper:
    """
    A wrapper class for the original model.

    :param _model: The original model to be wrapped.
    :param _graph: The graph representation of the model.
    :param _attributes: The storage of the model attributes.
    :param _backend: The backend of the model.
    """

    def __init__(
        self, model: TModel, *, graph: Optional[NNCFGraph] = None, attributes: Optional[ModelAttributes] = None
    ) -> None:
        self._model = model
        self._graph = graph
        self._attributes = attributes or ModelAttributes()
        self._backend = get_backend(model)

    @property
    def model(self) -> TModel:
        """
        Retrieves the original model.
        """
        return self._model

    @property
    def graph(self) -> NNCFGraph:
        """
        Returns the NNCFGraph representation of the model.

        If the graph has not been created yet, it will be created using the model,
        example input arguments, and example input keyword arguments stored in the state.
        """
        if self._graph is None:
            self._graph = NNCFGraphFactory.create(
                self.model, self.attributes.example_input_args, self.attributes.example_input_kwargs
            )
        return self._graph

    @property
    def attributes(self) -> ModelAttributes:
        """
        Retrieves the model attributes.
        """
        return self._attributes

    @property
    def backend(self) -> BackendType:
        """
        Retrieves the model backend.
        """
        return self._backend

    def unwrap(self) -> Tuple[TModel, NNCFGraph]:
        """
        Retrieves the model and graph.

        :return: A tuple of the model and graph.
        """
        return self.model, self.graph
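A short usage sketch of the lazy-graph contract above; backend_model and example_input are placeholders for any backend-specific model and a matching sample input:

from nncf.common.model import ModelAttributes
from nncf.common.model import ModelWrapper

# Wrapping is cheap: no NNCFGraph is built until .graph is first accessed.
wrapped = ModelWrapper(backend_model, attributes=ModelAttributes(example_input_args=(example_input,)))

graph = wrapped.graph  # built once via NNCFGraphFactory.create(...) and cached
assert wrapped.graph is graph  # later accesses reuse the cached instance

model, graph = wrapped.unwrap()  # convenience accessor for both pieces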
25 changes: 12 additions & 13 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
@@ -10,10 +10,8 @@
 # limitations under the License.

 from copy import deepcopy
-from typing import Optional
+from typing import Optional, cast

-import torch
-import torch.fx
 from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
 from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
 from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
@@ -22,8 +20,8 @@
 from torch.fx.passes.infra.pass_manager import PassManager

 import nncf
-from nncf.common.factory import NNCFGraphFactory
 from nncf.common.logging import nncf_logger
+from nncf.common.model import ModelWrapper
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
 from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
@@ -46,7 +44,7 @@


 def quantize_impl(
-    model: torch.fx.GraphModule,
+    model: GraphModule,
     calibration_dataset: Dataset,
     mode: Optional[QuantizationMode] = None,
     preset: Optional[QuantizationPreset] = None,
@@ -56,7 +54,7 @@ def quantize_impl(
     model_type: Optional[ModelType] = None,
     ignored_scope: Optional[IgnoredScope] = None,
     advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
-) -> torch.fx.GraphModule:
+) -> GraphModule:
     """
     Implementation of the `quantize()` method for the Torch FX backend.
     """
@@ -86,9 +84,9 @@

     # To make it easier for bias correction algorithms.
     apply_quantization_transformations(copied_model)
-
-    nncf_graph = NNCFGraphFactory.create(copied_model)
-    quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
+    model_wrapper = ModelWrapper(copied_model)
+    quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset)
+    quantized_model = cast(GraphModule, quantized_model_wrapper.model)

     if is_weight_compression_needed(advanced_parameters):
         compress_post_quantize_transformation(quantized_model)
@@ -116,7 +114,7 @@


 def compress_weights_impl(
-    model: torch.fx.GraphModule,
+    model: GraphModule,
     dataset: Dataset,
     mode: CompressWeightsMode,
     ratio: float,
@@ -131,7 +129,7 @@ def compress_weights_impl(
     lora_correction: bool,
     backup_mode: BackupMode,
     advanced_parameters: Optional[AdvancedCompressionParameters] = None,
-) -> torch.fx.GraphModule:
+) -> GraphModule:
     """
     Implementation of the `compress_weights()` method for the Torch Fx backend.
     """
@@ -151,8 +149,9 @@
         backup_mode,
         advanced_parameters,
     )
-    graph = NNCFGraphFactory.create(model)
-    compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
+    model_wrapper = ModelWrapper(model)
+    compressed_model_wrapper = compression_algorithm.apply(model_wrapper, dataset=dataset)
+    compressed_model = compressed_model_wrapper.model
     compressed_model = GraphModule(compressed_model, compressed_model.graph)
     compressed_model = _disallow_eval_train(compressed_model)

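The same migration repeats in each backend entry point in this PR: instead of building an NNCFGraph by hand and passing (model, graph) to the algorithm, callers construct a single ModelWrapper, and the algorithm both consumes and returns one. Sketched abstractly, with algorithm, model, and dataset as placeholders:

# Before: every entry point built the graph itself.
# graph = NNCFGraphFactory.create(model)
# quantized = algorithm.apply(model, graph, dataset=dataset)

# After: wrap once; graph construction is owned by the wrapper.
model_wrapper = ModelWrapper(model)
result_wrapper = algorithm.apply(model_wrapper, dataset=dataset)
quantized = result_wrapper.model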
14 changes: 8 additions & 6 deletions nncf/onnx/quantization/quantize_model.py
@@ -9,16 +9,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union, cast

 import onnx

 import nncf
 from nncf.common.logging.logger import nncf_logger
+from nncf.common.model import ModelWrapper
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
 from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
-from nncf.onnx.graph.nncf_graph_builder import GraphConverter
 from nncf.parameters import DropType
 from nncf.parameters import ModelType
 from nncf.parameters import QuantizationMode
@@ -78,10 +78,12 @@ def quantize_impl(
         advanced_parameters=advanced_parameters,
     )

-    graph = GraphConverter.create_nncf_graph(model)
-    warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
-    quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
-
+    model_wrapper = ModelWrapper(model)
+    warning_model_no_batchwise_support(
+        model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
+    )
+    quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset)
+    quantized_model = cast(onnx.ModelProto, quantized_model_wrapper.model)
     return quantized_model

11 changes: 9 additions & 2 deletions nncf/openvino/quantization/quantize_ifmodel.py
@@ -10,7 +10,7 @@
 # limitations under the License.

 from itertools import islice
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, cast

 import openvino.runtime as ov

@@ -25,6 +25,7 @@
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
+from nncf.common.model import ModelWrapper
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype
 from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
@@ -155,7 +156,13 @@ def apply_algorithm_if_bodies(
     """
     nncf_logger.info(f"Iteration [{current_model_num}/{len(graphs)}] ...")
     parent_graph = graphs[graph_id]
-    quantized_model = algorithm.apply(parent_model, parent_graph, parent_statistic_points, parent_dataset)
+
+    model_wrapper = ModelWrapper(parent_model, graph=parent_graph)
+    quantized_model_wrapper = algorithm.apply(
+        model_wrapper, statistic_points=parent_statistic_points, dataset=parent_dataset
+    )
+    quantized_model = cast(ov.Model, quantized_model_wrapper.model)
+
     if get_number_if_op(parent_model) == 0:
         return quantized_model, current_model_num
     model_transformer_fp32 = factory.ModelTransformerFactory.create(parent_model)
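Note the graph=parent_graph keyword above: when a graph is already available (here, the per-body graphs of an If operation), the wrapper is seeded with it so that accessing .graph never triggers a rebuild. A tiny sketch with placeholder names:

# Seeding the wrapper with a pre-built graph skips lazy construction entirely.
wrapper = ModelWrapper(parent_model, graph=parent_graph)
assert wrapper.graph is parent_graph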
27 changes: 16 additions & 11 deletions nncf/openvino/quantization/quantize_model.py
@@ -11,21 +11,21 @@

 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union, cast

 import openvino.runtime as ov
 from openvino._offline_transformations import compress_quantize_weights_transformation

 from nncf.common.factory import NNCFGraphFactory
 from nncf.common.factory import StatisticsAggregatorFactory
 from nncf.common.logging import nncf_logger
+from nncf.common.model import ModelWrapper
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
 from nncf.openvino.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
 from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
-from nncf.openvino.graph.nncf_graph_builder import GraphConverter
 from nncf.openvino.graph.node_utils import get_number_if_op
 from nncf.openvino.quantization.backend_parameters import BackendParameters
 from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
@@ -166,9 +166,12 @@ def native_quantize_impl(
         ignored_scope=ignored_scope,
         advanced_parameters=advanced_parameters,
     )
-    graph = GraphConverter.create_nncf_graph(model)
-    warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
-    quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)
+    model_wrapper = ModelWrapper(model)
+    warning_model_no_batchwise_support(
+        model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
+    )
+    quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset)
+    quantized_model = cast(ov.Model, quantized_model_wrapper.model)

     if is_weight_compression_needed(advanced_parameters):
         compress_quantize_weights_transformation(quantized_model)
@@ -383,7 +386,7 @@ def compress_weights_impl(
     Implementation of the `compress_weights()` method for the OpenVINO backend.
     """
     model = remove_friendly_name_duplicates(model)
-    graph = NNCFGraphFactory.create(model)
+    model_wrapper = ModelWrapper(model)
     compression_algorithm = WeightCompression(
         mode,
         ratio,
@@ -405,18 +408,20 @@
         # If there is no such directory, then caches statistics
         statistics_path = Path(advanced_parameters.statistics_path)
         if not statistics_path.exists():
-            cache_weight_compression_statistics(model, graph, dataset, subset_size, statistics_path)
+            cache_weight_compression_statistics(
+                model_wrapper.model, model_wrapper.graph, dataset, subset_size, advanced_parameters.statistics_path
+            )
         statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
         compression_algorithm.set_backend_entity(model)
-        _, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(graph)
+        _, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(model_wrapper.graph)
         register_statistics_for_algorithm(
             statistics_aggregator,
-            model,
-            graph,
+            model_wrapper.model,
+            model_wrapper.graph,
             compression_algorithm,
             matmul_input_to_output_nodes_map,
         )
         statistics_aggregator.load_statistics_from_dir(statistics_path)
         statistics_points = statistics_aggregator.statistic_points

-    return compression_algorithm.apply(model, graph, statistics_points, dataset)
+    return compression_algorithm.apply(model_wrapper, statistic_points=statistics_points, dataset=dataset).model