feat (core): support topK setting (#2375)
lgrammel authored Jul 22, 2024
1 parent f088548 commit a5b5884
Showing 27 changed files with 196 additions and 88 deletions.
13 changes: 13 additions & 0 deletions .changeset/polite-readers-wink.md
@@ -0,0 +1,13 @@
---
'@ai-sdk/amazon-bedrock': patch
'@ai-sdk/google-vertex': patch
'@ai-sdk/anthropic': patch
'@ai-sdk/provider': patch
'@ai-sdk/mistral': patch
'@ai-sdk/cohere': patch
'@ai-sdk/google': patch
'@ai-sdk/openai': patch
'ai': patch
---

feat (core): support topK setting
7 changes: 7 additions & 0 deletions content/docs/03-ai-sdk-core/25-settings.mdx
@@ -49,6 +49,13 @@ E.g. 0.1 would mean that only tokens with the top 10% probability mass are consi

It is recommended to set either `temperature` or `topP`, but not both.

### `topK`

Only sample from the top K options for each subsequent token.

Used to remove "long tail" low probability responses.
Recommended for advanced use cases only. You usually only need to use `temperature`.
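
For illustration, a minimal sketch of passing the new setting through `generateText` (the model, value, and prompt are illustrative and not part of this diff):

```ts
import { generateText } from 'ai';
import { anthropic } from '@ai-sdk/anthropic';

const { text } = await generateText({
  model: anthropic('claude-3-haiku-20240307'),
  topK: 40, // only sample from the 40 most likely tokens at each step
  prompt: 'Invent a new holiday and describe its traditions.',
});
```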

### `presencePenalty`

The presence penalty affects the likelihood of the model to repeat information that is already in the prompt.
6 changes: 6 additions & 0 deletions content/docs/07-reference/ai-sdk-core/01-generate-text.mdx
@@ -288,6 +288,12 @@ console.log(text);
description:
'Nucleus sampling. The value is passed through to the provider. The range depends on the provider and model. It is recommended to set either `temperature` or `topP`, but not both.',
},
{
name: 'topK',
type: 'number',
isOptional: true,
description: `Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.`,
},
{
name: 'presencePenalty',
type: 'number',
6 changes: 6 additions & 0 deletions content/docs/07-reference/ai-sdk-core/02-stream-text.mdx
@@ -290,6 +290,12 @@ for await (const textPart of textStream) {
description:
'Nucleus sampling. The value is passed through to the provider. The range depends on the provider and model. It is recommended to set either `temperature` or `topP`, but not both.',
},
{
name: 'topK',
type: 'number',
isOptional: true,
description: `Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.`,
},
{
name: 'presencePenalty',
type: 'number',
6 changes: 6 additions & 0 deletions content/docs/07-reference/ai-sdk-core/03-generate-object.mdx
@@ -268,6 +268,12 @@ console.log(JSON.stringify(object, null, 2));
description:
'Nucleus sampling. The value is passed through to the provider. The range depends on the provider and model. It is recommended to set either `temperature` or `topP`, but not both.',
},
{
name: 'topK',
type: 'number',
isOptional: true,
description: `Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.`,
},
{
name: 'presencePenalty',
type: 'number',
6 changes: 6 additions & 0 deletions content/docs/07-reference/ai-sdk-core/04-stream-object.mdx
@@ -271,6 +271,12 @@ for await (const partialObject of partialObjectStream) {
description:
'Nucleus sampling. The value is passed through to the provider. The range depends on the provider and model. It is recommended to set either `temperature` or `topP`, but not both.',
},
{
name: 'topK',
type: 'number',
isOptional: true,
description: `Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.`,
},
{
name: 'presencePenalty',
type: 'number',
6 changes: 6 additions & 0 deletions content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx
@@ -238,6 +238,12 @@ A helper function to create a streamable UI from LLM providers. This function is
description:
'Nucleus sampling. The value is passed through to the provider. The range depends on the provider and model. It is recommended to set either `temperature` or `topP`, but not both.',
},
{
name: 'topK',
type: 'number',
isOptional: true,
description: `Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.`,
},
{
name: 'presencePenalty',
type: 'number',
18 changes: 0 additions & 18 deletions content/providers/01-ai-sdk-providers/05-anthropic.mdx
@@ -74,24 +74,6 @@ Some models have multi-modal capabilities.
const model = anthropic('claude-3-haiku-20240307');
```

Anthropic Messages models also support some model-specific settings that are not part of the [standard call settings](/docs/ai-sdk-core/settings).
You can pass them as an options argument:

```ts
const model = anthropic('claude-3-haiku-20240307', {
topK: 0.2,
});
```

The following optional settings are available for Anthropic models:

- **topK** _number_

Only sample from the top K options for each subsequent token.

Used to remove "long tail" low probability responses.
Recommended for advanced use cases only. You usually only need to use temperature.

### Example

You can use Anthropic language models to generate text with the `generateText` function:
@@ -79,20 +79,14 @@ You can pass them as an options argument:

```ts
const model = google('models/gemini-1.5-pro-latest', {
topK: 0.2,
safetySettings: [
{ category: 'HARM_CATEGORY_UNSPECIFIED', threshold: 'BLOCK_LOW_AND_ABOVE' },
],
});
```

The following optional settings are available for Google Generative AI models:

- **topK** _number_

Optional. The maximum number of tokens to consider when sampling.

Models use nucleus sampling or combined Top-k and nucleus sampling.
Top-k sampling considers the set of topK most probable tokens.
Models running with nucleus sampling don't allow topK setting.

- **cachedContent** _string_

Optional. The name of the cached content used as context to serve the prediction.
12 changes: 3 additions & 9 deletions content/providers/01-ai-sdk-providers/11-google-vertex.mdx
@@ -100,20 +100,14 @@ You can pass them as an options argument:

```ts
const model = vertex('gemini-1.5-pro', {
topK: 0.2,
safetySettings: [
{ category: 'HARM_CATEGORY_UNSPECIFIED', threshold: 'BLOCK_LOW_AND_ABOVE' },
],
});
```

The following optional settings are available for Google Vertex models:

- **topK** _number_

Optional. The maximum number of tokens to consider when sampling.

Models use nucleus sampling or combined Top-k and nucleus sampling.
Top-k sampling considers the set of topK most probable tokens.
Models running with nucleus sampling don't allow topK setting.

- **safetySettings** _Array\<\{ category: string; threshold: string \}\>_

Optional. Safety settings for the model.
8 changes: 8 additions & 0 deletions packages/amazon-bedrock/src/bedrock-chat-language-model.ts
@@ -54,6 +54,7 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
maxTokens,
temperature,
topP,
topK,
frequencyPenalty,
presencePenalty,
seed,
@@ -92,6 +93,13 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
});
}

if (topK != null) {
warnings.push({
type: 'unsupported-setting',
setting: 'topK',
});
}

const { system, messages } = await convertToBedrockChatMessages({ prompt });

const baseArgs: ConverseCommandInput = {
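
As of this change, Bedrock reports `topK` as an unsupported setting rather than failing the call. A hedged sketch of inspecting that warning (model id, value, and prompt are illustrative):

```ts
import { generateText } from 'ai';
import { bedrock } from '@ai-sdk/amazon-bedrock';

const { text, warnings } = await generateText({
  model: bedrock('anthropic.claude-3-haiku-20240307-v1:0'),
  topK: 40,
  prompt: 'Summarize the plot of Hamlet in two sentences.',
});

// Expected to include: [{ type: 'unsupported-setting', setting: 'topK' }]
console.log(warnings);
```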
3 changes: 2 additions & 1 deletion packages/anthropic/src/anthropic-messages-language-model.ts
@@ -58,6 +58,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
maxTokens,
temperature,
topP,
topK,
frequencyPenalty,
presencePenalty,
seed,
@@ -95,7 +96,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
model: this.modelId,

// model specific settings:
top_k: this.settings.topK,
top_k: topK ?? this.settings.topK,

// standardized settings:
max_tokens: maxTokens ?? 4096, // 4096: max model output tokens
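
For context, a hedged sketch of the resulting precedence: the call-level `topK` wins over the deprecated model-level setting via `topK ?? this.settings.topK` (values and prompt are illustrative):

```ts
import { generateText } from 'ai';
import { anthropic } from '@ai-sdk/anthropic';

// Deprecated: topK configured on the model instance.
const model = anthropic('claude-3-haiku-20240307', { topK: 10 });

// The call-level setting takes precedence, so top_k: 40 is sent to Anthropic.
await generateText({
  model,
  topK: 40,
  prompt: 'List three uncommon pasta shapes.',
});
```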
2 changes: 2 additions & 0 deletions packages/anthropic/src/anthropic-messages-settings.ts
@@ -12,6 +12,8 @@ Only sample from the top K options for each subsequent token.
Used to remove "long tail" low probability responses.
Recommended for advanced use cases only. You usually only need to use temperature.
@deprecated use the topK setting on the request instead.
*/
topK?: number;
}
2 changes: 2 additions & 0 deletions packages/cohere/src/cohere-chat-language-model.ts
@@ -54,6 +54,7 @@ export class CohereChatLanguageModel implements LanguageModelV1 {
maxTokens,
temperature,
topP,
topK,
frequencyPenalty,
presencePenalty,
stopSequences,
@@ -78,6 +79,7 @@
max_tokens: maxTokens,
temperature,
p: topP,
k: topK,
seed,
stop_sequences: stopSequences,

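
A hedged sketch of how the standardized setting reaches Cohere after this mapping — `topK` is sent as the `k` request parameter (model id and prompt are illustrative):

```ts
import { generateText } from 'ai';
import { cohere } from '@ai-sdk/cohere';

const { text } = await generateText({
  model: cohere('command-r-plus'),
  topK: 40, // forwarded to Cohere as `k`
  prompt: 'Write a haiku about sampling.',
});
```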
11 changes: 7 additions & 4 deletions packages/core/core/generate-object/generate-object.ts
@@ -35,13 +35,16 @@ This function does not stream the output. If you want to stream the output, use
@param messages - A list of messages. You can either use `prompt` or `messages` but not both.
@param maxTokens - Maximum number of tokens to generate.
@param temperature - Temperature setting.
The value is passed through to the provider. The range depends on the provider and model.
It is recommended to set either `temperature` or `topP`, but not both.
@param topP - Nucleus sampling.
The value is passed through to the provider. The range depends on the provider and model.
It is recommended to set either `temperature` or `topP`, but not both.
@param topK - Only sample from the top K options for each subsequent token.
Used to remove "long tail" low probability responses.
Recommended for advanced use cases only. You usually only need to use temperature.
@param presencePenalty - Presence penalty setting.
It affects the likelihood of the model to repeat information that is already in the prompt.
The value is passed through to the provider. The range depends on the provider and model.
@param frequencyPenalty - Frequency penalty setting.
@@ -54,7 +57,7 @@ If set and supported by the model, calls will generate deterministic results.
@param abortSignal - An optional abort signal that can be used to cancel the call.
@param headers - Additional HTTP headers to be sent with the request. Only applicable for HTTP-based providers.
@returns
A result object that contains the generated object, the finish reason, the token usage, and additional information.
*/
export async function generateObject<T>({
@@ -336,7 +339,7 @@ Response headers.
};

/**
Logprobs for the completion.
`undefined` if the mode does not support logprobs or if it was not enabled
*/
readonly logprobs: LogProbs | undefined;
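
For illustration, a minimal sketch of `generateObject` with the new `topK` call setting (schema, model, and prompt are illustrative assumptions, not part of this diff):

```ts
import { generateObject } from 'ai';
import { anthropic } from '@ai-sdk/anthropic';
import { z } from 'zod';

const { object } = await generateObject({
  model: anthropic('claude-3-haiku-20240307'),
  topK: 40, // passed through with the other call settings
  schema: z.object({
    name: z.string(),
    ingredients: z.array(z.string()),
  }),
  prompt: 'Generate a lasagna recipe.',
});
```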
15 changes: 9 additions & 6 deletions packages/core/core/generate-object/stream-object.ts
@@ -45,13 +45,16 @@ This function streams the output. If you do not want to stream the output, use `
@param messages - A list of messages. You can either use `prompt` or `messages` but not both.
@param maxTokens - Maximum number of tokens to generate.
@param temperature - Temperature setting.
The value is passed through to the provider. The range depends on the provider and model.
It is recommended to set either `temperature` or `topP`, but not both.
@param topP - Nucleus sampling.
The value is passed through to the provider. The range depends on the provider and model.
It is recommended to set either `temperature` or `topP`, but not both.
@param topK - Only sample from the top K options for each subsequent token.
Used to remove "long tail" low probability responses.
Recommended for advanced use cases only. You usually only need to use temperature.
@param presencePenalty - Presence penalty setting.
It affects the likelihood of the model to repeat information that is already in the prompt.
The value is passed through to the provider. The range depends on the provider and model.
@param frequencyPenalty - Frequency penalty setting.
@@ -452,8 +455,8 @@ The generated object (typed according to the schema). Resolved when the response

/**
Stream of partial objects. It gets more complete as the stream progresses.
Note that the partial object is not validated.
If you want to be certain that the actual content matches your schema, you need to implement your own validation for partial results.
*/
get partialObjectStream(): AsyncIterableStream<DeepPartial<T>> {
@@ -482,7 +485,7 @@ If you want to be certain that the actual content matches your schema, you need
}

/**
Text stream of the JSON representation of the generated object. It contains text chunks.
When the stream is finished, the object is valid JSON that can be parsed.
*/
get textStream(): AsyncIterableStream<string> {
@@ -524,7 +527,7 @@ Only errors that stop the stream, such as network errors, are thrown.

/**
Writes text delta output to a Node.js response-like object.
It sets a `Content-Type` header to `text/plain; charset=utf-8` and
writes each text delta as a separate chunk.
@param response A Node.js response-like object (ServerResponse).
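
Similarly, a hedged sketch of `streamObject` with `topK`; note that the partial objects are not validated against the schema while streaming (schema, model, and prompt are illustrative):

```ts
import { streamObject } from 'ai';
import { anthropic } from '@ai-sdk/anthropic';
import { z } from 'zod';

const { partialObjectStream } = await streamObject({
  model: anthropic('claude-3-haiku-20240307'),
  topK: 40,
  schema: z.object({ title: z.string(), tags: z.array(z.string()) }),
  prompt: 'Suggest a blog post about token sampling strategies.',
});

for await (const partialObject of partialObjectStream) {
  console.log(partialObject); // incrementally completed, unvalidated object
}
```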
11 changes: 7 additions & 4 deletions packages/core/core/generate-text/generate-text.ts
@@ -50,7 +50,10 @@ It is recommended to set either `temperature` or `topP`, but not both.
@param topP - Nucleus sampling.
The value is passed through to the provider. The range depends on the provider and model.
It is recommended to set either `temperature` or `topP`, but not both.
@param topK - Only sample from the top K options for each subsequent token.
Used to remove "long tail" low probability responses.
Recommended for advanced use cases only. You usually only need to use temperature.
@param presencePenalty - Presence penalty setting.
It affects the likelihood of the model to repeat information that is already in the prompt.
The value is passed through to the provider. The range depends on the provider and model.
@param frequencyPenalty - Frequency penalty setting.
@@ -109,8 +112,8 @@ The tool choice strategy. Default: 'auto'.
/**
Maximal number of automatic roundtrips for tool calls.
An automatic tool call roundtrip is another LLM call with the
tool call results when all tool calls of the last assistant
message have results.
A maximum number is required to prevent infinite loops in the
@@ -387,7 +390,7 @@ Warnings from the model provider (e.g. unsupported settings)

/**
The response messages that were generated during the call. It consists of an assistant message,
potentially containing tool calls.
When there are tool results, there is an additional tool message with the tool results that are available.
If there are tools that do not have execute functions, they are not included in the tool results and
need to be added separately.
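
To ground the roundtrip and response-message wording above, a hedged sketch of a tool call with `maxToolRoundtrips` (the tool, model, and prompt are illustrative assumptions):

```ts
import { generateText, tool } from 'ai';
import { anthropic } from '@ai-sdk/anthropic';
import { z } from 'zod';

const { text, responseMessages } = await generateText({
  model: anthropic('claude-3-haiku-20240307'),
  maxToolRoundtrips: 2, // stop after two automatic tool-result roundtrips
  tools: {
    weather: tool({
      description: 'Get the weather for a city.',
      parameters: z.object({ city: z.string() }),
      execute: async ({ city }) => ({ city, temperatureC: 21 }),
    }),
  },
  prompt: 'What is the weather in Berlin?',
});

// responseMessages contains the assistant message(s) and a tool message with the tool results.
console.log(responseMessages);
```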