ModuleExtension, a new type of extension for PyTorch frontend #22867

Closed
@@ -13,6 +13,7 @@
    from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
    from openvino.frontend.pytorch.py_pytorch_frontend import ConversionExtensionPytorch as ConversionExtension
    from openvino.frontend.pytorch.py_pytorch_frontend import OpExtensionPytorch as OpExtension
    from openvino.frontend.pytorch.module_extension import ModuleExtension
except ImportError as err:
    raise ImportError("OpenVINO PyTorch frontend is not available, please make sure the frontend is built. "
                      "{}".format(err))
@@ -0,0 +1,42 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

import typing

class ModuleExtension:
    def __init__(self, module, target_op, evaluate=None, convert=None):
        """
        Creates an extension that replaces an entire PyTorch module with a single operation.
        This functionality works with PyTorch models only. A module can be identified by
        module type (e.g. torch.nn.Linear), by a module instance in the model, or by module name.

        Args:
            module (str, torch.nn.Module, type(torch.nn.Module)): PyTorch module to replace.

            target_op (str): a target operation that will be used as a replacement for the module;
Contributor: What if I want to decompose an existing torch.nn.Module into a set of operations? What is target_op here?

slyalin (Contributor, author), Feb 15, 2024: Only one custom operation is supported; it is substituted in place of the target_op argument. To extend this further, I think we can give access to all custom/existing operations by passing a kind of builder instead of a single target_op. That would be a more scalable approach, but I haven't drafted it yet. It wouldn't use the target_op parameter in ModuleExtension.__init__, and in the convert parameter you would write something like ops.PagedAttentionExtension(args), where ops is an op builder and PagedAttentionExtension is the name of the custom op, so you would be able to instantiate many ops of different types. It is not so straightforward, because evaluate is still defined at the level of each op, but we can experiment with that approach.
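A hypothetical sketch of that builder-style convert callback, using the names from the comment above (`ops` and `PagedAttentionExtension` are illustrative; nothing like this is implemented in this PR):

```python
# Hypothetical builder-style API, not part of this PR: instead of a single
# target_op, the user's convert callback would receive an op builder.
def convert(ops, *args, **kwargs):
    # Each attribute of the builder would instantiate a custom op by name,
    # so one callback could emit several ops of different types.
    return ops.PagedAttentionExtension(*args)
```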

            could be a name of the extension operation or an existing PyTorch operation
            (with a prim:: or aten:: prefix, following TorchScript syntax).

            evaluate (callable with args module, *args, **kwargs): a callable that replaces the target
                module during model execution; it is responsible for producing valid output for
                the module so that the model can be traced correctly. By default it calls the original
                module's forward with the same arguments. The provided code will not be a part of the
                final traced model; it is used only to produce valid results during tracing.

            convert (callable with args target_op, *args, **kwargs): a callable that will be traced and become
                a part of the final model instead of the target module. It accepts target_op as the first
                parameter; target_op is a callable that will appear as a single node in the graph, and the
                type of that node is the target_op provided as the argument above.
Contributor: Could you please describe why I need to set evaluate and convert? From the description it is not entirely clear what the difference between them is, or why I would need them at all. I've already provided an OpenVINO operation to replace the original torch.nn.Module; can it be used during tracing automatically, without extra convert and evaluate?

"""
self.module = module
self.target_op = target_op
self.evaluate = evaluate
if self.evaluate is None:
self.evaluate = lambda module, *args, **kwargs: module(*args, **kwargs)
self.convert = convert
if self.convert is None:
self.convert = lambda target_op, *args, **kwargs: target_op(*args, **kwargs)
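To make the roles of evaluate and convert concrete (see the review question above), a minimal usage sketch; MyAttention and the op name "MyAttentionOp" are illustrative, and the `extension` argument of openvino.convert_model is how extensions are passed further down in this PR (see convert_impl.py below):

```python
import torch
from openvino.frontend.pytorch.module_extension import ModuleExtension

class MyAttention(torch.nn.Module):  # hypothetical user module to be replaced
    def forward(self, q, k, v):
        return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

# Replace every MyAttention instance with a single op of type "MyAttentionOp".
# evaluate runs during tracing only, to produce valid tensors; convert is what
# actually gets traced into the final model, with target_op as its single node.
ext = ModuleExtension(
    MyAttention,
    "MyAttentionOp",
    evaluate=lambda module, *args, **kwargs: module(*args, **kwargs),   # same as the default
    convert=lambda target_op, *args, **kwargs: target_op(*args, **kwargs),  # same as the default
)

# ov_model = openvino.convert_model(model, example_input=(q, k, v), extension=[ext])
```

Both callables shown equal the defaults and could be omitted; custom bodies are needed only when the module cannot run as-is during tracing (evaluate) or when the replacement op takes different arguments than the module (convert).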
@@ -0,0 +1,29 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors


def patch_model(model, module_patcher, orig_forward_name):
    for name, m in model.named_modules():
        # TODO: Use one way to identify a patched module; currently GPTQ model patching uses a different attribute name
        if hasattr(m, orig_forward_name):
            # already patched, skipping with a warning because it is unexpected
            print(f'[ WARNING ] Unexpectedly found an already patched module {name} while applying ModuleExtension during PyTorch model conversion. '
                  'The result of the conversion may be broken. Depending on the exact issue it may lead to a broken original model.')
            continue
        module_patcher(m, name)


def unpatch_model(model, orig_forward_name):
    for _, m in model.named_modules():
        if hasattr(m, orig_forward_name):
            try:
                m.forward = getattr(m, orig_forward_name)
                delattr(m, orig_forward_name)
            except Exception as error:
                print('[ WARNING ] Exception raised during model unpatching. Depending on the exact issue it may lead to a broken original model.')
                print('Original exception details:')
                print(error)
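To illustrate the contract between these two helpers, a toy sketch; the patcher and attribute name are hypothetical, and only patch_model/unpatch_model come from the file above:

```python
import torch
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model

ORIG_NAME = '_demo_orig_forward'  # hypothetical attribute for the stashed forward

def demo_patcher(module, name):
    # Called for every submodule; patch only Linear layers and stash the
    # original forward under ORIG_NAME so unpatch_model can restore it.
    if isinstance(module, torch.nn.Linear):
        orig = module.forward
        setattr(module, ORIG_NAME, orig)
        module.forward = lambda *args, **kwargs: orig(*args, **kwargs) * 2

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
patch_model(model, demo_patcher, ORIG_NAME)   # every Linear now doubles its output
unpatch_model(model, ORIG_NAME)               # original forwards restored
```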
92 changes: 89 additions & 3 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
@@ -10,20 +10,41 @@
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor, graph_has_ops
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq
from openvino.frontend.pytorch import patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import typing
import torch


class no_jit_trace:
    def __enter__(self):
        self.state = torch._C._get_tracing_state()
        torch._C._set_tracing_state(None)

    def __exit__(self, *args):
        torch._C._set_tracing_state(self.state)
        self.state = None

class TorchScriptPythonDecoder (Decoder):
    def __init__(
            self,
            pt_module,
            graph_element=None,
            example_input=None,
            alias_db=None,
            shared_memory=True,
            skip_freeze=False,
            constant_cache=None,
            module_extensions=None):
        Decoder.__init__(self)
        # We store every decoder created by this decoder so that none of them are deleted before the first decoder is deleted
        self.m_decoders = []
        self._input_signature = None
        self._shared_memory = shared_memory
        self._input_is_list = False
        self.constant_cache = constant_cache if constant_cache is not None else dict()
        self.module_extensions = module_extensions
        if graph_element is None:
            try:
                pt_module = self._get_scripted_model(
@@ -76,6 +97,47 @@ def _get_preserved_attributes(model) -> list:
                preserved_attributes.append(name)
        return preserved_attributes

    def _patch_modules(self, model, orig_forward_name):
        def module_patcher(module, name):
            extension = None
            if module in self.module_extensions:
                extension = self.module_extensions[module]
            elif module.__class__ in self.module_extensions:
                extension = self.module_extensions[module.__class__]
            elif name in self.module_extensions:
                extension = self.module_extensions[name]

            if extension:
                # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
                class Trampoline(torch.autograd.Function):
                    target_extension = extension
                    original_module = module
                    stashed_args = None
                    stashed_kwargs = None

                    @staticmethod
                    def forward(*args, **kwargs):
                        with no_jit_trace():
                            # `module` is going to be passed to the user-defined function `evaluate`.
                            # `module` is patched: its forward was replaced, and we are actually inside that patched function right now.
                            # If we passed `module` as-is to the user code below and it happened to call forward, it would lead to infinite recursion or a failure,
                            # so we temporarily restore the original forward on the module and patch it back again afterwards.
                            patched_forward = module.forward  # stash the current forward to be able to set it back
                            module.forward = getattr(module, orig_forward_name)  # set the original forward for the module
                            results = extension.evaluate(module, *Trampoline.stashed_args, **Trampoline.stashed_kwargs)  # call user code
                            module.forward = patched_forward  # set the patched forward back
                            return results

                def new_forward(*args, **kwargs):
                    Trampoline.stashed_args = args
                    Trampoline.stashed_kwargs = kwargs
                    return extension.convert(module, Trampoline.apply, *args, **kwargs)

                setattr(module, orig_forward_name, module.forward)
                module.forward = new_forward

        patch_model.patch_model(model, module_patcher, '_openvino_module_extension_patch_orig_forward')

    def _unpatch_modules(self, model, orig_forward_name):
        patch_model.unpatch_model(model, orig_forward_name)

    def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
        import torch
        import inspect
@@ -95,6 +157,12 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
            else:
                input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                    example_inputs, input_params, pt_module)

            # name of attribute in a patched module where the original forward method is kept
            orig_forward_name = '_openvino_module_extension_patch_orig_forward'
            if self.module_extensions:
                self._patch_modules(pt_module, orig_forward_name)

            gptq_patched = False

            if gptq.detect_gptq_model(pt_module):
@@ -115,6 +183,8 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
            finally:
                if gptq_patched:
                    gptq.unpatch_model(pt_module)
                if self.module_extensions:
                    self._unpatch_modules(pt_module, orig_forward_name)

        if not freeze_by_default and graph_has_ops(scripted.inlined_graph, ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"]):
            # freeze models with unsupported ops
@@ -230,7 +300,8 @@ def visit_subgraph(self, node_visitor) -> None:
                node,
                alias_db=self.alias_db,
                shared_memory=self._shared_memory,
                constant_cache=self.constant_cache,
                module_extensions=self.module_extensions)
            self.m_decoders.append(decoder)
            node_visitor(decoder)

@@ -253,13 +324,28 @@ def get_subgraph_decoder(self, index: int):
        decoder = TorchScriptPythonDecoder(self.pt_module,
                                           self.get_subgraphs()[index],
                                           alias_db=self.alias_db,
                                           shared_memory=self._shared_memory,
                                           module_extensions=self.module_extensions)
        self.m_decoders.append(decoder)
        return decoder

    def get_op_type(self) -> str:
        assert isinstance(
            self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
        if self.graph_element.kind() == "prim::PythonOp":
            if hasattr(self.graph_element, 'pyobj') and callable(self.graph_element.pyobj) and hasattr(self.graph_element.pyobj(), '__self__'):
                trampoline = self.graph_element.pyobj().__self__
                if hasattr(trampoline, 'target_extension') and isinstance(trampoline.target_extension, ModuleExtension):
                    target_op = trampoline.target_extension.target_op
                    if callable(target_op):
                        target = target_op(trampoline.original_module)
                    elif isinstance(target_op, str):
                        target = target_op
                    # TODO: Support target as a callable that will play the role of a ConversionExtension for an entire module instead of a single op.
                    #       Without supporting a callable target here, ConversionExtension functionality can still be achieved by combining two
                    #       extensions: a ModuleExtension that uses a temporary name as a target op and a ConversionExtension that translates
                    #       that particular temporary name to a custom graph. But providing the conversion code as a callable `target` is more convenient.
Contributor: What if we allow target_op to be a ConversionExtension? We would register it as an extension in the FE, and if we detect a ConversionExtension here, we could take the name of the op it matches and return that name here.

                    return target
        return self.graph_element.kind()

    def get_schema(self) -> str:
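The callable(target_op) branch in get_op_type above means target_op may be a function of the original module, so each replaced instance can map to a differently named op type (the docstring types target_op as str; the callable form is what this branch enables). A small sketch with a hypothetical naming scheme:

```python
import torch
from openvino.frontend.pytorch.module_extension import ModuleExtension

# target_op as a callable: it receives the original module instance and returns
# the op type name, so Linear(4, 8) would appear as "CustomLinear_8" in the graph.
ext = ModuleExtension(
    torch.nn.Linear,
    target_op=lambda module: f"CustomLinear_{module.out_features}",
)
```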
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/frontend.cpp
@@ -6,6 +6,7 @@

#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/core/op_extension.hpp"
Contributor (suggested change): remove the line #include "openvino/core/op_extension.hpp".

#include "openvino/core/so_extension.hpp"
#include "openvino/frontend/pytorch/extension/conversion.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
4 changes: 3 additions & 1 deletion tools/ovc/openvino/tools/ovc/convert_impl.py
@@ -38,6 +38,7 @@

# pylint: disable=no-name-in-module,import-error
from openvino.frontend import FrontEndManager, OpConversionFailure, TelemetryExtension
from openvino.frontend.pytorch.module_extension import ModuleExtension
from openvino.runtime import get_version as get_rt_version
from openvino.runtime import Type, PartialShape

@@ -173,7 +174,8 @@ def prepare_ir(argv: argparse.Namespace):
        moc_front_end.add_extension(TelemetryExtension("ovc", t.send_event, t.send_error, t.send_stack_trace))
        if any_extensions_used(argv):
            for extension in argv.extension:
                if not isinstance(extension, ModuleExtension):
                    moc_front_end.add_extension(extension)
        ov_model = moc_pipeline(argv, moc_front_end)
        return ov_model

@@ -8,6 +8,11 @@
# pylint: disable=no-name-in-module,import-error
from openvino.runtime import Tensor, PartialShape
from openvino.tools.ovc.error import Error
from openvino.frontend.pytorch.module_extension import ModuleExtension


def extract_module_extensions(args):
    return {extension.module: extension for extension in args.get('extension', []) or []
            if isinstance(extension, ModuleExtension)}


@@ -37,7 +42,11 @@ def get_pytorch_decoder(model, example_inputs, args):
        if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
            raise RuntimeError("Models received from torch.export are not yet supported by convert_model.")
        else:
            decoder = TorchScriptPythonDecoder(
                model,
                example_input=inputs,
                shared_memory=args.get("share_weights", True),
                module_extensions=extract_module_extensions(args))
    else:
        decoder = model
    args['input_model'] = decoder
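Putting it together: extract_module_extensions above keys ModuleExtension instances by their module field, while other extension types stay with the frontend (see the convert_impl.py filter above). A small sketch of that split, with a hypothetical args dict:

```python
import torch
from openvino.frontend.pytorch.module_extension import ModuleExtension

ext = ModuleExtension(torch.nn.Linear, 'CustomLinear')
args = {'extension': [ext, 'libcustom_ops.so'],  # hypothetical mixed extension list
        'share_weights': True}

# Same comprehension as extract_module_extensions above: only ModuleExtension
# entries are routed to TorchScriptPythonDecoder, keyed by their module field.
per_module = {e.module: e for e in args.get('extension', []) or []
              if isinstance(e, ModuleExtension)}
assert per_module == {torch.nn.Linear: ext}
```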