Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion pkg/model/provider/anthropic/beta_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ func (c *Client) createBetaStream(
}
}

sys := extractBetaSystemBlocks(messages)

if used, err := countAnthropicTokensBeta(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools); err == nil {
configuredMaxTokens := maxTokens
maxTokens = clampMaxTokens(anthropicContextLimit(c.ModelConfig.Model), used, maxTokens)
if maxTokens < configuredMaxTokens {
slog.Warn("Anthropic Beta API max_tokens clamped to", "max_tokens", maxTokens)
}
}

params := anthropic.BetaMessageNewParams{
Model: anthropic.Model(c.ModelConfig.Model),
MaxTokens: maxTokens,
Expand All @@ -47,7 +57,7 @@ func (c *Client) createBetaStream(
}

// Populate proper Anthropic system prompt from input messages
if sys := extractBetaSystemBlocks(messages); len(sys) > 0 {
if len(sys) > 0 {
params.System = sys
}

Expand Down Expand Up @@ -186,3 +196,42 @@ func contentArrayBeta(m map[string]any) []any {
}
return nil
}

// countAnthropicTokensBeta calls Anthropic's Count Tokens API for the provided Beta API payload
// and returns the number of input tokens.
func countAnthropicTokensBeta(
ctx context.Context,
client anthropic.Client,
model anthropic.Model,
messages []anthropic.BetaMessageParam,
system []anthropic.BetaTextBlockParam,
anthropicTools []anthropic.BetaToolUnionParam,
) (int64, error) {
params := anthropic.BetaMessageCountTokensParams{
Model: model,
Messages: messages,
}
if len(system) > 0 {
params.System = anthropic.BetaMessageCountTokensParamsSystemUnion{
OfBetaTextBlockArray: system,
}
}
if len(anthropicTools) > 0 {
// Convert BetaToolUnionParam to BetaMessageCountTokensParamsToolUnion
toolParams := make([]anthropic.BetaMessageCountTokensParamsToolUnion, len(anthropicTools))
for i, tool := range anthropicTools {
if tool.OfTool != nil {
toolParams[i] = anthropic.BetaMessageCountTokensParamsToolUnion{
OfTool: tool.OfTool,
}
}
}
params.Tools = toolParams
}

result, err := client.Beta.Messages.CountTokens(ctx, params)
if err != nil {
return 0, err
}
return result.InputTokens, nil
}
305 changes: 305 additions & 0 deletions pkg/model/provider/anthropic/beta_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
package anthropic

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/chat"
)

// TestCountAnthropicTokensBeta_Success tests successful token counting for beta API
func TestCountAnthropicTokensBeta_Success(t *testing.T) {
// Setup mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/messages/count_tokens", r.URL.Path)
assert.Equal(t, "application/json", r.Header.Get("content-type"))
assert.NotEmpty(t, r.Header.Get("x-api-key"))

// Verify request body contains expected fields
var payload map[string]any
err := json.NewDecoder(r.Body).Decode(&payload)
assert.NoError(t, err)
assert.Equal(t, "claude-3-5-sonnet-20241022", payload["model"])
assert.NotNil(t, payload["messages"])

// Return mock response
w.Header().Set("content-type", "application/json")
err = json.NewEncoder(w).Encode(map[string]int64{"input_tokens": 150})
assert.NoError(t, err)
}))
defer server.Close()

// Create test data
messages := []anthropic.BetaMessageParam{
{
Role: anthropic.BetaMessageParamRoleUser,
Content: []anthropic.BetaContentBlockParamUnion{
{OfText: &anthropic.BetaTextBlockParam{Text: "Hello"}},
},
},
}
system := []anthropic.BetaTextBlockParam{
{Text: "You are helpful"},
}

// Create client with test server URL
client := anthropic.NewClient(
option.WithAPIKey("test-key"),
option.WithBaseURL(server.URL),
)

// Call function
tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, nil)

// Verify
require.NoError(t, err)
assert.Equal(t, int64(150), tokens)
}

// TestCountAnthropicTokensBeta_NoAPIKey tests error when API key is missing
func TestCountAnthropicTokensBeta_NoAPIKey(t *testing.T) {
messages := []anthropic.BetaMessageParam{}
system := []anthropic.BetaTextBlockParam{}

// Create client without base URL to trigger error
client := anthropic.NewClient(
option.WithAPIKey("test-key"),
// No base URL set
)

tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, nil)

require.Error(t, err)
assert.Equal(t, int64(0), tokens)
}

// TestCountAnthropicTokensBeta_ServerError tests error handling for server errors
func TestCountAnthropicTokensBeta_ServerError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

messages := []anthropic.BetaMessageParam{}
system := []anthropic.BetaTextBlockParam{}

// Create client with test server URL
client := anthropic.NewClient(
option.WithAPIKey("test-key"),
option.WithBaseURL(server.URL),
)

tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, nil)
require.Error(t, err)
assert.Equal(t, int64(0), tokens)
}

// TestCountAnthropicTokensBeta_WithTools tests token counting includes tools
func TestCountAnthropicTokensBeta_WithTools(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var payload map[string]any
err := json.NewDecoder(r.Body).Decode(&payload)
assert.NoError(t, err)

// Verify tools are included in payload
assert.NotNil(t, payload["tools"])
tools, ok := payload["tools"].([]any)
assert.True(t, ok)
assert.Len(t, tools, 1)

w.Header().Set("content-type", "application/json")
err = json.NewEncoder(w).Encode(map[string]int64{"input_tokens": 200})
assert.NoError(t, err)
}))
defer server.Close()

