Model I am using (Bert, XLNet ...): BERT, but the issue also occurs with e.g. GPT-2
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)

The task I am working on is:
- an official GLUE/SQuAD task: (give the name)
- my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
import torch
from transformers import BertConfig, BertModel
from transformers.utils import fx
bert = BertModel(BertConfig())
bert.eval()
bs, seq_length = 20, 512
bert_input = torch.zeros(bs, seq_length, dtype=torch.long).random_(bert.config.vocab_size)
orig_out = bert(bert_input)
# Set-up: fx trace the model
bert_traced = fx.symbolic_trace(bert)
traced_out = bert_traced(bert_input)
torch.testing.assert_allclose(traced_out['last_hidden_state'], orig_out['last_hidden_state'])
# Issue 1: TorchScript breakage. Leaf function patching breaks torch.jit.script, in this
# instance on the generated wrapper for `torch.ones`. I believe this is because TorchScript is
# unable to retrieve the source of the programmatically generated wrapper.
# scripted = torch.jit.script(bert_traced)
#
# The preceding call fails at pytorch/torch/_sources.py, line 22, in get_source_lines_and_file:
# sourcelines, file_lineno = inspect.getsourcelines(obj). When printing out the object that
# is being resolved, `obj` is `<function _VariableFunctionsClass.ones at 0x7fbc8a6c9af0>`, the
# torch.ones wrapper that is programmatically generated in transformers.utils.fx._function_to_leaf.
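# (Added illustration, not part of the original report: `inspect` -- which
# torch.jit.script relies on to fetch function sources -- cannot recover the source
# of a function that was created programmatically rather than loaded from a source
# file. An exec-generated function shows the same failure mode.)
# import inspect
# namespace = {}
# exec("def generated_fn(x):\n    return x", namespace)
# inspect.getsourcelines(namespace["generated_fn"])  # raises OSError: could not get source code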
# Issue 2: Pickle serialization fails because the generated leaf-function wrapper
# cannot be resolved by its qualified name
import pickle, tempfile, os
with tempfile.TemporaryDirectory() as tmp_dir_name:
    pkl_file_name = os.path.join(tmp_dir_name, "bert_model.pkl")
    # with open(pkl_file_name, 'wb') as f:
    #     pickle.dump(bert_traced, f)
    # with open(pkl_file_name, 'rb') as f:
    #     loaded = pickle.load(f)
# The previous block fails with: torch.package.importer.ObjNotFoundError:
# <function _VariableFunctionsClass.ones at 0x7f4e46740ca0> was not
# found as transformers.utils.fx._VariableFunctionsClass.ones. This is
# because the `torch.ones` wrapper was programmatically generated and cannot
# be resolved to a call target in a deserialization context, which
# only has references to targets by qualified name (by virtue of needing
# to work across different processes).
# We can hack around this and replace the `torch.ones` wrapper with a wrapper
# that can be resolved by qualified name:
def ones_wrapper(*args, **kwargs):
    return torch.ones(*args, **kwargs)

for node in bert_traced.graph.nodes:
    if node.op == 'call_function' and node.target.__qualname__ == '_VariableFunctionsClass.ones':
        node.target = ones_wrapper

bert_traced.recompile()
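# (Added illustration, not in the original report: a function defined at module
# scope pickles by reference to its qualified name, which is exactly what the
# generated `_VariableFunctionsClass.ones` wrapper could not do.)
# assert pickle.loads(pickle.dumps(ones_wrapper)) is ones_wrapper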
# This leads us to Issue 3: module does not have enough metadata to do re-tracing
# on the deserialization path.
with tempfile.TemporaryDirectory() as tmp_dir_name:
    pkl_file_name = os.path.join(tmp_dir_name, "bert_model.pkl")
    with open(pkl_file_name, 'wb') as f:
        pickle.dump(bert_traced, f)
    # with open(pkl_file_name, 'rb') as f:
    #     loaded = pickle.load(f)
#
# The above fails with:
#
# Traceback (most recent call last):
#   File "/transformers_issue.py", line 64, in <module>
#     loaded = pickle.load(f)
#   File "/pytorch/torch/fx/graph_module.py", line 105, in reduce_graph_module
#     return _deserialize_graph_module(forward, body)
#   File "/pytorch/torch/fx/graph_module.py", line 163, in _deserialize_graph_module
#     graph = KeepModules().trace(com)
#   File "/transformers/src/transformers/utils/fx.py", line 467, in trace
#     self.record(root, input_names, method_names=method_names)
#   File "/transformers/src/transformers/utils/fx.py", line 418, in record
#     inputs.update(self._generate_dummy_input(model, input_name, shape))
#   File "/transformers/src/transformers/utils/fx.py", line 361, in _generate_dummy_input
#     device = model.device
#   File "/pytorch/torch/nn/modules/module.py", line 1186, in __getattr__
#     raise AttributeError("'{}' object has no attribute '{}'".format(
# AttributeError: 'CodeOnlyModule' object has no attribute 'device'
# We can patch HF transformers to customize the serialization/deserialization process
# to include metadata like `device` and the input shapes that were generated during
# initial symbolic tracing: https://gist.github.com/jamesr66a/7304d8818c04abd49df7a70a2ae51c02
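# (Added illustration only -- the actual patch lives in the gist linked above, and
# the attribute names below are hypothetical.) The general idea: any plain attribute
# stored on the GraphModule ends up in its pickled __dict__, so it is visible again
# on the CodeOnlyModule that torch.fx re-traces at load time, e.g.:
#
# bert_traced.device = bert.device
# bert_traced.dummy_input_shape = (bs, seq_length)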
# The following should now pass:
with tempfile.TemporaryDirectory() as tmp_dir_name:
    pkl_file_name = os.path.join(tmp_dir_name, "bert_model.pkl")
    with open(pkl_file_name, 'wb') as f:
        pickle.dump(bert_traced, f)
    with open(pkl_file_name, 'rb') as f:
        loaded = pickle.load(f)
    loaded_outs = loaded(bert_input)
    torch.testing.assert_allclose(loaded_outs['last_hidden_state'], orig_out['last_hidden_state'])
Expected behavior
torch.jit.script and pickle.dump/pickle.load serialization/deserialization should work out of the box. I believe that (a) switching the leaf-function wrappers to reference functions that can be resolved by qualified name, and (b) customizing HFTracer serialization to preserve the metadata needed for re-tracing, would fix this issue.
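A minimal sketch of what (a) could look like, assuming a hypothetical module (say transformers/utils/fx_leaf_wrappers.py; the name and the exact set of wrappers are assumptions, not the current implementation) that defines the leaf wrappers as ordinary importable functions instead of generating them at runtime:

import torch

def ones(*args, **kwargs):
    # Defined at module scope in a real source file, so inspect.getsourcelines
    # succeeds (unblocking torch.jit.script) and pickle/torch.package can resolve
    # the function by its qualified name (unblocking GraphModule serialization).
    return torch.ones(*args, **kwargs)

def zeros(*args, **kwargs):
    return torch.zeros(*args, **kwargs)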
Environment info

transformers version: 4.17.0.dev0

Who can help

@michaelbenayoun
@sgugger

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.