Skip to content

Commit

Permalink
Add async support (#146)
Browse files Browse the repository at this point in the history
* Add async support

* Fix aiohttp requests
* Fix some syntax errors

* Close aiohttp session properly
* This is due to a lack of an async __del__ method

* Fix code per review

* Fix async tests and some mypy errors

* Run black

* Add todo for multipart form generation

* Fix more mypy

* Fix exception type

* Don't yield twice

Co-authored-by: Damien Deville <damien@openai.com>
  • Loading branch information
Andrew-Chen-Wang and ddeville authored Jan 5, 2023
1 parent ec4943f commit 0abf641
Show file tree
Hide file tree
Showing 30 changed files with 1,288 additions and 74 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,32 @@ image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting",

```

## Async API

Async support is available in the API by prepending `a` to a network-bound method:

```python
import openai
openai.api_key = "sk-..." # supply your API key however you choose

async def create_completion():
completion_resp = await openai.Completion.acreate(prompt="This is a test", engine="davinci")

```

To make async requests more efficient, you can pass in your own
``aiohttp.ClientSession``, but you must manually close the client session at the end
of your program/event loop:

```python
import openai
from aiohttp import ClientSession

openai.aiosession.set(ClientSession())
# At the end of your program, close the http session
await openai.aiosession.get().close()
```

See the [usage guide](https://beta.openai.com/docs/guides/images) for more details.

## Requirements
Expand Down
11 changes: 10 additions & 1 deletion openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# Originally forked from the MIT-licensed Stripe Python bindings.

import os
from typing import Optional
from contextvars import ContextVar
from typing import Optional, TYPE_CHECKING

from openai.api_resources import (
Answer,
Expand All @@ -24,6 +25,9 @@
)
from openai.error import APIError, InvalidRequestError, OpenAIError

if TYPE_CHECKING:
from aiohttp import ClientSession

api_key = os.environ.get("OPENAI_API_KEY")
# Path of a file with an API key, whose contents can change. Supercedes
# `api_key` if set. The main use case is volume-mounted Kubernetes secrets,
Expand All @@ -44,6 +48,11 @@
debug = False
log = None # Set to either 'debug' or 'info', controls console logging

aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
"aiohttp-session", default=None
) # Acts as a global aiohttp ClientSession that reuses connections.
# This is user-supplied; otherwise, a session is remade for each request.

__all__ = [
"APIError",
"Answer",
Expand Down
255 changes: 234 additions & 21 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import json
import platform
import sys
import threading
import warnings
from json import JSONDecodeError
from typing import Dict, Iterator, Optional, Tuple, Union, overload
from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload
from urllib.parse import urlencode, urlsplit, urlunsplit

import aiohttp
import requests

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -49,6 +51,20 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
)


def _aiohttp_proxies_arg(proxy) -> Optional[str]:
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
if proxy is None:
return None
elif isinstance(proxy, str):
return proxy
elif isinstance(proxy, dict):
return proxy["https"] if "https" in proxy else proxy["http"]
else:
raise ValueError(
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
)


def _make_session() -> requests.Session:
if not openai.verify_ssl_certs:
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
Expand All @@ -63,18 +79,32 @@ def _make_session() -> requests.Session:
return s


def parse_stream_helper(line):
if line:
if line == b"data: [DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
return None
if hasattr(line, "decode"):
line = line.decode("utf-8")
if line.startswith("data: "):
line = line[len("data: ") :]
return line
return None


def parse_stream(rbody):
for line in rbody:
if line:
if line == b"data: [DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
continue
if hasattr(line, "decode"):
line = line.decode("utf-8")
if line.startswith("data: "):
line = line[len("data: ") :]
yield line
_line = parse_stream_helper(line)
if _line is not None:
yield _line


async def parse_stream_async(rbody: aiohttp.StreamReader):
async for line in rbody:
_line = parse_stream_helper(line)
if _line is not None:
yield _line


class APIRequestor:
Expand Down Expand Up @@ -186,6 +216,86 @@ def request(
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key

@overload
async def arequest(
self,
method,
url,
params,
headers,
files,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
pass

@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
*,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
pass

@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: Literal[False] = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[OpenAIResponse, bool, str]:
pass

@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: bool = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
pass

async def arequest(
self,
method,
url,
params=None,
headers=None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
result = await self.arequest_raw(
method.lower(),
url,
params=params,
supplied_headers=headers,
files=files,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = await self._interpret_async_response(result, stream)
return resp, got_stream, self.api_key

def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
try:
error_data = resp["error"]
Expand Down Expand Up @@ -315,18 +425,15 @@ def _validate_headers(

return headers

def request_raw(
def _prepare_request_raw(
self,
method,
url,
*,
params=None,
supplied_headers: Dict[str, str] = None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
supplied_headers,
method,
params,
files,
request_id: Optional[str],
) -> Tuple[str, Dict[str, str], Optional[bytes]]:
abs_url = "%s%s" % (self.api_base, url)
headers = self._validate_headers(supplied_headers)

Expand Down Expand Up @@ -355,6 +462,24 @@ def request_raw(
util.log_info("Request to OpenAI API", method=method, path=abs_url)
util.log_debug("Post details", data=data, api_version=self.api_version)

return abs_url, headers, data

def request_raw(
self,
method,
url,
*,
params=None,
supplied_headers: Optional[Dict[str, str]] = None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
abs_url, headers, data = self._prepare_request_raw(
url, supplied_headers, method, params, files, request_id
)

if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
try:
Expand Down Expand Up @@ -385,6 +510,71 @@ def request_raw(
)
return result

async def arequest_raw(
self,
method,
url,
*,
params=None,
supplied_headers: Optional[Dict[str, str]] = None,
files=None,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> aiohttp.ClientResponse:
abs_url, headers, data = self._prepare_request_raw(
url, supplied_headers, method, params, files, request_id
)

if isinstance(request_timeout, tuple):
timeout = aiohttp.ClientTimeout(
connect=request_timeout[0],
total=request_timeout[1],
)
else:
timeout = aiohttp.ClientTimeout(
total=request_timeout if request_timeout else TIMEOUT_SECS
)
user_set_session = openai.aiosession.get()

if files:
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
# For now we use the private `requests` method that is known to have worked so far.
data, content_type = requests.models.RequestEncodingMixin._encode_files( # type: ignore
files, data
)
headers["Content-Type"] = content_type
request_kwargs = {
"method": method,
"url": abs_url,
"headers": headers,
"data": data,
"proxy": _aiohttp_proxies_arg(openai.proxy),
"timeout": timeout,
}
try:
if user_set_session:
result = await user_set_session.request(**request_kwargs)
else:
async with aiohttp.ClientSession() as session:
result = await session.request(**request_kwargs)
util.log_info(
"OpenAI API response",
path=abs_url,
response_code=result.status,
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
request_id=result.headers.get("X-Request-Id"),
)
# Don't read the whole stream for debug logging unless necessary.
if openai.log == "debug":
util.log_debug(
"API response body", body=result.content, headers=result.headers
)
return result
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise error.Timeout("Request timed out") from e
except aiohttp.ClientError as e:
raise error.APIConnectionError("Error communicating with OpenAI") from e

def _interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
Expand All @@ -404,6 +594,29 @@ def _interpret_response(
False,
)

async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
return (
self._interpret_response_line(
line, result.status, result.headers, stream=True
)
async for line in parse_stream_async(result.content)
), True
else:
try:
await result.read()
except aiohttp.ClientError as e:
util.log_warn(e, body=result.content)
return (
self._interpret_response_line(
await result.read(), result.status, result.headers, stream=False
),
False,
)

def _interpret_response_line(
self, rbody, rcode, rheaders, stream: bool
) -> OpenAIResponse:
Expand Down
Loading

0 comments on commit 0abf641

Please sign in to comment.