
# Tuning Llama Models

Training Large Language Models (LLMs) on Google Tensor Processing Units (TPUs) with Single Program Multiple Data (SPMD) offers several benefits. TPUs provide competitive processing power, enabling fast training times and allowing researchers to experiment with larger models and datasets efficiently. SPMD optimizes resource utilization by distributing the work across multiple TPU cores, improving parallelism and scalability. The easiest way to tune a model with SPMD is to use Fully Sharded Data Parallel (FSDP). PyTorch/XLA's most recent and most performant implementation is FSDP v2, which allows sharding weights, activations, and outputs.

This example shows how to tune Meta's Llama 2 and Llama 3 models on single-host TPUs. For information on the TPU architecture, you can consult the documentation.

## Prerequisites

We assume you have already created a single-host TPU VM, such as a v5litepod-8, and that you have SSH access to the machine. The optimum-tpu package (with its TPU-enabled PyTorch/XLA dependencies) is assumed to be installed already; on top of it, you need to install a few extra modules:

```bash
pip install datasets evaluate
```
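Optionally, as a quick sanity check that the TPU runtime is visible from Python (a check added here for convenience, not part of the original recipe; the exact API may vary across torch_xla versions), you can query the XLA device:

```python
# Verify PyTorch/XLA can see the TPU; prints something like "xla:0" on a TPU VM.
import torch_xla.core.xla_model as xm

print(xm.xla_device())
```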

Note that to work with gated models, you will need to export the HF_TOKEN environment variable, or authenticate using the huggingface-cli login command (see here for details).
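If you prefer to authenticate from Python rather than the shell, the huggingface_hub helper below is one option (the token value is a placeholder you must replace with your own access token):

```python
# Programmatic alternative to `export HF_TOKEN=...` or `huggingface-cli login`.
from huggingface_hub import login

login(token="hf_...")  # placeholder: use your own Hugging Face access token
```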

## Instructions

To use FSDP v2, it first needs to be enabled:

```python
from optimum.tpu import fsdp_v2

fsdp_v2.use_fsdp_v2()
```

Then, the tokenizer and model need to be loaded. We will use meta-llama/Meta-Llama-3-8B for this example:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Llama tokenizers do not define a padding token by default; reuse the EOS token.
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
```

To tune the model with the Abirate/english_quotes dataset, you can load it and tokenize its quote column:

```python
from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
```
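As an optional sanity check (not part of the original recipe), you can inspect one tokenized sample to confirm the mapping produced input_ids:

```python
# Look at a single tokenized quote.
sample = data["train"][0]
print(sample["quote"])
print(len(sample["input_ids"]), "tokens")
```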

You then need to specify the FSDP training arguments to enable the sharding feature; the function will deduce the model classes that should be sharded:

```python
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
```
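If you are curious about what will be forwarded to the trainer, you can simply print the returned dictionary. Its exact contents depend on the installed optimum-tpu version, but for Llama models it is expected to describe an XLA FSDP policy wrapping the decoder layers:

```python
# The exact keys depend on the installed optimum-tpu version.
print(fsdp_training_args)
```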

Training can now be done simply by using the standard Trainer class:

```python
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        per_device_train_batch_size=24,
        num_train_epochs=10,
        max_steps=-1,
        output_dir="/tmp/output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # Required by FSDP v2 and SPMD.
        **fsdp_training_args,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
```
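Once training completes, you may want to persist the fine-tuned weights. The standard Trainer API can be used for that; the path below is just an example:

```python
# Save the fine-tuned model and tokenizer (example path, adjust as needed).
trainer.save_model("/tmp/output/final")
tokenizer.save_pretrained("/tmp/output/final")
```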