From d6fc95604aecfda150bdbbff02adea44e25fbbf8 Mon Sep 17 00:00:00 2001 From: Amey Agrawal Date: Thu, 25 Jul 2024 15:53:47 -0400 Subject: [PATCH] minor --- examples/offline_inference.py | 2 +- sarathi/engine/pipeline_parallel_llm_engine.py | 9 ++++----- sarathi/worker/pipeline_parallel_worker.py | 1 - 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 27f1762..e68526d 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -35,7 +35,7 @@ parallel_config = ParallelConfig( tensor_parallel_size=1, - pipeline_parallel_size=1, + pipeline_parallel_size=4, ) scheduler_config = SarathiSchedulerConfig( diff --git a/sarathi/engine/pipeline_parallel_llm_engine.py b/sarathi/engine/pipeline_parallel_llm_engine.py index 0388221..1a8c579 100644 --- a/sarathi/engine/pipeline_parallel_llm_engine.py +++ b/sarathi/engine/pipeline_parallel_llm_engine.py @@ -6,7 +6,7 @@ import zmq -from sarathi.config import SchedulerType, SystemConfig +from sarathi.config import SystemConfig from sarathi.core.datatypes.request_output import RequestOutput from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import SamplerOutputs, SequenceMetadata @@ -53,7 +53,7 @@ def __init__( self.scheduler_output_queue = Queue() self.output_queue = Queue() self.schedule_event = Event() - self.microbatch_watch_queue = Queue() + self.microbatch_watch_event = Event() self.schedule_thread = Thread(target=self._schedule_loop, daemon=True) self.microbatch_watch_thread = Thread( target=self._microbatch_watch_loop, daemon=True @@ -139,7 +139,7 @@ def _schedule_loop(self) -> None: end_time = time.perf_counter() if not scheduler_outputs.is_empty(): - self.microbatch_watch_queue.put(scheduler_outputs) + self.microbatch_watch_event.set() self.enqueue_socket.send_pyobj( StepInputs( scheduler_outputs, @@ -153,8 +153,7 @@ def _schedule_loop(self) -> None: @exit_on_error def _microbatch_watch_loop(self) -> None: while True: - scheduler_outputs = self.microbatch_watch_queue.get() - + self.microbatch_watch_event.wait() self.microbatch_socket.recv_pyobj() self.schedule_event.set() diff --git a/sarathi/worker/pipeline_parallel_worker.py b/sarathi/worker/pipeline_parallel_worker.py index 2a6cc2d..ead565b 100644 --- a/sarathi/worker/pipeline_parallel_worker.py +++ b/sarathi/worker/pipeline_parallel_worker.py @@ -6,7 +6,6 @@ import torch.distributed import zmq -from sarathi.config import SystemConfig from sarathi.core.datatypes.scheduler_output import SchedulerOutputs from sarathi.core.datatypes.sequence import SamplerOutputs from sarathi.logger import init_logger