How to fine-tune llama3.1 8b with a custom dataset on multiple GPUs? #1202

Open
dataai1205 opened this issue Nov 11, 2024 · 0 comments
Describe the bug

When fine-tuning Llama with a custom dataset, this error occurs: RuntimeError: chunk expects at least a 1-dimensional tensor
The same code works on a single GPU but fails on multiple GPUs.
What are the possible reasons?
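
For context, the output below reports "Training with DataParallel", i.e. the run happens in a single notebook process and Trainer wraps the model in torch.nn.DataParallel across the two GPUs. For comparison, a distributed launch that lets Accelerate/FSDP shard the work instead would typically be started like the following sketch (the training_function wrapper is a hypothetical name, not part of the original code):

from accelerate import notebook_launcher

def training_function():
    # Hypothetical wrapper: build the tokenizer, model, dataset and trainer
    # here and call trainer.train(); left empty in this sketch.
    ...

# Spawn one process per GPU so each rank gets its own shard of the batch,
# instead of Trainer falling back to nn.DataParallel in a single process.
notebook_launcher(training_function, args=(), num_processes=2)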

Minimal reproducible example

import torch, multiprocessing
from datasets import load_dataset, Dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
    TrainingArguments
)
from trl import SFTTrainer
from peft.utils.other import fsdp_auto_wrap_policy
from accelerate import Accelerator
import os

accelerator = Accelerator()
set_seed(1234)
# Use bf16 and FlashAttention if supported
if torch.cuda.is_bf16_supported():
    os.system('pip install flash_attn')
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

# Define model and tokenizer

model_id = "meta-llama/Meta-Llama-3-8B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation)
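# Note: compute_dtype chosen above is never passed to from_pretrained here
# (e.g. as torch_dtype=compute_dtype); kept as in the original report.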

# Dataset mapping and SFTTrainer construction elided in the original report
# ------

if hasattr(accelerator.state, 'fsdp_plugin'):  
    fsdp_plugin = accelerator.state.fsdp_plugin  
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)  
else:  
    print("FSDP plugin is not available.")  
trainer.train()

output_dir="./fine-tuned-output"

model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
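
Since the dataset mapping and trainer construction are elided above, here is a purely hypothetical sketch of how that part is typically wired up with trl's SFTTrainer; every value below is an assumption (not the reporter's code), and keyword arguments differ between trl versions:

# Hypothetical reconstruction of the elided section (not the reporter's code).
peft_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
dataset = load_dataset("json", data_files="train.jsonl", split="train")  # assumed file

training_args = TrainingArguments(
    output_dir="./fine-tuned-output",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=4,
    bf16=(compute_dtype == torch.bfloat16),
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,   # expected to contain a "text" column
    peft_config=peft_config,
    args=training_args,
    tokenizer=tokenizer,     # renamed to processing_class in newer trl releases
)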

Output

Currently training with a batch size of: 2
The following columns in the training set don't have a corresponding argument in `PeftModelForCausalLM.forward` and have been ignored: text. If text are not expected by `PeftModelForCausalLM.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 489
  Num Epochs = 4
  Instantaneous batch size per device = 1
  Training with DataParallel so batch size has been adjusted to: 2
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 16
  Total optimization steps = 50
  Number of trainable parameters = 41,943,040
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[42], line 8
      6 else:  
      7     print("FSDP plugin is not available.")  
----> 8 trainer.train()
     10 output_dir="./fine-tuned-output"
     12 model.save_pretrained(output_dir)

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2121         hf_hub_utils.enable_progress_bars()
   2122 else:
-> 2123     return inner_training_loop(
   2124         args=args,
   2125         resume_from_checkpoint=resume_from_checkpoint,
   2126         trial=trial,
   2127         ignore_keys_for_eval=ignore_keys_for_eval,
   2128     )

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2475 context = (
   2476     functools.partial(self.accelerator.no_sync, model=model)
   2477     if i == len(batch_samples) - 1
   2478     else contextlib.nullcontext
   2479 )
   2480 with context():
-> 2481     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2483 if (
   2484     args.logging_nan_inf_filter
   2485     and not is_torch_xla_available()
   2486     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2487 ):
   2488     # if loss is nan or inf simply add the average of previous logged losses
   2489     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3576     return loss_mb.reduce_mean().detach().to(self.args.device)
   3578 with self.compute_loss_context_manager():
-> 3579     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3581 del inputs
   3582 if (
   3583     self.args.torch_empty_cache_steps is not None
   3584     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3585 ):

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3631         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3632     inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
   3634 # Save past state if it exists
   3635 # TODO: this needs to be fixed and made cleaner later.
   3636 if self.args.past_index >= 0:

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:175, in DataParallel.forward(self, *inputs, **kwargs)
    170     if t.device != self.src_device_obj:
    171         raise RuntimeError("module must have its parameters and buffers "
    172                            f"on device {self.src_device_obj} (device_ids[0]) but found one of "
    173                            f"them on device: {t.device}")
--> 175 inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
    176 # for forward function without any inputs, empty list and dict will be created
    177 # so the module can be executed on one device which is the first one in device_ids
    178 if not inputs and not module_kwargs:

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:197, in DataParallel.scatter(self, inputs, kwargs, device_ids)
    191 def scatter(
    192     self,
    193     inputs: Tuple[Any, ...],
    194     kwargs: Optional[Dict[str, Any]],
    195     device_ids: Sequence[Union[int, torch.device]],
    196 ) -> Any:
--> 197     return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:73, in scatter_kwargs(inputs, kwargs, target_gpus, dim)
     71 r"""Scatter with support for kwargs dictionary."""
     72 scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
---> 73 scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     74 if len(scattered_inputs) < len(scattered_kwargs):
     75     scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs)))

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:59, in scatter(inputs, target_gpus, dim)
     53 # After scatter_map is called, a scatter_map cell will exist. This cell
     54 # has a reference to the actual function scatter_map, which has references
     55 # to a closure that has a reference to the scatter_map cell (because the
     56 # fn is recursive). To avoid this reference cycle, we set the function to
     57 # None, clearing the cell
     58 try:
---> 59     res = scatter_map(inputs)
     60 finally:
     61     scatter_map = None  # type: ignore[assignment]

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:50, in scatter.<locals>.scatter_map(obj)
     48     return [list(i) for i in zip(*map(scatter_map, obj))]
     49 if isinstance(obj, dict) and len(obj) > 0:
---> 50     return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
     51 return [obj for _ in target_gpus]

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:46, in scatter.<locals>.scatter_map(obj)
     44     return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
     45 if isinstance(obj, tuple) and len(obj) > 0:
---> 46     return list(zip(*map(scatter_map, obj)))
     47 if isinstance(obj, list) and len(obj) > 0:
     48     return [list(i) for i in zip(*map(scatter_map, obj))]

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:42, in scatter.<locals>.scatter_map(obj)
     40 def scatter_map(obj):
     41     if isinstance(obj, torch.Tensor):
---> 42         return Scatter.apply(target_gpus, None, dim, obj)
     43     if _is_namedtuple(obj):
     44         return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]

File /opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
    550 if not torch._C._are_functorch_transforms_active():
    551     # See NOTE: [functorch vjp and autograd interaction]
    552     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553     return super().apply(*args, **kwargs)  # type: ignore[misc]
    555 if not is_setup_ctx_defined:
    556     raise RuntimeError(
    557         "In order to use an autograd.Function with functorch transforms "
    558         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    559         "staticmethod. For more details, please see "
    560         "https://pytorch.org/docs/master/notes/extending.func.html"
    561     )

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:96, in Scatter.forward(ctx, target_gpus, chunk_sizes, dim, input)
     93 if torch.cuda.is_available() and ctx.input_device == -1:
     94     # Perform CPU to GPU copies in a background stream
     95     streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
---> 96 outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
     97 # Synchronize with the copy stream
     98 if streams is not None:

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/comm.py:187, in scatter(tensor, devices, chunk_sizes, dim, streams, out)
    185 if out is None:
    186     devices = [_get_device_index(d) for d in devices]
--> 187     return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
    188 else:
    189     if devices is not None:

RuntimeError: chunk expects at least a 1-dimensional tensor
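
The last frames show DataParallel.scatter trying to torch.chunk a 0-dimensional tensor across the two GPUs; the most plausible candidate is the scalar num_items_in_batch that transformers merges into the model inputs (visible in the compute_loss frame above). The same RuntimeError can be reproduced in isolation, as an illustration of the failure mode only:

import torch

# A 0-dimensional (scalar) tensor, e.g. a per-batch token count.
scalar = torch.tensor(489)

# DataParallel.scatter splits every tensor input along dim 0 via torch.chunk;
# a scalar tensor has no dim 0 to split, hence the error above.
torch.chunk(scalar, chunks=2, dim=0)
# RuntimeError: chunk expects at least a 1-dimensional tensor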

Runtime Environment

  • Model: llama-3.1 8b
  • Using via huggingface?: yes
  • OS: Jupyter Lab (Ubuntu, CUDA 12.0, PyTorch 2.1)
  • GPU VRAM: 151GB * 2
  • Number of GPUs: 2
  • GPU Make: Nvidia (H200)

