feat: Add OpenAI Rate limiting (#1805)
* Implement adaptive rate limiter for OpenAI

* Add adaptive rate limiter to Bedrock model

* Use a sensible default maximum request rate

* Ruff 🐶

* Mark test as xfail after llama_index update

* Do not retry on rate limit errors with tenacity

* Remove xfail after llama_index version lock

* Use events and locks instead of nesting asyncio.run

* Ensure that events are always set after rate limit handling

* Retry on httpx ReadTimeout errors

* Update rate limiters with verbose generation info

* Improve end of queue handling in AsyncExecutor

* improve types to remove the need for casts (#1817)

* Improve interrupt handling

* Exit early from queue.join on termination events

* Properly cancel running tasks

* Add pytest-asyncio to hatch env

* Do not await cancelled tasks

* Improve task_done marking logic

* Increase default concurrency

---------

Co-authored-by: Xander Song <axiomofjoy@gmail.com>
anticorrelator and axiomofjoy authored Nov 29, 2023
1 parent 2ca3613 commit 115e044
Showing 8 changed files with 699 additions and 59 deletions.
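The heart of the commit is an adaptive RateLimiter (defined in src/phoenix/experimental/evals/models/rate_limiters.py, which is not expanded in this capture) that both model wrappers build in a new _init_rate_limiter step and apply through its limit/alimit decorators. The toy class below is a conceptual sketch of the technique, not the shipped implementation: it reuses the constructor parameter names visible in the diffs, backs off when the provider raises its rate-limit error, and creeps back toward the maximum request rate on success. The real class also exposes an async alimit variant and an enforcement window, both omitted here.

import time
from functools import wraps
from typing import Any, Callable, Type


class SketchRateLimiter:
    """Toy adaptive limiter: slow down on rate-limit errors, speed back up on success."""

    def __init__(
        self,
        rate_limit_error: Type[BaseException],
        max_rate_limit_retries: int = 10,
        initial_per_second_request_rate: float = 5.0,
        maximum_per_second_request_rate: float = 20.0,
    ) -> None:
        self._error = rate_limit_error
        self._max_retries = max_rate_limit_retries
        self._rate = initial_per_second_request_rate
        self._max_rate = maximum_per_second_request_rate
        self._last_request = 0.0

    def limit(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for _ in range(self._max_retries + 1):
                # space requests out according to the current adaptive rate
                delay = self._last_request + 1.0 / self._rate - time.time()
                if delay > 0:
                    time.sleep(delay)
                self._last_request = time.time()
                try:
                    result = fn(*args, **kwargs)
                except self._error:
                    # throttled: halve the request rate and try again
                    self._rate = max(self._rate / 2, 0.1)
                    continue
                # success: creep back up toward the maximum request rate
                self._rate = min(self._rate * 1.1, self._max_rate)
                return result
            raise RuntimeError("exceeded max_rate_limit_retries")

        return wrapper

A model then wraps its raw client call with the decorator, roughly as the Bedrock and OpenAI diffs below do with the real RateLimiter.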
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -50,8 +50,8 @@ dev = [
"ruff==0.1.5",
"pandas-stubs<=2.0.2.230605", # version 2.0.3.230814 is causing a dependency conflict.
"pytest",
"pytest-cov",
"pytest-asyncio",
"pytest-cov",
"pytest-lazy-fixture",
"strawberry-graphql[debug-server]==0.208.2",
"pre-commit",
@@ -93,6 +93,7 @@ artifacts = ["src/phoenix/server/static"]
dependencies = [
"pandas==1.4.0",
"pytest",
"pytest-asyncio",
"pytest-cov",
"pytest-lazy-fixture",
"arize",
119 changes: 81 additions & 38 deletions src/phoenix/experimental/evals/functions/classify.py
@@ -4,6 +4,7 @@
import json
import logging
import signal
import traceback
from typing import (
Any,
Callable,
@@ -48,10 +48,6 @@
_EXPLANATION = "explanation"


class EndOfQueue:
pass


class Unset:
pass

@@ -96,74 +93,120 @@ def __init__(
self.tqdm_bar_format = tqdm_bar_format
self.exit_on_error = exit_on_error

# An end of queue sentinel is used to signal to consumers that the queue is empty and that
# they should exit. This is necessary because some consumers may still be waiting for an
# item to be added to the queue when the producer finishes.
self.end_of_queue = EndOfQueue()

self._TERMINATE = False
self._TERMINATE = asyncio.Event()

def _signal_handler(self, signum: int, frame: Any) -> None:
self._TERMINATE = True
self._TERMINATE.set()
tqdm.write("Process was interrupted. The return value will be incomplete...")

async def producer(
self,
inputs: Sequence[Any],
queue: asyncio.Queue[Union[EndOfQueue, Tuple[int, Any]]],
queue: asyncio.Queue[Tuple[int, Any]],
done_producing: asyncio.Event,
) -> None:
for index, input in enumerate(inputs):
if self._TERMINATE:
break
await queue.put((index, input))
# adds an end of queue sentinel for each consumer, guaranteeing that any consumer that is
# currently waiting for an item will gracefully stop.
for _ in range(self.concurrency):
await queue.put(self.end_of_queue)
try:
for index, input in enumerate(inputs):
if self._TERMINATE.is_set():
break
await queue.put((index, input))
finally:
done_producing.set()

async def consumer(
self,
output: List[Any],
queue: asyncio.Queue[Union[EndOfQueue, Tuple[int, Any]]],
queue: asyncio.Queue[Tuple[int, Any]],
done_producing: asyncio.Event,
progress_bar: tqdm[Any],
) -> None:
termination_signal_task = None
while True:
item = await queue.get()
if item is self.end_of_queue:
return
if self._TERMINATE:
marked_done = False
try:
item = await asyncio.wait_for(queue.get(), timeout=1)
except asyncio.TimeoutError:
if done_producing.is_set() and queue.empty():
break
continue
if self._TERMINATE.is_set():
# discard any remaining items in the queue
queue.task_done()
marked_done = True
continue

item = cast(Tuple[int, Any], item)
index, payload = item
try:
result = await self.generate(payload)
output[index] = result
progress_bar.update()
except Exception as e:
tqdm.write(f"Exception in worker: {e}")
generate_task = asyncio.create_task(self.generate(payload))
termination_signal_task = asyncio.create_task(self._TERMINATE.wait())
done, pending = await asyncio.wait(
[generate_task, termination_signal_task],
timeout=60,
return_when=asyncio.FIRST_COMPLETED,
)
if generate_task in done:
output[index] = generate_task.result()
progress_bar.update()
elif self._TERMINATE.is_set():
# discard the pending task and remaining items in the queue
if not generate_task.done():
generate_task.cancel()
try:
# allow any cleanup to finish for the cancelled task
await generate_task
except asyncio.CancelledError:
# Handle the cancellation exception
pass
queue.task_done()
marked_done = True
continue
else:
tqdm.write(f"Worker timeout, requeuing: {payload}")
await queue.put(item)
except Exception:
tqdm.write(f"Exception in worker: {traceback.format_exc()}")
if self.exit_on_error:
self._TERMINATE = True
self._TERMINATE.set()
else:
progress_bar.update()
finally:
if not marked_done:
queue.task_done()
if termination_signal_task and not termination_signal_task.done():
termination_signal_task.cancel()

async def execute(self, inputs: Sequence[Any]) -> List[Any]:
signal.signal(signal.SIGINT, self._signal_handler)
outputs = [self.fallback_return_value] * len(inputs)
progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)

queue: asyncio.Queue[Union[EndOfQueue, Tuple[int, Any]]] = asyncio.Queue(
maxsize=2 * self.concurrency
)
queue: asyncio.Queue[Tuple[int, Any]] = asyncio.Queue(maxsize=2 * self.concurrency)
done_producing = asyncio.Event()

producer = self.producer(inputs, queue)
producer = asyncio.create_task(self.producer(inputs, queue, done_producing))
consumers = [
asyncio.create_task(self.consumer(outputs, queue, progress_bar))
asyncio.create_task(self.consumer(outputs, queue, done_producing, progress_bar))
for _ in range(self.concurrency)
]

await asyncio.gather(producer, *consumers)
join_task = asyncio.create_task(queue.join())
termination_signal_task = asyncio.create_task(self._TERMINATE.wait())
done, pending = await asyncio.wait(
[join_task, termination_signal_task], return_when=asyncio.FIRST_COMPLETED
)
if termination_signal_task in done:
# Cancel all tasks
if not join_task.done():
join_task.cancel()
if not producer.done():
producer.cancel()
for task in consumers:
if not task.done():
task.cancel()

if not termination_signal_task.done():
termination_signal_task.cancel()
return outputs

def run(self, inputs: Sequence[Any]) -> List[Any]:
@@ -276,7 +319,7 @@ def llm_classify(
verbose: bool = False,
use_function_calling_if_available: bool = True,
provide_explanation: bool = False,
concurrency: int = 3,
concurrency: int = 20,
) -> pd.DataFrame:
"""Classifies each input row of the dataframe using an LLM. Returns a pandas.DataFrame
where the first column is named `label` and contains the classification labels. An optional
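The classify.py changes above rework the shutdown path of the AsyncExecutor: the EndOfQueue sentinel is gone, the producer signals completion through a done_producing event, consumers poll the queue with a timeout and always call task_done, and execute() races queue.join() against a termination event set by the SIGINT handler. Below is a stripped-down, self-contained sketch of that pattern, with no rate limiting, per-task timeout, or progress bar; the short sleep stands in for the LLM call.

import asyncio
from typing import Any, List, Sequence, Tuple


async def run_all(inputs: Sequence[Any], concurrency: int = 4) -> List[Any]:
    outputs: List[Any] = [None] * len(inputs)
    queue: asyncio.Queue[Tuple[int, Any]] = asyncio.Queue(maxsize=2 * concurrency)
    done_producing = asyncio.Event()
    terminate = asyncio.Event()  # set by a signal handler in the real executor

    async def produce() -> None:
        try:
            for index, item in enumerate(inputs):
                if terminate.is_set():
                    break
                await queue.put((index, item))
        finally:
            done_producing.set()  # lets idle consumers exit once the queue drains

    async def consume() -> None:
        while True:
            try:
                index, payload = await asyncio.wait_for(queue.get(), timeout=1)
            except asyncio.TimeoutError:
                if done_producing.is_set() and queue.empty():
                    break  # no more work is coming
                continue
            try:
                await asyncio.sleep(0.01)  # stand-in for the LLM call
                outputs[index] = payload
            finally:
                queue.task_done()  # always mark the item so queue.join() can finish

    producer = asyncio.create_task(produce())
    consumers = [asyncio.create_task(consume()) for _ in range(concurrency)]
    join_task = asyncio.create_task(queue.join())
    stop_task = asyncio.create_task(terminate.wait())
    await asyncio.wait([join_task, stop_task], return_when=asyncio.FIRST_COMPLETED)
    # on interruption (or once the queue is fully processed) cancel whatever is left
    for task in (producer, *consumers, join_task, stop_task):
        if not task.done():
            task.cancel()
    return outputs


# asyncio.run(run_all(list(range(10))))

In the real executor, setting the termination event also makes consumers discard any items still sitting in the queue instead of processing them.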
12 changes: 9 additions & 3 deletions src/phoenix/experimental/evals/models/base.py
@@ -1,9 +1,11 @@
import logging
from abc import ABC, abstractmethod, abstractproperty
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Sequence, Type

from phoenix.experimental.evals.models.rate_limiters import RateLimiter

if TYPE_CHECKING:
from tiktoken import Encoding

@@ -44,16 +46,20 @@ def set_verbosity(
model: "BaseEvalModel", verbose: bool = False
) -> Generator["BaseEvalModel", None, None]:
try:
_verbose = model._verbose
_model_verbose_setting = model._verbose
_rate_limiter_verbose_setting = model._rate_limiter._verbose
model._verbose = verbose
model._rate_limiter._verbose = verbose
yield model
finally:
model._verbose = _verbose
model._verbose = _model_verbose_setting
model._rate_limiter._verbose = _rate_limiter_verbose_setting


@dataclass
class BaseEvalModel(ABC):
_verbose: bool = False
_rate_limiter: RateLimiter = field(default_factory=RateLimiter)

def _retry(
self,
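The base.py change threads verbosity through the rate limiter as well as the model, and gives every BaseEvalModel a default RateLimiter field. A hypothetical usage sketch follows; the import paths match this module, but the constructor argument name and the presence of an API key in the environment are assumptions.

from phoenix.experimental.evals import OpenAIModel
from phoenix.experimental.evals.models.base import set_verbosity

model = OpenAIModel(model_name="gpt-4")  # argument name assumed; requires OPENAI_API_KEY

with set_verbosity(model, verbose=True) as verbose_model:
    # inside the block, verbose generation info and rate-limit backoff messages
    # are emitted by both the model and its rate limiter
    assert verbose_model._verbose and verbose_model._rate_limiter._verbose

# on exit, both flags are restored to whatever they were before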
20 changes: 12 additions & 8 deletions src/phoenix/experimental/evals/models/bedrock.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from phoenix.experimental.evals.models.base import BaseEvalModel
from phoenix.experimental.evals.models.rate_limiters import RateLimiter

if TYPE_CHECKING:
from tiktoken import Encoding
@@ -52,6 +53,7 @@ def __post_init__(self) -> None:
self._init_environment()
self._init_client()
self._init_tiktoken()
self._init_rate_limiter()

def _init_environment(self) -> None:
try:
@@ -69,13 +71,6 @@ def _init_client(self) -> None:
import boto3 # type:ignore

self.client = boto3.client("bedrock-runtime")
self._retry_errors = [self.client.exceptions.ThrottlingException]
self.retry = self._retry(
error_types=self._retry_errors,
min_seconds=self.retry_min_seconds,
max_seconds=self.retry_max_seconds,
max_retries=self.max_retries,
)
except ImportError:
self._raise_import_error(
package_name="boto3",
@@ -90,6 +85,15 @@ def _init_tiktoken(self) -> None:
encoding = self._tiktoken.get_encoding("cl100k_base")
self._tiktoken_encoding = encoding

def _init_rate_limiter(self) -> None:
self._rate_limiter = RateLimiter(
rate_limit_error=self.client.exceptions.ThrottlingException,
max_rate_limit_retries=10,
initial_per_second_request_rate=2,
maximum_per_second_request_rate=20,
enforcement_window_minutes=1,
)

@property
def max_context_size(self) -> int:
context_size = self.max_content_size or MODEL_TOKEN_LIMIT_MAPPING.get(self.model_id, None)
@@ -130,7 +134,7 @@ def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
def _generate_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""

@self.retry
@self._rate_limiter.limit
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.invoke_model(**kwargs)

16 changes: 15 additions & 1 deletion src/phoenix/experimental/evals/models/openai.py
@@ -15,6 +15,7 @@
)

from phoenix.experimental.evals.models.base import BaseEvalModel
from phoenix.experimental.evals.models.rate_limiters import RateLimiter

if TYPE_CHECKING:
from tiktoken import Encoding
@@ -104,9 +105,11 @@ def __post_init__(self) -> None:
self._init_environment()
self._init_open_ai()
self._init_tiktoken()
self._init_rate_limiter()

def _init_environment(self) -> None:
try:
import httpx
import openai
import openai._utils as openai_util

@@ -116,8 +119,8 @@ def _init_environment(self) -> None:
self._openai.APITimeoutError,
self._openai.APIError,
self._openai.APIConnectionError,
self._openai.RateLimitError,
self._openai.InternalServerError,
httpx.ReadTimeout,
]
self.retry = self._retry(
error_types=self._openai_retry_errors,
@@ -235,6 +238,15 @@ def _get_azure_options(self) -> AzureOptions:
options[option.name] = None
return AzureOptions(**options)

def _init_rate_limiter(self) -> None:
self._rate_limiter = RateLimiter(
rate_limit_error=self._openai.RateLimitError,
max_rate_limit_retries=10,
initial_per_second_request_rate=5,
maximum_per_second_request_rate=20,
enforcement_window_minutes=1,
)

@staticmethod
def _build_messages(
prompt: str, system_instruction: Optional[str] = None
@@ -289,6 +301,7 @@ async def _async_generate_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""

@self.retry
@self._rate_limiter.alimit
async def _completion_with_retry(**kwargs: Any) -> Any:
if self._model_uses_legacy_completion_api:
if "prompt" not in kwargs:
@@ -309,6 +322,7 @@ def _generate_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""

@self.retry
@self._rate_limiter.limit
def _completion_with_retry(**kwargs: Any) -> Any:
if self._model_uses_legacy_completion_api:
if "prompt" not in kwargs:
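With this split, tenacity keeps retrying transient transport errors for the OpenAI model (now including httpx.ReadTimeout), while RateLimitError is dropped from the tenacity list and handled by the rate limiter instead. The sketch below shows that decorator stacking for the async path; the alimit function is a deliberately naive stand-in for RateLimiter.alimit, and the client setup and retry policy are illustrative assumptions, not phoenix's _retry helper.

import asyncio
from functools import wraps

import httpx
import openai
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential


def alimit(fn):
    @wraps(fn)
    async def wrapper(*args, **kwargs):
        for _ in range(10):  # max_rate_limit_retries
            try:
                return await fn(*args, **kwargs)
            except openai.RateLimitError:
                await asyncio.sleep(1)  # a real limiter adapts its request rate here
        raise RuntimeError("exceeded rate limit retries")

    return wrapper


client = openai.AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment


@retry(
    retry=retry_if_exception_type(
        (openai.APITimeoutError, openai.APIConnectionError, httpx.ReadTimeout)
    ),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(5),
)
@alimit
async def acompletion(**kwargs):
    return await client.chat.completions.create(**kwargs)


# asyncio.run(acompletion(model="gpt-4", messages=[{"role": "user", "content": "hi"}]))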