
Fix ROPE extension issue and device mismatch #840

Merged: 5 commits into unslothai:nightly on Jul 31, 2024

Conversation

xyangk (Contributor) commented Jul 31, 2024

Environment:

==((====))==  Unsloth 2024.8: Fast Llama patching. Transformers = 4.43.3.
   \\   /|    GPU: NVIDIA A40. Max memory: 44.352 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = True]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth

Initial error:

  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 959, in PeftModelForCausalLM_fast_forward
    return self.base_model(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 878, in _CausalLM_fast_forward
    outputs = self.model(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 715, in LlamaModel_fast_forward
    hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/_utils.py", line 645, in forward
    output = forward_function(hidden_states, *args)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 467, in LlamaDecoderLayer_fast_forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 349, in LlamaAttention_fast_forward
    self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 1158, in extend_rope_embedding
    self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 1110, in _set_cos_sin_cache
    freqs = torch.outer(t, self.inv_freq)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I examined self.inv_freq and found that, although it was initialized on the CPU, it had since been moved to the GPU, which caused this error (see the screenshot attached to the PR).

To resolve this, I modified t to be on the same device as self.inv_freq, which solved the initial problem.
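The fix can be sketched as follows. This is a minimal stand-in for Unsloth's `_set_cos_sin_cache`, not the exact code; the key line is building `t` on `inv_freq.device` instead of a hard-coded device:

```python
import torch

def set_cos_sin_cache(inv_freq: torch.Tensor, seq_len: int, dtype: torch.dtype):
    # Build the position indices on the SAME device as inv_freq. With a
    # hard-coded device string, torch.outer fails with "Expected all tensors
    # to be on the same device" whenever accelerate has moved inv_freq.
    t = torch.arange(seq_len, device=inv_freq.device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # (seq_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)
```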

However, a new error then occurred:

  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 958, in PeftModelForCausalLM_fast_forward
    return self.base_model(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 179, in forward
    return self.model.forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 877, in _CausalLM_fast_forward
    outputs = self.model(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 714, in LlamaModel_fast_forward
    hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/_utils.py", line 645, in forward
    output = forward_function(hidden_states, *args)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 466, in LlamaDecoderLayer_fast_forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/models/llama.py", line 353, in LlamaAttention_fast_forward
    Q, K = fast_rope_embedding(Q, K, cos, sin)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/kernels/rope_embedding.py", line 135, in fast_rope_embedding
    Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/miniconda3/envs/py310/lib/python3.10/site-packages/unsloth/kernels/rope_embedding.py", line 81, in forward
    assert(seq_len <= cos.shape[0])
AssertionError

Upon checking the output, I discovered that my input sequence length is 34k, which exceeds the initial current_rope_size. extend_rope_embedding had no effect because round() was used to compute the new size: for a sequence just above the current cache size it rounds down, so the rope size never increases. I changed it to round up (ceiling) instead.
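The rounding change can be sketched as follows. The names and the 8192-position chunk size are illustrative assumptions, not the exact Unsloth values; the point is that growing the cache in fixed-size chunks must round up, or a length just above the current cache computes the same size and the kernel's `assert(seq_len <= cos.shape[0])` fires:

```python
import math

CHUNK = 8192  # assumed: grow the RoPE cache in chunks of this many positions

def extended_rope_size_round(seq_len: int) -> int:
    # Buggy: round() can round DOWN. For seq_len = 34_000 with a
    # 32_768-position cache this returns 32_768, so the cache never grows.
    return round(seq_len / CHUNK) * CHUNK

def extended_rope_size_ceil(seq_len: int) -> int:
    # Fixed: always round UP so the cache covers seq_len.
    return math.ceil(seq_len / CHUNK) * CHUNK
```

For example, `extended_rope_size_round(34_000)` gives 32_768 (too small), while `extended_rope_size_ceil(34_000)` gives 40_960, which covers the 34k-token input.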

@Serega6678

I am experiencing the same issue
Thank you @xyangk
Does the fine-tuning work for you now on long documents?

@xyangk
Contributor Author

xyangk commented Jul 31, 2024

> I am experiencing the same issue. Thank you @xyangk. Does the fine-tuning work for you now on long documents?

Yes, it works now.

@danielhanchen danielhanchen changed the base branch from main to nightly July 31, 2024 19:04
@danielhanchen danielhanchen merged commit 2de1427 into unslothai:nightly Jul 31, 2024
danielhanchen added a commit that referenced this pull request Jul 31, 2024
* bugs

* Update _utils.py

* flash-attn softcapping

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update mapper.py

* Update README.md

* Update _utils.py

* Fix ROPE extension issue and device mismatch (#840)

* When an exception has been assigned using an `as` target, it is cleared at the end of the except clause (https://docs.python.org/3/reference/compound_stmts.html#the-try-statement)

* Update loader.py

* round up to extend rope size

* inv_freq.device changed, make sure they are on the same device

---------

Co-authored-by: xiaoyang <xiaoyang@youzan.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update gemma.py

---------

Co-authored-by: XiaoYang <xyangk@gmail.com>
Co-authored-by: xiaoyang <xiaoyang@youzan.com>
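The commit bullet about `except ... as` refers to documented Python behavior: the bound name is deleted when the except clause ends, so code that needs the exception afterwards must copy it to another name first. A small illustrative demo (not the Unsloth change itself):

```python
def last_error() -> str:
    try:
        raise ValueError("boom")
    except ValueError as e:
        saved = e  # copy the reference out; 'e' is unbound when the clause ends
    try:
        e  # noqa: F821 -- 'e' was implicitly deleted by the except clause
    except NameError:
        return f"'e' was cleared; saved = {saved!r}"
    return "unexpected: 'e' survived"
```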
@danielhanchen
Contributor

Thanks for this! I'll first merge this for now!
