remove to_gen_params functions
SecretiveShell committed Sep 21, 2024
1 parent c5a06db commit 035269c
Showing 6 changed files with 100 additions and 149 deletions.
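At a glance, the commit replaces the hand-rolled `to_gen_params()` conversion with Pydantic's built-in `model_dump()` and moves the ad-hoc coercions (string to list, comma-separated token strings to ints, JSON parsing) into `field_validator`s. A minimal sketch of the resulting pattern, using a standalone toy model rather than the repo's full schema:

```python
from typing import List, Optional, Union

from pydantic import BaseModel, field_validator


class SamplerRequest(BaseModel):
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = 150

    @field_validator("stop", mode="before")
    def convert_str_to_list(cls, v):
        # Coercion that to_gen_params() previously did imperatively
        return [v] if isinstance(v, str) else v


req = SamplerRequest(stop="###")
print(req.model_dump())  # {'stop': ['###'], 'max_tokens': 150}
```

Call sites that previously built kwargs with `data.to_gen_params()` now pass `**data.model_dump()` directly, as the endpoint changes below show.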
141 changes: 50 additions & 91 deletions common/sampling.py
@@ -3,10 +3,18 @@
import aiofiles
import json
import pathlib
from pydantic_core import ValidationError
from ruamel.yaml import YAML
from copy import deepcopy
from loguru import logger
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from typing import Dict, List, Optional, Union

from common.utils import unwrap, prune_dict
@@ -178,6 +186,8 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
)

mirostat: Optional[bool] = False

mirostat_mode: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
@@ -265,104 +275,53 @@ class BaseSamplerRequest(BaseModel):

model_config = ConfigDict(validate_assignment=True)

# TODO: Return to adaptable class-based validation. But that's just too much
# abstraction compared to simple if statements at the moment
@model_validator(mode="after")
def validate_params(self):
"""
Validates sampler parameters to be within sane ranges.
"""

if self.min_temp > self.max_temp:
raise ValueError("min temp cannot be more then max temp")

if self.min_tokens > self.max_tokens:
raise ValueError("min tokens cannot be more then max tokens")

def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""

# Add forced overrides if present
# FIXME: find a better way to register this
# Maybe make a function to assign values to the
# model if they do not exist post creation
apply_forced_sampler_overrides(self)

# Convert stop to an array of strings
if self.stop and isinstance(self.stop, str):
self.stop = [self.stop]

# Convert banned_strings to an array of strings
if self.banned_strings and isinstance(self.banned_strings, str):
self.banned_strings = [self.banned_strings]

# Convert string banned and allowed tokens to an integer list
if self.banned_tokens and isinstance(self.banned_tokens, str):
self.banned_tokens = [
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
]

if self.allowed_tokens and isinstance(self.allowed_tokens, str):
self.allowed_tokens = [
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
]

# Convert sequence breakers into an array of strings
# NOTE: This sampler sucks to parse.
if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
if not self.dry_sequence_breakers.startswith("["):
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"

try:
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
except Exception:
self.dry_sequence_breakers = []

gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"generate_window": self.generate_window,
"stop": self.stop,
"banned_strings": self.banned_strings,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"skip_special_tokens": self.skip_special_tokens,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"banned_tokens": self.banned_tokens,
"allowed_tokens": self.allowed_tokens,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"min_temp": self.min_temp,
"max_temp": self.max_temp,
"temp_exponent": self.temp_exponent,
"smoothing_factor": self.smoothing_factor,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"skew": self.skew,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"dry_multiplier": self.dry_multiplier,
"dry_base": self.dry_base,
"dry_allowed_length": self.dry_allowed_length,
"dry_sequence_breakers": self.dry_sequence_breakers,
"dry_range": self.dry_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
"json_schema": self.json_schema,
"regex_pattern": self.regex_pattern,
"grammar_string": self.grammar_string,
"speculative_ngram": self.speculative_ngram,
}

return {**gen_params, **kwargs}
if self.min_temp and self.max_temp and self.min_temp > self.max_temp:
raise ValueError("min temp cannot be more than max temp")

if self.min_tokens and self.max_tokens and self.min_tokens > self.max_tokens:
raise ValueError("min tokens cannot be more than max tokens")

return self

@field_validator("stop", "banned_strings", mode="before")
def convert_str_to_list(cls, v):
"""Convert single string to list of strings."""
if isinstance(v, str):
return [v]
return v

@field_validator("banned_tokens", "allowed_tokens", mode="before")
def convert_tokens_to_int_list(cls, v):
"""Convert comma-separated string of numbers to a list of integers."""
if isinstance(v, str):
return [int(x) for x in v.split(",") if x.isdigit()]
return v

@field_validator("dry_sequence_breakers", mode="before")
def parse_json_if_needed(cls, v):
"""Parse dry_sequence_breakers string to JSON array."""
if isinstance(v, str) and not v.startswith("["):
v = f"[{v}]"
try:
return json.loads(v) if isinstance(v, str) else v
except Exception:
return [] # Return empty list if parsing fails

@field_validator("mirostat", mode="before")
def convert_mirostat(cls, v, values):
"""Mirostat is enabled if mirostat_mode == 2."""
return values.get("mirostat_mode") == 2


class SamplerOverridesContainer(BaseModel):
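A subtlety worth keeping in mind with the new validators: a `field_validator` only sees previously validated fields through `info.data` (declaration order matters), and validators do not run for defaulted fields unless `validate_default=True` is set. A standalone sketch of both behaviors, not the repo's model:

```python
from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Demo(BaseModel):
    mirostat_mode: int = 0  # declared first, so it lands in info.data below
    mirostat: bool = Field(default=False, validate_default=True)

    @field_validator("mirostat", mode="before")
    def derive_mirostat(cls, v, info: ValidationInfo):
        # info.data holds only the fields validated before this one
        return info.data.get("mirostat_mode") == 2


print(Demo(mirostat_mode=2).mirostat)  # True
print(Demo(mirostat_mode=0).mirostat)  # False
```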
42 changes: 25 additions & 17 deletions endpoints/Kobold/types/generation.py
@@ -1,6 +1,7 @@
from functools import partial
from typing import List, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from common import model
from common.sampling import BaseSamplerRequest, get_default_sampler_value
from common.utils import flat_map, unwrap
@@ -11,29 +12,36 @@ class GenerateRequest(BaseSamplerRequest):
genkey: Optional[str] = None
use_default_badwordsids: Optional[bool] = False
dynatemp_range: Optional[float] = Field(
default_factory=get_default_sampler_value("dynatemp_range")
default_factory=partial(get_default_sampler_value, "dynatemp_range")
)

def to_gen_params(self, **kwargs):
# Exl2 uses -1 to include all tokens in repetition penalty
if self.penalty_range == 0:
self.penalty_range = -1

if self.dynatemp_range:
self.min_temp = self.temperature - self.dynatemp_range
self.max_temp = self.temperature + self.dynatemp_range

# Move badwordsids into banned tokens for generation
if self.use_default_badwordsids:
@field_validator("penalty_range")
@classmethod
def validate_penalty_range(cls, v):
return -1 if v == 0 else v

@field_validator("min_temp", "max_temp")
@classmethod
def validate_temp_range(cls, v, info):
if "dynatemp_range" in info.data and info.data["dynatemp_range"] is not None:
temperature = info.data.get("temperature", 0) # Assume 0 if not present
if info.field_name == "min_temp":
return temperature - info.data["dynatemp_range"]
elif info.field_name == "max_temp":
return temperature + info.data["dynatemp_range"]
return v

@field_validator("banned_tokens")
@classmethod
def validate_banned_tokens(cls, v, info):
if info.data.get("use_default_badwordsids"):
bad_words_ids = unwrap(
model.container.generation_config.bad_words_ids,
model.container.hf_config.get_badwordsids(),
)

if bad_words_ids:
self.banned_tokens += flat_map(bad_words_ids)

return super().to_gen_params(**kwargs)
return v + flat_map(bad_words_ids)
return v


class GenerateResponseResult(BaseModel):
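The `dynatemp_range` change above also fixes a latent bug: `default_factory` expects a zero-argument callable, but the old line called `get_default_sampler_value(...)` immediately and handed Pydantic its result. Wrapping the call in `functools.partial` defers it. Roughly, with a stand-in defaults table in place of the repo's sampler config:

```python
from functools import partial
from typing import Optional

from pydantic import BaseModel, Field

# Stand-in for the repo's sampler defaults lookup
_DEFAULTS = {"dynatemp_range": 0.0}


def get_default_sampler_value(key, fallback=None):
    return _DEFAULTS.get(key, fallback)


class GenerateRequest(BaseModel):
    # partial() yields the zero-argument callable default_factory expects
    dynatemp_range: Optional[float] = Field(
        default_factory=partial(get_default_sampler_value, "dynatemp_range")
    )


print(GenerateRequest().dynatemp_range)  # 0.0
```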
2 changes: 1 addition & 1 deletion endpoints/Kobold/utils/generation.py
@@ -53,7 +53,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
logger.info(f"Received Kobold generation request {data.genkey}")

generator = model.container.generate_gen(
data.prompt, data.genkey, abort_event, **data.to_gen_params()
request_id=data.genkey, abort_event=abort_event, **data.model_dump()
)
async for generation in generator:
if disconnect_task.done():
15 changes: 3 additions & 12 deletions endpoints/OAI/types/common.py
@@ -1,6 +1,6 @@
"""Common types for OAI."""

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing import Optional

from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -54,17 +54,8 @@ class CommonCompletionRequest(BaseSamplerRequest):
description="Not parsed. Only used for OAI compliance.", default=None
)

@model_validator(mode="after")
def validate_params(self):
# Completion count
if self.n < 1:
raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")

return super().validate_params()

def to_gen_params(self):
extra_gen_params = {
"stream": self.stream,
"logprobs": self.logprobs,
}

return super().to_gen_params(**extra_gen_params)
return self
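For reference, an `'after'` model validator receives the fully constructed instance and must return it, which is why the pared-down override now ends in `return self` instead of delegating to `super().validate_params()`. A minimal sketch, assuming only a simple `n` field:

```python
from pydantic import BaseModel, model_validator


class CompletionRequest(BaseModel):
    n: int = 1

    @model_validator(mode="after")
    def validate_params(self):
        if self.n < 1:
            raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
        return self  # 'after' validators must return the model instance


print(CompletionRequest(n=2).n)  # 2
```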
16 changes: 4 additions & 12 deletions endpoints/OAI/utils/chat_completion.py
@@ -381,21 +381,13 @@ async def generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
gen_tasks: List[asyncio.Task] = []
gen_params = data.to_gen_params()

try:
for n in range(0, data.n):
# Deepcopy gen params above the first index
# to ensure nested structures aren't shared
if n > 0:
task_gen_params = deepcopy(gen_params)
else:
task_gen_params = gen_params

for _ in range(0, data.n):
gen_tasks.append(
asyncio.create_task(
model.container.generate(
prompt, request.state.id, **task_gen_params
prompt, request.state.id, **data.model_dump()
)
)
)
@@ -433,9 +425,9 @@ async def generate_tool_calls(

# Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
tool_data = deepcopy(data)
tool_data = data.model_copy(deep=True)
tool_data.json_schema = tool_data.tool_call_schema
gen_params = tool_data.to_gen_params()
gen_params = tool_data.model_dump()

for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start:
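`model_copy(deep=True)` is Pydantic v2's native replacement for `copy.deepcopy` on models: it clones nested structures, so mutating the copy (here, its `json_schema`) leaves the original request untouched. A hedged sketch of the idiom with hypothetical fields:

```python
from typing import Optional

from pydantic import BaseModel


class ToolRequest(BaseModel):
    json_schema: Optional[dict] = None
    tool_call_schema: dict = {"type": "object"}


data = ToolRequest()
tool_data = data.model_copy(deep=True)  # deep clone; nested dicts are independent
tool_data.json_schema = tool_data.tool_call_schema

print(data.json_schema)       # None: the original is unmodified
print(tool_data.json_schema)  # {'type': 'object'}
```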
33 changes: 17 additions & 16 deletions endpoints/server.py
@@ -36,27 +36,28 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
)

api_servers = config.network.api_servers
api_servers = (
api_servers
if api_servers
else [
"oai",
]
)

# Map for API id to server router
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}

# Include the OAI api by default
if api_servers:
for server in api_servers:
selected_server = router_mapping.get(server.lower())

if selected_server:
app.include_router(selected_server.setup())

logger.info(f"Starting {selected_server.api_name} API")
for path, url in selected_server.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")
else:
app.include_router(OAIRouter.setup())
for path, url in OAIRouter.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")
for server in api_servers:
selected_server = router_mapping.get(server.lower())

if selected_server:
app.include_router(selected_server.setup())

logger.info(f"Starting {selected_server.api_name} API")
for path, url in selected_server.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")

# Include core API request paths
app.include_router(CoreRouter)
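The server setup now normalizes the config up front: an unset or empty `api_servers` falls back to `["oai"]`, so a single loop replaces the old if/else branching. The shape of the logic, sketched with stand-in router names:

```python
# Stand-ins for the repo's router objects
router_mapping = {"oai": "OAIRouter", "kobold": "KoboldRouter"}


def select_routers(api_servers):
    # Include the OAI API by default when nothing is configured
    api_servers = api_servers if api_servers else ["oai"]
    return [router_mapping[s.lower()] for s in api_servers if s.lower() in router_mapping]


print(select_routers(None))               # ['OAIRouter']
print(select_routers(["OAI", "kobold"]))  # ['OAIRouter', 'KoboldRouter']
```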
