
DeepSpeed ZeRO stage3+Qwen2/Qwen2-57B-A14B-Instruct: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) #32312

Closed
RobertXWL opened this issue Jul 30, 2024 · 9 comments

@RobertXWL

RobertXWL commented Jul 30, 2024

System Info

transformers==4.40.0
trl>=0.8.6
deepspeed==0.9.3
gpu: A100-SXM4-80GB*32

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

dz_zero3.json

{
    "train_micro_batch_size_per_gpu": "auto",
    "bf16": {
      "enabled": "auto"
    },
    "fp16": {
      "enabled": "auto",
      "loss_scale": 0,
      "initial_scale_power": 16,
      "loss_scale_window": 1000,
      "hysteresis": 2,
      "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}
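
As a rough sketch of how the "auto" fields above get filled in at runtime (the mappings in the comments reflect my understanding of the HF Trainer's DeepSpeed integration and may differ between versions):

# Sketch only: how the Trainer is expected to resolve the "auto" values in dz_zero3.json.
import json

with open("dz_zero3.json") as f:
    ds_config = json.load(f)

# "train_micro_batch_size_per_gpu": "auto" -> TrainingArguments.per_device_train_batch_size
# "bf16.enabled" / "fp16.enabled": "auto"  -> TrainingArguments.bf16 / TrainingArguments.fp16
# "reduce_bucket_size", "stage3_prefetch_bucket_size" and
# "stage3_param_persistence_threshold": "auto" are derived from the model's
# hidden size once the Trainer knows the model config.
print(json.dumps(ds_config["zero_optimization"], indent=2))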

train.py

# Imports reconstructed for readability; `change_data` is a helper defined elsewhere
# in the reporter's rm_cls.py and is not shown here.
import numpy as np
from datetime import datetime
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


def train(args):
    model_checkpoint = args.model_checkpoint
    print(f"model_checkpoint:  {model_checkpoint} @@@###")
    Targs = TrainingArguments(
        report_to="tensorboard",
        output_dir=args.output_dir,
        evaluation_strategy=args.eval_save_strategy,
        save_strategy=args.eval_save_strategy,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_train_epochs,
        weight_decay=args.weight_decay,
        load_best_model_at_end=False,
        overwrite_output_dir=True,
        deepspeed=args.deepspeed,
        bf16=args.bf16,
    )
    ds = load_dataset("json", data_files={"train": args.train_dataset_path, "test": args.test_dataset_path})
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
    ds = change_data(ds, args.data_type)

    def model_init():
        return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, return_dict=True)

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=4096)

    encoded_ds = ds.map(preprocess_function, batched=True, num_proc=100)
    encoded_ds = encoded_ds.rename_column('label', 'labels')
    encoded_ds.set_format(type='torch', columns=['text', 'input_ids', 'attention_mask', 'labels'])
    print(f"First train sample device: {next(iter(encoded_ds['train']))['input_ids'].device}")
    print(f"First test sample device: {next(iter(encoded_ds['test']))['input_ids'].device}")

    # encoded_ds.set_format("torch", device="cuda") 
    num_labels = {0: 108, 1: 48, 2: 27, 3: 12}
    current_time = datetime.now().strftime("%m%d")
    args.output_dir = f"{args.output_dir}/{current_time}_{args.sub_name}"
    if not args.use_param_search:
        model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels[args.data_type])

    def compute_metrics(p):
        metric = load_metric("accuracy")
        preds = np.argmax(p.predictions, axis=1)
        return metric.compute(predictions=preds, references=p.label_ids)

    if args.use_param_search:
        trainer = Trainer(
            args=Targs,
            train_dataset=encoded_ds["train"],
            eval_dataset=encoded_ds["test"],
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
        )
        best_run = trainer.hyperparameter_search(n_trials=4, direction="maximize")
        for n, v in best_run.hyperparameters.items():
            setattr(trainer.args, n, v)
        trainer.train()
    else:
        trainer = Trainer(
            model=model,
            args=Targs,
            train_dataset=encoded_ds["train"],
            eval_dataset=encoded_ds["test"],
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
        )
        trainer.train()
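
For reference, a minimal entry point for driving train() might look like the sketch below; the original rm_cls.py builds its own script_args, so every flag name here is an assumption rather than the reporter's actual CLI.

# Hypothetical entry point; argument names are guesses based on how `args` is used above.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_checkpoint", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--deepspeed", default="dz_zero3.json")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--train_dataset_path", required=True)
    parser.add_argument("--test_dataset_path", required=True)
    parser.add_argument("--data_type", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--eval_save_strategy", default="steps")
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--sub_name", default="run")
    parser.add_argument("--use_param_search", action="store_true")
    script_args = parser.parse_args()
    train(script_args)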

Error message

Traceback (most recent call last):
File "/checkpoint/binary/train_package/./rm_cls.py", line 202, in <module>
train(script_args)
File "/checkpoint/binary/train_package/./rm_cls.py", line 159, in train
trainer.train()
File "/root/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
File "/root/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/root/.local/lib/python3.10/site-packages/transformers/trainer.py", line 3138, in training_step
loss = self.compute_loss(model, inputs)
File "/root/.local/lib/python3.10/site-packages/transformers/trainer.py", line 3161, in compute_loss
outputs = model(**inputs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1736, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1528, in forward
transformer_outputs = self.model(
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1219, in forward
layer_outputs = decoder_layer(
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 915, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 753, in forward
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
File "/root/.local/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 240, in apply_rotary_pos_emb
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

Bug

I found that there is a device mismatch in the apply_rotary_pos_emb function in modeling_qwen2_moe.py, where both cos and sin are on the CPU, but q, k, and position_ids are on CUDA.

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
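
For context, this error class can be reproduced in isolation with a few lines (a standalone illustration assuming a CUDA device is available; it is not part of the original report):

# Indexing a CPU tensor with CUDA indices raises the same RuntimeError seen in the traceback.
import torch

cos = torch.randn(4096, 128)                                 # like the CPU-resident cos/sin cache
position_ids = torch.arange(8, device="cuda").unsqueeze(0)   # like the GPU-resident inputs

try:
    cos[position_ids]  # advanced indexing across devices
except RuntimeError as e:
    print(e)  # "indices should be either on cpu or on the same device as the indexed tensor (cpu)"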

It's quite strange. I'm not sure what the issue is, but when I made the following temporary modifications, it was able to run normally:

    if position_ids.device != cos.device:
        cos = cos.to(position_ids.device)
        # print(f"Moved cos to device: {cos.device}")
        sin = sin.to(position_ids.device)
        # print(f"Moved sin to device: {sin.device}")
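
Put together, the workaround amounts to the patched function below (a sketch based on the snippet above, not the upstream fix referenced later in this thread; rotate_half comes from the same modeling file):

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # Workaround: move cos/sin onto the same device as position_ids before indexing.
    if position_ids.device != cos.device:
        cos = cos.to(position_ids.device)
        sin = sin.to(position_ids.device)
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed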

Expected behavior

How can I resolve this device inconsistency issue? Is it due to my configuration?

@RobertXWL RobertXWL added the bug label Jul 30, 2024
@ArthurZucker
Collaborator

cc @muellerzr and @SunMarc !

@orrzohar
Contributor

Same issue encountered with LLaMA, ZeRO3, transformers==4.44.0, deepspeed==0.12.6, accelerate==0.33.0

@ArthurZucker
Collaborator

I'll ping @muellerzr and @SunMarc internally on this, thanks for reporting

@irislin1006

Hi @ArthurZucker, (just FYI) I would like to try to help with this ticket 🤗

@irislin1006

irislin1006 commented Sep 10, 2024

This issue was resolved after I rebased on top of the master branch, where it was addressed by a merge from 14 hours ago: 65bb284#diff-b1a8dc3bcd0052fd5b8db79800ba0f5049656730dd5788990cf41d36badc354e.

Problem Summary

A device mismatch occurred in the apply_rotary_pos_emb function of the Qwen2 model. Specifically, the cos and sin tensors (used for rotary positional embeddings) were on the CPU, while the q, k, and position_ids tensors were on the GPU. This mismatch led to a runtime error during training.

Resolution

There's also a small issue in your train.py script. Ensure that the model is explicitly moved to the correct device with the following code:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Reference

Here is the environment I used to reproduce the error and diagnose the issue: https://github.com/irislin1006/deepSeedZero

@SunMarc
Member

SunMarc commented Sep 10, 2024

Thanks for the help @irislin1006! Really appreciate that! Does it solve your issue @RobertXWL?

@LittleGreenYuan

LittleGreenYuan commented Sep 19, 2024

I also encountered a similar situation, but during the backward pass; I was doing LoRA fine-tuning of the Qwen2 model.
If 'offload_optimizer' is set to 'auto', this issue does not occur (it is converted to none during DeepSpeed initialization).

"zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
           "device": "cpu",
           "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e5,
        "stage3_max_reuse_distance": 1e5,
        "stage3_gather_16bit_weights_on_model_save": true
    },
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/ps/Code/YamlStarter/corellm/sftSCD2024/called_sft.py", line 210, in <module>
[rank1]:     trainer.train(
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/transformers/trainer.py", line 1859, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/transformers/trainer.py", line 3147, in training_step
[rank1]:     self.accelerator.backward(loss)
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/accelerate/accelerator.py", line 2143, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 175, in backward
[rank1]:     self.engine.step()
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2169, in step
[rank1]:     self._take_model_step(lr_kwargs)
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2075, in _take_model_step
[rank1]:     self.optimizer.step()
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/deepspeed/runtime/zero/stage3.py", line 2047, in step
[rank1]:     self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ps/anaconda3/envs/minillm/lib/python3.12/site-packages/deepspeed/runtime/zero/stage3.py", line 2117, in unscale_and_clip_grads
[rank1]:     self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
[rank1]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!

@SunMarc
Member

SunMarc commented Sep 27, 2024

Hey @orrzohar @RobertXWL, is the solution from @irislin1006 working for you?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
