Merged
14 changes: 8 additions & 6 deletions src/core/condense/__tests__/index.test.ts
@@ -9,6 +9,8 @@ jest.mock("../../../api/transform/image-cleaning", () => ({
maybeRemoveImageBlocks: jest.fn((messages: ApiMessage[], _apiHandler: ApiHandler) => [...messages]),
}))

+const taskId = "test-task-id"
+
describe("getMessagesSinceLastSummary", () => {
it("should return all messages when there is no summary", () => {
const messages: ApiMessage[] = [
@@ -106,7 +108,7 @@ describe("summarizeConversation", () => {
{ role: "assistant", content: "Hi there", ts: 2 },
]

-const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
+const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt, taskId)
expect(result.messages).toEqual(messages)
expect(result.cost).toBe(0)
expect(result.summary).toBe("")
@@ -125,7 +127,7 @@ describe("summarizeConversation", () => {
{ role: "user", content: "Tell me more", ts: 7 },
]

-const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
+const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt, taskId)
expect(result.messages).toEqual(messages)
expect(result.cost).toBe(0)
expect(result.summary).toBe("")
@@ -144,7 +146,7 @@ describe("summarizeConversation", () => {
{ role: "user", content: "Tell me more", ts: 7 },
]

-const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
+const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt, taskId)

// Check that the API was called correctly
expect(mockApiHandler.createMessage).toHaveBeenCalled()
@@ -202,7 +204,7 @@ describe("summarizeConversation", () => {
return messages.map(({ role, content }: { role: string; content: any }) => ({ role, content }))
})

-const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
+const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt, taskId)

// Should return original messages when summary is empty
expect(result.messages).toEqual(messages)
@@ -225,7 +227,7 @@ describe("summarizeConversation", () => {
{ role: "user", content: "Tell me more", ts: 7 },
]

-await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
+await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt, taskId)

// Verify the final request message
const expectedFinalMessage = {
@@ -266,7 +268,7 @@ describe("summarizeConversation", () => {
// Override the mock for this test
mockApiHandler.createMessage = jest.fn().mockReturnValue(streamWithUsage) as any

-const result = await summarizeConversation(messages, mockApiHandler, systemPrompt)
+const result = await summarizeConversation(messages, mockApiHandler, systemPrompt, taskId)

// Verify that countTokens was called with the correct messages including system prompt
expect(mockApiHandler.countTokens).toHaveBeenCalled()
4 changes: 4 additions & 0 deletions src/core/condense/index.ts
@@ -2,6 +2,7 @@ import Anthropic from "@anthropic-ai/sdk"
import { ApiHandler } from "../../api"
import { ApiMessage } from "../task-persistence/apiMessages"
import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning"
+import { telemetryService } from "../../services/telemetry/TelemetryService"

export const N_MESSAGES_TO_KEEP = 3

@@ -58,13 +59,16 @@ export type SummarizeResponse = {
* @param {ApiMessage[]} messages - The conversation messages
* @param {ApiHandler} apiHandler - The API handler to use for token counting.
* @param {string} systemPrompt - The system prompt for API requests, which should be considered in the context token count
+ * @param {string} taskId - The task ID for the conversation, used for telemetry
* @returns {SummarizeResponse} - The result of the summarization operation (see above)
*/
export async function summarizeConversation(
messages: ApiMessage[],
apiHandler: ApiHandler,
systemPrompt: string,
+	taskId: string,
): Promise<SummarizeResponse> {
+	telemetryService.captureContextCondensed(taskId)
Collaborator (inline review comment on the line above): Could we add some way to distinguish between automatic and manual?

const response: SummarizeResponse = { messages, cost: 0, summary: "" }
const messagesToSummarize = getMessagesSinceLastSummary(messages.slice(0, -N_MESSAGES_TO_KEEP))
if (messagesToSummarize.length <= 1) {
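One way to answer the reviewer's question is to thread the trigger through to telemetry. A minimal sketch, not part of this PR: it assumes captureContextCondensed can accept an extra properties argument (the diff shows it taking only taskId), and the CondenseTrigger type and its default value are hypothetical:

// Hypothetical variant of the new signature, distinguishing user-initiated
// condensation from condensation triggered by the context-window threshold.
export type CondenseTrigger = "automatic" | "manual"

export async function summarizeConversation(
	messages: ApiMessage[],
	apiHandler: ApiHandler,
	systemPrompt: string,
	taskId: string,
	trigger: CondenseTrigger = "automatic", // assumed default, not in this PR
): Promise<SummarizeResponse> {
	// Assumes the telemetry service accepts a properties bag alongside taskId;
	// the signature shown in this diff takes only taskId.
	telemetryService.captureContextCondensed(taskId, { trigger })
	const response: SummarizeResponse = { messages, cost: 0, summary: "" }
	// ... summarization logic as in the PR ...
	return response
}

Under this sketch, truncateConversationIfNeeded would pass "automatic", while a user-invoked condense command would pass "manual".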
38 changes: 31 additions & 7 deletions src/core/sliding-window/__tests__/sliding-window.test.ts
@@ -37,6 +37,7 @@ class MockApiHandler extends BaseProvider {

// Create a singleton instance for tests
const mockApiHandler = new MockApiHandler()
+const taskId = "test-task-id"

/**
* Tests for the truncateConversation function
@@ -49,7 +50,7 @@ describe("truncateConversation", () => {
{ role: "user", content: "Third message" },
]

-const result = truncateConversation(messages, 0.5)
+const result = truncateConversation(messages, 0.5, taskId)

// With 2 messages after the first, 0.5 fraction means remove 1 message
// But 1 is odd, so it rounds down to 0 (to make it even)
@@ -70,7 +71,7 @@

// 4 messages excluding first, 0.5 fraction = 2 messages to remove
// 2 is already even, so no rounding needed
-const result = truncateConversation(messages, 0.5)
+const result = truncateConversation(messages, 0.5, taskId)

expect(result.length).toBe(3)
expect(result[0]).toEqual(messages[0])
@@ -91,7 +92,7 @@

// 6 messages excluding first, 0.3 fraction = 1.8 messages to remove
// 1.8 rounds down to 1, then to 0 to make it even
-const result = truncateConversation(messages, 0.3)
+const result = truncateConversation(messages, 0.3, taskId)

expect(result.length).toBe(7) // No messages removed
expect(result).toEqual(messages)
@@ -104,7 +105,7 @@
{ role: "user", content: "Third message" },
]

-const result = truncateConversation(messages, 0)
+const result = truncateConversation(messages, 0, taskId)

expect(result).toEqual(messages)
})
@@ -119,7 +120,7 @@

// 3 messages excluding first, 1.0 fraction = 3 messages to remove
// But 3 is odd, so it rounds down to 2 to make it even
-const result = truncateConversation(messages, 1)
+const result = truncateConversation(messages, 1, taskId)

expect(result.length).toBe(2)
expect(result[0]).toEqual(messages[0])
@@ -251,6 +252,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

// Check the new return type
@@ -282,6 +284,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

expect(result).toEqual({
@@ -311,6 +314,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

const result2 = await truncateConversationIfNeeded({
@@ -322,6 +326,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

expect(result1.messages).toEqual(result2.messages)
@@ -340,6 +345,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

const result4 = await truncateConversationIfNeeded({
@@ -351,6 +357,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

expect(result3.messages).toEqual(result4.messages)
@@ -384,6 +391,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(resultWithSmall).toEqual({
messages: messagesWithSmallContent,
@@ -416,6 +424,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(resultWithLarge.messages).not.toEqual(messagesWithLargeContent) // Should truncate
expect(resultWithLarge.summary).toBe("")
@@ -441,6 +450,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(resultWithVeryLarge.messages).not.toEqual(messagesWithVeryLargeContent) // Should truncate
expect(resultWithVeryLarge.summary).toBe("")
@@ -469,6 +479,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result).toEqual({
messages: expectedResult,
@@ -510,10 +521,11 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: true,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

// Verify summarizeConversation was called with the right parameters
-expect(summarizeSpy).toHaveBeenCalledWith(messagesWithSmallContent, mockApiHandler, "System prompt")
+expect(summarizeSpy).toHaveBeenCalledWith(messagesWithSmallContent, mockApiHandler, "System prompt", taskId)

// Verify the result contains the summary information
expect(result).toMatchObject({
@@ -557,6 +569,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: true,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})

// Verify summarizeConversation was called
@@ -594,6 +607,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 50, // This shouldn't matter since autoCondenseContext is false
systemPrompt: "System prompt",
+	taskId,
})

// Verify summarizeConversation was not called
@@ -645,10 +659,11 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: true,
autoCondenseContextPercent: 50, // Set threshold to 50% - our tokens are at 60%
systemPrompt: "System prompt",
+	taskId,
})

// Verify summarizeConversation was called with the right parameters
-expect(summarizeSpy).toHaveBeenCalledWith(messagesWithSmallContent, mockApiHandler, "System prompt")
+expect(summarizeSpy).toHaveBeenCalledWith(messagesWithSmallContent, mockApiHandler, "System prompt", taskId)

// Verify the result contains the summary information
expect(result).toMatchObject({
@@ -682,6 +697,7 @@ describe("truncateConversationIfNeeded", () => {
autoCondenseContext: true,
autoCondenseContextPercent: 50, // Set threshold to 50% - our tokens are at 40%
systemPrompt: "System prompt",
+	taskId,
})

// Verify summarizeConversation was not called
@@ -738,6 +754,7 @@ describe("getMaxTokens", () => {
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result1).toEqual({
messages: messagesWithSmallContent,
@@ -756,6 +773,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result2.messages).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
@@ -782,6 +800,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result1).toEqual({
messages: messagesWithSmallContent,
@@ -800,6 +819,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result2.messages).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
@@ -825,6 +845,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result1.messages).toEqual(messagesWithSmallContent)

@@ -838,6 +859,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result2).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
@@ -861,6 +883,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result1.messages).toEqual(messagesWithSmallContent)

@@ -874,6 +897,7 @@
autoCondenseContext: false,
autoCondenseContextPercent: 100,
systemPrompt: "System prompt",
+	taskId,
})
expect(result2).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
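The expectations in these tests follow from the rounding rule in truncateConversation (shown in the next file): the removal count is floored and then rounded down to an even number, so user/assistant message pairs are dropped together. Working the 0.3-fraction case above through that rule:

// 7 messages, fracToRemove = 0.3 — the "rounds down to even" test case
const rawMessagesToRemove = Math.floor((7 - 1) * 0.3) // = Math.floor(1.8) = 1
const messagesToRemove = rawMessagesToRemove - (rawMessagesToRemove % 2) // = 1 - 1 = 0
// Nothing is removed, which is why that test expects all 7 messages back.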
11 changes: 8 additions & 3 deletions src/core/sliding-window/index.ts
@@ -2,6 +2,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "../../api"
import { summarizeConversation, SummarizeResponse } from "../condense"
import { ApiMessage } from "../task-persistence/apiMessages"
+import { telemetryService } from "../../services/telemetry/TelemetryService"

/**
* Default percentage of the context window to use as a buffer when deciding when to truncate
@@ -31,9 +32,11 @@ export async function estimateTokenCount(
*
* @param {ApiMessage[]} messages - The conversation messages.
* @param {number} fracToRemove - The fraction (between 0 and 1) of messages (excluding the first) to remove.
+ * @param {string} taskId - The task ID for the conversation, used for telemetry
* @returns {ApiMessage[]} The truncated conversation messages.
*/
-export function truncateConversation(messages: ApiMessage[], fracToRemove: number): ApiMessage[] {
+export function truncateConversation(messages: ApiMessage[], fracToRemove: number, taskId: string): ApiMessage[] {
+	telemetryService.captureSlidingWindowTruncation(taskId)
const truncatedMessages = [messages[0]]
const rawMessagesToRemove = Math.floor((messages.length - 1) * fracToRemove)
const messagesToRemove = rawMessagesToRemove - (rawMessagesToRemove % 2)
@@ -66,6 +69,7 @@ type TruncateOptions = {
autoCondenseContext: boolean
autoCondenseContextPercent: number
systemPrompt: string
+	taskId: string
}

type TruncateResponse = SummarizeResponse & { prevContextTokens: number }
@@ -86,6 +90,7 @@ export async function truncateConversationIfNeeded({
autoCondenseContext,
autoCondenseContextPercent,
systemPrompt,
+	taskId,
}: TruncateOptions): Promise<TruncateResponse> {
// Calculate the maximum tokens reserved for response
const reservedTokens = maxTokens || contextWindow * 0.2
@@ -108,7 +113,7 @@
const contextPercent = (100 * prevContextTokens) / contextWindow
if (contextPercent >= autoCondenseContextPercent || prevContextTokens > allowedTokens) {
// Attempt to intelligently condense the context
-const result = await summarizeConversation(messages, apiHandler, systemPrompt)
+const result = await summarizeConversation(messages, apiHandler, systemPrompt, taskId)
if (result.summary) {
return { ...result, prevContextTokens }
}
@@ -117,7 +122,7 @@

// Fall back to sliding window truncation if needed
if (prevContextTokens > allowedTokens) {
-const truncatedMessages = truncateConversation(messages, 0.5)
+const truncatedMessages = truncateConversation(messages, 0.5, taskId)
return { messages: truncatedMessages, prevContextTokens, summary: "", cost: 0 }
}
// No truncation or condensation needed
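For orientation, every caller now has to supply taskId in TruncateOptions. A minimal sketch of an updated call site; the fields not visible in this diff (messages, totalTokens, contextWindow, maxTokens, apiHandler) are inferred from the tests above, and all values are illustrative:

const result = await truncateConversationIfNeeded({
	messages, // ApiMessage[] for the conversation
	totalTokens: 50_000, // illustrative
	contextWindow: 100_000, // illustrative
	maxTokens: undefined, // reservedTokens then falls back to contextWindow * 0.2
	apiHandler,
	autoCondenseContext: true,
	autoCondenseContextPercent: 80, // illustrative threshold
	systemPrompt: "System prompt",
	taskId: "task-123", // new in this PR; forwarded to telemetry
})
// result: SummarizeResponse & { prevContextTokens: number }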