Improves completion API with new input parameters for LLMs (#23)
* #17 Extend LLM input options

* Add the new options to the test script, enable TogetherAI testing, and make small changes for clarity.

* Remove the Cohere Command (non-R) model from tests. Change the default test_options to be more conservative, to prevent long repetition and improve compatibility across test targets.
LeonRuggiero authored Nov 5, 2024
1 parent dd7622d commit a585389
Showing 9 changed files with 146 additions and 36 deletions.
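
For orientation, all of the new knobs travel through the shared execution options object that each driver receives; a minimal usage sketch follows (the import path, constructor shape, and model id are assumptions for illustration, while the option names are the ones added in the diffs below):

import { TogetherAIDriver } from "@llumiverse/drivers"; // import path assumed for illustration

// Hypothetical setup: the constructor options are placeholders, not part of this commit.
const driver = new TogetherAIDriver({ apiKey: process.env.TOGETHER_API_KEY! });

// requestCompletion(prompt, options) is the entry point shown in the TogetherAI diff below;
// the sampling and penalty fields are the parameters this commit threads through.
const completion = await driver.requestCompletion("Write a haiku about rivers.", {
    model: "mistralai/Mixtral-8x7B-Instruct-v0.1", // placeholder model id
    max_tokens: 256,
    temperature: 0.7,
    top_p: 0.9,
    top_k: 40,
    presence_penalty: 0,
    frequency_penalty: 0,
    stop_sequence: "\n\n", // a single string or an array; drivers normalize both
});
console.log(completion);
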
35 changes: 30 additions & 5 deletions drivers/src/bedrock/index.ts
@@ -6,7 +6,7 @@ import { transformAsyncIterator } from "@llumiverse/core/async";
import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
import { AwsCredentialIdentity, Provider } from "@smithy/types";
import mnemonist from "mnemonist";
import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
import { AI21JurassicRequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama3RequestPayload, MistralPayload } from "./payloads.js";
import { forceUploadFile } from "./s3.js";

const { LRUCache } = mnemonist;
@@ -347,7 +347,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
prompt,
temperature: options.temperature,
max_gen_len: options.max_tokens,
} as LLama2RequestPayload
top_p: options.top_p
} as LLama3RequestPayload
} else if (contains(options.model, "claude")) {

const maxToken = () => {
@@ -365,42 +366,66 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
...(prompt as ClaudeMessagesPrompt),
temperature: options.temperature,
max_tokens: maxToken(),
top_p: options.top_p,
top_k: options.top_k,
stop_sequences: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence,
} as ClaudeRequestPayload;
} else if (contains(options.model, "ai21")) {
return {
prompt: prompt,
temperature: options.temperature,
maxTokens: options.max_tokens,
} as AI21RequestPayload;
topP: options.top_p,
stopSequences: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence,
presencePenalty: {scale: options.presence_penalty},
frequencyPenalty: {scale: options.frequency_penalty},
} as AI21JurassicRequestPayload;
} else if (contains(options.model, "command-r-plus")) {
return {
message: prompt as string,
max_tokens: options.max_tokens,
temperature: options.temperature,
p: options.top_p,
k: options.top_k,
frequency_penalty: options.frequency_penalty,
presence_penalty: options.presence_penalty,
stop_sequences: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence,
} as CohereCommandRPayload;

}
else if (contains(options.model, "cohere")) {
return {
prompt: prompt,
temperature: options.temperature,
max_tokens: options.max_tokens,
p: options.top_p,
k: options.top_k,
stop_sequences: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence,
} as CohereRequestPayload;
} else if (contains(options.model, "amazon")) {
const stop_seq: string[] = (typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence) ?? [];
return {
inputText: "User: " + (prompt as string) + "\nBot:", // see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html#model-parameters-titan-request-response
textGenerationConfig: {
temperature: options.temperature,
topP: options.top_p,
maxTokenCount: options.max_tokens,
//stopSequences: ["\n"],
stopSequences: ["\n", ...stop_seq],
},
} as AmazonRequestPayload;
} else if (contains(options.model, "mistral")) {
return {
prompt: prompt,
temperature: options.temperature,
max_tokens: options.max_tokens,
top_k: options.top_k,
top_p: options.top_p,
stop: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence,
} as MistralPayload;
} else {
throw new Error("Cannot prepare payload for unknown provider: " + options.model);
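
The string-or-array normalization of stop_sequence above is repeated once per model family; a small helper along these lines (hypothetical, not part of this commit) could factor it out:

// Hypothetical helper, not part of this commit: normalize stop_sequence from
// `string | string[] | undefined` to `string[] | undefined` so every payload
// builder can pass it straight through.
function toStopSequences(stop?: string | string[]): string[] | undefined {
    if (stop === undefined) return undefined;
    return typeof stop === "string" ? [stop] : stop;
}

// Usage sketch inside the payload builder:
// stop_sequences: toStopSequences(options.stop_sequence),
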
45 changes: 35 additions & 10 deletions drivers/src/bedrock/payloads.ts
@@ -1,48 +1,73 @@
import { ClaudeMessagesPrompt } from "@llumiverse/core/formatters";

export interface LLama2RequestPayload {
//Overall documentation:
//https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
export interface LLama3RequestPayload {
prompt: string;
temperature: number;
temperature?: number;
top_p?: number;
max_gen_len: number;
max_gen_len?: number;
}

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
export interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
anthropic_version: "bedrock-2023-05-31";
max_tokens: number;
prompt: string;
temperature?: number;
top_p?: number;
top_k?: number;
stop_sequences?: [string];
}
export interface AI21RequestPayload {

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
export interface AI21JurassicRequestPayload {
prompt: string;
temperature: number;
maxTokens: number;
topP?: number;
stopSequences?: [string]
presencePenalty?: {
scale : number
}
frequencyPenalty?: {
scale : number
}
}

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
export interface CohereRequestPayload {
prompt: string;
temperature: number;
max_tokens?: number;
p?: number;
k?: number;
stop_sequences: [string],
}

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
export interface AmazonRequestPayload {
inputText: string;
textGenerationConfig: {
temperature: number;
topP: number;
maxTokenCount: number;
stopSequences: [string];
textGenerationConfig?: {
temperature?: number;
topP?: number;
maxTokenCount?: number;
stopSequences?: [string];
};
}

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html
export interface MistralPayload {
prompt: string;
temperature: number;
max_tokens: number;
top_p?: number;
top_k?: number;
stop?: [string]
}

//Docs at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
export interface CohereCommandRPayload {

message: string,
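
Since most payload fields are now optional, unset options can simply be left undefined and (assuming the request body is JSON-serialized) disappear from the payload, letting Bedrock apply its model defaults; a sketch of building a MistralPayload with only some fields set (field names are the ones declared above):

import { MistralPayload } from "./payloads.js";

// Sketch: the required fields (prompt, temperature, max_tokens) are always set;
// the optional sampling fields are included only when the caller provided them.
// Fields left undefined are dropped by JSON.stringify, so Bedrock falls back
// to the model's own defaults for them.
const payload: MistralPayload = {
    prompt: "Summarize the following text: ...",
    temperature: 0.2,
    max_tokens: 512,
    top_p: 0.95,
    // top_k and stop intentionally omitted
};
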
12 changes: 11 additions & 1 deletion drivers/src/groq/index.ts
@@ -66,6 +66,11 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
messages: messages,
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: options.top_p,
//top_logprobs: options.top_logprobs, //Logprobs output currently not supported
//logprobs: options.top_logprobs ? true : false,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
response_format: this.getResponseFormat(options),
});

@@ -92,8 +97,13 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
messages: messages,
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: options.top_p,
//top_logprobs: options.top_logprobs, //Logprobs output currently not supported
//logprobs: options.top_logprobs ? true : false,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
response_format: this.getResponseFormat(options),
stream: true
stream: true,
});

return transformAsyncIterator(res, (res) => ({
7 changes: 6 additions & 1 deletion drivers/src/mistral/index.ts
@@ -94,8 +94,11 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
messages: messages,
maxTokens: options.max_tokens,
temperature: options.temperature,
topP: options.top_p,
responseFormat: this.getResponseFormat(options),
stream: true
stream: true,
stopSequences: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence,
}),
reader: 'sse'
});
@@ -181,6 +184,7 @@ function _makeChatCompletionRequest({
safePrompt,
toolChoice,
responseFormat,
stopSequences,
}: CompletionRequestParams) {
return {
model: model,
@@ -194,5 +198,6 @@
safe_prompt: (safeMode || safePrompt) ?? undefined,
tool_choice: toolChoice ?? undefined,
response_format: responseFormat ?? undefined,
stop: stopSequences ?? undefined,
};
};
3 changes: 2 additions & 1 deletion drivers/src/mistral/types.ts
@@ -138,7 +138,8 @@ export interface CompletionRequestParams {
safeMode?: boolean,
safePrompt?: boolean,
toolChoice?: ToolChoice,
responseFormat?: ResponseFormat
responseFormat?: ResponseFormat,
stopSequences?: string[],
}

// class MistralClient {
12 changes: 11 additions & 1 deletion drivers/src/openai/index.ts
@@ -111,6 +111,11 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
model: options.model,
messages: prompt,
temperature: options.temperature,
top_p: options.top_p,
//top_logprobs: options.top_logprobs, //Logprobs output currently not supported
//logprobs: options.top_logprobs ? true : false,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
n: 1,
max_tokens: options.max_tokens,
tools: options.result_schema
@@ -146,13 +151,18 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
} as OpenAI.Chat.ChatCompletionTool,
]
: undefined;

//TODO: OpenAI o1 support requires max_completion_tokens
const res = await this.service.chat.completions.create({
stream: false,
model: options.model,
messages: prompt,
temperature: options.temperature,
top_p: options.top_p,
//top_logprobs: options.top_logprobs, //Logprobs output currently not supported
//logprobs: options.top_logprobs ? true : false,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
n: 1,
max_tokens: options.max_tokens,
tools: functions,
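
On the TODO above: newer OpenAI reasoning models reject max_tokens and expect max_completion_tokens instead, so the request builder will eventually need to switch fields by model; a hedged sketch of one way to do it (not part of this commit, and the model-name check is an assumption):

// Hypothetical follow-up for the o1 TODO, not part of this commit.
// Assumption: reasoning models can be recognized by an "o1" prefix, and the
// installed openai SDK version accepts max_completion_tokens.
function tokenLimitParams(model: string, maxTokens?: number) {
    return model.startsWith("o1")
        ? { max_completion_tokens: maxTokens }
        : { max_tokens: maxTokens };
}

// Usage sketch: spread into the chat.completions.create body instead of
// passing max_tokens directly.
// const res = await this.service.chat.completions.create({
//     model: options.model,
//     messages: prompt,
//     ...tokenLimitParams(options.model, options.max_tokens),
// });
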
23 changes: 21 additions & 2 deletions drivers/src/togetherai/index.ts
@@ -30,16 +30,26 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
}

async requestCompletion(prompt: string, options: ExecutionOptions): Promise<Completion<any>> {

const stop_seq = typeof options.stop_sequence == 'string' ?
[options.stop_sequence] : options.stop_sequence ?? [];

const res = await this.fetchClient.post('/v1/completions', {
payload: {
model: options.model,
prompt: prompt,
response_format: this.getResponseFormat(options),
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: options.top_p,
top_k: options.top_k,
//logprobs: options.top_logprobs, //Logprobs output currently not supported
frequency_penalty: options.frequency_penalty,
presence_penalty: options.presence_penalty,
stop: [
"</s>",
"[/INST]"
"[/INST]",
...stop_seq,
],
}
}) as TextCompletion;
@@ -59,17 +69,26 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
}

async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunk>> {
const stop_seq = typeof options.stop_sequence == 'string' ?
[options.stop_sequence] : options.stop_sequence ?? [];

const stream = await this.fetchClient.post('/v1/completions', {
payload: {
model: options.model,
prompt: prompt,
max_tokens: options.max_tokens,
temperature: options.temperature,
response_format: this.getResponseFormat(options),
top_p: options.top_p,
top_k: options.top_k,
//logprobs: options.top_logprobs, //Logprobs output currently not supported
frequency_penalty: options.frequency_penalty,
presence_penalty: options.presence_penalty,
stream: true,
stop: [
"</s>",
"[/INST]"
"[/INST]",
...stop_seq,
],
},
reader: 'sse'
5 changes: 5 additions & 0 deletions drivers/src/vertexai/models/gemini.ts
@@ -39,6 +39,11 @@ function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions, m
candidateCount: modelParams?.generationConfig?.candidateCount ?? 1,
temperature: options.temperature,
maxOutputTokens: options.max_tokens,
topP: options.top_p,
topK: options.top_k,
frequencyPenalty: options.frequency_penalty,
stopSequences: typeof options.stop_sequence === 'string' ?
[options.stop_sequence] : options.stop_sequence
},
});
