Skip to content

Commit

Permalink
Support LLaMA training through SPMD
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Jul 20, 2023
1 parent 6112b1c commit 4f806cd
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
57 changes: 57 additions & 0 deletions examples/pytorch/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import torch
from datasets import load_dataset

import torch_xla.debug.profiler as xp
import transformers
from transformers import (
CONFIG_MAPPING,
Expand Down Expand Up @@ -139,6 +140,30 @@ class ModelArguments:
)
},
)
spmd_grad_chkpt: bool = field(
default=False,
metadata={
"help": (
"Apply gradient checkpointing to the model"
)
},
)
spmd_fsdp_sharding: bool = field(
default=False,
metadata={
"help": (
"Will apply XLA SPMD to run FSDP"
)
},
)
spmd_batch_sharding: bool = field(
default=False,
metadata={
"help": (
"Will apply XLA SPMD to shard the input along the batch dimension"
)
},
)

def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
Expand Down Expand Up @@ -238,6 +263,9 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clm", model_args, data_args)
Expand Down Expand Up @@ -285,6 +313,10 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)

server = xp.start_server(9012)
logger.info('Profiling server started: {str(server)}')


# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
Expand Down Expand Up @@ -430,6 +462,31 @@ def main():
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))

import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
num_devices = xr.global_device_count()
device_ids = torch.arange(num_devices)
print('Using dtype', model_args.torch_dtype)
model = model.to(xm.xla_device(), dtype=getattr(torch, model_args.torch_dtype))

if model_args.spmd_grad_chkpt:
print("Applying gradient checkpointing")
from torch_xla.distributed.fsdp import checkpoint_module
for i, block in enumerate(model.model.layers):
# LLaMA-specific
model.model.layers[i] = checkpoint_module(block)

if model_args.spmd_fsdp_sharding:
print('Applying FSDP sharding to all parameters')
for name, param in model.named_parameters():
# Shard all parameters along a single axis
print('> Sharding tensor', name)
shape = (num_devices,) + (1,) * (len(param.shape) - 1)
mesh = xs.Mesh(device_ids, shape)
xs.mark_sharding(param, mesh, range(len(param.shape)))


# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
Expand Down
38 changes: 34 additions & 4 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import sys
import time
import warnings
import torch_xla.debug.profiler as xp
from collections.abc import Mapping
from pathlib import Path
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union


Expand Down Expand Up @@ -162,6 +164,7 @@
import datasets

if is_torch_tpu_available(check_device=False):
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

Expand Down Expand Up @@ -838,7 +841,8 @@ def get_train_dataloader(self) -> DataLoader:
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
# TODO(jonbolin): Disabling Accelerate on the dataloader (`Unknown device SPMD:0`)
return DataLoader(train_dataset, **dataloader_params)

def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
# Deprecated code
Expand Down Expand Up @@ -1444,6 +1448,21 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):

return model

def _xla_sharded_dataloader(self, dataloader):
if is_torch_tpu_available():
sharding_spec = None
if self.args.spmd_batch_sharding:
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
num_devices = xr.global_device_count()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (num_devices, 1))
sharding_spec = xs.ShardingSpec(mesh, (0, 1))
return pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4)
else:
return dataloader

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
Expand Down Expand Up @@ -1537,7 +1556,7 @@ def _inner_training_loop(
self._train_batch_size = batch_size
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
train_dataloader = self._xla_sharded_dataloader(self.get_train_dataloader())

# Setting up training control variables:
# number of training epochs: num_train_epochs
Expand Down Expand Up @@ -1771,7 +1790,13 @@ def _inner_training_loop(
rng_to_sync = True

step = -1
profile_step = int(os.environ.get('PROFILE_STEP', -1))
profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
for step, inputs in enumerate(epoch_iterator):
if step == 0 and epoch == 0:
print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
total_batched_samples += 1
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
Expand All @@ -1792,6 +1817,10 @@ def _inner_training_loop(
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

if step == profile_step and epoch == profile_epoch:
trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
Thread(target=trace).start()

with self.accelerator.accumulate(model):
tr_loss_step = self.training_step(model, inputs)

Expand Down Expand Up @@ -2199,7 +2228,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
self.log(logs)

metrics = None
if self.control.should_evaluate:
# TODO(jonbolin): Disabling eval loop
if False: # self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
metrics = {}
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
Expand Down Expand Up @@ -2914,7 +2944,7 @@ def evaluate(
# memory metrics - must set up as early as possible
self._memory_tracker.start()

eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_dataloader = self._xla_sharded_dataloader(self.get_eval_dataloader(eval_dataset))
start_time = time.time()

eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
Expand Down

0 comments on commit 4f806cd

Please sign in to comment.