Skip to content

Commit

Permalink
[TorchFX] Deptwise convolution support (#2896)
Browse files Browse the repository at this point in the history
### Changes

Depthwise convolution subtypes are integrated to torchFX nncf graph
builder

### Reason for changes

To correctly quantize deptiwise convolutions

### Related tickets

#2766 

### Tests

`tests/torch/fx/test_models.py` is updated
  • Loading branch information
daniil-lyakhov authored Aug 23, 2024
1 parent 865798c commit c57df43
Show file tree
Hide file tree
Showing 8 changed files with 5,406 additions and 1,100 deletions.
49 changes: 47 additions & 2 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFNode
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.logging import nncf_logger
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.dynamic_graph.layer_attributes_handlers import apply_args_defaults
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES

Expand All @@ -27,12 +31,48 @@ class GraphConverter:
Builds the NNCFGraph from an torch.fx.GraphModule instance.
"""

def _get_layer_attributes(
node: torch.fx.Node, metatype: om.OperatorMetatype, model: torch.fx.GraphModule
) -> BaseLayerAttributes:
"""
Collects layer attributes for the given node.
:param node: Given node.
:param metatype: Given node metatype.
:param model: Target GraphModule instance.
:return: Given node layer attributes.
"""
if metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]:
conv_default_args = [(arg.name, arg.default_value) for arg in node.target._schema.arguments]
kwargs = apply_args_defaults(node.args, node.kwargs, conv_default_args)

weight_node = kwargs["weight"]
if weight_node.op != "get_attr":
# Convs with constant subgraphs or two inputs are not supported yet.
return None
weight = get_tensor_constant_from_node(weight_node, model)
return ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=weight.shape[0],
out_channels=weight.shape[1],
kernel_size=list(weight.shape[2:]),
stride=kwargs["stride"],
dilations=kwargs["dilation"],
groups=kwargs["groups"],
padding_values=kwargs["padding"],
transpose=False,
)
return None

@staticmethod
def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]:
def _get_node_type_and_metatype(
node: torch.fx.Node, model: torch.fx.GraphModule
) -> Tuple[str, om.OperatorMetatype]:
"""
Retrieves node's type and metatype.
:param node: Given node.
:param model: Given GraphModule.
:return: Node's type and metatype.
"""
if node.op == "placeholder":
Expand All @@ -58,6 +98,11 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
node_metatype = UnknownMetatype
if node_metatype is UnknownMetatype:
nncf_logger.debug(f"Unknown metatype for node: {node}")

if node_metatype.get_subtypes():
layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model)
node_subtype = node_metatype.determine_subtype(layer_attrs)
node_metatype = node_subtype or node_metatype
return node_type, node_metatype

@staticmethod
Expand All @@ -74,7 +119,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
nncf_graph = PTNNCFGraph()

for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node)
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model)

nncf_graph.add_nncf_node(
node_name=source_node.name,
Expand Down
2,138 changes: 1,047 additions & 1,091 deletions tests/torch/data/reference_graphs/fx/quantized/mobilenet_v3_small.dot

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2,685 changes: 2,684 additions & 1 deletion tests/torch/data/reference_graphs/fx/reference_metatypes/swin_v2_s.json

Large diffs are not rendered by default.

278 changes: 277 additions & 1 deletion tests/torch/data/reference_graphs/fx/reference_metatypes/unet.json

Large diffs are not rendered by default.

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_ref_metatypes_from_json(
if not os.path.exists(json_parent_dir):
os.makedirs(json_parent_dir)
with safe_open(complete_path, "w") as file:
json.dump(model_metatypes, file)
json.dump(model_metatypes, file, indent=4)

with safe_open(complete_path, "r") as file:
return json.load(file)
Expand All @@ -115,7 +115,7 @@ def test_model(test_case: ModelCase):
check_graph(nncf_graph, dot_filename, FX_DIR_NAME)

# Check metatypes
model_metatypes = {n.node_name: n.metatype.name for n in nncf_graph.get_all_nodes()}
model_metatypes = {n.node_name: n.metatype.__name__ for n in nncf_graph.get_all_nodes()}
ref_metatypes = get_ref_metatypes_from_json(model_name, model_metatypes)
assert model_metatypes == ref_metatypes

Expand Down

0 comments on commit c57df43

Please sign in to comment.