Added support for min_p
My small contribution to this great project.

Ref: ggerganov/llama.cpp#3841

Closes: abetlen#911
tk-master committed Nov 16, 2023
1 parent 96a3776 commit 1b1a918
Showing 3 changed files with 53 additions and 2 deletions.
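With this change, min_p (default 0.05) can be passed anywhere top_p is accepted. A minimal usage sketch, where the model path and prompt are placeholders rather than anything from this commit:

    from llama_cpp import Llama

    # Placeholder path; any local GGUF model works here.
    llm = Llama(model_path="./models/mistral-7b-instruct.Q4_K_M.gguf")

    # min_p keeps only tokens whose probability is at least 5% of the most
    # likely token's probability; top_p=1.0 leaves nucleus sampling effectively
    # disabled so that min_p does the filtering.
    output = llm(
        "Q: Name the planets in the solar system. A: ",
        max_tokens=64,
        temperature=0.8,
        top_p=1.0,
        min_p=0.05,
    )
    print(output["choices"][0]["text"])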
29 changes: 27 additions & 2 deletions llama_cpp/llama.py
@@ -1046,6 +1046,8 @@ def sample(
self,
top_k: int = 40,
top_p: float = 0.95,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
temp: float = 0.80,
repeat_penalty: float = 1.1,
frequency_penalty: float = 0.0,
@@ -1108,7 +1110,9 @@ def sample(
grammar=grammar,
)

- if temp == 0.0:
+ if temp < 0.0:
+     id = self._ctx.sample_softmax(candidates=self._candidates)
+ elif temp == 0.0:
id = self._ctx.sample_token_greedy(candidates=self._candidates)
elif mirostat_mode == 1:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
@@ -1130,8 +1134,9 @@
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
- self._ctx.sample_typical(candidates=self._candidates, p=1.0, min_keep=1)
+ self._ctx.sample_typical(candidates=self._candidates, p=typical_p, min_keep=1)
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
+ self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token(candidates=self._candidates)
if grammar is not None:
@@ -1143,6 +1148,8 @@ def generate(
tokens: Sequence[int],
top_k: int = 40,
top_p: float = 0.95,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
temp: float = 0.80,
repeat_penalty: float = 1.1,
reset: bool = True,
@@ -1200,6 +1207,8 @@ def generate(
token = self.sample(
top_k=top_k,
top_p=top_p,
+ min_p=min_p,
+ typical_p=typical_p,
temp=temp,
repeat_penalty=repeat_penalty,
frequency_penalty=frequency_penalty,
@@ -1298,6 +1307,8 @@ def _create_completion(
max_tokens: Optional[int] = 16,
temperature: float = 0.8,
top_p: float = 0.95,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
@@ -1396,6 +1407,8 @@ def _create_completion(
prompt_tokens,
top_k=top_k,
top_p=top_p,
+ min_p=min_p,
+ typical_p=typical_p,
temp=temperature,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
@@ -1764,6 +1777,8 @@ def create_completion(
max_tokens: Optional[int] = 16,
temperature: float = 0.8,
top_p: float = 0.95,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
@@ -1810,6 +1825,8 @@ def create_completion(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ min_p=min_p,
+ typical_p=typical_p,
logprobs=logprobs,
echo=echo,
stop=stop,
@@ -1841,6 +1858,8 @@ def __call__(
max_tokens: int = 128,
temperature: float = 0.8,
top_p: float = 0.95,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
logprobs: Optional[int] = None,
echo: bool = False,
stop: Optional[Union[str, List[str]]] = [],
@@ -1887,6 +1906,8 @@ def __call__(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
+ min_p=min_p,
+ typical_p=typical_p,
logprobs=logprobs,
echo=echo,
stop=stop,
@@ -1916,6 +1937,8 @@ def create_chat_completion(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -1962,6 +1985,8 @@ def create_chat_completion(
temperature=temperature,
top_p=top_p,
top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
stream=stream,
stop=stop,
seed=seed,
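In Llama.sample() above, the new sample_min_p call runs after top_p and before the temperature step; the actual filtering is done by llama.cpp (see ggerganov/llama.cpp#3841). As a rough illustration of the rule only, not the upstream implementation, min-p keeps the tokens whose probability is at least min_p times the probability of the most likely token:

    import math

    def min_p_filter(logits, min_p=0.05, min_keep=1):
        """Toy illustration of min-p filtering over raw logits."""
        # Softmax over the logits.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        probs = [e / total for e in exps]

        # Keep tokens at or above min_p * (highest probability),
        # always retaining at least min_keep tokens.
        threshold = min_p * max(probs)
        kept = [i for i, p in enumerate(probs) if p >= threshold]
        if len(kept) < min_keep:
            kept = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:min_keep]
        return kept

    # With min_p=0.05 and a top probability around 0.87, the threshold is
    # roughly 0.043, so only the two strong candidates survive here.
    print(min_p_filter([5.0, 3.0, 1.0, -2.0], min_p=0.05))  # -> [0, 1]

This mirrors the example given for min_p_field in server/app.py below: with min_p=0.05 and a most-likely probability of 0.9, candidates below 0.045 are dropped.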
16 changes: 16 additions & 0 deletions llama_cpp/llama_chat_format.py
@@ -26,6 +26,8 @@ def __call__(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -287,6 +289,8 @@ def basic_create_chat_completion(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -330,6 +334,8 @@ def basic_create_chat_completion(
temperature=temperature,
top_p=top_p,
top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
stream=stream,
stop=stop,
seed=seed,
@@ -579,6 +585,8 @@ def functionary_chat_handler(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
@@ -761,6 +769,8 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
temperature=temperature,
top_p=top_p,
top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
stream=stream,
stop=["user:", "</s>"],
max_tokens=max_tokens,
@@ -831,6 +841,8 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
temperature=temperature,
top_p=top_p,
top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
@@ -929,6 +941,8 @@ def __call__(
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[
@@ -1045,6 +1059,8 @@ def __call__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
stream=stream,
stop=stop,
max_tokens=max_tokens,
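The llama_chat_format.py changes only thread min_p and typical_p through each chat handler into Llama.create_completion, so the parameter is also available from create_chat_completion. A brief sketch, again with a placeholder model path and chat format:

    from llama_cpp import Llama

    llm = Llama(
        model_path="./models/mistral-7b-instruct.Q4_K_M.gguf",  # placeholder
        chat_format="llama-2",
    )
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Give me one fun fact about llamas."}],
        temperature=0.7,
        min_p=0.1,  # stricter than the 0.05 default
    )
    print(response["choices"][0]["message"]["content"])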
10 changes: 10 additions & 0 deletions llama_cpp/server/app.py
@@ -521,6 +521,14 @@ async def get_event_publisher(
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
)

+ min_p_field = Field(
+     default=0.05,
+     ge=0.0,
+     le=1.0,
+     description="Sets a minimum base probability threshold for token selection.\n\n"
+     + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
+ )
+
stop_field = Field(
default=None,
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
@@ -593,6 +601,7 @@ class CreateCompletionRequest(BaseModel):
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
+ min_p: float = min_p_field
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -788,6 +797,7 @@ class CreateChatCompletionRequest(BaseModel):
)
temperature: float = temperature_field
top_p: float = top_p_field
+ min_p: float = min_p_field
stop: Optional[List[str]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
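On the server side, min_p is now an accepted field of both CreateCompletionRequest and CreateChatCompletionRequest, so it can be sent in the JSON request body. A rough sketch against a locally running llama_cpp.server instance, assuming the default host, port, and OpenAI-compatible /v1/completions route:

    import json
    import urllib.request

    # Assumes `python -m llama_cpp.server --model <path>` is already running.
    body = json.dumps({
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_p": 1.0,
        "min_p": 0.05,  # the field added by this commit
    }).encode()

    request = urllib.request.Request(
        "http://localhost:8000/v1/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        print(json.load(response)["choices"][0]["text"])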
