Skip to content

Commit

Permalink
[minor] allow overriding args/kwargs behavior in Runtime (learning-at…
Browse files Browse the repository at this point in the history
…-home#587)

* allow overriding args/kwargs in Runtime
* switch stats time to time.perf_counter

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
  • Loading branch information
3 people authored Aug 25, 2023
1 parent 6f5c471 commit 33a9a41
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions hivemind/moe/server/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional
from time import perf_counter
from typing import Any, Dict, NamedTuple, Optional, Tuple

import torch
from prefetch_generator import BackgroundGenerator

from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -85,15 +86,11 @@ def run(self):

for pool, batch_index, batch in batch_iterator:
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

start = time()
start = perf_counter()
try:
outputs = pool.process_func(*batch)
batch_processing_time = time() - start

batch_size = outputs[0].size(0)
outputs, batch_size = self.process_batch(pool, batch_index, *batch)
batch_processing_time = perf_counter() - start
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

Expand All @@ -108,6 +105,11 @@ def run(self):
if not self.shutdown_trigger.is_set():
self.shutdown()

def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor) -> Tuple[Any, int]:
"""process one batch of tasks from a given pool, return a batch of results and total batch size"""
outputs = pool.process_func(*batch)
return outputs, outputs[0].size(0)

def shutdown(self):
"""Gracefully terminate a running runtime."""
logger.info("Shutting down")
Expand Down

0 comments on commit 33a9a41

Please sign in to comment.