|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from dataclasses import dataclass, field |
| 8 | + |
| 9 | +from torchtitan.config.job_config import FaultTolerance as BaseFaultTolerance |
| 10 | + |
| 11 | + |
| 12 | +@dataclass |
| 13 | +class FaultTolerance(BaseFaultTolerance): |
| 14 | + """ |
| 15 | + Extends fault tolerance to also support Streaming DiLoCo. |
| 16 | + """ |
| 17 | + |
| 18 | + sync_steps: int = 5 |
| 19 | + """ |
| 20 | + Number of steps to wait before performing synchronization. This is only used when "semi_sync_method" |
| 21 | + is set. |
| 22 | + """ |
| 23 | + |
| 24 | + should_quantize: bool = False |
| 25 | + """ |
| 26 | + Whether to quantize the gradients before allreduce. |
| 27 | +
|
| 28 | + Disabled by default since the quantization does utilize the GPU |
| 29 | + and uses more collectives. Enabling this requires knowing about |
| 30 | + the tradeoffs between GPU utilization and communication. |
| 31 | +
|
| 32 | +
|
| 33 | + This is only used when "semi_sync_method" is set. |
| 34 | + """ |
| 35 | + |
| 36 | + fragment_sync_delay: int = 0 |
| 37 | + """ |
| 38 | + Controls the number of inner steps to wait before blocking on a |
| 39 | + model fragment's synchronization. This is the "tau" parameter in |
| 40 | + the Streaming DiLoCo paper. |
| 41 | +
|
| 42 | + By default, each model fragment will be synced at the same step |
| 43 | + at which the allreduce is issued. Enabling delay can improve |
| 44 | + communication and computation overlap, but at the cost of compromising |
| 45 | + model quality. |
| 46 | +
|
| 47 | + This is only used when "semi_sync_method" is set. |
| 48 | + """ |
| 49 | + |
| 50 | + fragment_update_alpha: float = 0.0 |
| 51 | + """ |
| 52 | + Determines how to mix the local and global optimized parameters. |
| 53 | +
|
| 54 | + By default, we just use the global parameters. This ensures all |
| 55 | + DDP replicas have the same parameters after synchronizing on |
| 56 | + the fragment. Tuning this can also affect the model quality. |
| 57 | +
|
| 58 | + This is only used when "semi_sync_method" is set. |
| 59 | + """ |
| 60 | + |
| 61 | + module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list) |
| 62 | + """ |
| 63 | + Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment. |
| 64 | + Each inner list represents one model fragment and contains the module names that belong to that fragment. |
| 65 | + e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']] |
| 66 | + will create 3 chunks: the first containing tok_embeddings and layers.0, |
| 67 | + the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. |
| 68 | + """ |
| 69 | + |
| 70 | + num_fragments: int = 1 |
| 71 | + """ |
| 72 | + Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco". |
| 73 | + This is used to automatically split the model into fragments provided that the model |
| 74 | + implements FaultTolerantTrainSpec. |
| 75 | + """ |
| 76 | + |
| 77 | + |
| 78 | +@dataclass |
| 79 | +class JobConfig: |
| 80 | + fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) |
0 commit comments