ModuleExtension, a new type of extension for PyTorch frontend #22867
Changes from all commits
@@ -0,0 +1,42 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

import typing


class ModuleExtension:
    def __init__(self, module, target_op, evaluate=None, convert=None):
        """
        Creates an extension that replaces an entire PyTorch module with a single operation.
        This functionality works with PyTorch models only. A module can be identified by
        module type (e.g. torch.nn.Linear), module instance in the model, or module name.

        Args:
            module (str, torch.nn.Module, type(torch.nn.Module)): PyTorch module to replace.

            target_op (str): a target operation that will be used as a replacement for the module;
                it can be the name of an extension operation or of an existing PyTorch operation
                (with prim:: or aten:: prefix following TorchScript syntax).

            evaluate (callable with args module, *args, **kwargs): a callable that replaces the target
                module during model execution. It is responsible for producing valid output for
                the module so that the model can be traced correctly. By default it calls the
                original module's forward with the same arguments. The provided code does not become
                part of the final traced model; it is used only to produce valid results during tracing.

            convert (callable with args target_op, *args, **kwargs): a callable that is traced and becomes
                part of the final model instead of the target module. It accepts target_op as
                the first parameter; target_op is a callable that will appear as a single node in the
                graph, and the type of that node is the target_op provided as the argument above.
        """
        self.module = module
        self.target_op = target_op
        self.evaluate = evaluate
        if self.evaluate is None:
            self.evaluate = lambda module, *args, **kwargs: module(*args, **kwargs)
        self.convert = convert
        if self.convert is None:
            self.convert = lambda target_op, *args, **kwargs: target_op(*args, **kwargs)

Review comment (on the `convert` description): Could you please provide a description of why I need to set
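For context, here is a minimal usage sketch (not part of the diff). It assumes ov.convert_model forwards ModuleExtension objects passed through its `extension` argument down to the decoder's `module_extensions` parameter shown later in this PR; the module class, op name, and shapes are illustrative only.

import torch
import openvino as ov
from openvino.frontend.pytorch.module_extension import ModuleExtension

class MyGelu(torch.nn.Module):
    # hand-written approximation we do not want traced op-by-op
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

model = torch.nn.Sequential(torch.nn.Linear(8, 8), MyGelu())

# Ask the frontend to represent every MyGelu instance as a single aten::gelu node;
# the default `evaluate` keeps calling the original forward so tracing still sees valid values.
gelu_as_one_op = ModuleExtension(MyGelu, "aten::gelu")

ov_model = ov.convert_model(model,
                            example_input=torch.zeros(1, 8),
                            extension=[gelu_as_one_op])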
@@ -0,0 +1,29 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors


def patch_model(model, module_patcher, orig_forward_name):
    for name, m in model.named_modules():
        # TODO: Use one way to identify a patched module; currently GPTQ model patching uses a different attribute name
        if hasattr(m, orig_forward_name):
            # already patched, skip with a warning because this is unexpected
            print(f'[ WARNING ] Unexpectedly found an already patched module {name} while applying ModuleExtension during PyTorch model conversion. '
                  'The result of the conversion may be broken. Depending on the exact issue it may lead to a broken original model.')
            continue
        module_patcher(m, name)


def unpatch_model(model, orig_forward_name):
    for _, m in model.named_modules():
        if hasattr(m, orig_forward_name):
            try:
                m.forward = getattr(m, orig_forward_name)
                delattr(m, orig_forward_name)
            except Exception as error:
                print('[ WARNING ] Exception raised during model unpatching. Depending on the exact issue it may lead to a broken original model.')
                print('Original exception details:')
                print(error)
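A short sketch of how these helpers are meant to be driven (the real module_patcher lives in the decoder further down this PR; the logging patcher and import path here are illustrative assumptions).

import torch
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model  # assumed import path for the module above

ORIG_NAME = '_openvino_module_extension_patch_orig_forward'

def logging_patcher(module, name):
    # toy patcher: only wrap Linear layers
    if not isinstance(module, torch.nn.Linear):
        return
    def new_forward(*args, **kwargs):
        print(f'calling patched {name}')
        return getattr(module, ORIG_NAME)(*args, **kwargs)
    setattr(module, ORIG_NAME, module.forward)  # stash the original forward so unpatch_model can restore it
    module.forward = new_forward

model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())
patch_model(model, logging_patcher, ORIG_NAME)
model(torch.zeros(1, 2))         # prints once, for the Linear layer
unpatch_model(model, ORIG_NAME)  # restores the original forward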
@@ -10,20 +10,41 @@

from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor, graph_has_ops
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq
from openvino.frontend.pytorch import patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import typing
import torch


class no_jit_trace:
    def __enter__(self):
        self.state = torch._C._get_tracing_state()
        torch._C._set_tracing_state(None)

    def __exit__(self, *args):
        torch._C._set_tracing_state(self.state)
        self.state = None


class TorchScriptPythonDecoder (Decoder):
    def __init__(
            self,
            pt_module,
            graph_element=None,
            example_input=None,
            alias_db=None,
            shared_memory=True,
            skip_freeze=False,
            constant_cache=None,
            module_extensions=None):
        Decoder.__init__(self)
        # We store every decoder created by this decoder so that none of them is deleted before the first decoder is deleted
        self.m_decoders = []
        self._input_signature = None
        self._shared_memory = shared_memory
        self._input_is_list = False
        self.constant_cache = constant_cache if constant_cache is not None else dict()
        self.module_extensions = module_extensions
        if graph_element is None:
            try:
                pt_module = self._get_scripted_model(
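The no_jit_trace helper at the top of this hunk suspends TorchScript tracing while user-supplied `evaluate` code runs, so that code leaves nothing in the traced graph. A minimal behaviour sketch (not part of the PR; it assumes no_jit_trace from the hunk above is in scope):

import torch
# no_jit_trace is the context manager defined in the hunk above

def probe(x):
    seen = [torch.jit.is_tracing()]          # True: torch.jit.trace is recording this call
    with no_jit_trace():
        seen.append(torch.jit.is_tracing())  # False: the tracing state is temporarily cleared
    seen.append(torch.jit.is_tracing())      # True again: __exit__ restored the saved state
    print(seen)                              # expected output: [True, False, True]
    return x + 1

torch.jit.trace(probe, torch.zeros(2))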
@@ -76,6 +97,47 @@ def _get_preserved_attributes(model) -> list:
                preserved_attributes.append(name)
        return preserved_attributes

    def _patch_modules(self, model, orig_forward_name):
        def module_patcher(module, name):
            extension = None
            if module in self.module_extensions:
                extension = self.module_extensions[module]
            elif module.__class__ in self.module_extensions:
                extension = self.module_extensions[module.__class__]
            elif name in self.module_extensions:
                extension = self.module_extensions[name]

            if extension:
                # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
                class Trampoline(torch.autograd.Function):
                    target_extension = extension
                    original_module = module
                    stashed_args = None
                    stashed_kwargs = None

                    @staticmethod
                    def forward(*args, **kwargs):
                        with no_jit_trace():
                            # `module` is going to be passed to the user-defined function `evaluate`.
                            # `module` is patched: its forward function was replaced, and we are actually inside that patched function right here.
                            # If we passed `module` as-is to the user code below and it happened to call forward, it would lead to infinite recursion or fail,
                            # so we temporarily restore the original forward and then put the patched forward back afterwards.
                            patched_forward = module.forward  # stash the current forward to be able to restore it later
                            module.forward = getattr(module, orig_forward_name)  # set the original forward for the module
                            results = extension.evaluate(module, *Trampoline.stashed_args, **Trampoline.stashed_kwargs)  # call user code
                            module.forward = patched_forward  # put the patched forward back
                            return results

                def new_forward(*args, **kwargs):
                    Trampoline.stashed_args = args
                    Trampoline.stashed_kwargs = kwargs
                    return extension.convert(module, Trampoline.apply, *args, **kwargs)

                setattr(module, orig_forward_name, module.forward)
                module.forward = new_forward

        patch_model.patch_model(model, module_patcher, '_openvino_module_extension_patch_orig_forward')

    def _unpatch_modules(self, model, orig_forward_name):
        patch_model.unpatch_model(model, orig_forward_name)

    def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
        import torch
        import inspect
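To see why _patch_modules routes the call through torch.autograd.Function, here is a standalone sketch (illustrative only, not from the PR) of the same trick: applying an autograd Function during torch.jit.trace leaves a single prim::PythonOp node in the graph, and running the body under no_jit_trace keeps the eager computation itself out of the trace, which is what later lets get_op_type map that node to the extension's target_op.

import torch
# reuses the no_jit_trace context manager defined earlier in this file

class Stub(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with no_jit_trace():
            # executed eagerly; these ops are not recorded into the traced graph
            return x + 1

def f(x):
    return Stub.apply(x)

traced = torch.jit.trace(f, torch.zeros(2))
print(traced.graph)  # expected to show a prim::PythonOp node rather than aten::add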
@@ -95,6 +157,12 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
            else:
                input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                    example_inputs, input_params, pt_module)

                # name of the attribute in a patched module where the original forward method is kept
                orig_forward_name = '_openvino_module_extension_patch_orig_forward'
                if self.module_extensions:
                    self._patch_modules(pt_module, orig_forward_name)

                gptq_patched = False

                if gptq.detect_gptq_model(pt_module):
@@ -115,6 +183,8 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
                finally:
                    if gptq_patched:
                        gptq.unpatch_model(pt_module)
                    if self.module_extensions:
                        self._unpatch_modules(pt_module, orig_forward_name)

        if not freeze_by_default and graph_has_ops(scripted.inlined_graph, ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"]):
            # freeze models with unsupported ops
@@ -230,7 +300,8 @@ def visit_subgraph(self, node_visitor) -> None:
                node,
                alias_db=self.alias_db,
                shared_memory=self._shared_memory,
                constant_cache=self.constant_cache,
                module_extensions=self.module_extensions)
            self.m_decoders.append(decoder)
            node_visitor(decoder)
@@ -253,13 +324,28 @@ def get_subgraph_decoder(self, index: int):
        decoder = TorchScriptPythonDecoder(self.pt_module,
                                           self.get_subgraphs()[index],
                                           alias_db=self.alias_db,
                                           shared_memory=self._shared_memory,
                                           module_extensions=self.module_extensions)
        self.m_decoders.append(decoder)
        return decoder

    def get_op_type(self) -> str:
        assert isinstance(
            self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
        if self.graph_element.kind() == "prim::PythonOp":
            if hasattr(self.graph_element, 'pyobj') and callable(self.graph_element.pyobj) and hasattr(self.graph_element.pyobj(), '__self__'):
                trampoline = self.graph_element.pyobj().__self__
                if hasattr(trampoline, 'target_extension') and isinstance(trampoline.target_extension, ModuleExtension):
                    target_op = trampoline.target_extension.target_op
                    if callable(target_op):
                        target = target_op(trampoline.original_module)
                    elif isinstance(target_op, str):
                        target = target_op
                    # TODO: Support target as a callable that plays the role of a ConversionExtension for an entire module instead of a single op.
                    # Without supporting a callable target here, ConversionExtension functionality can still be achieved
                    # by combining two extensions: a ModuleExtension that uses a temporary name as the target op and a ConversionExtension
                    # that translates that particular temporary name to a custom graph. But providing conversion code as a callable `target` is more convenient.
                    return target
        return self.graph_element.kind()

    def get_schema(self) -> str:

Review comment (on the TODO above): What if we allow
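Because get_op_type() above calls target_op(trampoline.original_module) when target_op is callable, the node type reported to the frontend can be derived from the replaced module instance. A hypothetical sketch (MyBlock and the naming scheme are illustrative, not from the PR):

import torch
from openvino.frontend.pytorch.module_extension import ModuleExtension

class MyBlock(torch.nn.Module):  # placeholder module type, not from the PR
    def forward(self, x):
        return x * 2.0

# Each replaced MyBlock instance is reported to the frontend under a node type
# computed from the instance itself, via the callable branch in get_op_type().
per_instance_ext = ModuleExtension(
    MyBlock,
    target_op=lambda m: f"custom::{type(m).__name__}",
)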
@@ -6,6 +6,7 @@

#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/core/op_extension.hpp"
#include "openvino/core/so_extension.hpp"
#include "openvino/frontend/pytorch/extension/conversion.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
Review comment:
What if I want to decompose an existing torch.nn.Module into a set of operations? What is target_op here?

Reply:
Only one custom operation is supported. It will be substituted in place of the target_op argument. To extend this further, I think we can give access to all custom/existing operations by passing a kind of builder instead of a single target_op. It would be a more scalable approach, but I haven't drafted it yet. It wouldn't use the target_op parameter in ModuleExtension.__init__, and in the convert parameter you would write something like ops.PagedAttentionExtension(args), where ops is an op builder and PagedAttentionExtension is the name of the custom op, so you would be able to instantiate many ops of different types. It is not so straightforward, because evaluate is still defined at the level of each op, but we can experiment with that approach.