Skip to content

Commit

Permalink
feat: cache http client instance and inject default settings
Browse files Browse the repository at this point in the history
chore: call `raise_for_status` after fetching
  • Loading branch information
CNSeniorious000 committed Feb 5, 2024
1 parent 878a062 commit 7b99735
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 26 deletions.
12 changes: 2 additions & 10 deletions python/promplate/llm/openai/v1.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from copy import copy
from functools import cached_property
from importlib.metadata import version
from sys import version as py_version
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar

from openai import AsyncClient, Client # type: ignore

from ...prompt.chat import Message, ensure
from ...prompt.utils import get_user_agent
from ..base import *

P = ParamSpec("P")
Expand All @@ -33,14 +32,7 @@ def bind(self, **run_config):

@cached_property
def _user_agent(self):
return " ".join(
(
f"Promplate/{version('promplate')} ({self.__class__.__name__})",
f"OpenAI/{version('openai')}",
f"HTTPX/{version('httpx')}",
f"Python/{py_version.split()[0]}",
)
)
return get_user_agent(self, ("OpenAI", "openai"))

@property
def _config(self): # type: ignore
Expand Down
28 changes: 13 additions & 15 deletions python/promplate/prompt/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,31 +145,29 @@ async def aread(cls, path: str | Path, encoding="utf-8"):
obj.name = path.stem
return obj

_client = None
@classmethod
def _patch_kwargs(cls, kwargs: dict):
return {
"follow_redirects": True,
"base_url": "https://promplate.dev/",
"headers": {"User-Agent": get_user_agent(cls)},
} | kwargs

@classmethod
def fetch(cls, url: str, **kwargs):
if cls._client is None:
from httpx import Client

cls._client = Client(**kwargs)
from .utils import _get_client

response = cls._client.get(url)
obj = cls(response.text)
response = _get_client(cls._patch_kwargs(kwargs)).get(url)
obj = cls(response.raise_for_status().text)
obj.name = Path(url).stem
return obj

_aclient = None

@classmethod
async def afetch(cls, url: str, **kwargs):
if cls._aclient is None:
from httpx import AsyncClient

cls._aclient = AsyncClient(**kwargs)
from .utils import _get_aclient

response = await cls._aclient.get(url)
obj = cls(response.text)
response = await _get_aclient(cls._patch_kwargs(kwargs)).get(url)
obj = cls(response.raise_for_status().text)
obj.name = Path(url).stem
return obj

Expand Down
31 changes: 30 additions & 1 deletion python/promplate/prompt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property, wraps
from inspect import currentframe
from inspect import currentframe, isclass
from re import compile
from typing import Any, Callable, ParamSpec, TypeVar

Expand Down Expand Up @@ -88,3 +88,32 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
@cache_once
def get_builtins() -> dict[str, Any]:
return __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__


@cache_once
def get_user_agent(self, *additional_packages: tuple[str, str]):
from importlib.metadata import version
from sys import version as py_version

return " ".join(
(
f"Promplate/{version('promplate')} ({self.__name__ if isclass(self) else self.__class__.__name__})",
*(f"{display_name}/{version(package)}" for display_name, package in additional_packages),
f"HTTPX/{version('httpx')}",
f"Python/{py_version.split()[0]}",
)
)


@cache_once
def _get_client(kwargs):
from httpx import Client

return Client(**kwargs)


@cache_once
def _get_aclient(kwargs):
from httpx import AsyncClient

return AsyncClient(**kwargs)

0 comments on commit 7b99735

Please sign in to comment.