Feat(batching): refactor; troubleshoot back pressure
bojiang committed May 6, 2020
1 parent bc3f6ae commit c6573e3
Showing 4 changed files with 243 additions and 148 deletions.
295 changes: 153 additions & 142 deletions bentoml/marshal/dispatcher.py
@@ -1,176 +1,187 @@
 import asyncio
 import logging
+import traceback
 import time
-import random
+import collections
 from typing import Callable
-from bentoml.utils import cached_property
+
 import numpy as np
 
+from bentoml.utils.alg import FixedBucket, TokenBucket
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.ERROR)
 
-class Bucket:
-    '''
-    Fixed size container.
-    '''
-
-    def __init__(self, size):
-        self._data = [None] * size
-        self._cur = 0
-        self._size = size
-        self._flag_full = False
-
-    def put(self, v):
-        self._data[self._cur] = v
-        self._cur += 1
-        if self._cur == self._size:
-            self._cur = 0
-            self._flag_full = True
-
-    @property
-    def data(self):
-        if not self._flag_full:
-            return self._data[: self._cur]
-        return self._data
-
-    def __len__(self):
-        if not self._flag_full:
-            return self._cur
-        return self._size
+class NonBlockSema:
+    def __init__(self, count):
+        self.sema = count
+
+    def aquire(self):
+        if self.sema < 1:
+            return False
+        self.sema -= 1
+        return True
+
+    def is_locked(self):
+        return self.sema < 1
+
+    def release(self):
+        self.sema += 1
 
 
 class Optimizer:
-    N_OUTBOUND_SAMPLE = 500
-    N_OUTBOUND_WAIT_SAMPLE = 20
+    N_OUTBOUND_SAMPLE = 50
+    INTERVAL_REFRESH_PARAMS = 5
+    N_DATA_DROP_FIRST = 2
 
     def __init__(self):
-        self.outbound_stat = Bucket(self.N_OUTBOUND_SAMPLE)
-        self.outbound_wait_stat = Bucket(self.N_OUTBOUND_WAIT_SAMPLE)
-        self.outbound_a = 0.0001
-        self.outbound_b = 0
-        self.outbound_wait = 0.01
-
-    async def adaptive_wait(self, parade, max_time):
-        dt = 0.001
-        decay = 0.9
-        while True:
-            now = time.time()
-            w0 = now - parade.time_first
-            wn = now - parade.time_last
-            n = parade.length
-            a = max(self.outbound_a, 0)
-
-            if w0 >= max_time:
-                print("warning: max latency reached")
-                break
-            if max(n - 1, 1) * (wn + dt + a) <= self.outbound_wait * decay:
-                await asyncio.sleep(dt)
-                continue
-            break
-
-    def log_outbound_time(self, info):
-        if info[0] < 5:  # skip all small batch
-            return
-        self.outbound_stat.put(info)
-        if random.random() < 0.05:
-            x = tuple((i, 1) for i, _ in self.outbound_stat.data)
-            y = tuple(i for _, i in self.outbound_stat.data)
-            self.outbound_a, self.outbound_b = np.linalg.lstsq(x, y, rcond=None)[0]
-
-    def log_outbound_wait(self, info):
-        self.outbound_wait_stat.put(info)
-        self.outbound_wait = (
-            sum(self.outbound_wait_stat.data) * 1.0 / len(self.outbound_wait_stat)
-        )
-
-
-class Parade:
-    STATUSES = (STATUS_OPEN, STATUS_CLOSED, STATUS_RETURNED,) = range(3)
-
-    def __init__(self, max_size, outbound_sema, optimizer):
-        self.max_size = max_size
-        self.outbound_sema = outbound_sema
-        self.batch_input = [None] * max_size
-        self.batch_output = [None] * max_size
-        self.length = 0
-        self.returned = asyncio.Condition()
-        self.status = self.STATUS_OPEN
-        self.optimizer = optimizer
-        self.time_first = None
-        self.time_last = None
-
-    def feed(self, data) -> Callable:
-        '''
-        feed data into this parade.
-        return:
-            the output index in parade.batch_output
-        '''
-        self.batch_input[self.length] = data
-        self.length += 1
-        if self.length == self.max_size:
-            self.status = self.STATUS_CLOSED
-        self.time_last = time.time()
-        return self.length - 1
-
-    async def start_wait(self, max_wait_time, call):
-        now = time.time()
-        self.time_first = now
-        self.time_last = now
-        try:
-            await self.optimizer.adaptive_wait(self, max_wait_time)
-            async with self.outbound_sema:
-                self.status = self.STATUS_CLOSED
-                _time_start = time.time()
-                self.optimizer.log_outbound_wait(_time_start - self.time_first)
-                self.batch_output = await call(self.batch_input[: self.length])
-                self.optimizer.log_outbound_time(
-                    (self.length, time.time() - _time_start)
-                )
-                self.status = self.STATUS_RETURNED
-        finally:
-            # make sure parade is closed
-            if self.status == self.STATUS_OPEN:
-                self.status = self.STATUS_CLOSED
-            async with self.returned:
-                self.returned.notify_all()
+        self.o_stat = FixedBucket(self.N_OUTBOUND_SAMPLE)
+        self.o_a = 2
+        self.o_b = 0.1
+        self.o_w = 0.01
+        self._refresh_tb = TokenBucket(2)
+        self._outbound_init_counter = 0
+        self._o_a = self.o_a
+        self._o_b = self.o_b
+        self._o_w = self.o_w
+
+    def log_outbound(self, n, wait, duration):
+        # drop first N_DATA_DROP_FIRST datas
+        if self._outbound_init_counter <= self.N_DATA_DROP_FIRST:
+            self._outbound_init_counter += 1
+            return
+
+        self.o_stat.put((n, duration, wait))
+
+        if self._refresh_tb.consume(1, 1.0 / self.INTERVAL_REFRESH_PARAMS, 1):
+            self.trigger_refresh()
+
+    def trigger_refresh(self):
+        x = tuple((i, 1) for i, _, _ in self.o_stat.data)
+        y = tuple(i for _, i, _ in self.o_stat.data)
+        self._o_a, self._o_b = np.linalg.lstsq(x, y, rcond=None)[0]
+        self._o_w = sum(w for _, _, w in self.o_stat) * 1.0 / len(self.o_stat)
+
+        self.o_a, self.o_b = max(0.000001, self._o_a), max(0, self._o_b)
+        self.o_w = max(0, self._o_w)
+        logger.info(
+            "optimizer params updated: o_a: %.6f, o_b: %.6f, o_w: %.6f",
+            self._o_a,
+            self._o_b,
+            self._o_w,
+        )
 
 
 class ParadeDispatcher:
-    def __init__(self, max_wait_time, max_size, shared_sema: callable = None):
+    def __init__(
+        self,
+        max_expected_time: int,
+        max_batch_size: int,
+        shared_sema: NonBlockSema = None,
+        fallback: Callable = None,
+    ):
         """
         params:
-            * max_wait_time: max_wait_time to wait for inbound tasks in milliseconds
-            * max_size: inbound tasks buffer size
+            * max_expected_time: max_expected_time for inbound tasks in milliseconds
+            * max_batch_size: max batch size of inbound tasks
            * shared_sema: semaphore to limit concurrent tasks
+            * fallback: callable to return fallback result
         """
-        self.max_wait_time = max_wait_time
-        self.max_size = int(max_size)
-        self.shared_sema = shared_sema
+        self.max_expected_time = max_expected_time / 1000.0
         self.callback = None
-        self._current_parade = None
+        self.fallback = fallback
         self.optimizer = Optimizer()
+        self.max_batch_size = int(max_batch_size)
+        self._controller = None
+        self._queue = collections.deque()  # TODO(hrmthw): maxlen
+        self._loop = asyncio.get_event_loop()
+        self._wake_event = asyncio.Condition()
+        self._sema = shared_sema if shared_sema else NonBlockSema(1)
 
-    @cached_property
-    def outbound_sema(self):
-        '''
-        semaphore should be created after process forked
-        '''
-        return self.shared_sema() if self.shared_sema else asyncio.Semaphore(1)
-
-    def get_parade(self):
-        if self._current_parade and self._current_parade.status == Parade.STATUS_OPEN:
-            return self._current_parade
-        self._current_parade = Parade(self.max_size, self.outbound_sema, self.optimizer)
-        asyncio.get_event_loop().create_task(
-            self._current_parade.start_wait(self.max_wait_time / 1000.0, self.callback)
-        )
-        return self._current_parade
+        self.tick_interval = 0.001
 
     def __call__(self, callback):
         self.callback = callback
+        self._controller = self._loop.create_task(self.controller())
 
-        async def _func(inputs):
-            parade = self.get_parade()
-            _id = parade.feed(inputs)
-            async with parade.returned:
-                await parade.returned.wait()
-            return parade.batch_output[_id]
+        async def _func(data):
+            try:
+                r = await self.inbound_call(data)
+            except asyncio.CancelledError:
+                return None if self.fallback is None else self.fallback()
+            return r
 
         return _func
+
+    async def controller(self):
+        while True:
+            try:
+                async with self._wake_event:  # block until request in queue
+                    await self._wake_event.wait_for(self._queue.__len__)
+
+                n = len(self._queue)
+                dt = self.tick_interval
+                decay = 0.95
+                now = time.time()
+                w0 = now - self._queue[0][0]
+                wn = now - self._queue[-1][0]
+                a = self.optimizer.o_a
+                b = self.optimizer.o_b
+
+                if (w0 + a * n + b + dt) >= self.max_expected_time:
+                    self._queue.popleft()[2].cancel()
+                    continue
+                if self._sema.is_locked():
+                    await asyncio.sleep(self.tick_interval)
+                    continue
+                if n * (wn + dt + a) <= self.optimizer.o_w * decay:
+                    await asyncio.sleep(self.tick_interval)
+                    continue
+
+                n_call_out = min(self.max_batch_size, n,)
+                # call
+                self._sema.aquire()
+                inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
+                self._loop.create_task(self.outbound_call(inputs_info))
+            except asyncio.CancelledError:
+                break
+            except Exception:  # pylint: disable=broad-except
+                logger.error(traceback.format_exc())
+
+    async def inbound_call(self, data) -> asyncio.Future:
+        t = time.time()
+        future = self._loop.create_future()
+        input_info = (t, data, future)
+        self._queue.append(input_info)
+        async with self._wake_event:
+            self._wake_event.notify_all()
+        return await future
+
+    async def outbound_call(self, inputs_info):
+        _time_start = time.time()
+        _done = False
+        logger.info("outbound function called: %d", len(inputs_info))
+        try:
+            outputs = await self.callback(tuple(d for _, d, _ in inputs_info))
+            assert len(outputs) == len(inputs_info)
+            for (_, _, fut), out in zip(inputs_info, outputs):
+                if not fut.done():
+                    fut.set_result(out)
+            _done = True
+            self.optimizer.log_outbound(
+                n=len(inputs_info),
+                wait=_time_start - inputs_info[-1][0],
+                duration=time.time() - _time_start,
+            )
+        except asyncio.CancelledError:
+            pass
+        except Exception:  # pylint: disable=broad-except
+            logger.error(traceback.format_exc())
+        finally:
+            if not _done:
+                for _, _, fut in inputs_info:
+                    if not fut.done():
+                        fut.cancel()
+            self._sema.release()
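
The refactored controller prices an outbound call at roughly o_a * n + o_b seconds for a batch of n, and trigger_refresh re-fits (o_a, o_b) by least squares over the recorded (batch size, duration) samples. Below is a standalone sketch of that fit with made-up numbers; only the np.linalg.lstsq call and the (n, 1) design-matrix rows mirror the commit, everything else is illustrative.

import numpy as np

# Hypothetical (batch_size, duration_in_seconds) samples, standing in
# for the (n, duration, wait) tuples accumulated in Optimizer.o_stat.
samples = [(1, 0.012), (4, 0.021), (8, 0.033), (16, 0.058)]

# Rows of (n, 1), as in trigger_refresh, so least squares solves
# duration ~= o_a * n + o_b.
x = np.array([(n, 1) for n, _ in samples])
y = np.array([d for _, d in samples])
o_a, o_b = np.linalg.lstsq(x, y, rcond=None)[0]
print("per-item cost o_a=%.4fs, fixed overhead o_b=%.4fs" % (o_a, o_b))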
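
And a minimal sketch of driving the new dispatcher end to end, mirroring the add_batch_handler wiring in marshal.py below; the predict_batch callback and all numbers are hypothetical, and the code assumes the 2020-era asyncio API this commit targets.

import asyncio

from bentoml.marshal.dispatcher import NonBlockSema, ParadeDispatcher


async def main():
    async def predict_batch(inputs):
        # Hypothetical outbound callable: one output per input, in order.
        await asyncio.sleep(0.01)  # stand-in for model inference
        return tuple(x * 2 for x in inputs)

    # A generous 10 s budget: the optimizer's conservative initial
    # o_a = 2 s/item makes the controller shed aggressively under much
    # tighter budgets until real outbound durations have been observed.
    dispatcher = ParadeDispatcher(
        max_expected_time=10000,
        max_batch_size=8,
        shared_sema=NonBlockSema(2),
        fallback=lambda: None,  # shed requests resolve to None here
    )
    infer = dispatcher(predict_batch)  # also starts the controller task

    results = await asyncio.gather(*(infer(i) for i in range(4)))
    print(results)  # [0, 2, 4, 6], batched behind the scenes


asyncio.run(main())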
15 changes: 9 additions & 6 deletions bentoml/marshal/marshal.py
@@ -27,7 +27,7 @@
 from bentoml.handlers import HANDLER_TYPES_BATCH_MODE_SUPPORTED
 from bentoml.bundler import load_bento_service_metadata
 from bentoml.utils.usage_stats import track_server
-from bentoml.marshal.dispatcher import ParadeDispatcher
+from bentoml.marshal.dispatcher import ParadeDispatcher, NonBlockSema
 
 logger = logging.getLogger(__name__)
 ZIPKIN_API_URL = config("tracing").get("zipkin_api_url")
@@ -141,6 +141,7 @@ def __init__(
         self.setup_routes_from_pb(self.bento_service_metadata_pb)
         if psutil.POSIX:
             import resource
+
             self.CONNECTION_LIMIT = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
         else:
             self.CONNECTION_LIMIT = 1024
@@ -157,14 +158,17 @@ def set_outbound_port(self, outbound_port):
 
     def fetch_sema(self):
         if self._outbound_sema is None:
-            self._outbound_sema = asyncio.Semaphore(self.outbound_workers)
+            self._outbound_sema = NonBlockSema(self.outbound_workers)
         return self._outbound_sema
 
     def add_batch_handler(self, api_name, max_latency, max_batch_size):
 
         if api_name not in self.batch_handlers:
             _func = ParadeDispatcher(
-                max_latency, max_batch_size, shared_sema=self.fetch_sema
+                max_latency,
+                max_batch_size,
+                shared_sema=self.fetch_sema(),
+                fallback=aiohttp.web.HTTPTooManyRequests,
             )(partial(self._batch_handler_template, api_name=api_name))
             self.batch_handlers[api_name] = _func
 
@@ -243,10 +247,10 @@ async def _batch_handler_template(self, requests, api_name):
                 raw = await resp.read()
                 merged = DataLoader.split_responses(raw)
         except (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError):
-            return (aiohttp.web.HTTPServiceUnavailable,) * len(requests)
+            return (aiohttp.web.HTTPServiceUnavailable(),) * len(requests)
 
         if merged is None:
-            return (aiohttp.web.HTTPInternalServerError,) * len(requests)
+            return (aiohttp.web.HTTPInternalServerError(),) * len(requests)
         return tuple(
             aiohttp.web.Response(body=i.data, headers=i.headers, status=i.status)
             for i in merged
@@ -260,7 +264,6 @@ def async_start(self, port):
 
         marshal_proc = multiprocessing.Process(
             target=self.fork_start_app, kwargs=dict(port=port), daemon=True,
         )
-        # TODO: make sure child process dies when parent process is killed.
         marshal_proc.start()
         logger.info("Running micro batch service on :%d", port)