Skip to content

Commit

Permalink
minor
Browse files (browse the repository at this point in the history)
  • Loading branch information
AgrawalAmey committed Jul 25, 2024
1 parent 20b0999 commit d6fc956
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

parallel_config = ParallelConfig(
tensor_parallel_size=1,
- pipeline_parallel_size=1,
+ pipeline_parallel_size=4,
)

scheduler_config = SarathiSchedulerConfig(
Expand Down
9 changes: 4 additions & 5 deletions sarathi/engine/pipeline_parallel_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import zmq

- from sarathi.config import SchedulerType, SystemConfig
+ from sarathi.config import SystemConfig
from sarathi.core.datatypes.request_output import RequestOutput
from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
from sarathi.core.datatypes.sequence import SamplerOutputs, SequenceMetadata
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
self.scheduler_output_queue = Queue()
self.output_queue = Queue()
self.schedule_event = Event()
- self.microbatch_watch_queue = Queue()
+ self.microbatch_watch_event = Event()
self.schedule_thread = Thread(target=self._schedule_loop, daemon=True)
self.microbatch_watch_thread = Thread(
target=self._microbatch_watch_loop, daemon=True
Expand Down Expand Up @@ -139,7 +139,7 @@ def _schedule_loop(self) -> None:
end_time = time.perf_counter()

if not scheduler_outputs.is_empty():
- self.microbatch_watch_queue.put(scheduler_outputs)
+ self.microbatch_watch_event.set()
self.enqueue_socket.send_pyobj(
StepInputs(
scheduler_outputs,
Expand All @@ -153,8 +153,7 @@ def _schedule_loop(self) -> None:
@exit_on_error
def _microbatch_watch_loop(self) -> None:
while True:
- scheduler_outputs = self.microbatch_watch_queue.get()
-
+ self.microbatch_watch_event.wait()
self.microbatch_socket.recv_pyobj()
self.schedule_event.set()

Expand Down
1 change: 0 additions & 1 deletion sarathi/worker/pipeline_parallel_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch.distributed
import zmq

from sarathi.config import SystemConfig
from sarathi.core.datatypes.scheduler_output import SchedulerOutputs
from sarathi.core.datatypes.sequence import SamplerOutputs
from sarathi.logger import init_logger
Expand Down

0 comments on commit d6fc956

Please sign in to comment.