Commit: Fix typing

anticorrelator committed Oct 16, 2023
1 parent a75d419 commit eaeee66

Showing 2 changed files with 22 additions and 20 deletions.
33 changes: 17 additions & 16 deletions scripts/rate_limit/openai_testing.py
@@ -4,6 +4,7 @@
 import time
 from collections import defaultdict, deque
 from json import JSONDecodeError
+from typing import Any, Deque, Dict

 import httpx
 import openai
@@ -16,10 +17,10 @@
 TOTAL_TOKENS = 0
 ERRORS = 0
 stop_event = asyncio.Event()
-request_times = deque()
-tokens_processed = deque()
+request_times: Deque[float] = deque()
+tokens_processed: Deque[int] = deque()
 log = defaultdict(list)
-error_log = defaultdict(list)
+error_log: Dict[str, Any] = defaultdict(list)

 API_URL = "https://api.openai.com/v1/chat/completions"
 openai.api_key = os.environ["OPENAI_API_KEY"]
@@ -31,7 +32,7 @@
 MAX_CONCURRENT_REQUESTS = 20

 MAX_QUEUE_SIZE = 40
-request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
+request_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)

 prompt = "hello!"
 payload_template = {
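
A side note on the new annotation above: `asyncio.Queue[...]` and the builtin generic `dict[str, Any]` are only subscriptable at runtime on Python 3.9+, and a module-level variable annotation is evaluated at import time unless the module uses `from __future__ import annotations`. This is the same reason the earlier hunk pulls `Deque` and `Dict` from `typing` rather than subscripting the builtins. If this script ever needs to import on 3.8, a quoted annotation is a minimal fix; a sketch, reusing only names already in the script:

import asyncio
from typing import Any, Dict

MAX_QUEUE_SIZE = 40

# Quoted annotation: nothing is subscripted at runtime, so this line
# also imports cleanly on Python 3.8.
request_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)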
@@ -44,25 +45,25 @@
 rate_limiter = OpenAIRateLimiter(openai.api_key)


-def request_time(bucket_size):
+def request_time(bucket_size: int) -> float:
     global request_times
     recent_request_times = (request_times.popleft() for _ in range(bucket_size))
     return sum(recent_request_times) / bucket_size  # seconds


-def effective_rate():
+def effective_rate() -> float:
     elapsed_time = time.time() - START_TIME
     global COMPLETED_RESPONSES
     return 60 * COMPLETED_RESPONSES / elapsed_time  # requests per minute


-def effective_token_rate():
+def effective_token_rate() -> float:
     elapsed_time = time.time() - START_TIME
     global TOTAL_TOKENS
     return 60 * TOTAL_TOKENS / elapsed_time  # tokens per minute


-def print_rate_info():
+def print_rate_info() -> None:
     info_interval = 20
     if len(request_times) > info_interval:
         elapsed_time = time.time() - START_TIME
@@ -83,7 +84,7 @@ def print_rate_info():
         print(info_str)


-def print_error(response):
+def print_error(response: httpx.Response) -> None:
     if response.status_code != 200:
         elapsed_time = time.time() - START_TIME

@@ -96,7 +97,7 @@
         error_log["error_payload"].append("no json payload")


-def initial_token_cost(payload) -> int:
+def initial_token_cost(payload: Dict[str, Any]) -> int:
     """Return the number of tokens used by a list of messages.
     Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
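
The body of initial_token_cost is collapsed in this view; per its docstring it follows the OpenAI cookbook recipe for counting chat-completion prompt tokens. A minimal sketch of that recipe using tiktoken; the per-message and reply-priming constants (3 and 3) are the cookbook's values for gpt-4-style models, not something taken from this diff, and the cookbook's extra token for a "name" field is omitted:

import tiktoken

def count_message_tokens(payload):
    encoding = tiktoken.encoding_for_model("gpt-4")
    token_count = 0
    for message in payload["messages"]:
        token_count += 3  # per-message overhead: <|start|>{role}\n...<|end|>\n
        for value in message.values():
            token_count += len(encoding.encode(value))
    token_count += 3  # every reply is primed with <|start|>assistant<|message|>
    return token_count

print(count_message_tokens({"messages": [{"role": "user", "content": "hello!"}]}))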
@@ -118,15 +119,15 @@ def initial_token_cost(payload) -> int:
     return token_count


-def response_token_cost(response):
+def response_token_cost(response: httpx.Response) -> int:
     if response.status_code == 200:
-        return response.json()["usage"]["completion_tokens"]
+        return int(response.json()["usage"]["completion_tokens"])
     else:
         return 0


 @rate_limiter.alimit("gpt-4", initial_token_cost, response_token_cost)
-async def openai_request(payload):
+async def openai_request(payload: Dict[str, Any]) -> httpx.Response:
     async with httpx.AsyncClient() as client:
         response = await client.post(API_URL, headers=HEADERS, json=payload, timeout=None)
         if response.status_code != 200:
@@ -147,21 +148,21 @@ async def openai_request(payload):
         return response


-async def producer():
+async def producer() -> None:
     while not stop_event.is_set():
         await request_queue.put(payload_template)
         await asyncio.sleep(0.001)


-async def consumer():
+async def consumer() -> None:
     while not stop_event.is_set():
         global request_queue
         payload = await request_queue.get()
         await openai_request(payload)
         request_queue.task_done()


-async def main(timeout_duration):
+async def main(timeout_duration: int) -> None:
     producer_task = asyncio.create_task(producer())
     [asyncio.create_task(consumer()) for _ in range(MAX_CONCURRENT_REQUESTS)]

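The remainder of main() is collapsed in the original view. For context, a self-contained sketch of the harness shape this file builds: a bounded queue, one producer, a pool of consumers, and a timed shutdown. The shutdown step (sleep for timeout_duration, set stop_event, cancel the tasks) is an assumption about what the collapsed lines do, not a quote of them:

import asyncio
from typing import Any, Dict


async def demo(timeout_duration: float) -> None:
    stop_event = asyncio.Event()
    queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=4)

    async def producer() -> None:
        while not stop_event.is_set():
            await queue.put({"prompt": "hello!"})
            await asyncio.sleep(0.001)

    async def consumer() -> None:
        while not stop_event.is_set():
            await queue.get()  # may block past stop; cancellation covers it
            queue.task_done()

    tasks = [asyncio.create_task(producer())]
    tasks += [asyncio.create_task(consumer()) for _ in range(2)]

    await asyncio.sleep(timeout_duration)  # assumed: let the load test run
    stop_event.set()
    for task in tasks:
        task.cancel()


asyncio.run(demo(0.05))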
9 changes: 5 additions & 4 deletions src/phoenix/utilities/ratelimits.py
@@ -3,7 +3,7 @@
 import time
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Callable, Dict, TypeVar, Union
+from typing import Any, Awaitable, Callable, Dict, TypeVar, Union, cast

 if sys.version_info < (3, 10):
     from typing_extensions import ParamSpec
@@ -13,6 +13,7 @@
 Numeric = Union[int, float]
 P = ParamSpec("P")
 T = TypeVar("T")
+A = TypeVar("A", bound=Callable[..., Awaitable[Any]])


 class UnavailableTokensError(Exception):
@@ -181,8 +182,8 @@ def alimit(
         model_name: str,
         input_cost_fn: Callable[..., Numeric],
         response_cost_fn: Callable[..., Numeric],
-    ) -> Callable[[Callable[P, T]], Callable[P, T]]:
-        def rate_limit_decorator(fn: Callable[P, T]) -> Callable[P, T]:
+    ) -> Callable[[A], A]:
+        def rate_limit_decorator(fn: A) -> A:
             @wraps(fn)
             async def wrapper(*args: Any, **kwargs: Any) -> T:
                 key = self.key(model_name)
@@ -193,6 +194,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> T:
                 self._store.spend_rate_limits(key, {"tokens": response_cost_fn(result)})
                 return result

-        return wrapper
+        return cast(A, wrapper)

         return rate_limit_decorator
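
This is the substance of the fix: alimit previously typed its decorator as Callable[[Callable[P, T]], Callable[P, T]], but the wrapper it returns is an async def, so the declared type did not match what callers actually got back. This commit retypes it with a TypeVar bound to async callables plus a cast, so the checker treats a decorated coroutine function as coming back with exactly its original type. A minimal self-contained sketch of the same pattern; alimit_like and greet are hypothetical stand-ins, not names from this repository:

import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar, cast

A = TypeVar("A", bound=Callable[..., Awaitable[Any]])


def alimit_like(fn: A) -> A:
    @wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # A real limiter would wait for rate-limit tokens here before calling.
        return await fn(*args, **kwargs)

    # wraps() preserves the runtime metadata; the cast preserves the static
    # type, since the checker otherwise only sees wrapper's (*args, **kwargs).
    return cast(A, wrapper)


@alimit_like
async def greet(name: str) -> str:
    return "hello " + name


print(asyncio.run(greet("world")))  # greet still type-checks as (name: str) -> ...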
