Skip to content

Commit

Permalink
fix RAM OOM when load large models in tensor parallel mode. (vllm-pro…
Browse files Browse the repository at this point in the history
…ject#1395)

Co-authored-by: ran_lin <rlin@thoughtworks.com>
  • Loading branch information
boydfd and rlin-tw authored Nov 21, 2023
1 parent 6283390 commit 236d95a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 7 deletions.
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ def __init__(
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers

self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
Expand Down
10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class EngineArgs:
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
Expand Down Expand Up @@ -128,6 +129,12 @@ def add_cli_args(
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
parser.add_argument(
'--max-parallel-loading-workers',
type=int,
help='load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
Expand Down Expand Up @@ -195,7 +202,8 @@ def create_engine_configs(
getattr(model_config.hf_config, 'sliding_window', None))
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray)
self.worker_use_ray,
self.max_parallel_loading_workers)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
Expand Down
45 changes: 39 additions & 6 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def _init_workers(self, distributed_init_method: str):
"init_model",
get_all_outputs=True,
)
self._run_workers(
"load_model",
get_all_outputs=True,
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
Expand Down Expand Up @@ -182,6 +188,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
"init_model",
get_all_outputs=True,
)
self._run_workers(
"load_model",
get_all_outputs=True,
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
Expand Down Expand Up @@ -682,26 +694,47 @@ def _check_stop(self, seq: Sequence,
seq.status = SequenceStatus.FINISHED_STOPPED
return

def _run_workers(
def _run_workers_in_batch(
self,
workers,
method: str,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
):
all_outputs = []
for worker in self.workers:
for worker in workers:
if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)

output = executor(*args, **kwargs)
all_outputs.append(output)

if self.parallel_config.worker_use_ray:
all_outputs = ray.get(all_outputs)
return all_outputs

def _run_workers(
self,
method: str,
*args,
get_all_outputs: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
if max_concurrent_workers:
work_groups = [
self.workers[i:i + max_concurrent_workers]
for i in range(0, len(self.workers), max_concurrent_workers)
]
else:
work_groups = [self.workers]

for workers in work_groups:
all_outputs.extend(
self._run_workers_in_batch(workers, method, *args, **kwargs))

if get_all_outputs:
return all_outputs
Expand Down
2 changes: 2 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def init_model(self):

# Initialize the model.
set_random_seed(self.model_config.seed)

def load_model(self):
self.model = get_model(self.model_config)

@torch.inference_mode()
Expand Down

0 comments on commit 236d95a

Please sign in to comment.