Models traced with HFTracer cannot be TorchScripted or serialized #15974

Closed
jamesr66a opened this issue Mar 8, 2022 · 1 comment · Fixed by #17206

jamesr66a commented Mar 8, 2022

Environment info

  • transformers version: 4.17.0.dev0
  • Platform: Linux-5.4.0-1051-aws-x86_64-with-glibc2.27
  • Python version: 3.9.5
  • PyTorch version (GPU?): 1.11.0a0+git708f7b1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help

@michaelbenayoun
@sgugger

Information

Model I am using (Bert, XLNet ...): BERT, but the same happens for e.g. GPT-2

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

import torch
from transformers import BertConfig, BertModel
from transformers.utils import fx

bert = BertModel(BertConfig())
bert.eval()
bs, seq_length = 20, 512
bert_input = torch.zeros(bs, seq_length, dtype=torch.long).random_(bert.config.vocab_size)
orig_out = bert(bert_input)

# Set-up: fx trace the model

bert_traced = fx.symbolic_trace(bert)
traced_out = bert_traced(bert_input)
torch.testing.assert_allclose(traced_out['last_hidden_state'], orig_out['last_hidden_state'])

# Issue 1: TorchScript breakage. Leaf function patching breaks TorchScript scripting, in this
# instance for the generated wrapper for `torch.ones`. I believe this is because TorchScript is
# unable to retrieve the source of the programmatically generated wrapper (see the traceback below).

# scripted = torch.jit.script(bert_traced)
#
# The preceding fails at pytorch/torch/_sources.py, line 22, in get_source_lines_and_file:
#     sourcelines, file_lineno = inspect.getsourcelines(obj)
# When printing out the object being resolved, `obj` is
# `<function _VariableFunctionsClass.ones at 0x7fbc8a6c9af0>`, the torch.ones wrapper that is
# programmatically generated in transformers.utils.fx._function_to_leaf.
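
# As a standalone illustration (not transformers code) of why that lookup fails:
# a function compiled from a string has no backing source file, so
# inspect.getsourcelines(), which TorchScript relies on, raises:

import inspect

namespace = {}
exec(compile("def generated_ones(*args, **kwargs): return None", "<generated>", "exec"), namespace)
try:
    inspect.getsourcelines(namespace["generated_ones"])
except OSError as e:
    print(f"getsourcelines failed as expected: {e}")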


# Issue 2: Pickling breaks because the generated `torch.ones` wrapper cannot be
# resolved by qualified name in a deserialization context.

import pickle, tempfile, os

with tempfile.TemporaryDirectory() as tmp_dir_name:
    pkl_file_name = os.path.join(tmp_dir_name, "bert_model.pkl")

    # with open(pkl_file_name, 'wb') as f:
    #     pickle.dump(bert_traced, f)

    # with open(pkl_file_name, 'rb') as f:
    #     loaded = pickle.load(f)
    # The preceding fails with: torch.package.importer.ObjNotFoundError:
    # <function _VariableFunctionsClass.ones at 0x7f4e46740ca0> was not 
    # found as transformers.utils.fx._VariableFunctionsClass.ones. This is
    # because the ones wrapper was programmatically generated and cannot
    # be resolved to a call target in a deserialization context, which
    # only has references to target by qualified name (by virtue of needing
    # to work across different processes).
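
# An aside on the underlying constraint (general pickle behavior, not
# transformers-specific): pickle stores a function by qualified name and
# re-imports it on load, so anything not importable as <module>.<qualname>
# cannot round-trip. For example:

try:
    pickle.dumps(lambda: None)  # a lambda has no importable qualified name
except pickle.PicklingError as e:
    print(f"pickling failed as expected: {e}")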


# We can hack around this and replace the `torch.ones` wrapper with a wrapper
# that can be resolved by qualified name:

def ones_wrapper(*args, **kwargs):
    return torch.ones(*args, **kwargs)

for node in bert_traced.graph.nodes:
    if node.op == 'call_function' and node.target.__qualname__ == '_VariableFunctionsClass.ones':
        node.target = ones_wrapper

bert_traced.recompile()
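
# Quick sanity check (an aside; assumes this is run as a saved script so
# inspect can find this file's source): the module-level wrapper avoids both
# failure modes above, since it has real source on disk for inspect/TorchScript
# and a stable qualified name for pickle.
import inspect
assert inspect.getsourcelines(ones_wrapper)
assert pickle.dumps(ones_wrapper)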

# This leads us to Issue 3: the module does not have enough metadata to re-trace
# on the deserialization path.

with tempfile.TemporaryDirectory() as tmp_dir_name:
    pkl_file_name = os.path.join(tmp_dir_name, "bert_model.pkl")

    with open(pkl_file_name, 'wb') as f:
        pickle.dump(bert_traced, f)

    # with open(pkl_file_name, 'rb') as f:
    #     loaded = pickle.load(f)
    #
    # The above fails with:
    #
    # Traceback (most recent call last):
    #   File "/transformers_issue.py", line 64, in <module>
    #     loaded = pickle.load(f)
    #   File "/pytorch/torch/fx/graph_module.py", line 105, in reduce_graph_module
    #     return _deserialize_graph_module(forward, body)
    #   File "/pytorch/torch/fx/graph_module.py", line 163, in _deserialize_graph_module
    #     graph = KeepModules().trace(com)
    #   File "/transformers/src/transformers/utils/fx.py", line 467, in trace
    #     self.record(root, input_names, method_names=method_names)
    #   File "/transformers/src/transformers/utils/fx.py", line 418, in record
    #     inputs.update(self._generate_dummy_input(model, input_name, shape))
    #   File "/transformers/src/transformers/utils/fx.py", line 361, in _generate_dummy_input
    #     device = model.device
    #   File "/pytorch/torch/nn/modules/module.py", line 1186, in __getattr__
    #     raise AttributeError("'{}' object has no attribute '{}'".format(
    # AttributeError: 'CodeOnlyModule' object has no attribute 'device'

# We can patch HF transformers to customize the serialization/deserialization process
# to include metadata like `device` and the input shapes that were generated during
# initial symbolic tracing: https://gist.github.com/jamesr66a/7304d8818c04abd49df7a70a2ae51c02
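#
# One possible shape for such a patch (an illustrative sketch with assumed
# attribute names, not the gist's actual code): GraphModule pickling appears to
# round-trip the instance __dict__, and the failing re-trace only needs things
# like `device` and the recorded input shapes, so they could be stashed as
# plain attributes before dumping:
#
#     bert_traced.__dict__["device"] = torch.device("cpu")
#     bert_traced.__dict__["input_shapes"] = {"input_ids": (bs, seq_length)}
#
# (writing into __dict__ directly sidesteps any `device` property defined on
# the traced class)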

# The following should now pass:

with tempfile.TemporaryDirectory() as tmp_dir_name:
    pkl_file_name = os.path.join(tmp_dir_name, "bert_model.pkl")

    with open(pkl_file_name, 'wb') as f:
        pickle.dump(bert_traced, f)

    with open(pkl_file_name, 'rb') as f:
        loaded = pickle.load(f)

loaded_outs = loaded(bert_input)
torch.testing.assert_allclose(loaded_outs['last_hidden_state'], orig_out['last_hidden_state'])

Expected behavior

torch.jit.script and pickle.dump/pickle.load serialization/deserialization should work out of the box. I believe that a) switching leaf function wrappers to reference functions that can be resolved by qualified name and b) customizing HFTracer serialization to preserve the metadata needed for re-tracing on the load path should fix this issue.
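
For concreteness, a sketch of what fix a) could look like (illustrative only; the helper is hypothetical, not the merged implementation): define each leaf wrapper with a real def inside the tracer's module and fix up its name, so it has on-disk source for inspect and a stable qualified name for pickle:

import torch

def _define_leaf_wrapper(name, fn, namespace):
    # Hypothetical helper: `namespace` is a real module's globals(), so the
    # wrapper gets on-disk source and is importable as <module>.<name>.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapper.__name__ = wrapper.__qualname__ = name
    namespace[name] = wrapper
    return wrapper

ones_leaf = _define_leaf_wrapper("ones_leaf", torch.ones, globals())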


github-actions bot commented Apr 7, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
