Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions src/api/providers/__tests__/vscode-lm.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,132 @@ describe("VsCodeLmHandler", () => {
await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed")
})
})

describe("countTokens", () => {
	// Each test starts with a client and an active cancellation source installed on the handler.
	beforeEach(() => {
		const selectChatModelsMock = vscode.lm.selectChatModels as Mock
		selectChatModelsMock.mockResolvedValueOnce([{ ...mockLanguageModelChat }])

		handler["client"] = mockLanguageModelChat
		handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
	})

	it("should count tokens for text content", async () => {
		mockLanguageModelChat.countTokens.mockResolvedValue(15)

		const textBlocks: Anthropic.Messages.ContentBlockParam[] = [
			{ type: "text", text: "Hello world" },
			{ type: "text", text: "How are you?" },
		]
		const tokenCount = await handler.countTokens(textBlocks)

		expect(tokenCount).toBe(15)
		// The client is expected to receive the text blocks joined with no separator.
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Hello worldHow are you?", expect.any(Object))
	})

	it("should handle image content with placeholder", async () => {
		mockLanguageModelChat.countTokens.mockResolvedValue(10)

		const mixedBlocks: Anthropic.Messages.ContentBlockParam[] = [
			{ type: "text", text: "Look at this:" },
			{ type: "image", source: { type: "base64", media_type: "image/png", data: "base64data" } },
		]
		const tokenCount = await handler.countTokens(mixedBlocks)

		expect(tokenCount).toBe(10)
		// The image block is expected to reach the client as the literal "[IMAGE]" marker.
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Look at this:[IMAGE]", expect.any(Object))
	})
})

describe("internalCountTokens", () => {
	// Each test starts with a client and an active cancellation source installed on the handler.
	beforeEach(() => {
		const selectChatModelsMock = vscode.lm.selectChatModels as Mock
		selectChatModelsMock.mockResolvedValueOnce([{ ...mockLanguageModelChat }])

		handler["client"] = mockLanguageModelChat
		handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
	})

	it("should count tokens for string input", async () => {
		mockLanguageModelChat.countTokens.mockResolvedValue(20)

		const tokenCount = await handler["internalCountTokens"]("Test string")

		expect(tokenCount).toBe(20)
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Test string", expect.any(Object))
	})

	it("should handle LanguageModelChatMessage with normal token count", async () => {
		mockLanguageModelChat.countTokens.mockResolvedValue(10)
		const chatMessage = vscode.LanguageModelChatMessage.User("Hello")

		const tokenCount = await handler["internalCountTokens"](chatMessage)

		expect(tokenCount).toBe(10)
		// A count other than 4 must not trigger a second counting pass.
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(1)
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith(chatMessage, expect.any(Object))
	})

	it("should recalculate when LanguageModelChatMessage returns token count of 4", async () => {
		const longText = "This is a longer message that should have more than 4 tokens"
		const chatMessage = vscode.LanguageModelChatMessage.User(longText)

		// Pass 1 yields the suspicious value 4; pass 2 yields the count for the plain text.
		mockLanguageModelChat.countTokens.mockResolvedValueOnce(4).mockResolvedValueOnce(25)

		const tokenCount = await handler["internalCountTokens"](chatMessage)

		expect(tokenCount).toBe(25)
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(2)
		// Pass 1 receives the message object itself.
		expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith(1, chatMessage, expect.any(Object))
		// Pass 2 receives only the text extracted from the message.
		expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith(2, longText, expect.any(Object))
	})

	it("should handle LanguageModelChatMessage with array content when token count is 4", async () => {
		// Plain object with `role` and `content` instead of a LanguageModelChatMessage instance.
		const chatMessage = {
			role: "user",
			content: [new vscode.LanguageModelTextPart("Part 1"), new vscode.LanguageModelTextPart(" Part 2")],
		}

		mockLanguageModelChat.countTokens.mockResolvedValueOnce(4).mockResolvedValueOnce(15)

		const tokenCount = await handler["internalCountTokens"](chatMessage as any)

		expect(tokenCount).toBe(15)
		expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(2)
		// The recount is expected to run on the text parts concatenated in order.
		expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith(2, "Part 1 Part 2", expect.any(Object))
	})

	it("should return 0 when no client is available", async () => {
		handler["client"] = null

		const tokenCount = await handler["internalCountTokens"]("Test")

		expect(tokenCount).toBe(0)
		expect(mockLanguageModelChat.countTokens).not.toHaveBeenCalled()
	})

	it("should return 0 when no cancellation token is available", async () => {
		handler["currentRequestCancellation"] = null

		const tokenCount = await handler["internalCountTokens"]("Test")

		expect(tokenCount).toBe(0)
		expect(mockLanguageModelChat.countTokens).not.toHaveBeenCalled()
	})

	it("should handle errors gracefully", async () => {
		// A rejected count from the client should surface as 0, not as a thrown error.
		mockLanguageModelChat.countTokens.mockRejectedValue(new Error("Token counting failed"))

		const tokenCount = await handler["internalCountTokens"]("Test")

		expect(tokenCount).toBe(0)
	})
})
})
31 changes: 30 additions & 1 deletion src/api/providers/vscode-lm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,42 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan

if (typeof text === "string") {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
} else if (text instanceof vscode.LanguageModelChatMessage) {
} else if (text && typeof text === "object" && "role" in text && "content" in text) {
// For chat messages, ensure we have content
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
console.debug("Roo Code <Language Model API>: Empty chat message content")
return 0
}
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)

// Special handling: when tokenCount equals 4 for LanguageModelChatMessage,
// convert to string and recalculate
if (tokenCount === 4) {
console.debug(
"Roo Code <Language Model API>: Token count is 4, converting message to string for recalculation",
)

// Extract text content from the message
let textContent = ""
if (typeof text.content === "string") {
textContent = text.content
} else if (Array.isArray(text.content)) {
// Handle array of content parts
for (const part of text.content) {
if (part instanceof vscode.LanguageModelTextPart) {
textContent += part.value
} else if (typeof part === "string") {
textContent += part
}
}
}

// Recalculate tokens using the extracted text
if (textContent) {
tokenCount = await this.client.countTokens(textContent, this.currentRequestCancellation.token)
console.debug(`Roo Code <Language Model API>: Recalculated token count: ${tokenCount}`)
}
}
} else {
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
return 0
Expand Down