Skip to content

Commit

Permalink
Merge pull request #74 from MichaelDoyle/anthropic
Browse files Browse the repository at this point in the history
Anthropic: Fix tool calls
  • Loading branch information
Dabolus authored Jun 4, 2024
2 parents 2092259 + 0e93913 commit e283a83
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 119 deletions.
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/anthropic/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"author": "TheFireCo",
"license": "Apache-2.0",
"dependencies": {
"@anthropic-ai/sdk": "^0.21.0",
"@anthropic-ai/sdk": "^0.22.0",
"zod": "^3.23.8"
},
"peerDependencies": {
Expand Down
231 changes: 122 additions & 109 deletions plugins/anthropic/src/claude.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
* limitations under the License.
*/

import { Message } from '@genkit-ai/ai';
import { Message as GenkitMessage } from '@genkit-ai/ai';
import {
GenerateResponseData,
GenerationCommonConfigSchema,
ModelAction,
defineModel,
Expand All @@ -30,8 +31,21 @@ import {
} from '@genkit-ai/ai/model';
import Anthropic from '@anthropic-ai/sdk';
import z from 'zod';
import {
type ImageBlockParam,
type TextBlock,
type TextBlockParam,
type MessageCreateParams,
type Tool,
type ToolResultBlockParam,
type ContentBlock,
type Message,
type MessageParam,
type MessageStreamEvent,
type ToolUseBlockParam,
} from '@anthropic-ai/sdk/resources/messages.mjs';

const AnthropicConfigSchema = GenerationCommonConfigSchema.extend({
export const AnthropicConfigSchema = GenerationCommonConfigSchema.extend({
tool_choice: z
.union([
z.object({
Expand Down Expand Up @@ -123,7 +137,7 @@ export const SUPPORTED_CLAUDE_MODELS: Record<
function toAnthropicRole(
role: Role,
toolMessageType?: 'tool_use' | 'tool_result'
): Anthropic.Beta.Tools.ToolsBetaMessageParam['role'] {
): MessageParam['role'] {
switch (role) {
case 'user':
return 'user';
Expand Down Expand Up @@ -167,34 +181,41 @@ const extractDataFromBase64Url = (
*/
export function toAnthropicToolResponseContent(
part: Part
): Anthropic.TextBlockParam | Anthropic.ImageBlockParam {
): TextBlockParam | ImageBlockParam {
if (!part.toolResponse) {
throw Error(
`Invalid genkit part provided to toAnthropicToolResponseContent: ${JSON.stringify(
part
)}.`
);
}
const isMedia = isMediaObject(part.toolResponse?.output);
const isString = typeof part.toolResponse?.output === 'string';
if (!isMedia && !isString) {
throw Error(
`Invalid genkit part provided to toAnthropicToolResponseContent: ${part}.`
let base64Data;
if (isMedia) {
base64Data = extractDataFromBase64Url(
(part.toolResponse?.output as Media).url
);
} else if (isString) {
base64Data = extractDataFromBase64Url(part.toolResponse?.output as string);
}
const base64Data = extractDataFromBase64Url(
isMedia
? (part.toolResponse?.output as Media).url
: (part.toolResponse?.output as string)
);
// @ts-expect-error TODO: improve these types
return base64Data
? {
type: 'image',
source: {
type: 'base64',
data: base64Data.data,
media_type:
(part.toolResponse?.output as Media)?.contentType ??
((part.toolResponse?.output as Media)
?.contentType as ImageBlockParam.Source['media_type']) ??
base64Data.contentType,
},
}
: {
type: 'text',
text: part.toolResponse?.output as string,
text: isString
? (part.toolResponse?.output as string)
: JSON.stringify(part.toolResponse?.output),
};
}

Expand All @@ -206,11 +227,7 @@ export function toAnthropicToolResponseContent(
*/
export function toAnthropicMessageContent(
part: Part
):
| Anthropic.TextBlock
| Anthropic.ImageBlockParam
| Anthropic.Beta.Tools.ToolUseBlockParam
| Anthropic.Beta.Tools.ToolResultBlockParam {
): TextBlock | ImageBlockParam | ToolUseBlockParam | ToolResultBlockParam {
if (part.text) {
return {
type: 'text',
Expand Down Expand Up @@ -262,20 +279,18 @@ export function toAnthropicMessageContent(
*/
export function toAnthropicMessages(messages: MessageData[]): {
system?: string;
messages: Anthropic.Beta.Tools.ToolsBetaMessageParam[];
messages: MessageParam[];
} {
const system =
messages[0]?.role === 'system' ? messages[0].content?.[0]?.text : undefined;
const messagesToIterate = system ? messages.slice(1) : messages;
const anthropicMsgs: Anthropic.Beta.Tools.ToolsBetaMessageParam[] = [];
const anthropicMsgs: MessageParam[] = [];
for (const message of messagesToIterate) {
const msg = new Message(message);
const msg = new GenkitMessage(message);
const content = msg.content.map(toAnthropicMessageContent);
const toolMessageType = content.find(
(c) => c.type === 'tool_use' || c.type === 'tool_result'
) as
| Anthropic.Beta.Tools.ToolUseBlockParam
| Anthropic.Beta.Tools.ToolResultBlockParam;
) as ToolUseBlockParam | ToolResultBlockParam;
const role = toAnthropicRole(message.role, toolMessageType?.type);
anthropicMsgs.push({
role: role,
Expand All @@ -290,19 +305,16 @@ export function toAnthropicMessages(messages: MessageData[]): {
* @param tool The Genkit ToolDefinition to convert.
* @returns The converted Anthropic Tool object.
*/
export function toAnthropicTool(
tool: ToolDefinition
): Anthropic.Beta.Tools.Tool {
export function toAnthropicTool(tool: ToolDefinition): Tool {
return {
name: tool.name,
description: tool.description,
input_schema:
tool.inputSchema as Anthropic.Beta.Tools.Messages.Tool.InputSchema,
input_schema: tool.inputSchema as Tool.InputSchema,
};
}

const finishReasonMap: Record<
NonNullable<Anthropic.Beta.Tools.ToolsBetaMessage['stop_reason']>,
NonNullable<Message['stop_reason']>,
CandidateData['finishReason']
> = {
end_turn: 'stop',
Expand All @@ -312,76 +324,88 @@ const finishReasonMap: Record<
};

/**
* Converts an Anthropic content block to a Genkit CandidateData object.
* @param choice The Anthropic content block to convert.
* @param index The index of the content block.
* @param stopReason The reason the content block generation stopped.
* @returns The converted Genkit CandidateData object.
* Converts an Anthropic content block to a Genkit Part object.
* @param contentBlock The Anthropic content block to convert.
* @returns The converted Genkit Part object.
*/
function fromAnthropicContentBlock(
choice: Anthropic.Beta.Tools.Messages.ToolsBetaContentBlock,
index: number,
stopReason: Anthropic.Beta.Tools.Messages.ToolsBetaMessage['stop_reason']
): CandidateData {
return {
index,
finishReason: (stopReason && finishReasonMap[stopReason]) || 'other',
message:
choice.type === 'text'
? {
role: 'model',
content: [{ text: choice.text }],
}
: {
role: 'tool',
content: [
{
toolRequest: {
ref: choice.id,
name: choice.name,
input: choice.input,
},
},
],
},
};
function fromAnthropicContentBlock(contentBlock: ContentBlock): Part {
return contentBlock.type === 'tool_use'
? {
toolRequest: {
ref: contentBlock.id,
name: contentBlock.name,
input: contentBlock.input,
},
}
: { text: contentBlock.text };
}

/**
* Converts an Anthropic message stream event to a Genkit CandidateData object.
* @param choice The Anthropic message stream event to convert.
* @returns The converted Genkit CandidateData object if the event is a content block start or delta, otherwise undefined.
* Converts an Anthropic message stream event to a Genkit Part object.
* @param event The Anthropic message stream event to convert.
* @returns The converted Genkit Part object if the event is a content block
* start or delta, otherwise undefined.
*/
function fromAnthropicContentBlockChunk(
choice: Anthropic.Beta.Tools.Messages.ToolsBetaMessageStreamEvent
): CandidateData | undefined {
event: MessageStreamEvent
): Part | undefined {
if (
choice.type !== 'content_block_start' &&
choice.type !== 'content_block_delta'
event.type !== 'content_block_start' &&
event.type !== 'content_block_delta'
) {
return;
}
const choiceField =
choice.type === 'content_block_start' ? 'content_block' : 'delta';
const eventField =
event.type === 'content_block_start' ? 'content_block' : 'delta';
return event[eventField].type === 'text'
? {
text: event[eventField].text,
}
: {
toolRequest: {
ref: event[eventField].id,
name: event[eventField].name,
input: event[eventField].input,
},
};
}

function fromAnthropicStopReason(
reason: Message['stop_reason']
): CandidateData['finishReason'] {
switch (reason) {
case 'max_tokens':
return 'length';
case 'end_turn':
// fall through
case 'stop_sequence':
// fall through
case 'tool_use':
return 'stop';
case null:
return 'unknown';
default:
return 'other';
}
}

export function fromAnthropicResponse(response: Message): GenerateResponseData {
return {
index: choice.index,
finishReason: 'unknown',
message: {
role: 'model',
content: [
choice[choiceField].type === 'text'
? {
text: choice[choiceField].text,
}
: {
toolRequest: {
ref: choice[choiceField].id,
name: choice[choiceField].name,
input: choice[choiceField].input,
},
},
],
candidates: [
{
index: 0,
finishReason: fromAnthropicStopReason(response.stop_reason),
message: {
role: 'model',
content: response.content.map(fromAnthropicContentBlock),
},
},
],
usage: {
inputTokens: response.usage.input_tokens,
outputTokens: response.usage.output_tokens,
},
custom: response,
};
}

Expand All @@ -397,12 +421,12 @@ export function toAnthropicRequestBody(
modelName: string,
request: GenerateRequest<typeof AnthropicConfigSchema>,
stream?: boolean
): Anthropic.Beta.Tools.Messages.MessageCreateParams {
): MessageCreateParams {
const model = SUPPORTED_CLAUDE_MODELS[modelName];
if (!model) throw new Error(`Unsupported model: ${modelName}`);
const { system, messages } = toAnthropicMessages(request.messages);
const mappedModelName = request.config?.version || model.version || modelName;
const body: Anthropic.Beta.Tools.MessageCreateParams = {
const body: MessageCreateParams = {
system,
messages,
tools: request.tools?.map(toAnthropicTool),
Expand Down Expand Up @@ -451,35 +475,24 @@ export function claudeModel(
configSchema: model.configSchema,
},
async (request, streamingCallback) => {
let response: Anthropic.Beta.Tools.ToolsBetaMessage;
let response: Message;
const body = toAnthropicRequestBody(name, request, !!streamingCallback);
if (streamingCallback) {
const stream = client.beta.tools.messages.stream(body);
const stream = client.messages.stream(body);
for await (const chunk of stream) {
const c = fromAnthropicContentBlockChunk(chunk);
if (c) {
streamingCallback({
index: c.index,
content: c.message.content,
index: 0,
content: [c],
});
}
}
response = await stream.finalMessage();
} else {
response = (await client.beta.tools.messages.create(
body
)) as Anthropic.Beta.Tools.ToolsBetaMessage;
response = (await client.messages.create(body)) as Message;
}
return {
candidates: response.content.map((content, index) =>
fromAnthropicContentBlock(content, index, response.stop_reason)
),
usage: {
inputTokens: response.usage.input_tokens,
outputTokens: response.usage.output_tokens,
},
custom: response,
};
return fromAnthropicResponse(response);
}
);
}
Loading

0 comments on commit e283a83

Please sign in to comment.