WIP: Configurable Jinja2 Template for OpenAI Backend Request Format #194

Open · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -34,6 +34,7 @@ repos:
         setuptools,
         setuptools-git-versioning,
         transformers,
+        jinja2,

         # dev dependencies
         pytest,
1 change: 1 addition & 0 deletions pyproject.toml
@@ -56,6 +56,7 @@ dependencies = [
"pyyaml>=6.0.0",
"rich",
"transformers",
"jinja2>=3.1.6",
]

[project.optional-dependencies]
55 changes: 34 additions & 21 deletions src/guidellm/backend/openai.py
@@ -2,10 +2,13 @@
 import json
 import time
 from collections.abc import AsyncGenerator
+from functools import cached_property
 from pathlib import Path
 from typing import Any, Literal, Optional, Union

 import httpx
+import jinja2
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from PIL import Image

@@ -123,6 +126,29 @@ def __init__(
         self.extra_query = extra_query
         self.extra_body = extra_body
         self._async_client: Optional[httpx.AsyncClient] = None
+        self._request_template_str = settings.openai.request_template
+
+    def __getstate__(self) -> object:
+        state = self.__dict__.copy()
+        # Compiled templates are not serializable, so drop the
+        # cached template before pickling
+        state.pop("request_template", None)
+        return state
+
+    @cached_property
+    def request_template(self) -> jinja2.Template:
+        # Thanks to HuggingFace Tokenizers for this implementation
+        def tojson(x, ensure_ascii=False):
+            # We override the built-in tojson filter because Jinja's
+            # default filter escapes HTML characters
+            return json.dumps(x, ensure_ascii=ensure_ascii)
+
+        j2_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+
+        # Define custom filter functions
+        j2_env.filters["tojson"] = tojson
+
+        return j2_env.from_string(self._request_template_str)

     @property
     def target(self) -> str:
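Note: a quick standalone sketch of why the tojson override above matters. Jinja2's built-in tojson filter HTML-escapes characters such as & into \u0026, which would corrupt a request body, while plain json.dumps leaves them intact. The template string and prompt value here are illustrative, not part of the PR:

    import json
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    def tojson(x, ensure_ascii=False):
        # Plain json.dumps: no HTML escaping, unlike Jinja's built-in filter
        return json.dumps(x, ensure_ascii=ensure_ascii)

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.filters["tojson"] = tojson

    print(env.from_string('{"prompt": {{ prompt | tojson }}}').render(prompt="A & B"))
    # {"prompt": "A & B"}   (the built-in filter would render "A \u0026 B")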
@@ -155,6 +181,7 @@ def info(self) -> dict[str, Any]:
"project": self.project,
"text_completions_path": TEXT_COMPLETIONS_PATH,
"chat_completions_path": CHAT_COMPLETIONS_PATH,
"request_template": self._request_template_str,
}

async def check_setup(self):
@@ -422,29 +449,15 @@ def _completions_payload(
         max_output_tokens: Optional[int],
         **kwargs,
     ) -> dict:
-        payload = body or {}
+        payload = json.loads(
+            self.request_template.render(
+                model=self.model,
+                output_tokens=(max_output_tokens or self.max_output_tokens),
+            )
+        )
+        payload.update(body or {})
         payload.update(orig_kwargs or {})
         payload.update(kwargs)
-        payload["model"] = self.model
-        payload["stream"] = True
-        payload["stream_options"] = {
-            "include_usage": True,
-        }
-
-        if max_output_tokens or self.max_output_tokens:
-            logger.debug(
-                "{} adding payload args for setting output_token_count: {}",
-                self.__class__.__name__,
-                max_output_tokens or self.max_output_tokens,
-            )
-            payload["max_tokens"] = max_output_tokens or self.max_output_tokens
-            payload["max_completion_tokens"] = payload["max_tokens"]
-
-        if max_output_tokens:
-            # only set stop and ignore_eos if max_output_tokens set at request level
-            # otherwise the instance value is just the max to enforce we stay below
-            payload["stop"] = None
-            payload["ignore_eos"] = True

         return payload

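Note: to make the new control flow concrete, a minimal sketch of what _completions_payload now produces with the default template; the model name, token count, and prompt below are illustrative only:

    import json
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    TEMPLATE = (
        '{"model": "{{ model }}", {% if output_tokens %} '
        '"max_tokens": {{ output_tokens }}, "stop": null, '
        '"ignore_eos": true, {% endif %} '
        '"stream": true, "stream_options": {"include_usage": true}}'
    )

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    rendered = env.from_string(TEMPLATE).render(model="my-model", output_tokens=128)
    payload = json.loads(rendered)  # the rendered template must be valid JSON
    payload.update({"prompt": "Hello"})  # body and kwargs are layered on afterwards
    # payload == {"model": "my-model", "max_tokens": 128, "stop": None,
    #             "ignore_eos": True, "stream": True,
    #             "stream_options": {"include_usage": True}, "prompt": "Hello"}

One behavioral consequence worth noting: because json.loads runs directly on the template output, a template that renders to invalid JSON now fails at request-build time rather than surfacing as a server error.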
6 changes: 6 additions & 0 deletions src/guidellm/config.py
@@ -85,6 +85,12 @@ class OpenAISettings(BaseModel):
     project: Optional[str] = None
     base_url: str = "http://localhost:8000"
     max_output_tokens: int = 16384
+    request_template: str = (
+        '{"model": "{{ model }}", {% if output_tokens %} '
+        '"max_tokens": {{ output_tokens }}, "stop": null, '
+        '"ignore_eos": true, {% endif %} '
+        '"stream": true, "stream_options": {"include_usage": true}}'
+    )


 class Settings(BaseSettings):
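Note: since the template is an ordinary settings field, it can be overridden programmatically. The field path comes from this diff; the temperature key is an illustrative extension, not something the PR adds:

    from guidellm.config import settings

    # Must be set before the backend is constructed, since __init__
    # snapshots settings.openai.request_template.
    settings.openai.request_template = (
        '{"model": "{{ model }}", {% if output_tokens %} '
        '"max_tokens": {{ output_tokens }}, {% endif %} '
        '"temperature": 0.0, "stream": true, '
        '"stream_options": {"include_usage": true}}'
    )

If the Settings class follows the usual pydantic-settings conventions used elsewhere in guidellm, an environment-variable override along the lines of GUIDELLM__OPENAI__REQUEST_TEMPLATE should also work, though that spelling is an assumption rather than something this diff shows.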