@@ -583,6 +583,55 @@ def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
     return llama3_wrap


+def get_shard_conditions(
+    name: str,
+    module: nn.Module,
+    names_to_match: Optional[List[str]] = None,
+    *args,
+    **kwargs,
+) -> bool:
593+ """
594+ Returs True for layers named {}.layers.i or layers that exactly match names_to_match, otherwise,
595+ returns False. This is a helper function for sharding a model with FSDP.
596+ In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
597+ and apply fully_shard using this condition.
598+
599+ As part of our sharding strategy, we want each layer to be sharded separately, as this is
600+ generally efficient. We may also want to shard certain modules that are not layers, such as
601+ the embedding module.
602+
603+ #TODO: a more robust way would be to shard on the module type, not the name.
604+
605+ Args:
606+ name (str): Name of the module.
607+ module (nn.Module): Module to be sharded.
608+ names_to_match (Optional[List[str]]): List of names to match, if any.
609+ *args: Variable length argument list to be passed to the Embedding module.
610+ **kwargs: Arbitrary keyword arguments to be passed to the Embedding module.
611+
612+ Returns:
613+ bool: True if the module name matches the condition, False otherwise.
614+
615+ Examples:
616+ >>> names_to_match = ["embedding"]
617+ >>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
618+ "my_wrapper.layer.1.something", "embedding"]
619+ >>> matches = []
620+ >>> for name in layer_names:
621+ >>> if shard_condition_is_layer_or_match(name, None): matches.append(name)
622+ >>> print(matches)
623+ >>> ["layers.0", "decoder.layers.1", "embedding"]
624+ """
+    if names_to_match and name in names_to_match:
+        return True
+
+    name_list = name.split(".")
+    if len(name_list) >= 2:
+        return name_list[-2] == "layers" and str.isdigit(name_list[-1])
+
+    return False
+
+
 def shard_model(
     model: TransformerDecoder,
     shard_conditions: List[Callable[[str, nn.Module], bool]],
@@ -608,16 +657,25 @@ def shard_model(
             the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
             from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.

+    Raises:
+        ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
+    num_layers_sharded = 0
     for n, m in reversed(list(model.named_modules())):
         if any([shard_condition(n, m) for shard_condition in shard_conditions]):
             fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded. Please check if shard conditions are working as expected."
+        )

     # Finally shard the entire model to account for any stragglers
     fully_shard(model, **fsdp_kwargs)
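
Usage sketch (not part of the diff above): a minimal example of how get_shard_conditions and shard_model might be combined in a recipe. It assumes both helpers are importable from torchtune.training, as the docstring's cross-reference to torchtune.training.shard_model suggests; the toy model and the tok_embeddings name are placeholders, not something this change prescribes.

from functools import partial

import torch.nn as nn
# NOTE: import path assumed to mirror the docstring's torchtune.training reference.
from torchtune.training import get_shard_conditions, shard_model

# Toy stand-in for a TransformerDecoder: a stack of "layers.<i>" submodules
# plus an embedding table we also want to shard.
model = nn.Module()
model.tok_embeddings = nn.Embedding(128, 16)
model.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])

# Shard every "layers.<i>" module and, additionally, the embedding module.
condition = partial(get_shard_conditions, names_to_match=["tok_embeddings"])
print([n for n, m in model.named_modules() if condition(n, m)])
# ['tok_embeddings', 'layers.0', 'layers.1', 'layers.2', 'layers.3']

# In a distributed recipe (after torch.distributed is initialized), the same
# condition would be handed to shard_model:
# shard_model(
#     model=model,
#     shard_conditions=[condition],
#     cpu_offload=False,
#     reshard_after_forward=True,
# )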