Skip to content

Commit

Permalink
Sampling: Add XTC support
Browse files Browse the repository at this point in the history
Matches with upstream.

Signed-off-by: kingbri <bdashore3@proton.me>
  • Loading branch information
bdashore3 committed Sep 24, 2024
1 parent f4791e7 commit 56ce82e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
15 changes: 15 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,13 @@ def check_unsupported_settings(self, **kwargs):
Meant for dev wheels!
"""

if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "xtc_probability"
):
logger.warning(
"XTC is not supported by the currently " "installed ExLlamaV2 version."
)

return kwargs

async def generate_gen(
Expand Down Expand Up @@ -1003,6 +1010,14 @@ async def generate_gen(
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
gen_settings.skew = unwrap(kwargs.get("skew"), 0)

# XTC
xtc_probability = unwrap(kwargs.get("xtc_probability"), 0.0)
if xtc_probability > 0.0:
gen_settings.xtc_probability = xtc_probability

# 0.1 is the default for this value. Hand the default to unwrap() like the
# other sampler settings here do: as a kwargs.get() default it only covers
# a missing key, not a key that is present with an explicit None.
gen_settings.xtc_threshold = unwrap(kwargs.get("xtc_threshold"), 0.1)

# DynaTemp settings
max_temp = unwrap(kwargs.get("max_temp"), 1.0)
min_temp = unwrap(kwargs.get("min_temp"), 1.0)
Expand Down
10 changes: 10 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ class BaseSamplerRequest(BaseModel):
examples=[0.0],
)

xtc_probability: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("xtc_probability", 0.0),
)

xtc_threshold: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("xtc_threshold", 0.1)
)

frequency_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
)
Expand Down Expand Up @@ -366,6 +374,8 @@ def to_gen_params(self, **kwargs):
"min_p": self.min_p,
"tfs": self.tfs,
"skew": self.skew,
"xtc_probability": self.xtc_probability,
"xtc_threshold": self.xtc_threshold,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
Expand Down
6 changes: 6 additions & 0 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ typical:
skew:
override: 0.0
force: false
xtc_probability:
override: 0.0
force: false
xtc_threshold:
override: 0.1
force: false

# MARK: Penalty settings
frequency_penalty:
Expand Down

0 comments on commit 56ce82e

Please sign in to comment.