messages := []anthropic.BetaMessageParam{}
system := []anthropic.BetaTextBlockParam{}
tools := []anthropic.BetaToolUnionParam{
{OfTool: &anthropic.BetaToolParam{
Name: "test_tool",
Description: anthropic.String("A test tool"),
}},
}

// Create client with test server URL
client := anthropic.NewClient(
option.WithAPIKey("test-key"),
option.WithBaseURL(server.URL),
)

tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, tools)

require.NoError(t, err)
assert.Equal(t, int64(200), tokens)
}

// TestClampMaxTokens_WithinLimit tests clamping when configured tokens are within limit
func TestClampMaxTokens_WithinLimit(t *testing.T) {
// Context limit: 200k, used: 50k, safety: 1k, remaining: 149k
// Configured: 8k (within limit)
result := clampMaxTokens(200000, 50000, 8000)
assert.Equal(t, int64(8000), result)
}

// TestClampMaxTokens_ExceedsLimit tests clamping when configured tokens exceed remaining
func TestClampMaxTokens_ExceedsLimit(t *testing.T) {
// Context limit: 200k, used: 190k, safety: 1024, remaining: 8976
// Configured: 16k (exceeds limit)
result := clampMaxTokens(200000, 190000, 16000)
assert.Equal(t, int64(8976), result)
}

// TestClampMaxTokens_MinimumOne tests clamping never returns less than 1
func TestClampMaxTokens_MinimumOne(t *testing.T) {
// Context limit: 200k, used: 199k, safety: 1k, remaining: 0 (would be negative)
result := clampMaxTokens(200000, 199000, 8000)
assert.Equal(t, int64(1), result)
}

// TestClampMaxTokens_ExactlyAtLimit tests clamping when used + safety equals limit
func TestClampMaxTokens_ExactlyAtLimit(t *testing.T) {
// Context limit: 200k, used: 199k, safety: 1k, remaining: 0
result := clampMaxTokens(200000, 199000, 1000)
assert.Equal(t, int64(1), result)
}

// TestAnthropicContextLimit_ReturnsCorrectLimit tests context limit function
func TestAnthropicContextLimit_ReturnsCorrectLimit(t *testing.T) {
limit := anthropicContextLimit("claude-3-5-sonnet-20241022")
assert.Equal(t, int64(200000), limit)
}

// TestExtractBetaSystemBlocks_SingleSystemMessage tests extracting system messages
func TestExtractBetaSystemBlocks_SingleSystemMessage(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleSystem,
Content: "You are a helpful assistant",
},
}

blocks := extractBetaSystemBlocks(msgs)

require.Len(t, blocks, 1)
assert.Equal(t, "You are a helpful assistant", blocks[0].Text)
}

// TestExtractBetaSystemBlocks_MultipleSystemMessages tests extracting multiple system messages
func TestExtractBetaSystemBlocks_MultipleSystemMessages(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleSystem,
Content: "You are helpful",
},
{
Role: chat.MessageRoleUser,
Content: "Hello",
},
{
Role: chat.MessageRoleSystem,
Content: "Be concise",
},
}

blocks := extractBetaSystemBlocks(msgs)

require.Len(t, blocks, 2)
assert.Equal(t, "You are helpful", blocks[0].Text)
assert.Equal(t, "Be concise", blocks[1].Text)
}

// TestExtractBetaSystemBlocks_SkipsEmptyText tests that empty system text is skipped
func TestExtractBetaSystemBlocks_SkipsEmptyText(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleSystem,
Content: " \n\t ",
},
{
Role: chat.MessageRoleSystem,
Content: "Valid system prompt",
},
}

blocks := extractBetaSystemBlocks(msgs)

require.Len(t, blocks, 1)
assert.Equal(t, "Valid system prompt", blocks[0].Text)
}

// TestExtractBetaSystemBlocks_MultiContent tests extracting from multi-content system messages
func TestExtractBetaSystemBlocks_MultiContent(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleSystem,
MultiContent: []chat.MessagePart{
{Type: chat.MessagePartTypeText, Text: "Part 1"},
{Type: chat.MessagePartTypeText, Text: "Part 2"},
},
},
}

blocks := extractBetaSystemBlocks(msgs)

require.Len(t, blocks, 2)
assert.Equal(t, "Part 1", blocks[0].Text)
assert.Equal(t, "Part 2", blocks[1].Text)
}

// TestConvertBetaMessages_UserMessage tests converting user messages
func TestConvertBetaMessages_UserMessage(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleUser,
Content: "Hello, assistant!",
},
}

converted := convertBetaMessages(msgs)

require.Len(t, converted, 1)
assert.Equal(t, anthropic.BetaMessageParamRoleUser, converted[0].Role)
require.Len(t, converted[0].Content, 1)
}

// TestConvertBetaMessages_SkipsSystemMessages tests that system messages are skipped
func TestConvertBetaMessages_SkipsSystemMessages(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleSystem,
Content: "System prompt",
},
{
Role: chat.MessageRoleUser,
Content: "User message",
},
}

converted := convertBetaMessages(msgs)

require.Len(t, converted, 1)
assert.Equal(t, anthropic.BetaMessageParamRoleUser, converted[0].Role)
}

// TestConvertBetaMessages_AssistantMessage tests converting assistant messages
func TestConvertBetaMessages_AssistantMessage(t *testing.T) {
msgs := []chat.Message{
{
Role: chat.MessageRoleAssistant,
Content: "I can help with that",
},
}

converted := convertBetaMessages(msgs)

require.Len(t, converted, 1)
assert.Equal(t, anthropic.BetaMessageParamRoleAssistant, converted[0].Role)
}
Loading
Loading