Merged PR 1908: Fix performance regression in pipeline parallel benchmark (#28)

* Merged PR 1908: Fix performance regression in pipeline parallel benchmark

When we run the pipeline parallel engine step in non-blocking mode, the engine step runs in a busy loop within the benchmarking script. This causes other threads to stall due to the GIL.
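To illustrate the failure mode (a hypothetical sketch, not the benchmark code itself): when a non-blocking step() returns immediately with no results, the caller spins in a hot loop that keeps reacquiring the GIL and starves the engine's background threads; blocking on the output queue parks the thread and releases the GIL instead.

import queue
import threading
import time

output_queue: queue.Queue = queue.Queue()

def output_loop() -> None:
    # Stand-in for the engine's background output thread.
    for i in range(3):
        time.sleep(0.1)
        output_queue.put(f"output-{i}")

threading.Thread(target=output_loop, daemon=True).start()

results = []
while len(results) < 3:
    try:
        # Non-blocking poll: returns instantly, so this loop spins hot
        # and contends with output_loop() for the GIL.
        results.append(output_queue.get(block=False))
    except queue.Empty:
        pass

# The benchmark-friendly alternative: a blocking get() releases the GIL
# until an item is available.
# result = output_queue.get()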
AgrawalAmey authored Jul 23, 2024
1 parent c909495 commit c94dbf4
Showing 7 changed files with 53 additions and 22 deletions.
examples/offline_inference.py (3 changes: 2 additions & 1 deletion)
@@ -35,7 +35,7 @@

parallel_config = ParallelConfig(
tensor_parallel_size=1,
-    pipeline_parallel_size=1,
+    pipeline_parallel_size=4,
)

scheduler_config = SarathiSchedulerConfig(
@@ -78,6 +78,7 @@ def generate(
if output.finished:
outputs.append(output)
pbar.update(1)
+
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
sarathi/benchmark/benchmark_runner.py (1 change: 0 additions & 1 deletion)
@@ -1,4 +1,3 @@
-import json
import logging
import os
import time
sarathi/engine/async_llm_engine.py (4 changes: 3 additions & 1 deletion)
@@ -230,7 +230,9 @@ async def step_async(self) -> List[RequestOutput]:
"""
Simple wrapper around the synchronous `step` method to make it asynchronous.
"""
-        return await asyncio.get_event_loop().run_in_executor(None, self.engine.step)
+        return await asyncio.get_event_loop().run_in_executor(
+            None, self.engine.step, False
+        )


class AsyncLLMEngine(LLMEngine):
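For readers unfamiliar with the call shape: run_in_executor forwards trailing positional arguments to the callable, so the added False is the block argument of step. A minimal sketch (DummyEngine is a hypothetical stand-in for the real engine):

import asyncio
import functools
from typing import List

class DummyEngine:
    # Hypothetical stand-in; only the signature of step() matters here.
    def step(self, block: bool = True) -> List[str]:
        return []

async def main() -> None:
    engine = DummyEngine()
    loop = asyncio.get_running_loop()
    # Positional form, as in the diff: runs engine.step(False) in a thread.
    outputs = await loop.run_in_executor(None, engine.step, False)
    # Equivalent keyword form using functools.partial.
    outputs = await loop.run_in_executor(
        None, functools.partial(engine.step, block=False)
    )
    print(outputs)

asyncio.run(main())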
sarathi/engine/pipeline_parallel_llm_engine.py (40 changes: 30 additions & 10 deletions)
@@ -2,15 +2,15 @@
from dataclasses import dataclass
from queue import Empty, Queue
from threading import Event, Thread
-from typing import List
+from typing import List, Tuple

from sarathi.config import SystemConfig
from sarathi.core.datatypes.request_output import RequestOutput
from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
-from sarathi.core.datatypes.sequence import SequenceMetadata
+from sarathi.core.datatypes.sequence import SamplerOutputs, SequenceMetadata
from sarathi.engine.base_llm_engine import BaseLLMEngine
from sarathi.logger import init_logger
-from sarathi.utils.threading_utils import exit_on_error
+from sarathi.utils.threading_utils import exit_on_error, synchronized

logger = init_logger(__name__)

@@ -59,6 +59,8 @@ def __init__(
target=self._scheduler_timer_loop, daemon=True
)

+        self.pending_step_outputs: List[Tuple[SchedulerOutputs, SamplerOutputs]] = []
+
def _validate_parallel_config(self) -> None:
assert self.config.parallel_config.pipeline_parallel_size > 1

@@ -84,6 +86,20 @@ def _get_worker_impl(self):

return PipelineParallelWorker

+    @synchronized
+    def _append_pending_step_output(
+        self, scheduler_outputs: SchedulerOutputs, sampler_outputs: SamplerOutputs
+    ) -> None:
+        self.pending_step_outputs.append((scheduler_outputs, sampler_outputs))
+
+    @synchronized
+    def _get_pending_step_outputs(
+        self,
+    ) -> List[Tuple[SchedulerOutputs, SamplerOutputs]]:
+        pending_step_outputs = self.pending_step_outputs
+        self.pending_step_outputs = []
+        return pending_step_outputs
+
@exit_on_error
def _schedule_loop(self) -> None:
while True:
@@ -110,15 +126,17 @@ def _schedule_loop(self) -> None:
)
)

-            end_time = time.perf_counter()

if not scheduler_outputs.is_empty():
self.microbatch_watch_event.set()
self._run_workers(
"enqueue",
scheduler_outputs=scheduler_outputs,
+                    pending_step_outputs=self._get_pending_step_outputs(),
ignore_output=True,
)

+            end_time = time.perf_counter()
self.metrics_store.on_schedule(seq_metadata_list, start_time, end_time)

@exit_on_error
@@ -146,11 +164,8 @@ def _output_loop(self) -> None:
"get_output",
)

-            # this needs to be optimized
-            self._run_workers(
-                "on_sampling_completed",
-                scheduler_outputs=scheduler_stage_output.scheduler_outputs,
-                sampler_outputs=sampler_outputs,
+            self._append_pending_step_output(
+                scheduler_stage_output.scheduler_outputs, sampler_outputs
            )

all_request_outputs = self._on_step_completed(
@@ -161,16 +176,21 @@
scheduler_stage_output.start_time,
)
self.schedule_event.set()

self.output_queue.put(all_request_outputs)

-    def step(self) -> List[RequestOutput]:
+    def step(self, block: bool = True) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine.
This version does everything asynchronously and returns the results as they become available.
"""
if not self.has_started_execution_loops:
self.start_execution_loops()

+        if block:
+            return self.output_queue.get()
+
try:
return self.output_queue.get(block=False)
except Empty:
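Two things are worth unpacking here (a minimal sketch, assuming synchronized in sarathi.utils.threading_utils is a lock-based method decorator; its implementation is not shown in this diff): the pair of @synchronized helpers forms a swap-and-drain buffer, so the output loop can record completed steps cheaply while the schedule loop hands the whole batch to the workers in its next enqueue call, instead of issuing a blocking on_sampling_completed broadcast from the output loop.

import threading
from functools import wraps
from typing import Any, Callable, List, Tuple

def synchronized(method: Callable) -> Callable:
    # Hypothetical stand-in for sarathi's decorator: serialize method
    # calls on the instance's _sync_lock.
    @wraps(method)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        with self._sync_lock:
            return method(self, *args, **kwargs)
    return wrapper

class PendingStepOutputs:
    # Swap-and-drain buffer: the output loop appends pairs, and the
    # schedule loop drains the whole batch in one call.
    def __init__(self) -> None:
        self._sync_lock = threading.Lock()
        self.pending: List[Tuple[object, object]] = []

    @synchronized
    def append(self, scheduler_outputs: object, sampler_outputs: object) -> None:
        self.pending.append((scheduler_outputs, sampler_outputs))

    @synchronized
    def drain(self) -> List[Tuple[object, object]]:
        pending, self.pending = self.pending, []
        return pending

buf = PendingStepOutputs()
buf.append("scheduler_outputs", "sampler_outputs")
print(buf.drain())  # [('scheduler_outputs', 'sampler_outputs')]
print(buf.drain())  # []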
sarathi/utils/__init__.py (16 changes: 10 additions & 6 deletions)
@@ -40,13 +40,17 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)


-def in_wsl() -> bool:
-    # Reference: https://github.com/microsoft/WSL/issues/4071
-    return "microsoft" in " ".join(uname()).lower()
-
-
def get_ip() -> str:
-    return socket.gethostbyname(socket.gethostname())
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.settimeout(0)
+    try:
+        s.connect(("10.254.254.254", 1))
+        ip = s.getsockname()[0]
+    except Exception:
+        ip = "127.0.0.1"
+    finally:
+        s.close()
+    return ip


def get_open_port() -> int:
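Background on the new get_ip (not part of the diff): connect() on a UDP socket sends no packets; it only asks the kernel to choose a route, so getsockname() reports the primary outbound interface address. This avoids the old gethostbyname(gethostname()) path, which resolves to 127.0.0.1 on hosts whose hostname maps to loopback. A standalone version of the same trick:

import socket

def primary_ip() -> str:
    # UDP connect() sends no packets; it only selects a route,
    # so getsockname() reveals the outbound interface address.
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("10.254.254.254", 1))  # any non-loopback address works
        return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"
    finally:
        s.close()

print(primary_ip())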
sarathi/worker/cache_engine.py (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@
from sarathi.config import ModelConfig, ParallelConfig, SystemConfig
from sarathi.logger import init_logger
from sarathi.model_executor.attention import get_attention_wrapper
-from sarathi.utils import in_wsl

logger = init_logger(__name__)

sarathi/worker/pipeline_parallel_worker.py (10 changes: 8 additions & 2 deletions)
@@ -2,7 +2,7 @@

from queue import Queue
from threading import Thread
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

import torch
import torch.distributed
@@ -52,7 +52,11 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
def enqueue(
self,
scheduler_outputs: SchedulerOutputs,
+        pending_step_outputs: List[Tuple[SchedulerOutputs, SamplerOutputs]] = [],
) -> None:
+        for pending_step_output in pending_step_outputs:
+            self.on_sampling_completed(pending_step_output[0], pending_step_output[1])
+
self.execution_queue.put(scheduler_outputs)

def on_step_completed(
@@ -79,8 +83,10 @@ def _execution_loop(self) -> None:
if not self.is_tensor_parallel_rank_zero:
continue

-            if self.is_first_pipeline_stage or self.is_last_pipeline_stage:
+            if self.is_last_pipeline_stage:
self.output_queue.put(output)
+            elif self.is_first_pipeline_stage:
+                self.output_queue.put(None)

def get_output(self) -> Optional[SamplerOutputs]:
return self.output_queue.get()
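Reading the last hunk (an inference from the diff, not stated in the commit message): only the last pipeline stage produces real sampler outputs, so the first stage now puts None on its output queue purely as a completion signal, and get_output() unblocks once per microbatch either way. A toy model:

from queue import Queue
from typing import List

def execution_step(
    output_queue: Queue,
    is_first_stage: bool,
    is_last_stage: bool,
    output: List[str],
) -> None:
    # Mirrors the new branch in _execution_loop.
    if is_last_stage:
        output_queue.put(output)  # real sampler outputs
    elif is_first_stage:
        output_queue.put(None)    # completion signal only

q: Queue = Queue()
execution_step(q, is_first_stage=False, is_last_stage=True, output=["token"])
print(q.get())  # ['token']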
