Skip to content

Commit

Permalink
refactor out training code
Browse files Browse the repository at this point in the history
  • Loading branch information
sauyon committed Mar 14, 2023
1 parent 6e1af19 commit cd93d95
Showing 1 changed file with 41 additions and 140 deletions.
181 changes: 41 additions & 140 deletions src/bentoml/_internal/marshal/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, max_latency: float):
self.o_a = min(2, max_latency * 2.0 / 30)
self.o_b = min(1, max_latency * 1.0 / 30)

self.wait = 0.01 # the avg wait time before outbound called
self.wait = 0 # the avg wait time before outbound called

self._refresh_tb = TokenBucket(2) # to limit params refresh interval
self.outbound_counter = 0
Expand Down Expand Up @@ -168,52 +168,18 @@ async def _func(data: t.Any) -> t.Any:

return _func

async def controller(self):
"""
A standalone coroutine to wait/dispatch calling.
"""
logger.debug("Starting dispatcher optimizer training...")
# warm up the model
while self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE:
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
now = time.time()
w0 = now - self._queue[0][0]

# only cancel requests if there are more than enough for training
if (
n
> self.optimizer.N_SKIPPED_SAMPLE
- self.optimizer.outbound_counter
+ 6
and w0 >= self.max_latency_in_ms
):
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
# don't try to be smart here, just serve the first few requests
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
continue

n_call_out = 1
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
async def train_optimizer(
self, num_required_reqs: int, num_reqs_to_train: int, batch_size: int, wait: 0,
):
if self.max_batch_size < batch_size:
batch_size = self.max_batch_size

logger.debug("Dispatcher finished warming up model.")
if batch_size > 1:
wait = min(self.max_latency * 0.95, (batch_size*2+1)*(self.optimizer.o_a + self.optimizer.o_b))

while self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 1:
try:
# step 1: attempt to serve a single request immediately
req_count = 0
try:
while req_count < num_reqs_to_train:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

Expand All @@ -222,117 +188,52 @@ async def controller(self):
w0 = now - self._queue[0][0]

# only cancel requests if there are more than enough for training
if n > 6 and w0 >= self.max_latency_in_ms:
if n > num_required_reqs - req_count and w0 >= self.max_latency_in_ms:
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
if batch_size > 1: # only wait if batch_size
if n < batch_size and (batch_size * a + b) + w0 <= wait:
await asyncio.sleep(self.tick_interval)
continue
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
continue

n_call_out = 1
n_call_out = min(n, batch_size)
req_count += 1
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished optimizer training request 1.")
self.optimizer.trigger_refresh()
async def controller(self):
"""
A standalone coroutine to wait/dispatch calling.
"""
logger.debug("Starting dispatcher optimizer training...")
# warm up the model
self.train_optimizer(
self.optimizer.N_SKIPPED_SAMPLE, self.optimizer.N_SKIPPED_SAMPLE + 6, 1
)

if self.max_batch_size >= 2:
# we will attempt to keep the second request served within this time
step_2_wait = min(
self.max_latency_in_ms * 0.95,
5 * (self.optimizer.o_a + self.optimizer.o_b),
)
logger.debug("Dispatcher finished warming up model.")

# step 2: attempt to serve 2 requests
while (
self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 2
):
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
now = time.time()
w0 = now - self._queue[0][0]
a = self.optimizer.o_a
b = self.optimizer.o_b

# only cancel requests if there are more than enough for training
if n > 5 and w0 >= self.max_latency_in_ms:
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
if n < 2 and (2 * a + b) + w0 <= step_2_wait:
await asyncio.sleep(self.tick_interval)
continue
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
continue
await self.train_optimizer(1, 6, 1)
self.optimizer.trigger_refresh()
logger.debug("Dispatcher finished optimizer training request 1.")

n_call_out = min(n, 2)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished optimizer training request 2.")
self.optimizer.trigger_refresh()

if self.max_batch_size >= 3:
# step 3: attempt to serve 3 requests

# we will attempt to keep the second request served within this time
step_3_wait = min(
self.max_latency_in_ms * 0.95,
7 * (self.optimizer.o_a + self.optimizer.o_b),
)
while (
self.optimizer.outbound_counter <= self.optimizer.N_SKIPPED_SAMPLE + 3
):
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
now = time.time()
w0 = now - self._queue[0][0]
a = self.optimizer.o_a
b = self.optimizer.o_b

# only cancel requests if there are more than enough for training
if n > 3 and w0 >= self.max_latency_in_ms:
# we're being very conservative and only canceling requests if they have already timed out
self._queue.popleft()[2].cancel()
continue
if n < 3 and (3 * a + b) + w0 <= step_3_wait:
await asyncio.sleep(self.tick_interval)
continue
await self.train_optimizer(1, 5, 2)
self.optimizer.trigger_refresh()
logger.debug("Dispatcher finished optimizer training request 2.")

n_call_out = min(n, 3)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

logger.debug("Dispatcher finished optimizer training request 3.")
self.optimizer.trigger_refresh()
await self.train_optimizer(1, 3, 3)
self.optimizer.trigger_refresh()
logger.debug("Dispatcher finished optimizer training request 3.")

if self.optimizer.o_a + self.optimizer.o_b >= self.max_latency_in_ms:
logger.warning(
Expand Down

0 comments on commit cd93d95

Please sign in to comment.