```
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 152, in __init__
    assert isinstance(
AssertionError: expected FunctionType found _lru_cache_wrapper <functools._lru_cache_wrapper object at 0x7f1091109a40>

from user code:
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1214, in forward
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
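The assertion itself can be reproduced outside of torch: on CPython, `functools.lru_cache` wraps a function in a C-level `functools._lru_cache_wrapper` object, which is not a `types.FunctionType`, so a strict `isinstance` check like the one in dynamo's `functions.py` rejects it. A minimal sketch (plain Python, no torch involved; the function name is made up for illustration):

```python
import functools
import types

def plain_fn(x):
    return x + 1

cached_fn = functools.lru_cache(maxsize=None)(plain_fn)

# A plain function is a FunctionType...
assert isinstance(plain_fn, types.FunctionType)
# ...but the lru_cache wrapper is a functools._lru_cache_wrapper (on CPython),
# which is exactly the type named in the dynamo AssertionError above.
print(type(cached_fn).__name__)                   # _lru_cache_wrapper
print(isinstance(cached_fn, types.FunctionType))  # False
```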
System Info
transformers version: 4.46.0
- distributed_type: NO
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 1
- machine_rank: 0
- num_machines: 1
- gpu_ids: 0
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: True
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- dynamo_config: {'dynamo_backend': 'INDUCTOR'}
Who can help?
@muellerzr @ArthurZucker
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
#34191 introduced custom loss functions to the model classes. This appears to have broken training with accelerate + torch dynamo.
To reproduce, use run_clm.py with the accelerate config listed under System Info above. This produces an error from dynamo relating to the new model_cls.loss_function attribute added in #34191 (see transformers/src/transformers/models/llama/modeling_llama.py, lines 1209 to 1211 at commit 239a256).
Important part of the traceback:
Now, if you update the accelerate config to not use dynamo, it runs just fine:
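For reference, a config change that avoids the crash, assuming the same `dynamo_config` format as in the dump above (`'NO'` is the accelerate `DynamoBackend` value for disabling dynamo):

```yaml
# accelerate config excerpt: either set the backend to 'NO'...
dynamo_config: {'dynamo_backend': 'NO'}
# ...or remove the dynamo_config entry entirely.
```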
Expected behavior
Accelerate should not throw the error when using torch dynamo.