feat (core): support stopSequences setting. (#2371)

lgrammel authored Jul 22, 2024
1 parent 644f658 commit 2b9da0f
Showing 21 changed files with 99 additions and 9 deletions.
13 changes: 13 additions & 0 deletions .changeset/giant-shirts-relate.md

@@ -0,0 +1,13 @@
+---
+'@ai-sdk/amazon-bedrock': patch
+'@ai-sdk/google-vertex': patch
+'@ai-sdk/anthropic': patch
+'@ai-sdk/provider': patch
+'@ai-sdk/mistral': patch
+'@ai-sdk/cohere': patch
+'@ai-sdk/google': patch
+'@ai-sdk/openai': patch
+'ai': patch
+---
+
+feat (core): support stopSequences setting.
12 changes: 10 additions & 2 deletions content/docs/03-ai-sdk-core/25-settings.mdx

@@ -21,8 +21,9 @@ const result = await generateText({

 <Note>
   Some providers do not support all common settings. If you use a setting with a
-  provider that does not support it, a warning will be included in the AI
-  function result object.
+  provider that does not support it, a warning will be generated. You can check
+  the `warnings` property in the result object to see if any warnings were
+  generated.
 </Note>

 ### `maxTokens`

@@ -62,6 +63,13 @@ The frequency penalty affects the likelihood of the model to repeatedly use the
 The value is passed through to the provider. The range depends on the provider and model.
 For most providers, `0` means no penalty.

+### `stopSequences`
+
+The stop sequences to use for stopping the text generation.
+
+If set, the model will stop generating text when one of the stop sequences is generated.
+Providers may have limits on the number of stop sequences.
+
 ### `seed`

 It is the seed (integer) to use for random sampling.
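For illustration, here is a minimal sketch of the new setting in action with `generateText`; the model id, prompt, and stop sequence are placeholder assumptions, not part of the commit:

```ts
import { openai } from '@ai-sdk/openai';
import { generateText } from 'ai';

const { text, warnings } = await generateText({
  model: openai('gpt-3.5-turbo'),
  prompt: 'List three colors, then write END.',
  // generation halts once the model emits this sequence:
  stopSequences: ['END'],
});

console.log(text); // the text generated before the stop sequence
console.log(warnings); // any unsupported-setting warnings from the provider
```
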
7 changes: 7 additions & 0 deletions content/docs/07-reference/ai-sdk-core/01-generate-text.mdx

@@ -302,6 +302,13 @@ console.log(text);
       description:
         'Frequency penalty setting. It affects the likelihood of the model to repeatedly use the same words or phrases. The value is passed through to the provider. The range depends on the provider and model.',
     },
+    {
+      name: 'stopSequences',
+      type: 'string[]',
+      isOptional: true,
+      description:
+        'Sequences that will stop the generation of the text. If the model generates any of these sequences, it will stop generating further text.',
+    },
     {
       name: 'seed',
       type: 'number',
7 changes: 7 additions & 0 deletions content/docs/07-reference/ai-sdk-core/02-stream-text.mdx

@@ -304,6 +304,13 @@ for await (const textPart of textStream) {
       description:
         'Frequency penalty setting. It affects the likelihood of the model to repeatedly use the same words or phrases. The value is passed through to the provider. The range depends on the provider and model.',
     },
+    {
+      name: 'stopSequences',
+      type: 'string[]',
+      isOptional: true,
+      description:
+        'Sequences that will stop the generation of the text. If the model generates any of these sequences, it will stop generating further text.',
+    },
     {
       name: 'seed',
       type: 'number',
7 changes: 7 additions & 0 deletions content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx

@@ -252,6 +252,13 @@ A helper function to create a streamable UI from LLM providers. This function is
       description:
         'Frequency penalty setting. It affects the likelihood of the model to repeatedly use the same words or phrases. The value is passed through to the provider. The range depends on the provider and model.',
     },
+    {
+      name: 'stopSequences',
+      type: 'string[]',
+      isOptional: true,
+      description:
+        'Sequences that will stop the generation of the text. If the model generates any of these sequences, it will stop generating further text.',
+    },
     {
       name: 'seed',
       type: 'number',
2 changes: 2 additions & 0 deletions packages/amazon-bedrock/src/bedrock-chat-language-model.ts

@@ -57,6 +57,7 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
     frequencyPenalty,
     presencePenalty,
     seed,
+    stopSequences,
     headers,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

@@ -101,6 +102,7 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
         maxTokens,
         temperature,
         topP,
+        stopSequences,
       },
       messages,
     };
2 changes: 2 additions & 0 deletions packages/anthropic/src/anthropic-messages-language-model.ts

@@ -61,6 +61,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
     frequencyPenalty,
     presencePenalty,
     seed,
+    stopSequences,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

@@ -100,6 +101,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
       max_tokens: maxTokens ?? 4096, // 4096: max model output tokens
       temperature,
       top_p: topP,
+      stop_sequences: stopSequences,

       // prompt:
       system: messagesPrompt.system,
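Worth noting: the SDK's camelCase `stopSequences` becomes Anthropic's snake_case `stop_sequences`. A hypothetical, abbreviated request body for a call that sets `stopSequences: ['END']`; the model id and message are placeholders:

```ts
// Sketch of the messages request body this mapping produces:
const requestBody = {
  model: 'claude-3-haiku-20240307', // placeholder model id
  max_tokens: 4096, // the default applied above when maxTokens is unset
  stop_sequences: ['END'], // mapped from the AI SDK stopSequences setting
  messages: [{ role: 'user', content: 'Say hello, then write END.' }],
};

console.log(JSON.stringify(requestBody, null, 2));
```
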
2 changes: 2 additions & 0 deletions packages/cohere/src/cohere-chat-language-model.ts

@@ -56,6 +56,7 @@ export class CohereChatLanguageModel implements LanguageModelV1 {
     topP,
     frequencyPenalty,
     presencePenalty,
+    stopSequences,
     seed,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

@@ -78,6 +79,7 @@
       temperature,
       p: topP,
       seed,
+      stop_sequences: stopSequences,

       // messages:
       chat_history: history,
2 changes: 1 addition & 1 deletion packages/core/core/generate-object/generate-object.ts

@@ -69,7 +69,7 @@ export async function generateObject<T>({
   headers,
   experimental_telemetry: telemetry,
   ...settings
-}: CallSettings &
+}: Omit<CallSettings, 'stopSequences'> &
   Prompt & {
     /**
The language model to use.
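`generateObject` now excludes `stopSequences` from its accepted settings, presumably because truncating generation midway would yield invalid JSON for the object schema. A sketch of the resulting compile-time behavior; the model id, schema, and prompt are illustrative assumptions:

```ts
import { openai } from '@ai-sdk/openai';
import { generateObject } from 'ai';
import { z } from 'zod';

const result = await generateObject({
  model: openai('gpt-4o'),
  schema: z.object({ city: z.string() }),
  prompt: 'Name a city.',
  // @ts-expect-error -- excluded via Omit<CallSettings, 'stopSequences'>
  stopSequences: ['\n'],
});
```
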
2 changes: 1 addition & 1 deletion packages/core/core/generate-object/stream-object.ts

@@ -79,7 +79,7 @@ export async function streamObject<T>({
   headers,
   onFinish,
   ...settings
-}: CallSettings &
+}: Omit<CallSettings, 'stopSequences'> &
   Prompt & {
     /**
The language model to use.
2 changes: 2 additions & 0 deletions packages/core/core/generate-text/generate-text.ts

@@ -56,6 +56,8 @@ The value is passed through to the provider. The range depends on the provider a
 @param frequencyPenalty - Frequency penalty setting.
 It affects the likelihood of the model to repeatedly use the same words or phrases.
 The value is passed through to the provider. The range depends on the provider and model.
+@param stopSequences - Stop sequences.
+If set, the model will stop generating text when one of the stop sequences is generated.
 @param seed - The seed (integer) to use for random sampling.
 If set and supported by the model, calls will generate deterministic results.
2 changes: 2 additions & 0 deletions packages/core/core/generate-text/stream-text.ts

@@ -60,6 +60,8 @@ The value is passed through to the provider. The range depends on the provider a
 @param frequencyPenalty - Frequency penalty setting.
 It affects the likelihood of the model to repeatedly use the same words or phrases.
 The value is passed through to the provider. The range depends on the provider and model.
+@param stopSequences - Stop sequences.
+If set, the model will stop generating text when one of the stop sequences is generated.
 @param seed - The seed (integer) to use for random sampling.
 If set and supported by the model, calls will generate deterministic results.
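The setting flows through `streamText` the same way. A sketch of stopping a stream at a ReAct-style marker; the model id, prompt, and marker are placeholder assumptions:

```ts
import { openai } from '@ai-sdk/openai';
import { streamText } from 'ai';

const result = await streamText({
  model: openai('gpt-3.5-turbo'),
  prompt: 'Thought: I should check the weather.\nAction:',
  // stop before the model invents tool output:
  stopSequences: ['Observation:'],
});

for await (const textPart of result.textStream) {
  process.stdout.write(textPart);
}
```
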
7 changes: 7 additions & 0 deletions packages/core/core/prompt/call-settings.ts

@@ -46,6 +46,13 @@ and 1 (maximum penalty, decrease repetition). 0 means no penalty.
   */
   frequencyPenalty?: number;

+  /**
+Stop sequences.
+If set, the model will stop generating text when one of the stop sequences is generated.
+Providers may have limits on the number of stop sequences.
+  */
+  stopSequences?: string[];
+
   /**
The seed (integer) to use for random sampling. If set and supported
by the model, calls will generate deterministic results.
5 changes: 5 additions & 0 deletions packages/core/core/prompt/prepare-call-settings.ts

@@ -10,6 +10,7 @@ export function prepareCallSettings({
   topP,
   presencePenalty,
   frequencyPenalty,
+  stopSequences,
   seed,
   maxRetries,
 }: CallSettings): CallSettings {

@@ -105,6 +106,10 @@ export function prepareCallSettings({
     topP,
     presencePenalty,
     frequencyPenalty,
+    stopSequences:
+      stopSequences != null && stopSequences.length > 0
+        ? stopSequences
+        : undefined,
     seed,
     maxRetries: maxRetries ?? 2,
   };
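The normalization above makes an empty array behave like an unset value, so providers never receive an empty stop list. A sketch of that behavior; the relative import assumes code sitting next to the module:

```ts
import { prepareCallSettings } from './prepare-call-settings';

// an empty array is treated like an unset value:
console.log(prepareCallSettings({ stopSequences: [] }).stopSequences);
// -> undefined

// a non-empty array passes through unchanged:
console.log(prepareCallSettings({ stopSequences: ['END'] }).stopSequences);
// -> ['END']
```
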
2 changes: 2 additions & 0 deletions (file name not rendered in this view)

@@ -191,6 +191,7 @@ describe('doGenerate', () => {
       temperature: 0.5,
       maxTokens: 100,
       topP: 0.9,
+      stopSequences: ['abc', 'def'],
     });

     expect(mockVertexAI.lastModelParams).toStrictEqual({

@@ -200,6 +201,7 @@
       temperature: 0.5,
       topK: 0.1,
       topP: 0.9,
+      stopSequences: ['abc', 'def'],
     },
     tools: undefined,
     safetySettings: undefined,
10 changes: 6 additions & 4 deletions packages/google-vertex/src/google-vertex-language-model.ts

@@ -49,14 +49,15 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 {
   }

   private async getArgs({
-    prompt,
     mode,
-    frequencyPenalty,
-    presencePenalty,
-    seed,
+    prompt,
     maxTokens,
     temperature,
     topP,
+    frequencyPenalty,
+    presencePenalty,
+    stopSequences,
+    seed,
     headers,
   }: LanguageModelV1CallOptions) {
     const warnings: LanguageModelV1CallWarning[] = [];

@@ -97,6 +98,7 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 {
       maxOutputTokens: maxTokens,
       temperature,
       topP,
+      stopSequences,
     };

     const type = mode.type;
2 changes: 2 additions & 0 deletions packages/google/src/google-generative-ai-language-model.ts

@@ -61,6 +61,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
     topP,
     frequencyPenalty,
     presencePenalty,
+    stopSequences,
     seed,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

@@ -96,6 +97,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
       maxOutputTokens: maxTokens,
       temperature,
       topP,
+      stopSequences,
     };

     const { contents, systemInstruction } =
8 changes: 8 additions & 0 deletions packages/mistral/src/mistral-chat-language-model.ts

@@ -60,6 +60,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
     topP,
     frequencyPenalty,
     presencePenalty,
+    stopSequences,
     seed,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

@@ -80,6 +81,13 @@
       });
     }

+    if (stopSequences != null) {
+      warnings.push({
+        type: 'unsupported-setting',
+        setting: 'stopSequences',
+      });
+    }
+
     const baseArgs = {
       // model id:
       model: this.modelId,
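Mistral's chat API evidently had no stop-sequence parameter at this point, so the setting is surfaced as a warning rather than dropped silently. A sketch of observing that from user code; the model id and prompt are placeholders:

```ts
import { mistral } from '@ai-sdk/mistral';
import { generateText } from 'ai';

const { warnings } = await generateText({
  model: mistral('mistral-large-latest'),
  prompt: 'Say hello, then write END.',
  stopSequences: ['END'], // unsupported here, so it is reported instead
});

// expected to contain { type: 'unsupported-setting', setting: 'stopSequences' }
console.log(warnings);
```
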
2 changes: 2 additions & 0 deletions packages/openai/src/openai-chat-language-model.ts

@@ -64,6 +64,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
     topP,
     frequencyPenalty,
     presencePenalty,
+    stopSequences,
     seed,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

@@ -96,6 +97,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
       top_p: topP,
       frequency_penalty: frequencyPenalty,
       presence_penalty: presencePenalty,
+      stop: stopSequences,
       seed,

       // messages:
5 changes: 4 additions & 1 deletion packages/openai/src/openai-completion-language-model.ts

@@ -65,13 +65,16 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
     topP,
     frequencyPenalty,
     presencePenalty,
+    stopSequences: userStopSequences,
     seed,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;

     const { prompt: completionPrompt, stopSequences } =
       convertToOpenAICompletionPrompt({ prompt, inputFormat });

+    const stop = [...(stopSequences ?? []), ...(userStopSequences ?? [])];
+
     const baseArgs = {
       // model id:
       model: this.modelId,

@@ -102,7 +105,7 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 {
       prompt: completionPrompt,

       // stop sequences:
-      stop: stopSequences,
+      stop: stop.length > 0 ? stop : undefined,
     };

     switch (type) {
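The completion model merges two sources of stop sequences: those derived from the chat-style prompt format by `convertToOpenAICompletionPrompt`, and those supplied by the caller. A standalone sketch of the merge; the concrete values are assumptions:

```ts
// e.g. derived from the prompt format (keeps the model from writing
// the next fake chat turn):
const stopSequences: string[] | undefined = ['\nuser:'];

// e.g. the caller's stopSequences setting:
const userStopSequences: string[] | undefined = ['END'];

const stop = [...(stopSequences ?? []), ...(userStopSequences ?? [])];
console.log(stop); // ['\nuser:', 'END'], sent as `stop`; omitted when empty
```
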
7 changes: 7 additions & 0 deletions (file name not rendered in this view)

@@ -11,6 +11,13 @@ It is recommended to set either `temperature` or `topP`, but not both.
   */
   temperature?: number;

+  /**
+Stop sequences.
+If set, the model will stop generating text when one of the stop sequences is generated.
+Providers may have limits on the number of stop sequences.
+  */
+  stopSequences?: string[];
+
   /**
Nucleus sampling.
