Feat(batching): refactor; troubleshoot back pressure
Showing 4 changed files with 243 additions and 148 deletions.
@@ -1,176 +1,187 @@
import asyncio
import logging
import traceback
import time
import random
import collections
from typing import Callable
from bentoml.utils import cached_property
import numpy as np

from bentoml.utils.alg import FixedBucket, TokenBucket


class Bucket:
    '''
    Fixed size container.
    '''

    def __init__(self, size):
        self._data = [None] * size
        self._cur = 0
        self._size = size
        self._flag_full = False

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)

    def put(self, v):
        self._data[self._cur] = v
        self._cur += 1
        if self._cur == self._size:
            self._cur = 0
            self._flag_full = True

    @property
    def data(self):
        if not self._flag_full:
            return self._data[: self._cur]
        return self._data


class NonBlockSema:
    def __init__(self, count):
        self.sema = count

    def __len__(self):
        if not self._flag_full:
            return self._cur
        return self._size

    def aquire(self):
        if self.sema < 1:
            return False
        self.sema -= 1
        return True

    def is_locked(self):
        return self.sema < 1

    def release(self):
        self.sema += 1

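NonBlockSema above is a plain counter rather than an asyncio.Semaphore: callers never await on it, they only check and adjust it from the event-loop thread. A minimal sketch of the intended pattern, mirroring how controller() and outbound_call() below use it (the dispatch_once coroutine and the sleep interval are made-up placeholders, not part of this commit):

sema = NonBlockSema(1)  # allow a single in-flight outbound call

async def dispatch_once():
    if sema.is_locked():           # no permit left: back off instead of blocking
        await asyncio.sleep(0.001)
        return
    sema.aquire()                  # take a permit without awaiting
    try:
        await asyncio.sleep(0)     # stand-in for the real outbound work
    finally:
        sema.release()             # always hand the permit back
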
class Optimizer:
    N_OUTBOUND_SAMPLE = 500
    N_OUTBOUND_WAIT_SAMPLE = 20
    N_OUTBOUND_SAMPLE = 50
    INTERVAL_REFRESH_PARAMS = 5
    N_DATA_DROP_FIRST = 2

    def __init__(self):
        self.outbound_stat = Bucket(self.N_OUTBOUND_SAMPLE)
        self.outbound_wait_stat = Bucket(self.N_OUTBOUND_WAIT_SAMPLE)
        self.outbound_a = 0.0001
        self.outbound_b = 0
        self.outbound_wait = 0.01

    async def adaptive_wait(self, parade, max_time):
        dt = 0.001
        decay = 0.9
        while True:
            now = time.time()
            w0 = now - parade.time_first
            wn = now - parade.time_last
            n = parade.length
            a = max(self.outbound_a, 0)

            if w0 >= max_time:
                print("warning: max latency reached")
                break
            if max(n - 1, 1) * (wn + dt + a) <= self.outbound_wait * decay:
                await asyncio.sleep(dt)
                continue
            break

    def log_outbound_time(self, info):
        if info[0] < 5:  # skip all small batches
        self.o_stat = FixedBucket(self.N_OUTBOUND_SAMPLE)
        self.o_a = 2
        self.o_b = 0.1
        self.o_w = 0.01
        self._refresh_tb = TokenBucket(2)
        self._outbound_init_counter = 0
        self._o_a = self.o_a
        self._o_b = self.o_b
        self._o_w = self.o_w

    def log_outbound(self, n, wait, duration):
        # drop the first N_DATA_DROP_FIRST samples
        if self._outbound_init_counter <= self.N_DATA_DROP_FIRST:
            self._outbound_init_counter += 1
            return
        self.outbound_stat.put(info)
        if random.random() < 0.05:
            x = tuple((i, 1) for i, _ in self.outbound_stat.data)
            y = tuple(i for _, i in self.outbound_stat.data)
            self.outbound_a, self.outbound_b = np.linalg.lstsq(x, y, rcond=None)[0]

    def log_outbound_wait(self, info):
        self.outbound_wait_stat.put(info)
        self.outbound_wait = (
            sum(self.outbound_wait_stat.data) * 1.0 / len(self.outbound_wait_stat)
        )

        self.o_stat.put((n, duration, wait))


class Parade:
    STATUSES = (STATUS_OPEN, STATUS_CLOSED, STATUS_RETURNED,) = range(3)

    def __init__(self, max_size, outbound_sema, optimizer):
        self.max_size = max_size
        self.outbound_sema = outbound_sema
        self.batch_input = [None] * max_size
        self.batch_output = [None] * max_size
        self.length = 0
        self.returned = asyncio.Condition()
        self.status = self.STATUS_OPEN
        self.optimizer = optimizer
        self.time_first = None
        self.time_last = None

    def feed(self, data) -> Callable:
        '''
        feed data into this parade.
        return:
            the output index in parade.batch_output
        '''
        self.batch_input[self.length] = data
        self.length += 1
        if self.length == self.max_size:
            self.status = self.STATUS_CLOSED
        self.time_last = time.time()
        return self.length - 1

    async def start_wait(self, max_wait_time, call):
        now = time.time()
        self.time_first = now
        self.time_last = now
        try:
            await self.optimizer.adaptive_wait(self, max_wait_time)
            async with self.outbound_sema:
                self.status = self.STATUS_CLOSED
                _time_start = time.time()
                self.optimizer.log_outbound_wait(_time_start - self.time_first)
                self.batch_output = await call(self.batch_input[: self.length])
                self.optimizer.log_outbound_time(
                    (self.length, time.time() - _time_start)
                )
                self.status = self.STATUS_RETURNED
        finally:
            # make sure parade is closed
            if self.status == self.STATUS_OPEN:
                self.status = self.STATUS_CLOSED
            async with self.returned:
                self.returned.notify_all()
        if self._refresh_tb.consume(1, 1.0 / self.INTERVAL_REFRESH_PARAMS, 1):
            self.trigger_refresh()

    def trigger_refresh(self):
        x = tuple((i, 1) for i, _, _ in self.o_stat.data)
        y = tuple(i for _, i, _ in self.o_stat.data)
        self._o_a, self._o_b = np.linalg.lstsq(x, y, rcond=None)[0]
        self._o_w = sum(w for _, _, w in self.o_stat) * 1.0 / len(self.o_stat)

        self.o_a, self.o_b = max(0.000001, self._o_a), max(0, self._o_b)
        self.o_w = max(0, self._o_w)
        logger.info(
            "optimizer params updated: o_a: %.6f, o_b: %.6f, o_w: %.6f",
            self._o_a,
            self._o_b,
            self._o_w,
        )

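Taken together, log_outbound() and trigger_refresh() maintain a simple linear latency model from the recent (batch size, duration, wait) samples in o_stat: duration ≈ o_a * n + o_b, with o_w the average wait before dispatch. controller() below uses o_a and o_b to predict whether the oldest queued request can still meet max_expected_time, and o_w to decide whether corking a little longer is worthwhile. A standalone sketch of the same least-squares fit on made-up numbers (the sample values are purely illustrative, not from the commit):

import numpy as np

# synthetic (batch_size, duration) samples; duration grows roughly as 0.002 * n + 0.01
samples = [(1, 0.012), (4, 0.018), (8, 0.026), (16, 0.042), (32, 0.074)]

# same design matrix trigger_refresh() builds: one (n, 1) row per sample
x = tuple((n, 1) for n, _ in samples)
y = tuple(d for _, d in samples)

o_a, o_b = np.linalg.lstsq(x, y, rcond=None)[0]
print(o_a, o_b)  # ~0.002 s per item, ~0.01 s fixed overhead per outbound call
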
class ParadeDispatcher:
    def __init__(self, max_wait_time, max_size, shared_sema: callable = None):
    def __init__(
        self,
        max_expected_time: int,
        max_batch_size: int,
        shared_sema: NonBlockSema = None,
        fallback: Callable = None,
    ):
        """
        params:
            * max_wait_time: max_wait_time to wait for inbound tasks in milliseconds
            * max_size: inbound tasks buffer size
            * max_expected_time: max_expected_time for inbound tasks in milliseconds
            * max_batch_size: max batch size of inbound tasks
            * shared_sema: semaphore to limit concurrent tasks
            * fallback: callable to return fallback result
        """
        self.max_wait_time = max_wait_time
        self.max_size = int(max_size)
        self.shared_sema = shared_sema
        self.max_expected_time = max_expected_time / 1000.0
        self.callback = None
        self._current_parade = None
        self.fallback = fallback
        self.optimizer = Optimizer()

    @cached_property
    def outbound_sema(self):
        '''
        semaphore should be created after process forked
        '''
        return self.shared_sema() if self.shared_sema else asyncio.Semaphore(1)

    def get_parade(self):
        if self._current_parade and self._current_parade.status == Parade.STATUS_OPEN:
            return self._current_parade
        self._current_parade = Parade(self.max_size, self.outbound_sema, self.optimizer)
        asyncio.get_event_loop().create_task(
            self._current_parade.start_wait(self.max_wait_time / 1000.0, self.callback)
        )
        return self._current_parade
        self.max_batch_size = int(max_batch_size)
        self._controller = None
        self._queue = collections.deque()  # TODO(hrmthw): maxlen
        self._loop = asyncio.get_event_loop()
        self._wake_event = asyncio.Condition()
        self._sema = shared_sema if shared_sema else NonBlockSema(1)

        self.tick_interval = 0.001

    def __call__(self, callback):
        self.callback = callback
        self._controller = self._loop.create_task(self.controller())

        async def _func(inputs):
            parade = self.get_parade()
            _id = parade.feed(inputs)
            async with parade.returned:
                await parade.returned.wait()
            return parade.batch_output[_id]

        async def _func(data):
            try:
                r = await self.inbound_call(data)
            except asyncio.CancelledError:
                return None if self.fallback is None else self.fallback()
            return r

        return _func

    async def controller(self):
        while True:
            try:
                async with self._wake_event:  # block until request in queue
                    await self._wake_event.wait_for(self._queue.__len__)

                n = len(self._queue)
                dt = self.tick_interval
                decay = 0.95
                now = time.time()
                w0 = now - self._queue[0][0]
                wn = now - self._queue[-1][0]
                a = self.optimizer.o_a
                b = self.optimizer.o_b

                if (w0 + a * n + b + dt) >= self.max_expected_time:
                    self._queue.popleft()[2].cancel()
                    continue
                if self._sema.is_locked():
                    await asyncio.sleep(self.tick_interval)
                    continue
                if n * (wn + dt + a) <= self.optimizer.o_w * decay:
                    await asyncio.sleep(self.tick_interval)
                    continue

                n_call_out = min(self.max_batch_size, n,)
                # call
                self._sema.aquire()
                inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
                self._loop.create_task(self.outbound_call(inputs_info))
            except asyncio.CancelledError:
                break
            except Exception:  # pylint: disable=broad-except
                logger.error(traceback.format_exc())

    async def inbound_call(self, data) -> asyncio.Future:
        t = time.time()
        future = self._loop.create_future()
        input_info = (t, data, future)
        self._queue.append(input_info)
        async with self._wake_event:
            self._wake_event.notify_all()
        return await future

    async def outbound_call(self, inputs_info):
        _time_start = time.time()
        _done = False
        logger.info("outbound function called: %d", len(inputs_info))
        try:
            outputs = await self.callback(tuple(d for _, d, _ in inputs_info))
            assert len(outputs) == len(inputs_info)
            for (_, _, fut), out in zip(inputs_info, outputs):
                if not fut.done():
                    fut.set_result(out)
            _done = True
            self.optimizer.log_outbound(
                n=len(inputs_info),
                wait=_time_start - inputs_info[-1][0],
                duration=time.time() - _time_start,
            )
        except asyncio.CancelledError:
            pass
        except Exception:  # pylint: disable=broad-except
            logger.error(traceback.format_exc())
        finally:
            if not _done:
                for _, _, fut in inputs_info:
                    if not fut.done():
                        fut.cancel()
            self._sema.release()
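For reference, a minimal usage sketch of the dispatcher as it stands after this commit: __call__ registers the batch-capable callback, starts the controller task and returns a per-request coroutine, so concurrent single-item calls get corked together into batched outbound calls. The batch_square callback and the parameter values below are made-up placeholders, not part of the commit:

import asyncio

async def batch_square(inputs):
    # hypothetical batch-capable model call: one sequence in, one sequence out
    return [x * x for x in inputs]

async def main():
    dispatcher = ParadeDispatcher(
        max_expected_time=60 * 1000,  # cancel requests predicted to miss 60 s
        max_batch_size=32,
        fallback=lambda: None,        # returned to callers whose request was cancelled
    )
    infer = dispatcher(batch_square)  # register callback, start the controller task

    # each caller awaits its own result while the controller batches the calls
    results = await asyncio.gather(*(infer(i) for i in range(8)))
    print(results)

asyncio.run(main())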