
Commit bc486d4

felipemello1 (Felipe Mello) and ebsmothers authored

[bug] fix sharding multimodal (#1889)

Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: ebsmothers <ebs@meta.com>
1 parent 74139c9 commit bc486d4

File tree

7 files changed: +124, -110 lines


recipes/configs/llama3_2_vision/11B_full.yaml

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-custom_sharded_layers: ['tok_embeddings', 'output']
+custom_sharded_layers: ['decoder.tok_embeddings']
 dtype: bf16

 # Logging
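Note on the config change: in the Llama 3.2 Vision model the token embedding is a submodule of the decoder, so its fully qualified name is decoder.tok_embeddings, and the name-based shard condition introduced in this commit matches custom names exactly. A minimal sketch (not part of the commit) of which names the new value matches, using the get_shard_conditions helper added in torchtune/training/_distributed.py further down:

# Minimal sketch: check which module names the new config value matches.
from torchtune import training

names_to_match = ["decoder.tok_embeddings"]  # value from 11B_full.yaml
for name in ["decoder.tok_embeddings", "tok_embeddings", "decoder.layers.3", "decoder.output"]:
    print(name, training.get_shard_conditions(name, module=None, names_to_match=names_to_match))
# -> only "decoder.tok_embeddings" (exact name match) and "decoder.layers.3" (layer match) print True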

recipes/full_finetune_distributed.py

Lines changed: 13 additions & 27 deletions
@@ -225,9 +225,11 @@ def setup(self, cfg: DictConfig) -> None:
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
             optimizer_in_bwd=self._optimizer_in_bwd,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )

         # initialize loss
@@ -350,10 +352,10 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
-        custom_sharded_layers: Optional[List[str]],
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
     ) -> nn.Module:
@@ -396,29 +398,13 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-        fsdp_shard_conditions = []
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
-        # If wrapping any layers separately, we can add another shard condition
-        # A layer will be sharded if any of the fsdp_shard_conditions are met
-        if custom_sharded_layers:
-            fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
         training.shard_model(
             model=model,
             shard_conditions=fsdp_shard_conditions,
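What actually broke for multimodal: the removed _is_layer_fqn helper only matched names of the exact form layers.i, i.e. transformer blocks hanging directly off the model root, so on the vision model, where blocks are named decoder.layers.i, no shard condition ever fired. The new training.get_shard_conditions matches *.layers.i at any depth and folds the custom_sharded_layers name check into the same callable via functools.partial. An illustrative comparison (not from the diff):

# Illustration: old root-level check vs. the new helper, on multimodal module names.
from functools import partial
from torchtune import training

def _old_is_layer_fqn(s: str) -> bool:  # the condition this commit removes
    s_list = s.split(".")
    return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])

cond = partial(training.get_shard_conditions, names_to_match=["decoder.tok_embeddings"])

print(_old_is_layer_fqn("decoder.layers.7"))     # False -> this layer was never sharded
print(cond("decoder.layers.7", None))            # True
print(cond("decoder.tok_embeddings", None))      # True (custom name match)
print(cond("encoder.layers.2.attention", None))  # False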

recipes/lora_dpo_distributed.py

Lines changed: 9 additions & 20 deletions
@@ -8,7 +8,7 @@
 import time

 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from warnings import warn

 import torch
@@ -290,6 +290,7 @@ def _setup_model(
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
         """
@@ -323,28 +324,16 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_name(name: str, module: nn.Module) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            name_list = name.split(".")
-            return (
-                len(name_list) == 2
-                and name_list[0] == "layers"
-                and str.isdigit(name_list[1])
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
             )
-
+        ]
         training.shard_model(
             model=model,
-            shard_conditions=[_is_layer_name],
+            shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
         )

recipes/lora_finetune_distributed.py

Lines changed: 18 additions & 27 deletions
@@ -9,7 +9,7 @@
 import time

 from functools import partial
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn

 import torch
@@ -408,6 +408,7 @@ def _setup_model(
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
         """
@@ -445,28 +446,16 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_name(name: str, module: nn.Module) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            name_list = name.split(".")
-            return (
-                len(name_list) == 2
-                and name_list[0] == "layers"
-                and str.isdigit(name_list[1])
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
            )
-
+        ]
         training.shard_model(
             model=model,
-            shard_conditions=[_is_layer_name],
+            shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
         )
@@ -624,13 +613,15 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                collate_fn,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else padded_collate_packed,
+            collate_fn=(
+                partial(
+                    collate_fn,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else padded_collate_packed
+            ),
         )

         if self._is_rank_zero:

recipes/qat_distributed.py

Lines changed: 23 additions & 35 deletions
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:

         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            opt_state_dict=checkpoint_dict[training.OPT_KEY]
-            if self._resume_from_checkpoint
-            else None,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )

         # initialize loss
@@ -363,10 +365,10 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
-        custom_sharded_layers: Optional[List[str]],
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
         quantizer_cfg: Optional[DictConfig] = None,
@@ -420,29 +422,13 @@ def _setup_model(
             self._quantizer_mode = quantizer_mode
             model = quantizer.prepare(model)

-        # For FSDP sharding, we can condition on either the module or its name
-        # Shard conditions should be callables taking name (relative to model root)
-        # and the module itself and returning a bool on whether to shard the given module
-        fsdp_shard_conditions = []
-
-        # Shard transformer decoder layers (or AC-wrapped versions)
-        # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
-        # But directly using the name is more concise
-        def _is_layer_fqn(s: str) -> bool:
-            """
-            Return True for layers.i and False for all other module names
-            Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-            """
-            s_list = s.split(".")
-            return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
-        fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
-        # If wrapping any layers separately, we can add another shard condition
-        # A layer will be sharded if any of the fsdp_shard_conditions are met
-        if custom_sharded_layers:
-            fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
         training.shard_model(
             model=model,
             shard_conditions=fsdp_shard_conditions,
@@ -525,14 +511,16 @@ def _setup_data(
             sampler=sampler,
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
-            collate_fn=partial(
-                padded_collate_sft,
-                padding_idx=self._tokenizer.pad_id,
-                ignore_idx=self._loss_fn.ignore_index,
-            )
-            if not packed
-            else partial(
-                padded_collate_packed,
+            collate_fn=(
+                partial(
+                    padded_collate_sft,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else partial(
+                    padded_collate_packed,
+                )
             ),
         )

torchtune/training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
     get_full_finetune_fsdp_wrap_policy,
     get_full_model_state_dict,
     get_full_optimizer_state_dict,
+    get_shard_conditions,
     get_world_size_and_rank,
     init_distributed,
     is_distributed,
@@ -106,6 +107,7 @@
     "get_world_size_and_rank",
     "set_torch_num_threads",
     "shard_model",
+    "get_shard_conditions",
     "prepare_model_for_fsdp_with_meta_device",
     "validate_no_params_on_meta_device",
     "contains_fsdp",

torchtune/training/_distributed.py

Lines changed: 58 additions & 0 deletions
@@ -583,6 +583,55 @@ def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
    return llama3_wrap


+def get_shard_conditions(
+    name: str,
+    module: nn.Module,
+    names_to_match: Optional[List[str]] = None,
+    *args,
+    **kwargs,
+) -> bool:
+    """
+    Returns True for layers named like ``*.layers.i`` (at any depth) or for modules whose
+    names exactly match an entry in ``names_to_match``; otherwise returns False. This is a
+    helper function for sharding a model with FSDP.
+    In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
+    and apply fully_shard using this condition.
+
+    As part of our sharding strategy, we want each layer to be sharded separately, as this is
+    generally efficient. We may also want to shard certain modules that are not layers, such as
+    the embedding module.
+
+    # TODO: a more robust way would be to shard on the module type, not the name.
+
+    Args:
+        name (str): Name of the module.
+        module (nn.Module): Module to be sharded.
+        names_to_match (Optional[List[str]]): List of names to match, if any.
+        *args: Variable length argument list, ignored.
+        **kwargs: Arbitrary keyword arguments, ignored.
+
+    Returns:
+        bool: True if the module name matches the condition, False otherwise.
+
+    Examples:
+        >>> names_to_match = ["embedding"]
+        >>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
+        ...     "my_wrapper.layer.1.something", "embedding"]
+        >>> matches = []
+        >>> for name in layer_names:
+        ...     if get_shard_conditions(name, None, names_to_match=names_to_match):
+        ...         matches.append(name)
+        >>> print(matches)
+        ["layers.0", "decoder.layers.1", "embedding"]
+    """
+    if names_to_match and name in names_to_match:
+        return True
+
+    name_list = name.split(".")
+    if len(name_list) >= 2:
+        return name_list[-2] == "layers" and str.isdigit(name_list[-1])
+
+    return False
+
+
 def shard_model(
     model: TransformerDecoder,
     shard_conditions: List[Callable[[str, nn.Module], bool]],
@@ -608,16 +657,25 @@ def shard_model(
            the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
            from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

+    Raises:
+        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
+    num_layers_sharded = 0
     for n, m in reversed(list(model.named_modules())):
         if any([shard_condition(n, m) for shard_condition in shard_conditions]):
             fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions are working as expected."
+        )

     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)
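The num_layers_sharded guard turns the silent failure mode this PR fixes (no decoder.layers.i name matching, so nothing gets wrapped) into an explicit error. A rough sketch of how it surfaces, using a hypothetical toy module rather than a real TransformerDecoder:

import torch.nn as nn
from torchtune import training

toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # stand-in model for illustration
try:
    training.shard_model(
        model=toy,
        shard_conditions=[lambda name, module: False],  # nothing ever matches
        cpu_offload=False,
        reshard_after_forward=True,
    )
except ValueError as err:
    print(err)  # "No layer modules were sharded. Please check if shard conditions ..."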
