diff --git a/src/api/providers/__tests__/vscode-lm.spec.ts b/src/api/providers/__tests__/vscode-lm.spec.ts index afb349e5e09..75d3e300e8a 100644 --- a/src/api/providers/__tests__/vscode-lm.spec.ts +++ b/src/api/providers/__tests__/vscode-lm.spec.ts @@ -300,4 +300,132 @@ describe("VsCodeLmHandler", () => { await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed") }) }) + + describe("countTokens", () => { + beforeEach(() => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + handler["client"] = mockLanguageModelChat + handler["currentRequestCancellation"] = new vscode.CancellationTokenSource() + }) + + it("should count tokens for text content", async () => { + const content: Anthropic.Messages.ContentBlockParam[] = [ + { type: "text", text: "Hello world" }, + { type: "text", text: "How are you?" }, + ] + + mockLanguageModelChat.countTokens.mockResolvedValue(15) + + const result = await handler.countTokens(content) + expect(result).toBe(15) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith( + "Hello worldHow are you?", + expect.any(Object), + ) + }) + + it("should handle image content with placeholder", async () => { + const content: Anthropic.Messages.ContentBlockParam[] = [ + { type: "text", text: "Look at this:" }, + { type: "image", source: { type: "base64", media_type: "image/png", data: "base64data" } }, + ] + + mockLanguageModelChat.countTokens.mockResolvedValue(10) + + const result = await handler.countTokens(content) + expect(result).toBe(10) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Look at this:[IMAGE]", expect.any(Object)) + }) + }) + + describe("internalCountTokens", () => { + beforeEach(() => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + handler["client"] = mockLanguageModelChat + handler["currentRequestCancellation"] = new vscode.CancellationTokenSource() + }) + + it("should count tokens for string input", async () => { + mockLanguageModelChat.countTokens.mockResolvedValue(20) + + const result = await handler["internalCountTokens"]("Test string") + expect(result).toBe(20) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Test string", expect.any(Object)) + }) + + it("should handle LanguageModelChatMessage with normal token count", async () => { + const message = vscode.LanguageModelChatMessage.User("Hello") + mockLanguageModelChat.countTokens.mockResolvedValue(10) + + const result = await handler["internalCountTokens"](message) + expect(result).toBe(10) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(1) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith(message, expect.any(Object)) + }) + + it("should recalculate when LanguageModelChatMessage returns token count of 4", async () => { + const message = vscode.LanguageModelChatMessage.User( + "This is a longer message that should have more than 4 tokens", + ) + + // First call returns 4 (the problematic value) + // Second call returns the correct count after string conversion + mockLanguageModelChat.countTokens.mockResolvedValueOnce(4).mockResolvedValueOnce(25) + + const result = await handler["internalCountTokens"](message) + expect(result).toBe(25) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(2) + + // First call with the message object + expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith(1, message, expect.any(Object)) + + // Second call with the extracted text + expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith( + 2, + "This is a longer message that should have more than 4 tokens", + expect.any(Object), + ) + }) + + it("should handle LanguageModelChatMessage with array content when token count is 4", async () => { + const textPart = new vscode.LanguageModelTextPart("Part 1") + const textPart2 = new vscode.LanguageModelTextPart(" Part 2") + const message = { + role: "user", + content: [textPart, textPart2], + } + + mockLanguageModelChat.countTokens.mockResolvedValueOnce(4).mockResolvedValueOnce(15) + + const result = await handler["internalCountTokens"](message as any) + expect(result).toBe(15) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(2) + expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith(2, "Part 1 Part 2", expect.any(Object)) + }) + + it("should return 0 when no client is available", async () => { + handler["client"] = null + + const result = await handler["internalCountTokens"]("Test") + expect(result).toBe(0) + expect(mockLanguageModelChat.countTokens).not.toHaveBeenCalled() + }) + + it("should return 0 when no cancellation token is available", async () => { + handler["currentRequestCancellation"] = null + + const result = await handler["internalCountTokens"]("Test") + expect(result).toBe(0) + expect(mockLanguageModelChat.countTokens).not.toHaveBeenCalled() + }) + + it("should handle errors gracefully", async () => { + mockLanguageModelChat.countTokens.mockRejectedValue(new Error("Token counting failed")) + + const result = await handler["internalCountTokens"]("Test") + expect(result).toBe(0) + }) + }) }) diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts index 6474371beeb..28095ec5fc8 100644 --- a/src/api/providers/vscode-lm.ts +++ b/src/api/providers/vscode-lm.ts @@ -225,13 +225,42 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan if (typeof text === "string") { tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token) - } else if (text instanceof vscode.LanguageModelChatMessage) { + } else if (text && typeof text === "object" && "role" in text && "content" in text) { // For chat messages, ensure we have content if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) { console.debug("Roo Code : Empty chat message content") return 0 } tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token) + + // Special handling: when tokenCount equals 4 for LanguageModelChatMessage, + // convert to string and recalculate + if (tokenCount === 4) { + console.debug( + "Roo Code : Token count is 4, converting message to string for recalculation", + ) + + // Extract text content from the message + let textContent = "" + if (typeof text.content === "string") { + textContent = text.content + } else if (Array.isArray(text.content)) { + // Handle array of content parts + for (const part of text.content) { + if (part instanceof vscode.LanguageModelTextPart) { + textContent += part.value + } else if (typeof part === "string") { + textContent += part + } + } + } + + // Recalculate tokens using the extracted text + if (textContent) { + tokenCount = await this.client.countTokens(textContent, this.currentRequestCancellation.token) + console.debug(`Roo Code : Recalculated token count: ${tokenCount}`) + } + } } else { console.warn("Roo Code : Invalid input type for token counting") return 0