
torch.compile graph break introduced due to new loss function API #34615

Closed
ChanderG opened this issue Nov 5, 2024 · 6 comments · May be fixed by #34616

ChanderG commented Nov 5, 2024

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

PR #34191 introduced a new loss API. After that PR, Dynamo was broken outright; the breakage was identified and fixed via issue #34402. With that fix (on master), Dynamo runs without errors.

However, in the process a new graph break was introduced by this line:

loss_type = re.findall(loss_groups, self.__class__.__name__)

The break comes from the new regex check: Dynamo skips re.findall, so it splits the graph at this call.

Since the dispatch function actually checks for an attribute on the config, the fix is quite simple: set loss_type at model init time itself, as sketched below.
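
For illustration, here is a minimal sketch of that idea (the class name, regex pattern, and default are hypothetical placeholders, not the actual code from the PR):

import re

# illustrative pattern only; the real loss_groups regex lives in modeling_utils.py
LOSS_GROUPS = r"ForCausalLM|ForMaskedLM|ForSequenceClassification"

class MyPreTrainedModel:
    def __init__(self, config):
        self.config = config
        if getattr(config, "loss_type", None) is None:
            # run the regex once here, at init time, so nothing inside the
            # torch.compile'd forward ever calls re.findall
            matches = re.findall(LOSS_GROUPS, self.__class__.__name__)
            config.loss_type = matches[0] if matches else "ForCausalLM"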

Expected behavior

No additional graph breaks.

@ChanderG ChanderG added the bug label Nov 5, 2024
ChanderG added a commit to ChanderG/transformers that referenced this issue Nov 5, 2024
ensures no additional graph break introduced when torch.compile'ed

fixes huggingface#34615

Signed-off-by: ChanderG <mail@chandergovind.org>
@ArthurZucker
Collaborator

Thanks, do you have a small reproducer? We would need to add this to our tests!

@ChanderG
Author

Very straightforward use of compile actually, nothing out of the ordinary:

import argparse

import torch
from transformers import AutoModelForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--compile", help="turn on torch.compile", action="store_true")
args = parser.parse_args()

torch.set_default_device("cuda:0")

model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.95), weight_decay=0.0)

if args.compile:
    model = torch.compile(model)

def inpgen(model, size):
    # random token ids; passing labels makes the model compute its built-in loss,
    # which is what triggers the loss_function dispatch
    inp = {}
    inp["input_ids"] = torch.randint(0, model.config.vocab_size, size)
    inp["labels"] = torch.randint(0, model.config.vocab_size, size)
    inp["position_ids"] = torch.arange(0, size[1]).repeat(size[0], 1)
    return inp

def step(inp, model, optimizer):
    # one training step: forward (with loss), backward, optimizer update
    outputs = model(**inp)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

size = (8, 512)
for i in range(10):
    inp = inpgen(model, size)
    step(inp, model, optimizer)

@ChanderG
Author

Running it like this:

TORCH_LOGS="graph_breaks" python test.py -c

results in no graph-break logs on 4.45, but produces the following on 4.46.2:

V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks] Graph break: from user code at:
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]   File "/opt/conda/envs/cg-test/lib/python3.10/site-packages/transformers/modeling_utils.py", line 5024, in loss_function
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]     loss_type = re.findall(loss_groups, self.__class__.__name__)
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks] Traceback (most recent call last):
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]   File "/opt/conda/envs/cg-test/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]     return inner_fn(self, inst)
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]   File "/opt/conda/envs/cg-test/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]     self.call_function(fn, args, {})
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]   File "/opt/conda/envs/cg-test/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]   File "/opt/conda/envs/cg-test/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 727, in call_function
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]     unimplemented(msg)
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]   File "/opt/conda/envs/cg-test/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks]     raise Unsupported(msg, case_name=case_name)
V1118 06:38:04.150000 6149 site-packages/torch/_dynamo/symbolic_convert.py:617] [1/0] [__graph_breaks] torch._dynamo.exc.Unsupported: 'skip function findall in file /opt/conda/envs/cg-test/lib/python3.10/re.py'
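
As a cross-check, torch._dynamo.explain can count the breaks directly, without log parsing. A sketch reusing model and inpgen from the reproducer above (applied to the uncompiled model; the wording of break reasons varies across PyTorch versions):

import torch

# explain() traces the model and reports where Dynamo would split the graph
explanation = torch._dynamo.explain(model)(**inpgen(model, (2, 64)))
print(explanation.graph_break_count)  # expected: 0 on 4.45, >0 on 4.46.2
print(explanation.break_reasons)      # should point at re.findall in loss_function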

@ArthurZucker
Collaborator

It should have been fixed by #34511

@ChanderG
Author

@ArthurZucker Not really? As I mention in the issue description, the graph break was introduced by that very PR, the fix for #34402. Before that PR, Dynamo was broken outright; the PR fixed the error, but graph breaks were introduced in the process.

@ArthurZucker
Collaborator

Yep, sorry, working on #34616 with the fixes
