@@ -68,6 +68,8 @@ class ParallelConfig:
6868 """Number of pipeline parallel groups."""
6969 tensor_parallel_size : int = 1
7070 """Number of tensor parallel groups."""
71+ context_parallel_size : int = 1
72+ """Number of context parallel groups."""
7173 data_parallel_size : int = 1
7274 """Number of data parallel groups. MoE layers will be sharded according to
7375 the product of the tensor parallel size and data parallel size."""
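As a quick numeric reading of the MoE note above (illustrative values only, not taken from the patch): with tensor_parallel_size=2 and data_parallel_size=4, the expert weights of an MoE layer end up sharded across 2 × 4 = 8 ranks, while context_parallel_size only changes how the sequence dimension is split, not the expert sharding.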
@@ -185,7 +187,7 @@ class is dynamically inherited by the worker class. This is used to inject
     calls."""
 
     world_size: int = field(init=False)
-    """world_size is TPxPP, it affects the number of workers we create."""
+    """world_size is TPxCPxPP, it affects the number of workers we create."""
 
     rank: int = 0
     """Global rank in distributed setup."""
@@ -335,6 +337,7 @@ def compute_hash(self):
         factors: list[Any] = []
         factors.append(self.pipeline_parallel_size)
         factors.append(self.tensor_parallel_size)
+        factors.append(self.context_parallel_size)
         factors.append(self.enable_expert_parallel)
         factors.append(self.data_parallel_size)
         factors.append(envs.VLLM_ALL2ALL_BACKEND)
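The new field is also folded into the compute_hash factors, so two configurations that differ only in context parallelism produce different cache keys. A minimal sketch of that pattern, assuming a plain SHA-256 over the stringified factor list and only the parallel sizes as inputs (the real compute_hash includes more factors, such as envs.VLLM_ALL2ALL_BACKEND, and may use a different digest):

```python
import hashlib
from typing import Any


def parallel_config_hash(pipeline_parallel_size: int,
                         tensor_parallel_size: int,
                         context_parallel_size: int,
                         data_parallel_size: int) -> str:
    # Collect the parallelism settings that should invalidate cached artifacts.
    factors: list[Any] = [
        pipeline_parallel_size,
        tensor_parallel_size,
        context_parallel_size,  # newly added factor in this patch
        data_parallel_size,
    ]
    # Hash the stringified factor list into a stable cache key.
    return hashlib.sha256(str(factors).encode()).hexdigest()
```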
@@ -374,7 +377,7 @@ def __post_init__(self) -> None:
 
         # Continue with the rest of the initialization
         self.world_size = self.pipeline_parallel_size * \
-            self.tensor_parallel_size
+            self.tensor_parallel_size * self.context_parallel_size
 
         if self.distributed_executor_backend == "external_launcher":
             logger.info("Using external launcher for distributed inference.")
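With the extra dimension, the worker count grows multiplicatively: world_size = pipeline_parallel_size × tensor_parallel_size × context_parallel_size, matching the updated TPxCPxPP docstring above. A standalone sketch of that relationship, using a hypothetical trimmed-down stand-in for ParallelConfig rather than the real class:

```python
from dataclasses import dataclass, field


@dataclass
class MiniParallelConfig:
    # Hypothetical stand-in mirroring only the fields relevant here.
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    context_parallel_size: int = 1
    world_size: int = field(init=False)

    def __post_init__(self) -> None:
        # Mirrors the diff: world_size is TP x CP x PP.
        self.world_size = (self.pipeline_parallel_size
                           * self.tensor_parallel_size
                           * self.context_parallel_size)


# e.g. PP=2, TP=4, CP=2 -> 16 workers per data-parallel replica.
cfg = MiniParallelConfig(pipeline_parallel_size=2,
                         tensor_parallel_size=4,
                         context_parallel_size=2)
assert cfg.world_size == 16
```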