8 changes: 7 additions & 1 deletion recipes/configs/llama3_3/70B_full_multinode.yaml
@@ -39,6 +39,13 @@ checkpointer:
model_type: LLAMA3
resume_from_checkpoint: False

# Parallelism tweaks
tensor_parallel_dim: 8 # 8-way TP
tensor_parallel_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan
data_parallel_shard_dim: -1 # -1 means to infer based on other parallel dims & world size
data_parallel_replicate_dim: 1

# Fine-tuning arguments
batch_size: 4
epochs: 1
@@ -53,7 +60,6 @@ loss:
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size


# Training env
device: cuda
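
Note on the new parallelism keys above: the product of the data- and tensor-parallel degrees has to equal the launched world size, and `data_parallel_shard_dim: -1` is filled in from whatever is left over. A quick sketch of that bookkeeping with a hypothetical 16-rank (2 nodes x 8 GPUs) launch:

```python
# Hypothetical 16-rank launch with the config above:
# tensor_parallel_dim=8, data_parallel_replicate_dim=1, data_parallel_shard_dim=-1
world_size = 16
tp, dp_replicate, dp_shard = 8, 1, -1

if dp_shard == -1:
    # -1 means "infer from the other dims", mirroring ParallelDims._validate
    dp_shard = world_size // (dp_replicate * tp)  # -> 2

assert dp_replicate * dp_shard * tp == world_size  # 1 * 2 * 8 == 16
```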

81 changes: 41 additions & 40 deletions recipes/full_finetune_distributed.py
@@ -15,11 +15,7 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import (
destroy_process_group,
init_device_mesh,
init_process_group,
)
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
@@ -146,20 +142,31 @@ def __init__(self, cfg: DictConfig) -> None:
# Initialize distributed variables
self.world_size, self.rank = utils.get_world_size_and_rank()
self._is_rank_zero = self.rank == 0
self.tensor_parallel_plan = config.instantiate(
cfg.get("tensor_parallel_plan", None)
)
self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
self.tp_plan = config.instantiate(cfg.get("tensor_parallel_plan", None))
self.tp_degree = cfg.get("tensor_parallel_dim", 1)
if self.tp_degree > 1 and self.tp_plan is None:
raise ValueError(
"Tensor Parallel plan needs to be provided when tensor parallel is enabled."
)
if self.world_size % self.tensor_parallel_dim != 0:
raise ValueError(
f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer
data_replicate = cfg.get("data_parallel_replicate_dim", 1)

# Set up n-d device mesh
self.parallel_dims = training.ParallelDims(
dp_replicate=data_replicate,
dp_shard=data_shard,
tp=self.tp_degree,
world_size=self.world_size,
)
self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
if self.parallel_dims.dp_enabled:
dp_mesh = self.world_mesh["dp"]
self.dp_degree, self.dp_rank = (
dp_mesh.size(),
dp_mesh.get_local_rank(),
)

self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
else:
self.dp_degree, self.dp_rank = 1, 0

# Logging attributes
self._output_dir = cfg.output_dir
@@ -538,26 +545,18 @@ def _setup_model(
if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)

device_mesh = init_device_mesh(
self._device.type,
mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
mesh_dim_names=("dp", "tp"),
)
self.dp_size = device_mesh["dp"].size()
self.dp_rank = device_mesh["dp"].get_local_rank()

# Apply tensor parallelism to the model
if self.tensor_parallel_dim > 1:
if self.data_parallel_dim == 1 and self.fsdp_cpu_offload:
if self.parallel_dims.tp_enabled:
if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
raise ValueError(
"Tensor parallelism is not supported with FSDP CPU offloading when data parallelism is disabled."
)
# Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
model = training.prepare_mha_for_tp(model, device_mesh["tp"])
model = training.prepare_mha_for_tp(model, self.world_mesh["tp"])
parallelize_module(
model,
device_mesh["tp"],
parallelize_plan=self.tensor_parallel_plan,
self.world_mesh["tp"],
parallelize_plan=self.tp_plan,
)

# We currently have two versions of activation checkpointing in this recipe
@@ -580,19 +579,25 @@
)

# Apply Fully Sharded Data Parallelism to the model
if self.data_parallel_dim > 1:
if self.parallel_dims.dp_shard_enabled:
fsdp_shard_conditions = [
partial(
training.get_shard_conditions,
names_to_match=custom_sharded_layers,
)
]

if self.parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard")
else:
dp_mesh_dim_names = ("dp_shard",)

training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
dp_mesh=device_mesh["dp"],
dp_mesh=self.world_mesh[dp_mesh_dim_names],
)

with training.set_default_dtype(self._dtype), self._device:
@@ -629,7 +634,7 @@ def _setup_model(
training.log_memory_stats(memory_stats)

# synchronize before training begins
torch.distributed.barrier()
torch.distributed.barrier(device_ids=[self._device.index])

return model

@@ -716,7 +721,7 @@ def _setup_data(
collate_fn = _get_component_from_path(collate_fn)

sampler = StatefulDistributedSampler(
ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0
)
dataloader = StatefulDataLoader(
dataset=ds,
@@ -727,7 +732,7 @@
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
pad_to_multiple_of=self.tensor_parallel_dim,
pad_to_multiple_of=self.tp_degree,
)
if not packed
else padded_collate_packed
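
Side note on `pad_to_multiple_of=self.tp_degree`: padding the batch's sequence length to a multiple of the TP degree keeps the sequence dimension evenly divisible across tensor-parallel ranks. A rough illustration (numbers are made up):

```python
# Illustrative only: pad the longest sequence in a batch up to a multiple of tp_degree.
tp_degree = 8
longest_seq = 42
padded_len = (longest_seq + tp_degree - 1) // tp_degree * tp_degree  # ceil to 48
assert padded_len % tp_degree == 0  # sequence dim now splits evenly across TP ranks
```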
@@ -811,22 +816,18 @@ def train(self) -> None:
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (self.dp_size / num_tokens)
current_loss = current_loss * (self.dp_degree / num_tokens)

current_loss.backward()

# Step with optimizer
# Optimizer step (if not fused in backward call)
if (idx + 1) % self._gradient_accumulation_steps == 0:
if not self._optimizer_in_bwd:
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, self.dp_size / num_tokens)
training.scale_grads(self._model, self.dp_degree / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
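
On the loss/gradient scaling above: FSDP2 mean-reduces gradients across the data-parallel ranks, so multiplying by `dp_degree` undoes that averaging, and dividing by the all-reduced token count normalizes per token rather than per rank. A tiny worked example with made-up numbers:

```python
# Made-up numbers: 2 data-parallel ranks with per-rank token counts 100 and 150.
dp_degree = 2
num_tokens = 100 + 150          # global token count after all_reduce

# FSDP2 divides reduced gradients by dp_degree, so this factor restores the sum
# over ranks and then averages per token (same value on every rank).
scale = dp_degree / num_tokens  # the factor passed to training.scale_grads(...)
print(f"{scale=:.4f}")          # scale=0.0080
```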
2 changes: 2 additions & 0 deletions torchtune/training/__init__.py
@@ -19,6 +19,7 @@
is_distributed,
load_from_full_model_state_dict,
load_from_full_optimizer_state_dict,
ParallelDims,
prepare_mha_for_tp,
set_torch_num_threads,
shard_model,
@@ -91,6 +92,7 @@
"Checkpointer",
"update_state_dict_for_classifier",
"ADAPTER_CONFIG",
"ParallelDims",
"ADAPTER_KEY",
"EPOCHS_KEY",
"MAX_STEPS_KEY",
78 changes: 77 additions & 1 deletion torchtune/training/_distributed.py
@@ -7,6 +7,7 @@

import logging
import os
from dataclasses import dataclass
from itertools import chain
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

@@ -24,7 +25,7 @@
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
@@ -47,6 +48,81 @@
_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = False


@dataclass
class ParallelDims:
dp_replicate: int
dp_shard: int
tp: int
world_size: int

def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, tp = (
self.dp_replicate,
self.dp_shard,
self.tp,
)
for d in (dp_replicate, tp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
if dp_shard < 0:
self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp)
assert dp_shard >= 1

assert dp_replicate * dp_shard * tp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.dp_replicate, self.dp_shard, self.tp],
["dp_replicate", "dp_shard", "tp"],
):
if d > 1:
dims.append(d)
names.append(name)

names = tuple(names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submeshes here to ensure all required process groups are
# initialized:
# Mesh for data loading (no communication on this mesh)
dp_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
if self.dp_shard_enabled:
dp_mesh_dim_names.append("dp_shard")

if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

return mesh

@property
def dp_enabled(self):
return self.dp_replicate > 1 or self.dp_shard > 1

@property
def dp_replicate_enabled(self):
return self.dp_replicate > 1

@property
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def tp_enabled(self):
return self.tp > 1
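
For reference, a minimal usage sketch of `ParallelDims` (degrees are hypothetical; it assumes the script runs under `torchrun` with 16 ranks so `init_device_mesh` can create the process groups):

```python
from torchtune.training import ParallelDims

# dp_shard=-1 is inferred in _validate: 16 // (1 * 8) = 2
parallel_dims = ParallelDims(dp_replicate=1, dp_shard=-1, tp=8, world_size=16)
mesh = parallel_dims.build_mesh(device_type="cuda")

# Only degrees > 1 become mesh axes, so the mesh is 2 x 8 with dim names
# ("dp_shard", "tp"); build_mesh also flattens the data-parallel dims into "dp".
dp_mesh = mesh["dp"]  # used for the sampler and FSDP sharding in the recipe
tp_mesh = mesh["tp"]  # passed to parallelize_module / prepare_mha_for_tp
```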


def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
"""Helper function to convert sharding strategy strings to ShardingStrategy enum."""
return getattr(ShardingStrategy, strategy)