From 6433d87f7289ff4d81d9fa27201b91d7a01b6afe Mon Sep 17 00:00:00 2001 From: Roo Code Date: Mon, 28 Jul 2025 12:28:40 +0000 Subject: [PATCH] fix: handle special case when VS Code LM API returns tokenCount of 4 - Add special handling in internalCountTokens method for LanguageModelChatMessage - When tokenCount equals 4, convert message to string and recalculate - Add comprehensive test coverage for the special case - Fixes #6290 --- src/api/providers/__tests__/vscode-lm.spec.ts | 71 +++++++++++++++++++ src/api/providers/vscode-lm.ts | 29 +++++++- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/src/api/providers/__tests__/vscode-lm.spec.ts b/src/api/providers/__tests__/vscode-lm.spec.ts index afb349e5e09..f42dc8ca30d 100644 --- a/src/api/providers/__tests__/vscode-lm.spec.ts +++ b/src/api/providers/__tests__/vscode-lm.spec.ts @@ -300,4 +300,75 @@ describe("VsCodeLmHandler", () => { await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed") }) }) + + describe("countTokens", () => { + beforeEach(() => { + const mockModel = { ...mockLanguageModelChat } + ;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel]) + + // Override the default client with our test client + handler["client"] = mockLanguageModelChat + // Set up cancellation token + handler["currentRequestCancellation"] = new vscode.CancellationTokenSource() + }) + + it("should count tokens for string input", async () => { + mockLanguageModelChat.countTokens.mockResolvedValue(10) + + const result = await handler.countTokens([{ type: "text", text: "Hello world" }]) + + expect(result).toBe(10) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Hello world", expect.any(Object)) + }) + + it("should handle special case when LanguageModelChatMessage returns tokenCount of 4", async () => { + // First call returns 4 (triggering the special case) + // Second call returns the actual count after string conversion + mockLanguageModelChat.countTokens.mockResolvedValueOnce(4).mockResolvedValueOnce(25) + + // Use the mocked vscode.LanguageModelChatMessage.User to create a proper message + const mockMessage = vscode.LanguageModelChatMessage.User("This is a test message") + + const result = await handler["internalCountTokens"](mockMessage) + + expect(result).toBe(25) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(2) + // First call with the message object + expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith(1, mockMessage, expect.any(Object)) + // Second call with the extracted string + expect(mockLanguageModelChat.countTokens).toHaveBeenNthCalledWith( + 2, + "This is a test message", + expect.any(Object), + ) + }) + + it("should not recalculate when tokenCount is not 4", async () => { + mockLanguageModelChat.countTokens.mockResolvedValue(10) + + // Use the mocked vscode.LanguageModelChatMessage.User to create a proper message + const mockMessage = vscode.LanguageModelChatMessage.User("This is a test message") + + const result = await handler["internalCountTokens"](mockMessage) + + expect(result).toBe(10) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(1) + }) + + it("should handle image blocks", async () => { + // The countTokens method converts to string, so it won't trigger the special case + mockLanguageModelChat.countTokens.mockResolvedValue(7) + + const result = await handler.countTokens([ + { type: "text", text: "Hello" }, + { type: "image", source: { type: "base64", media_type: "image/png", data: "base64data" } }, + { type: "text", text: " world" }, + ]) + + expect(result).toBe(7) + // Should only be called once since it's a string, not a LanguageModelChatMessage + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledTimes(1) + expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Hello[IMAGE] world", expect.any(Object)) + }) + }) }) diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts index 6474371beeb..6d7edf6bef6 100644 --- a/src/api/providers/vscode-lm.ts +++ b/src/api/providers/vscode-lm.ts @@ -225,13 +225,40 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan if (typeof text === "string") { tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token) - } else if (text instanceof vscode.LanguageModelChatMessage) { + } else if (text && typeof text === "object" && "content" in text) { + // Handle LanguageModelChatMessage-like objects // For chat messages, ensure we have content if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) { console.debug("Roo Code : Empty chat message content") return 0 } tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token) + + // Special handling: if tokenCount is exactly 4 for a LanguageModelChatMessage, + // convert to string and recalculate + if (tokenCount === 4) { + console.debug( + "Roo Code : Token count is 4 for LanguageModelChatMessage, converting to string and recalculating", + ) + + // Convert message content to string + let messageText = "" + if (Array.isArray(text.content)) { + for (const part of text.content) { + if (part && typeof part === "object" && "value" in part && typeof part.value === "string") { + messageText += part.value + } + } + } else if (typeof text.content === "string") { + messageText = text.content + } + + // Recalculate using string + if (messageText) { + tokenCount = await this.client.countTokens(messageText, this.currentRequestCancellation.token) + console.debug(`Roo Code : Recalculated token count: ${tokenCount}`) + } + } } else { console.warn("Roo Code : Invalid input type for token counting") return 0