@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import torch
-from pydantic import model_validator
+from pydantic import TypeAdapter, model_validator
 from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from typing_extensions import Self
@@ -32,6 +32,38 @@
 DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
 
 
+@config
+@dataclass
+class EPLBConfig:
+    """Configuration for Expert Parallel Load Balancing (EPLB)."""
+
+    window_size: int = 1000
+    """Window size for expert load recording."""
+    step_interval: int = 3000
+    """
+    Interval for rearranging experts in expert parallelism.
+
+    Note that if this is greater than the EPLB window size, only the metrics
+    of the last `window_size` steps will be used for rearranging experts.
+    """
+
+    num_redundant_experts: int = 0
+    """Number of redundant experts to use for expert parallelism."""
+
+    log_balancedness: bool = False
+    """
+    Log the balancedness each step of expert parallelism.
+    This is turned off by default since it will cause communication overhead.
+    """
+
+    @classmethod
+    def from_cli(cls, cli_value: str) -> "EPLBConfig":
+        """Parse the CLI value (a JSON dict) for the EPLB config."""
+        return TypeAdapter(EPLBConfig).validate_json(cli_value)
+
+
 @config
 @dataclass
 class ParallelConfig:
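The new `EPLBConfig` groups the EPLB knobs that previously lived directly on `ParallelConfig`, and `from_cli` builds one from a JSON string via pydantic's `TypeAdapter.validate_json`. A minimal sketch of the intended round trip, assuming only the class added above:

```python
# Sketch: building an EPLBConfig from a JSON CLI value (class as added above).
cfg = EPLBConfig.from_cli(
    '{"window_size": 2000, "step_interval": 6000, "num_redundant_experts": 2}')
assert cfg.window_size == 2000
assert cfg.log_balancedness is False  # unspecified fields keep their defaults
```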
@@ -75,22 +107,24 @@ class ParallelConfig:
75107 """Use expert parallelism instead of tensor parallelism for MoE layers."""
76108 enable_eplb : bool = False
77109 """Enable expert parallelism load balancing for MoE layers."""
78- num_redundant_experts : int = 0
79- """Number of redundant experts to use for expert parallelism."""
80- eplb_window_size : int = 1000
81- """Window size for expert load recording."""
82- eplb_step_interval : int = 3000
83- """
84- Interval for rearranging experts in expert parallelism.
85-
86- Note that if this is greater than the EPLB window size, only the metrics
87- of the last `eplb_window_size` steps will be used for rearranging experts.
88- """
89- eplb_log_balancedness : bool = False
90- """
91- Log the balancedness each step of expert parallelism.
92- This is turned off by default since it will cause communication overhead.
93- """
110+ eplb_config : EPLBConfig = field (default_factory = EPLBConfig )
111+ """Expert parallelism configuration."""
112+ num_redundant_experts : Optional [int ] = None
113+ """`num_redundant_experts` is deprecated and has been replaced with
114+ `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
115+ Please use `eplb_config.num_redundant_experts` instead."""
116+ eplb_window_size : Optional [int ] = None
117+ """`eplb_window_size` is deprecated and has been replaced with
118+ `eplb_config.window_size`. This will be removed in v0.12.0.
119+ Please use `eplb_config.window_size` instead."""
120+ eplb_step_interval : Optional [int ] = None
121+ """`eplb_step_interval` is deprecated and has been replaced with
122+ `eplb_config.step_interval`. This will be removed in v0.12.0.
123+ Please use `eplb_config.step_interval` instead."""
124+ eplb_log_balancedness : Optional [bool ] = None
125+ """`eplb_log_balancedness` is deprecated and has been replaced with
126+ `eplb_config.log_balancedness`. This will be removed in v0.12.0.
127+ Please use `eplb_config.log_balancedness` instead."""
94128
95129 max_parallel_loading_workers : Optional [int ] = None
96130 """Maximum number of parallel loading workers when loading model
@@ -237,6 +271,38 @@ def compute_hash(self):
         return hashlib.sha256(str(factors).encode()).hexdigest()
 
     def __post_init__(self) -> None:
+        # Forward deprecated fields to their new location
+        if self.num_redundant_experts is not None:
+            self.eplb_config.num_redundant_experts = (
+                self.num_redundant_experts)
+            logger.warning_once(
+                "num_redundant_experts is deprecated and has been replaced "
+                "with eplb_config.num_redundant_experts. This will be removed "
+                "in v0.12.0. Changing this field after initialization will "
+                "have no effect.")
+        if self.eplb_window_size is not None:
+            self.eplb_config.window_size = self.eplb_window_size
+            logger.warning_once(
+                "eplb_window_size is deprecated and has been replaced "
+                "with eplb_config.window_size. This will be removed "
+                "in v0.12.0. Changing this field after initialization will "
+                "have no effect.")
+        if self.eplb_step_interval is not None:
+            self.eplb_config.step_interval = self.eplb_step_interval
+            logger.warning_once(
+                "eplb_step_interval is deprecated and has been replaced "
+                "with eplb_config.step_interval. This will be removed "
+                "in v0.12.0. Changing this field after initialization will "
+                "have no effect.")
+        if self.eplb_log_balancedness is not None:
+            self.eplb_config.log_balancedness = self.eplb_log_balancedness
+            logger.warning_once(
+                "eplb_log_balancedness is deprecated and has been replaced "
+                "with eplb_config.log_balancedness. This will be removed "
+                "in v0.12.0. Changing this field after initialization will "
+                "have no effect.")
+
+        # Continue with the rest of the initialization
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size
 
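The forwarding above is one-shot: it runs once in `__post_init__`, copies the deprecated value into `eplb_config`, and warns via `logger.warning_once`, so mutating the deprecated attribute afterwards has no effect, as the warning text says. For illustration, a standalone sketch of the same alias-forwarding pattern with hypothetical names (not vLLM's):

```python
# Hypothetical, self-contained sketch of the one-shot forwarding pattern.
from dataclasses import dataclass, field
from typing import Optional
import warnings


@dataclass
class Inner:
    window_size: int = 1000


@dataclass
class Outer:
    inner: Inner = field(default_factory=Inner)
    window_size: Optional[int] = None  # deprecated alias, forwarded once

    def __post_init__(self) -> None:
        if self.window_size is not None:
            self.inner.window_size = self.window_size
            warnings.warn(
                "window_size is deprecated; use inner.window_size instead. "
                "Changing this field after initialization has no effect.",
                DeprecationWarning,
                stacklevel=2)


cfg = Outer(window_size=500)  # warns once
assert cfg.inner.window_size == 500
```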
@@ -275,10 +341,10 @@ def __post_init__(self) -> None:
                 raise ValueError(
                     "Expert parallelism load balancing is only supported on "
                     "CUDA devices now.")
-            if self.num_redundant_experts < 0:
+            if self.eplb_config.num_redundant_experts < 0:
                 raise ValueError(
                     "num_redundant_experts must be non-negative, but got "
-                    f"{self.num_redundant_experts}.")
+                    f"{self.eplb_config.num_redundant_experts}.")
             if not self.enable_expert_parallel:
                 raise ValueError(
                     "enable_expert_parallel must be True to use EPLB.")
@@ -289,10 +355,10 @@ def __post_init__(self) -> None:
289355 f"TP={ self .tensor_parallel_size } ,DP={ self .data_parallel_size } ."
290356 )
291357 else :
292- if self .num_redundant_experts != 0 :
358+ if self .eplb_config . num_redundant_experts != 0 :
293359 raise ValueError (
294360 "num_redundant_experts should be used with EPLB."
295- f"{ self .num_redundant_experts } ." )
361+ f"{ self .eplb_config . num_redundant_experts } ." )
         if self.distributed_executor_backend is None and self.world_size > 1:
             # We use multiprocessing by default if world_size fits on the
             # current node and we aren't in a ray placement group.
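The two checks above tie `num_redundant_experts` to EPLB in both directions: it must be non-negative when `enable_eplb=True`, and it must stay at its default of 0 when EPLB is off. A hedged sketch of the second case, assuming `ParallelConfig` can be constructed standalone:

```python
# Sketch: redundant experts without enable_eplb=True should be rejected.
try:
    ParallelConfig(eplb_config=EPLBConfig(num_redundant_experts=2))
except ValueError as err:
    print(err)  # explains that num_redundant_experts requires EPLB
```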