Skip to content

Commit

Permalink
Merge branch 'main' into model-config
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelDoyle authored Jun 4, 2024
2 parents 59b8f13 + ee7a49c commit fdb4274
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 18 deletions.
11 changes: 4 additions & 7 deletions plugins/anthropic/src/claude.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ import {
import Anthropic from '@anthropic-ai/sdk';
import z from 'zod';

const API_NAME_MAP: Record<string, string> = {
'claude-3-opus': 'claude-3-opus-20240229',
'claude-3-sonnet': 'claude-3-sonnet-20240229',
'claude-3-haiku': 'claude-3-haiku-20240307',
};

const AnthropicConfigSchema = GenerationCommonConfigSchema.extend({
tool_choice: z
.union([
Expand Down Expand Up @@ -73,6 +67,7 @@ export const claude3Opus = modelRef({
},
},
configSchema: AnthropicConfigSchema,
version: 'claude-3-opus-20240229',
});

export const claude3Sonnet = modelRef({
Expand All @@ -89,6 +84,7 @@ export const claude3Sonnet = modelRef({
},
},
configSchema: AnthropicConfigSchema,
version: 'claude-3-sonnet-20240229',
});

export const claude3Haiku = modelRef({
Expand All @@ -105,6 +101,7 @@ export const claude3Haiku = modelRef({
},
},
configSchema: AnthropicConfigSchema,
version: 'claude-3-haiku-20240307',
});

export const SUPPORTED_CLAUDE_MODELS: Record<
Expand Down Expand Up @@ -404,7 +401,7 @@ export function toAnthropicRequestBody(
const model = SUPPORTED_CLAUDE_MODELS[modelName];
if (!model) throw new Error(`Unsupported model: ${modelName}`);
const { system, messages } = toAnthropicMessages(request.messages);
const mappedModelName = API_NAME_MAP[modelName] || modelName;
const mappedModelName = request.config?.version || model.version || modelName;
const body: Anthropic.Beta.Tools.MessageCreateParams = {
system,
messages,
Expand Down
2 changes: 1 addition & 1 deletion plugins/cohere/src/command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ export function toCohereRequestBody(
// Note: these types are the same in the Cohere API (not on the surface, e.g. one uses ChatRequestToolResultsItem and the other uses ChatStreamRequestToolResultsItem, but when the types are unwrapped they are exactly the same)
const model = SUPPORTED_COMMAND_MODELS[modelName];
if (!model) throw new Error(`Unsupported model: ${modelName}`);
const mappedModelName = request.config?.version || modelName;
const mappedModelName = request.config?.version || model.version || modelName;
const messageHistory = toCohereMessageHistory(request.messages);
const body: Cohere.ChatRequest = {
message: messageHistory[0].message,
Expand Down
13 changes: 5 additions & 8 deletions plugins/groq/src/groq_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ export const llama3x8b = modelRef({
},
},
configSchema: GroqConfigSchema,
version: 'llama3-8b-8192',
});

// Worst at JSON mode
Expand All @@ -85,6 +86,7 @@ export const llama3x70b = modelRef({
},
},
configSchema: GroqConfigSchema,
version: 'llama3-70b-8192',
});

// Best at JSON mode
Expand All @@ -102,6 +104,7 @@ export const mixtral8x7b = modelRef({
},
},
configSchema: GroqConfigSchema,
version: 'mixtral-8x7b-32768',
});

// Runner up at JSON mode
Expand All @@ -119,6 +122,7 @@ export const gemma7b = modelRef({
},
},
configSchema: GroqConfigSchema,
version: 'gemma-7b-it',
});

export const SUPPORTED_GROQ_MODELS = {
Expand All @@ -128,13 +132,6 @@ export const SUPPORTED_GROQ_MODELS = {
'gemma-7b': gemma7b,
};

export const DEFAULT_MODEL_VERSION = {
'llama-3-8b': 'llama3-8b-8192',
'llama-3-70b': 'llama3-70b-8192',
'mixtral-8-7b': 'mixtral-8x7b-32768',
'gemma-7b': 'gemma-7b-it',
};

/**
* Converts a Genkit message role to a Groq role.
*
Expand Down Expand Up @@ -395,7 +392,7 @@ export function toGroqRequestBody(
const body: ChatCompletionCreateParamsBase = {
messages: toGroqMessages(request.messages),
tools: request.tools?.map(toGroqTool),
model: request.config?.version || DEFAULT_MODEL_VERSION[modelName],
model: request.config?.version || model.version || modelName,
temperature: request.config?.temperature,
max_tokens: request.config?.maxOutputTokens,
top_p: request.config?.topP,
Expand Down
2 changes: 1 addition & 1 deletion plugins/mistral/src/mistral_llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ export function toMistralRequestBody(
const model = SUPPORTED_MISTRAL_MODELS[modelName];
if (!model) throw new Error(`Unsupported model: ${modelName}`);
const mistralMessages = toMistralMessages(request.messages);
const mappedModelName = request.config?.version || modelName;
const mappedModelName = request.config?.version || model.version || modelName;

let responseFormat;
if (request.output?.format !== 'json') {
Expand Down
2 changes: 1 addition & 1 deletion plugins/openai/src/gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ export function toOpenAiRequestBody(
request.messages,
request.config?.visualDetailLevel
);
const mappedModelName = request.config?.version || modelName;
const mappedModelName = request.config?.version || model.version || modelName;
const body = {
model: mappedModelName,
messages: openAiMessages,
Expand Down

0 comments on commit fdb4274

Please sign in to comment.