Generation with HybridCache fails (affecting Gemma-2) #31664

Closed
sanchit-gandhi opened this issue Jun 27, 2024 · 3 comments · Fixed by #31661

Comments

@sanchit-gandhi
Contributor

sanchit-gandhi commented Jun 27, 2024

System Info

  • transformers version: 4.43.0.dev0
  • Platform: macOS-14.5-arm64-arm-64bit
  • Python version: 3.11.6
  • Huggingface_hub version: 0.23.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.8.1 (cpu)
  • Jax version: 0.4.24
  • JaxLib version: 0.4.24

Who can help?

@sanchit-gandhi

Reproduction

Generation with Gemma-2 currently fails on main, e.g. with the following toy code snippet:

from transformers.models.gemma2 import Gemma2ForCausalLM, Gemma2Config
import torch

config = Gemma2Config(num_hidden_layers=1, vocab_size=128, hidden_size=16, intermediate_size=32, num_attention_heads=1, num_key_value_heads=1)
model = Gemma2ForCausalLM(config)

input_ids = torch.ones((1, 10), dtype=torch.int)
model.generate(input_ids)

Traceback:

Traceback (most recent call last):
  File "/Users/sanchitgandhi/transformers/debug_gemma2.py", line 8, in <module>
    model.generate(input_ids)
  File "/Users/sanchitgandhi/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/generation/utils.py", line 2644, in _sample
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sanchitgandhi/transformers/src/transformers/generation/utils.py", line 1409, in _get_initial_cache_position
    model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: arange() received an invalid combination of arguments - got (NoneType, int, device=torch.device), but expected one of:
 * (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, *, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

Expected behavior

This is fixed by handling the exception in generate (PR #31661).
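
For readers hitting the same TypeError, here is a minimal sketch of the kind of guard that avoids passing None as the start of the position range. It assumes the None in the traceback comes from the cache being unable to report its length. The names `LengthlessCache`, `CountingCache` and `resolve_past_length` are hypothetical and are not part of transformers, and this is not necessarily what PR #31661 does.

```python
class LengthlessCache:
    """Hypothetical stand-in for a cache that cannot report its length."""

    def get_seq_length(self):
        return None


class CountingCache:
    """Hypothetical stand-in for a cache that tracks how many tokens it holds."""

    def __init__(self, seen_tokens):
        self.seen_tokens = seen_tokens

    def get_seq_length(self):
        return self.seen_tokens


def resolve_past_length(cache):
    """Return an integer start position, falling back to 0 when the length is unknown."""
    if cache is None:
        return 0
    get_len = getattr(cache, "get_seq_length", None)
    length = get_len() if callable(get_len) else None
    return 0 if length is None else int(length)


cur_len = 10
# range() stands in for torch.arange(past_length, cur_len) to keep the sketch dependency-free.
positions = list(range(resolve_past_length(LengthlessCache()), cur_len))
```

With a real model the resolved integer would be handed to `torch.arange` together with the device of `input_ids`, as in the failing line of the traceback.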

cc @ArthurZucker

@ArthurZucker
Collaborator

Thanks for fixing!

@fst813

fst813 commented Jul 3, 2024

transformers version: 4.42.3
I have another error:

  File "/home/ss/train_frame/LLaMA-Factory/src/train.py", line 30, in <module>
    main()
  File "/home/ss/train_frame/LLaMA-Factory/src/train.py", line 21, in main
    run_exp()
  File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/tuner.py", line 93, in run_exp
    run_exe(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/tuner.py", line 47, in run_exe
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 107, in run_sft
    predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer_seq2seq.py", line 244, in predict
    return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer.py", line 3717, in predict
    output = eval_loop(
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer.py", line 3826, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/home/ss/train_frame/LLaMA-Factory/src/llamafactory/train/sft/trainer.py", line 99, in prediction_step
    loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/trainer_seq2seq.py", line 310, in prediction_step
    generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 789, in convert_to_fp32
    return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 118, in recursively_apply
    {
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 119, in <dictcomp>
    k: recursively_apply(
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 126, in recursively_apply
    return func(data, *args, **kwargs)
  File "/home/ss/anaconda3-new/envs/train/lib/python3.10/site-packages/accelerate/utils/operations.py", line 781, in _convert_to_fp32
    return tensor.float()
AttributeError: 'HybridCache' object has no attribute 'float'

@ArthurZucker
Collaborator

The cache should never be used when training!
