[PT2] MinMax #3166

Open
wants to merge 16 commits into base: develop
15 changes: 15 additions & 0 deletions nncf/common/factory.py
@@ -76,6 +76,12 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.torch.model_transformer import PTModelTransformer

return PTModelTransformer(model)

if model_backend == BackendType.TORCH2:
from nncf.experimental.torch2.model_transformer import PT2ModelTransformer

return PT2ModelTransformer(model)

if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

@@ -107,6 +113,10 @@ def create(model: TModel) -> Engine:
from nncf.torch.engine import PTEngine

return PTEngine(model)
if model_backend == BackendType.TORCH2:
from nncf.experimental.torch2.engine import PT2Engine

return PT2Engine(model)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
)
@@ -159,6 +169,11 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH2:
from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator

return PT2StatisticsAggregator(dataset)

if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

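A minimal usage sketch (not part of this diff) of how the new TORCH2 branches are reached. It assumes the create() methods above belong to the existing ModelTransformerFactory and EngineFactory classes in nncf/common/factory.py (the class names are not visible in this diff) and that a plain module can be handed to GraphModelWrapper directly:

import torch
from torch import nn

from nncf.common.factory import EngineFactory, ModelTransformerFactory  # assumed factory class names
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

model = nn.Linear(4, 2)
wrapped = GraphModelWrapper(model, example_input=torch.ones(1, 4))

# get_backend() resolves the wrapper to BackendType.TORCH2, so the new branches above fire.
transformer = ModelTransformerFactory.create(wrapped)  # -> PT2ModelTransformer
engine = EngineFactory.create(wrapped)                  # -> PT2Engine
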
6 changes: 3 additions & 3 deletions nncf/common/graph/model_transformer.py
@@ -9,14 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar
from typing import Generic, TypeVar

from nncf.common.graph.transformations.layout import TransformationLayout

TModel = TypeVar("TModel")


class ModelTransformer:
class ModelTransformer(Generic[TModel]):
"""
Applies transformations to the model.
"""
@@ -29,7 +29,7 @@ def __init__(self, model: TModel):
"""
self._model = model

def transform(self, transformation_layout: TransformationLayout) -> TModel: # type:ignore
def transform(self, transformation_layout: TransformationLayout) -> TModel:
"""
Applies transformations to the model.

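Making ModelTransformer generic lets each backend subclass bind TModel, so transform() is statically typed and the previous "# type: ignore" is no longer needed. A minimal sketch with a hypothetical DummyModel/DummyModelTransformer pair (illustration only, not part of NNCF):

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.layout import TransformationLayout


class DummyModel:
    """Stand-in model type used only for this illustration."""


class DummyModelTransformer(ModelTransformer[DummyModel]):
    def transform(self, transformation_layout: TransformationLayout) -> DummyModel:
        # A real transformer would apply the registered commands to self._model here.
        return self._model


result = DummyModelTransformer(DummyModel()).transform(TransformationLayout())  # typed as DummyModel
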
4 changes: 2 additions & 2 deletions nncf/common/graph/patterns/manager.py
@@ -47,7 +47,7 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam
Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict
)
return registry
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
if backend in (BackendType.TORCH, BackendType.TORCH_FX, BackendType.TORCH2):
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict)
@@ -77,7 +77,7 @@ def _get_backend_ignored_patterns_map(
Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict
)
return registry
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
if backend in (BackendType.TORCH, BackendType.TORCH_FX, BackendType.TORCH2):
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict)
8 changes: 4 additions & 4 deletions nncf/common/graph/transformations/layout.py
@@ -11,7 +11,7 @@

from typing import List

from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.commands import Command


class TransformationLayout:
@@ -27,13 +27,13 @@ def __init__(self) -> None:
"""
Initialize Transformation Layout.
"""
self._transformations: List[TransformationCommand] = []
self._transformations: List[Command] = []

@property
def transformations(self) -> List[TransformationCommand]:
def transformations(self) -> List[Command]:
return self._transformations

def register(self, transformation: TransformationCommand) -> None:
def register(self, transformation: Command) -> None:
"""
Registers the transformation command in the transformation layout.

6 changes: 3 additions & 3 deletions nncf/common/tensor_statistics/aggregator.py
@@ -22,7 +22,6 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.tensor_statistics.statistics_serializer import dump_statistics
from nncf.common.tensor_statistics.statistics_serializer import load_statistics
@@ -31,6 +30,7 @@
from nncf.data.dataset import Dataset
from nncf.data.dataset import ModelInput
from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
from nncf.tensor import Tensor

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
@@ -167,7 +167,7 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer)
self.stat_subset_size = max(self.stat_subset_size, tensor_collector.num_samples)

@abstractmethod
def _register_statistics(self, outputs: Dict[str, NNCFTensor], statistic_points: StatisticPointsContainer) -> None:
def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
"""
Process prepared raw model outputs and statistic points for the further usage.

@@ -205,7 +205,7 @@ def _get_merged_statistic_points(

@staticmethod
@abstractmethod
def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
def _process_outputs(outputs: Any) -> Dict[str, Tensor]:
"""
Post-process model outputs for the further statistics collection.

16 changes: 16 additions & 0 deletions nncf/common/utils/backend.py
@@ -19,6 +19,7 @@

class BackendType(Enum):
TORCH = "Torch"
TORCH2 = "Torch2"
TORCH_FX = "TorchFX"
TENSORFLOW = "Tensorflow"
ONNX = "ONNX"
@@ -49,6 +50,20 @@ def is_torch_model(model: TModel) -> bool:
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


@result_verifier
def is_torch2_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of GraphModelWrapper, otherwise False.

:param model: A target model.
:return: True if the model is an instance of GraphModelWrapper, otherwise False.
"""

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

return isinstance(model, GraphModelWrapper)


@result_verifier
def is_torch_fx_model(model: TModel) -> bool:
"""
@@ -125,6 +140,7 @@ def get_backend(model: TModel) -> BackendType:
verify_map = {
is_torch_fx_model: BackendType.TORCH_FX,
is_torch_model: BackendType.TORCH,
is_torch2_model: BackendType.TORCH2,
is_tensorflow_model: BackendType.TENSORFLOW,
is_onnx_model: BackendType.ONNX,
is_openvino_model: BackendType.OPENVINO,
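
A quick sketch (not part of this diff) of how the new entry changes backend resolution: a plain nn.Module still resolves to TORCH, while a model wrapped in GraphModelWrapper resolves to the new TORCH2 backend via is_torch2_model():

import torch
from torch import nn

from nncf.common.utils.backend import BackendType, get_backend
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

model = nn.Linear(4, 2)
assert get_backend(model) == BackendType.TORCH     # unchanged behaviour for plain modules

wrapped = GraphModelWrapper(model, example_input=torch.ones(1, 4))
assert get_backend(wrapped) == BackendType.TORCH2  # new branch: the wrapper is not an nn.Module
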
37 changes: 37 additions & 0 deletions nncf/experimental/torch2/commands.py
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from torch import nn

from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationType
from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
from nncf.torch.graph.transformations.commands import PTTargetPoint


class PT2InsertionCommand(Command):
"""
Insertion operation for the model.
"""

def __init__(
self,
target_points: List[PTTargetPoint],
hook_module: nn.Module,
*,
handle_storage: Optional[List[RemovableHookHandle]] = None,
):
super().__init__(TransformationType.INSERT)
self.target_points = target_points
self.hook_module = hook_module
self.handle_storage = handle_storage
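
A sketch (not part of this diff) of how a PT2InsertionCommand might be built and registered in a TransformationLayout. The node name is made up, and PTTargetPoint is assumed to keep its usual (target_type, target_node_name) constructor:

from torch import nn

from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.torch2.commands import PT2InsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint

# Hypothetical target: attach a hook module after some node of the traced graph.
target_point = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, "linear/linear_0")

layout = TransformationLayout()
layout.register(PT2InsertionCommand(target_points=[target_point], hook_module=nn.Identity()))
# The layout would later be applied by the backend-specific model transformer.
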
47 changes: 47 additions & 0 deletions nncf/experimental/torch2/engine.py
@@ -0,0 +1,47 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Union

import torch

from nncf.common.engine import Engine
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper


class PT2Engine(Engine):
Collaborator:
Can we reuse PTEngine, as is done for the TorchFX backend?

Collaborator (Author):
We can, but what is the reason? I have tried to keep a one-way dependency of torch2 on torch.

Collaborator:
I was asked to remove PTFXEngine by @alexsu52; awaiting the same comment from him.

"""
Engine for the PyTorch backend.
"""

def __init__(self, model: GraphModelWrapper):
"""
Constructor.

:param model: PyTorch module to infer.
"""

self._model = model.model
self._model.eval()

def infer(self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]) -> Any:
"""
Runs the Torch model on the provided input.

:param input_data: Inputs for the model.
:return: Model outputs.
"""

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
return self._model(*input_data)
return self._model(input_data)
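
A minimal sketch (not part of this diff) of how infer() dispatches the three supported input forms to the wrapped module; the module and tensors are illustrative only:

import torch
from torch import nn

from nncf.experimental.torch2.engine import PT2Engine
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

engine = PT2Engine(GraphModelWrapper(nn.Linear(4, 2), example_input=torch.ones(1, 4)))

out = engine.infer(torch.ones(1, 4))     # single tensor  -> model(tensor)
out = engine.infer((torch.ones(1, 4),))  # tuple          -> model(*inputs)
# out = engine.infer({"input": torch.ones(1, 4)})  # dict -> model(**inputs); keys must match forward()
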
@@ -10,7 +10,7 @@
# limitations under the License.

"""
This module implements selected functions from the `torch` module, excluding the `hand_function` mechanism.
This module implements selected functions from the `torch` module, excluding the `handle_torch_function` function.

It processes inner functions to handle exception hooks and graph analysis. The implementation is designed
to support custom handling of inner function exceptions for specific functions.
@@ -19,7 +19,6 @@

import nncf
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
@@ -28,6 +27,7 @@
from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
from nncf.torch.graph.graph import PTNNCFGraph


def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str:
@@ -91,14 +91,14 @@ def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta
return node_sub_meta_type or node_metatype


def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> PTNNCFGraph:
"""
Converts a graph to an NNCFGraph.
Converts a graph to a PTNNCFGraph.

:param nx_graph: The graph to convert.
:return: The converted NNCFGraph.
"""
nncf_graph = NNCFGraph()
nncf_graph = PTNNCFGraph()

map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
for node, data in nx_graph.nodes(data=True):
@@ -136,7 +136,7 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
return nncf_graph


def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> NNCFGraph:
def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> PTNNCFGraph:
"""
Builds an NNCF graph from the given PyTorch model.

@@ -147,3 +147,35 @@ def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> NNCFGraph:
"""
graph = build_graph(model, *args, **kwargs)
return convert_to_nncf_graph(graph)


class GraphModelWrapper:
"""
A class that wraps a PyTorch model with example inputs and provides an interface
to build a computational graph of the model.

:param model: The PyTorch model to be wrapped.
:param example_input: Example input for the model (a dict, a tuple, or a single tensor).
"""

def __init__(self, model: nn.Module, example_input: Any) -> None:
"""
Initialize the GraphModelWrapper.
"""
self.model = model
self.example_input = example_input

def build_nncf_graph(self) -> PTNNCFGraph:
"""
Constructs a computational graph of the given model.

This function builds a directed graph `PTNNCFGraph` representing the operations
and data flow within the model by leveraging hooks via GraphBuilderMode.

:return: A PTNNCFGraph whose nodes represent the operations of the model.
"""
if isinstance(self.example_input, dict):
return build_nncf_graph(self.model, **self.example_input)
if isinstance(self.example_input, tuple):
return build_nncf_graph(self.model, *self.example_input)
return build_nncf_graph(self.model, self.example_input)
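
A usage sketch (not part of this diff) for the wrapper. It assumes the plain nn.Module here can be traced by build_graph() as-is; in the real flow the model may first need to be prepared with the torch2 function-hook utilities:

import torch
from torch import nn

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
wrapper = GraphModelWrapper(model, example_input=torch.ones(1, 8))

graph = wrapper.build_nncf_graph()  # PTNNCFGraph built by tracing the example input with function hooks
print([node.node_name for node in graph.get_all_nodes()])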