diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 61b765a2..7d2d00ad 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,6 +34,7 @@ repos:
             setuptools,
             setuptools-git-versioning,
             transformers,
+            jinja2,
 
             # dev dependencies
             pytest,
diff --git a/pyproject.toml b/pyproject.toml
index a78b1fc5..8310f9ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,7 @@ dependencies = [
     "pyyaml>=6.0.0",
     "rich",
     "transformers",
+    "jinja2>=3.1.6",
 ]
 
 [project.optional-dependencies]
diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index e3f23963..8ac1e64d 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -2,10 +2,13 @@
 import json
 import time
 from collections.abc import AsyncGenerator
+from functools import cached_property
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
 
 import httpx
+import jinja2
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from PIL import Image
 
@@ -123,6 +126,29 @@ def __init__(
         self.extra_query = extra_query
         self.extra_body = extra_body
         self._async_client: Optional[httpx.AsyncClient] = None
+        self._request_template_str = settings.openai.request_template
+
+    def __getstate__(self) -> object:
+        state = self.__dict__.copy()
+        # Templates are not serializable
+        # so we delete it before pickling
+        state.pop("request_template", None)
+        return state
+
+    @cached_property
+    def request_template(self) -> jinja2.Template:
+        # Thanks to HuggingFace Tokenizers for this implementation
+        def tojson(x, ensure_ascii=False):
+            # We override the built-in tojson filter because Jinja's
+            # default filter escapes HTML characters
+            return json.dumps(x, ensure_ascii=ensure_ascii)
+
+        j2_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+
+        # Define custom filter functions
+        j2_env.filters["tojson"] = tojson
+
+        return j2_env.from_string(self._request_template_str)
 
     @property
     def target(self) -> str:
@@ -155,6 +181,7 @@ def info(self) -> dict[str, Any]:
             "project": self.project,
             "text_completions_path": TEXT_COMPLETIONS_PATH,
             "chat_completions_path": CHAT_COMPLETIONS_PATH,
+            "request_template": self._request_template_str,
         }
 
     async def check_setup(self):
@@ -422,29 +449,15 @@ def _completions_payload(
         max_output_tokens: Optional[int],
         **kwargs,
     ) -> dict:
-        payload = body or {}
+        payload = json.loads(
+            self.request_template.render(
+                model=self.model,
+                output_tokens=(max_output_tokens or self.max_output_tokens),
+            )
+        )
+        payload.update(body or {})
         payload.update(orig_kwargs or {})
         payload.update(kwargs)
-        payload["model"] = self.model
-        payload["stream"] = True
-        payload["stream_options"] = {
-            "include_usage": True,
-        }
-
-        if max_output_tokens or self.max_output_tokens:
-            logger.debug(
-                "{} adding payload args for setting output_token_count: {}",
-                self.__class__.__name__,
-                max_output_tokens or self.max_output_tokens,
-            )
-            payload["max_tokens"] = max_output_tokens or self.max_output_tokens
-            payload["max_completion_tokens"] = payload["max_tokens"]
-
-        if max_output_tokens:
-            # only set stop and ignore_eos if max_output_tokens set at request level
-            # otherwise the instance value is just the max to enforce we stay below
-            payload["stop"] = None
-            payload["ignore_eos"] = True
 
         return payload
 
diff --git a/src/guidellm/config.py b/src/guidellm/config.py
index ed7e782b..e744281e 100644
--- a/src/guidellm/config.py
+++ b/src/guidellm/config.py
@@ -85,6 +85,12 @@ class OpenAISettings(BaseModel):
     project: Optional[str] = None
     base_url: str = "http://localhost:8000"
     max_output_tokens: int = 16384
+    request_template: str = (
+        '{"model": "{{ model }}", {% if output_tokens %} '
+        '"max_tokens": {{ output_tokens }}, "stop": null, '
+        '"ignore_eos": true, {% endif %} '
+        '"stream": true, "stream_options": {"include_usage": true}}'
+    )
 
 
 class Settings(BaseSettings):
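
Not part of the diff above: a minimal standalone sketch of how the new request_template setting becomes a request payload, mirroring the json.loads(...render(...)) call in _completions_payload. The template string is the default from OpenAISettings; the model name and token count are made-up values for illustration.

    # Illustrative only; uses the default request_template from the diff.
    import json
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    template_str = (
        '{"model": "{{ model }}", {% if output_tokens %} '
        '"max_tokens": {{ output_tokens }}, "stop": null, '
        '"ignore_eos": true, {% endif %} '
        '"stream": true, "stream_options": {"include_usage": true}}'
    )

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    payload = json.loads(
        env.from_string(template_str).render(model="my-model", output_tokens=128)
    )
    # payload == {"model": "my-model", "max_tokens": 128, "stop": None,
    #             "ignore_eos": True, "stream": True,
    #             "stream_options": {"include_usage": True}}

When output_tokens is falsy, the {% if %} block is skipped and the rendered JSON omits max_tokens, stop, and ignore_eos entirely, leaving only the model and streaming keys.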