Skip to content

Commit

Permalink
Sampling: Format and space out
Browse files Browse the repository at this point in the history
Make the code more readable.

Signed-off-by: kingbri <bdashore3@proton.me>
  • Loading branch information
bdashore3 committed Oct 25, 2024
1 parent 0936d1a commit f389178
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def validate_params(self):

@field_validator("top_k", mode="before")
def convert_top_k(cls, v):
"""Fixes instance if Top-K is -1."""

if v == -1:
logger.warning("Provided a top-k value of -1. Converting to 0 instead.")
return 0
Expand All @@ -313,20 +315,25 @@ def convert_top_k(cls, v):
@field_validator("stop", "banned_strings", mode="before")
def convert_str_to_list(cls, v):
"""Convert single string to list of strings."""

if isinstance(v, str):
return [v]

return v

@field_validator("banned_tokens", "allowed_tokens", mode="before")
def convert_tokens_to_int_list(cls, v):
"""Convert comma-separated string of numbers to a list of integers."""

if isinstance(v, str):
return [int(x) for x in v.split(",") if x.isdigit()]

return v

@field_validator("dry_sequence_breakers", mode="before")
def parse_json_if_needed(cls, v):
"""Parse dry_sequence_breakers string to JSON array."""

if isinstance(v, str) and not v.startswith("["):
v = f"[{v}]"
try:
Expand All @@ -337,6 +344,7 @@ def parse_json_if_needed(cls, v):
@field_validator("mirostat", mode="before")
def convert_mirostat(cls, v, values):
"""Mirostat is enabled if mirostat_mode == 2."""

return values.get("mirostat_mode") == 2


Expand Down

0 comments on commit f389178

Please sign in to comment.