From e26e7f7cd237b696255d0e663513198647f8c2ae Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 19:01:17 -0800 Subject: [PATCH] First WIP prototype of async mode, refs #507 --- llm/__init__.py | 6 + llm/cli.py | 56 ++++-- llm/default_plugins/openai_models.py | 60 ++++++- llm/models.py | 249 ++++++++++++++++++++++++--- 4 files changed, 328 insertions(+), 43 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index 0ea6c242..de838418 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -4,6 +4,7 @@ NeedsKeyException, ) from .models import ( + AsyncModel, Attachment, Conversation, Model, @@ -26,6 +27,7 @@ __all__ = [ "hookimpl", + "get_async_model", "get_model", "get_key", "user_dir", @@ -143,6 +145,10 @@ def get_model_aliases() -> Dict[str, Model]: return model_aliases +def get_async_model(model_id: str) -> AsyncModel: + return get_model(model_id).get_async_model() + + class UnknownModelError(KeyError): pass diff --git a/llm/cli.py b/llm/cli.py index 941831c5..90eb78c5 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1,3 +1,4 @@ +import asyncio import click from click_default_group import DefaultGroup from dataclasses import asdict @@ -11,6 +12,7 @@ Template, UnknownModelError, encode, + get_async_model, get_default_model, get_default_embedding_model, get_embedding_models_with_aliases, @@ -193,6 +195,7 @@ def cli(): ) @click.option("--key", help="API key to use") @click.option("--save", help="Save prompt with this template name") +@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously") def prompt( prompt, system, @@ -209,6 +212,7 @@ def prompt( conversation_id, key, save, + async_, ): """ Execute a prompt @@ -325,7 +329,10 @@ def read_prompt(): # Now resolve the model try: - model = model_aliases[model_id] + if async_: + model = get_async_model(model_id) + else: + model = get_model(model_id) except KeyError: raise click.ClickException("'{}' is not a known model".format(model_id)) @@ -363,21 +370,48 @@ def read_prompt(): prompt_method = conversation.prompt try: - response = prompt_method( - prompt, attachments=resolved_attachments, system=system, **validated_options - ) - if should_stream: - for chunk in response: - print(chunk, end="") - sys.stdout.flush() - print("") + if async_: + + async def inner(): + if should_stream: + async for chunk in prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ): + print(chunk, end="") + sys.stdout.flush() + print("") + else: + response = await prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ) + print(response.text()) + + asyncio.run(inner()) else: - print(response.text()) + response = prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ) + if should_stream: + for chunk in response: + print(chunk, end="") + sys.stdout.flush() + print("") + else: + print(response.text()) except Exception as ex: raise click.ClickException(str(ex)) # Log to the database - if (logs_on() or log) and not no_log: + if (logs_on() or log) and not no_log and not async_: log_path = logs_db_path() (log_path.parent).mkdir(parents=True, exist_ok=True) db = sqlite_utils.Database(log_path) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 5cbb02bb..777bd346 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -1,4 +1,4 @@ -from llm import EmbeddingModel, Model, hookimpl +from llm import AsyncModel, EmbeddingModel, Model, hookimpl import llm from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client import click @@ -254,6 +254,9 @@ class Chat(Model): default_max_tokens = None + def get_async_model(self): + return AsyncChat(self.model_id, self.key) + class Options(SharedOptions): json_object: Optional[bool] = Field( description="Output a valid JSON object {...}. Prompt must mention JSON.", @@ -297,10 +300,8 @@ def __init__( def __str__(self): return "OpenAI Chat: {}".format(self.model_id) - def execute(self, prompt, stream, response, conversation=None): + def build_messages(self, prompt, conversation): messages = [] - if prompt.system and not self.allows_system_prompt: - raise NotImplementedError("Model does not support system prompts") current_system = None if conversation is not None: for prev_response in conversation.responses: @@ -345,7 +346,12 @@ def execute(self, prompt, stream, response, conversation=None): {"type": "image_url", "image_url": {"url": url}} ) messages.append({"role": "user", "content": attachment_message}) + return messages + def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) kwargs = self.build_kwargs(prompt, stream) client = self.get_client() if stream: @@ -376,7 +382,7 @@ def execute(self, prompt, stream, response, conversation=None): yield completion.choices[0].message.content response._prompt_json = redact_data_urls({"messages": messages}) - def get_client(self): + def get_client(self, async_=False): kwargs = {} if self.api_base: kwargs["base_url"] = self.api_base @@ -396,7 +402,10 @@ def get_client(self): kwargs["default_headers"] = self.headers if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"): kwargs["http_client"] = logging_client() - return openai.OpenAI(**kwargs) + if async_: + return openai.AsyncOpenAI(**kwargs) + else: + return openai.OpenAI(**kwargs) def build_kwargs(self, prompt, stream): kwargs = dict(not_nulls(prompt.options)) @@ -410,6 +419,45 @@ def build_kwargs(self, prompt, stream): return kwargs +class AsyncChat(AsyncModel, Chat): + needs_key = "openai" + key_env_var = "OPENAI_API_KEY" + + async def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) + kwargs = self.build_kwargs(prompt, stream) + client = self.get_client(async_=True) + if stream: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=True, + **kwargs, + ) + chunks = [] + async for chunk in completion: + chunks.append(chunk) + try: + content = chunk.choices[0].delta.content + except IndexError: + content = None + if content is not None: + yield content + response.response_json = remove_dict_none_values(combine_chunks(chunks)) + else: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + **kwargs, + ) + response.response_json = remove_dict_none_values(completion.model_dump()) + yield completion.choices[0].message.content + response._prompt_json = redact_data_urls({"messages": messages}) + + class Completion(Chat): class Options(SharedOptions): logprobs: Optional[int] = Field( diff --git a/llm/models.py b/llm/models.py index 838e25b1..d41b17e9 100644 --- a/llm/models.py +++ b/llm/models.py @@ -8,7 +8,19 @@ import puremagic import re import time -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union +from typing import ( + Any, + AsyncIterator, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Set, + TypeVar, + Union, +) from abc import ABC, abstractmethod import json from pydantic import BaseModel @@ -144,13 +156,19 @@ def from_row(cls, row): ) -class Response(ABC): +ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"]) +ConversationT = TypeVar( + "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]] +) + + +class _BaseResponse(ABC, Generic[ModelT, ConversationT]): def __init__( self, prompt: Prompt, - model: "Model", + model: ModelT, stream: bool, - conversation: Optional[Conversation] = None, + conversation: ConversationT = None, ): self.prompt = prompt self._prompt_json = None @@ -161,28 +179,9 @@ def __init__( self.response_json = None self.conversation = conversation self.attachments: List[Attachment] = [] - - def __iter__(self) -> Iterator[str]: - self._start = time.monotonic() - self._start_utcnow = datetime.datetime.utcnow() - if self._done: - yield from self._chunks - for chunk in self.model.execute( - self.prompt, - stream=self.stream, - response=self, - conversation=self.conversation, - ): - yield chunk - self._chunks.append(chunk) - if self.conversation: - self.conversation.responses.append(self) - self._end = time.monotonic() - self._done = True - - def _force(self): - if not self._done: - list(self) + self._start: Optional[float] = None + self._end: Optional[float] = None + self._start_utcnow: Optional[datetime.datetime] = None def __str__(self) -> str: return self.text() @@ -203,6 +202,30 @@ def datetime_utc(self) -> str: self._force() return self._start_utcnow.isoformat() + +class Response(_BaseResponse["Model", Optional["Conversation"]]): + def _force(self): + if not self._done: + list(self) + + def __iter__(self) -> Iterator[str]: + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + if self._done: + yield from self._chunks + for chunk in self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ): + yield chunk + self._chunks.append(chunk) + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + def log_to_db(self, db): conversation = self.conversation if not conversation: @@ -257,6 +280,51 @@ def log_to_db(self, db): }, ) + +class AsyncResponse(_BaseResponse["AsyncModel", Optional["AsyncConversation"]]): + async def _force(self): + if not self._done: + async for _ in self: + pass + + async def __aiter__(self) -> AsyncIterator[str]: + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + if self._done: + for chunk in self._chunks: + yield chunk + return + + async for chunk in self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ): + yield chunk + self._chunks.append(chunk) + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + + # Override base methods to make them async + async def text(self) -> str: + await self._force() + return "".join(self._chunks) + + async def json(self) -> Optional[Dict[str, Any]]: + await self._force() + return self.response_json + + async def duration_ms(self) -> int: + await self._force() + return int((self._end - self._start) * 1000) + + async def datetime_utc(self) -> str: + await self._force() + return self._start_utcnow.isoformat() + @classmethod def fake( cls, @@ -362,6 +430,135 @@ def get_key(self): raise NeedsKeyException(message) +ResponseT = TypeVar("ResponseT") +ConversationT = TypeVar("ConversationT") + + +class _BaseModel(ABC, _get_key_mixin, Generic[ResponseT, ConversationT]): + model_id: str + + # API key handling + key: Optional[str] = None + needs_key: Optional[str] = None + key_env_var: Optional[str] = None + + # Model characteristics + can_stream: bool = False + attachment_types: Set = set() + + class Options(_Options): + pass + + def _validate_attachments( + self, attachments: Optional[List[Attachment]] = None + ) -> None: + """Shared attachment validation logic""" + if attachments and not self.attachment_types: + raise ValueError( + "This model does not support attachments, but some were provided" + ) + for attachment in attachments or []: + attachment_type = attachment.resolve_type() + if attachment_type not in self.attachment_types: + raise ValueError( + "This model does not support attachments of type '{}', only {}".format( + attachment_type, ", ".join(self.attachment_types) + ) + ) + + def __str__(self) -> str: + return "{}: {}".format(self.__class__.__name__, self.model_id) + + def __repr__(self): + return "<{} '{}'>".format(self.__class__.__name__, self.model_id) + + +class Model(_BaseModel["Response", "Conversation"]): + def conversation(self) -> "Conversation": + return Conversation(model=self) + + @abstractmethod + def execute( + self, + prompt: Prompt, + stream: bool, + response: "Response", + conversation: Optional["Conversation"], + ) -> Iterator[str]: + """ + Execute a prompt and yield chunks of text, or yield a single big chunk. + Any additional useful information about the execution should be assigned to the response. + """ + pass + + def prompt( + self, + prompt: str, + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options + ) -> "Response": + self._validate_attachments(attachments) + return self.response( + Prompt( + prompt, + attachments=attachments, + system=system, + model=self, + options=self.Options(**options), + ), + stream=stream, + ) + + def response(self, prompt: Prompt, stream: bool = True) -> "Response": + return Response(prompt, self, stream) + + +class AsyncModel(_BaseModel["AsyncResponse", "AsyncConversation"]): + def conversation(self) -> "AsyncConversation": + return AsyncConversation(model=self) + + @abstractmethod + async def execute( + self, + prompt: Prompt, + stream: bool, + response: "AsyncResponse", + conversation: Optional["AsyncConversation"], + ) -> AsyncIterator[str]: + """ + Execute a prompt and yield chunks of text, or yield a single big chunk. + Any additional useful information about the execution should be assigned to the response. + """ + pass + + def prompt( + self, + prompt: str, + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options + ) -> "AsyncResponse": + self._validate_attachments(attachments) + return self.response( + Prompt( + prompt, + attachments=attachments, + system=system, + model=self, + options=self.Options(**options), + ), + stream=stream, + ) + + def response(self, prompt: Prompt, stream: bool = True) -> "AsyncResponse": + return AsyncResponse(prompt, self, stream) + + class Model(ABC, _get_key_mixin): model_id: str