From eaeee667a326fbae7cb0abdebb73beaa00d84abf Mon Sep 17 00:00:00 2001
From: Dustin Ngo
Date: Mon, 16 Oct 2023 19:09:07 -0400
Subject: [PATCH] Fix typing

---
 scripts/rate_limit/openai_testing.py | 33 ++++++++++++++--------------
 src/phoenix/utilities/ratelimits.py  |  9 ++++----
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/scripts/rate_limit/openai_testing.py b/scripts/rate_limit/openai_testing.py
index 403c453a86..fc8a8bb21d 100644
--- a/scripts/rate_limit/openai_testing.py
+++ b/scripts/rate_limit/openai_testing.py
@@ -4,6 +4,7 @@
 import time
 from collections import defaultdict, deque
 from json import JSONDecodeError
+from typing import Any, Deque, Dict
 
 import httpx
 import openai
@@ -16,10 +17,10 @@
 TOTAL_TOKENS = 0
 ERRORS = 0
 stop_event = asyncio.Event()
-request_times = deque()
-tokens_processed = deque()
+request_times: Deque[float] = deque()
+tokens_processed: Deque[int] = deque()
 log = defaultdict(list)
-error_log = defaultdict(list)
+error_log: Dict[str, Any] = defaultdict(list)
 
 API_URL = "https://api.openai.com/v1/chat/completions"
 openai.api_key = os.environ["OPENAI_API_KEY"]
@@ -31,7 +32,7 @@
 MAX_CONCURRENT_REQUESTS = 20
 MAX_QUEUE_SIZE = 40
 
-request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
+request_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
 
 prompt = "hello!"
 payload_template = {
@@ -44,25 +45,25 @@
 rate_limiter = OpenAIRateLimiter(openai.api_key)
 
 
-def request_time(bucket_size):
+def request_time(bucket_size: int) -> float:
     global request_times
     recent_request_times = (request_times.popleft() for _ in range(bucket_size))
     return sum(recent_request_times) / bucket_size  # seconds
 
 
-def effective_rate():
+def effective_rate() -> float:
     elapsed_time = time.time() - START_TIME
     global COMPLETED_RESPONSES
     return 60 * COMPLETED_RESPONSES / elapsed_time  # requests per minute
 
 
-def effective_token_rate():
+def effective_token_rate() -> float:
     elapsed_time = time.time() - START_TIME
     global TOTAL_TOKENS
     return 60 * TOTAL_TOKENS / elapsed_time  # tokens per minute
 
 
-def print_rate_info():
+def print_rate_info() -> None:
     info_interval = 20
     if len(request_times) > info_interval:
         elapsed_time = time.time() - START_TIME
@@ -83,7 +84,7 @@ def print_rate_info():
     print(info_str)
 
 
-def print_error(response):
+def print_error(response: httpx.Response) -> None:
     if response.status_code != 200:
 
         elapsed_time = time.time() - START_TIME
@@ -96,7 +97,7 @@ def print_error(response):
             error_log["error_payload"].append("no json payload")
 
 
-def initial_token_cost(payload) -> int:
+def initial_token_cost(payload: Dict[str, Any]) -> int:
     """Return the number of tokens used by a list of messages.
 
     Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
@@ -118,15 +119,15 @@ def initial_token_cost(payload: Dict[str, Any]) -> int:
     return token_count
 
 
-def response_token_cost(response):
+def response_token_cost(response: httpx.Response) -> int:
     if response.status_code == 200:
-        return response.json()["usage"]["completion_tokens"]
+        return int(response.json()["usage"]["completion_tokens"])
     else:
         return 0
 
 
 @rate_limiter.alimit("gpt-4", initial_token_cost, response_token_cost)
-async def openai_request(payload):
+async def openai_request(payload: Dict[str, Any]) -> httpx.Response:
     async with httpx.AsyncClient() as client:
         response = await client.post(API_URL, headers=HEADERS, json=payload, timeout=None)
         if response.status_code != 200:
@@ -147,13 +148,13 @@ async def openai_request(payload):
     return response
 
 
-async def producer():
+async def producer() -> None:
     while not stop_event.is_set():
         await request_queue.put(payload_template)
         await asyncio.sleep(0.001)
 
 
-async def consumer():
+async def consumer() -> None:
     while not stop_event.is_set():
         global request_queue
         payload = await request_queue.get()
@@ -161,7 +162,7 @@ async def consumer():
         request_queue.task_done()
 
 
-async def main(timeout_duration):
+async def main(timeout_duration: int) -> None:
     producer_task = asyncio.create_task(producer())
     [asyncio.create_task(consumer()) for _ in range(MAX_CONCURRENT_REQUESTS)]
 
diff --git a/src/phoenix/utilities/ratelimits.py b/src/phoenix/utilities/ratelimits.py
index 88c253cad4..13ba02bd59 100644
--- a/src/phoenix/utilities/ratelimits.py
+++ b/src/phoenix/utilities/ratelimits.py
@@ -3,7 +3,7 @@
 import time
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Callable, Dict, TypeVar, Union
+from typing import Any, Awaitable, Callable, Dict, TypeVar, Union, cast
 
 if sys.version_info < (3, 10):
     from typing_extensions import ParamSpec
@@ -13,6 +13,7 @@
 Numeric = Union[int, float]
 P = ParamSpec("P")
 T = TypeVar("T")
+A = TypeVar("A", bound=Callable[..., Awaitable[Any]])
 
 
 class UnavailableTokensError(Exception):
@@ -181,8 +182,8 @@ def alimit(
         model_name: str,
         input_cost_fn: Callable[..., Numeric],
         response_cost_fn: Callable[..., Numeric],
-    ) -> Callable[[Callable[P, T]], Callable[P, T]]:
-        def rate_limit_decorator(fn: Callable[P, T]) -> Callable[P, T]:
+    ) -> Callable[[A], A]:
+        def rate_limit_decorator(fn: A) -> A:
             @wraps(fn)
             async def wrapper(*args: Any, **kwargs: Any) -> T:
                 key = self.key(model_name)
@@ -193,6 +194,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> T:
                 self._store.spend_rate_limits(key, {"tokens": response_cost_fn(result)})
                 return result
 
-        return wrapper
+        return cast(A, wrapper)
 
     return rate_limit_decorator
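
For reference, the pattern the ratelimits.py hunk adopts — a TypeVar bound to an
async callable plus a cast() on the wrapper — is what lets a decorated coroutine
function keep its exact signature for type checkers. A minimal standalone sketch
of the same pattern follows; the names `traced` and `fetch` are illustrative
only and are not part of this patch:

    import asyncio
    from functools import wraps
    from typing import Any, Awaitable, Callable, TypeVar, cast

    # TypeVar bound to any async callable, mirroring `A` in ratelimits.py.
    AsyncFunc = TypeVar("AsyncFunc", bound=Callable[..., Awaitable[Any]])


    def traced(fn: AsyncFunc) -> AsyncFunc:
        @wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            print(f"calling {fn.__name__}")
            return await fn(*args, **kwargs)

        # cast() is needed because type checkers infer `wrapper` as its own
        # callable type, not as AsyncFunc; the runtime object is unchanged.
        return cast(AsyncFunc, wrapper)


    @traced
    async def fetch(url: str) -> str:
        # The decorated function keeps its (url: str) -> str signature,
        # so call sites are still checked against the original parameters.
        await asyncio.sleep(0)
        return url


    asyncio.run(fetch("https://example.com"))

The trade-off versus the Callable[P, T] annotation it replaces: the wrapper body
is no longer checked against the decorated function's parameter types — cast()
simply asserts the match — but callers still see the original signature, since
the TypeVar binds to the full callable type, awaitable return included.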