Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/api/providers/__tests__/vscode-lm.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,75 @@ describe("VsCodeLmHandler", () => {
await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed")
})
})

describe("countTokens", () => {
	/**
	 * Shared fixture: make model selection resolve to a copy of the mock chat model,
	 * then wire the handler directly to the mock client and give it a live
	 * cancellation source so token counting can be invoked without a real request.
	 */
	beforeEach(() => {
		const selectChatModels = vscode.lm.selectChatModels as Mock
		selectChatModels.mockResolvedValueOnce([{ ...mockLanguageModelChat }])

		// Bypass client creation: point the handler at the mock client.
		handler["client"] = mockLanguageModelChat
		// Token counting reads the token from the current cancellation source.
		handler["currentRequestCancellation"] = new vscode.CancellationTokenSource()
	})

	it("should count tokens for string input", async () => {
		const expectedTokens = 10
		const countTokensMock = mockLanguageModelChat.countTokens
		countTokensMock.mockResolvedValue(expectedTokens)

		const tokenTotal = await handler.countTokens([{ type: "text", text: "Hello world" }])

		expect(tokenTotal).toBe(expectedTokens)
		expect(countTokensMock).toHaveBeenCalledWith("Hello world", expect.any(Object))
	})

	it("should handle special case when LanguageModelChatMessage returns tokenCount of 4", async () => {
		const messageText = "This is a test message"
		const countTokensMock = mockLanguageModelChat.countTokens

		// The first answer (4) triggers the recount path; the second answer (25) is
		// what the client reports once the message has been flattened to a string.
		countTokensMock.mockResolvedValueOnce(4)
		countTokensMock.mockResolvedValueOnce(25)

		// Build the message through the mocked vscode factory so it has the expected shape.
		const chatMessage = vscode.LanguageModelChatMessage.User(messageText)

		const tokenTotal = await handler["internalCountTokens"](chatMessage)

		expect(tokenTotal).toBe(25)
		expect(countTokensMock).toHaveBeenCalledTimes(2)
		// Call #1 receives the message object itself.
		expect(countTokensMock).toHaveBeenNthCalledWith(1, chatMessage, expect.any(Object))
		// Call #2 receives the text extracted from the message.
		expect(countTokensMock).toHaveBeenNthCalledWith(2, messageText, expect.any(Object))
	})

	it("should not recalculate when tokenCount is not 4", async () => {
		const expectedTokens = 10
		const countTokensMock = mockLanguageModelChat.countTokens
		countTokensMock.mockResolvedValue(expectedTokens)

		// Build the message through the mocked vscode factory so it has the expected shape.
		const chatMessage = vscode.LanguageModelChatMessage.User("This is a test message")

		const tokenTotal = await handler["internalCountTokens"](chatMessage)

		// Any count other than 4 is accepted as-is, so the client is asked exactly once.
		expect(tokenTotal).toBe(expectedTokens)
		expect(countTokensMock).toHaveBeenCalledTimes(1)
	})

	it("should handle image blocks", async () => {
		const expectedTokens = 7
		const countTokensMock = mockLanguageModelChat.countTokens
		// countTokens flattens the blocks into one string, so the recount path is never hit.
		countTokensMock.mockResolvedValue(expectedTokens)

		const tokenTotal = await handler.countTokens([
			{ type: "text", text: "Hello" },
			{ type: "image", source: { type: "base64", media_type: "image/png", data: "base64data" } },
			{ type: "text", text: " world" },
		])

		expect(tokenTotal).toBe(expectedTokens)
		// One call only: the input reaches the client as a plain string, not a chat message.
		expect(countTokensMock).toHaveBeenCalledTimes(1)
		expect(countTokensMock).toHaveBeenCalledWith("Hello[IMAGE] world", expect.any(Object))
	})
})
})
29 changes: 28 additions & 1 deletion src/api/providers/vscode-lm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,40 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan

if (typeof text === "string") {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
} else if (text instanceof vscode.LanguageModelChatMessage) {
} else if (text && typeof text === "object" && "content" in text) {
// Handle LanguageModelChatMessage-like objects
// For chat messages, ensure we have content
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
console.debug("Roo Code <Language Model API>: Empty chat message content")
return 0
}
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)

// Special handling: if tokenCount is exactly 4 for a LanguageModelChatMessage,
// convert to string and recalculate
if (tokenCount === 4) {
console.debug(
"Roo Code <Language Model API>: Token count is 4 for LanguageModelChatMessage, converting to string and recalculating",
)

// Convert message content to string
let messageText = ""
if (Array.isArray(text.content)) {
for (const part of text.content) {
if (part && typeof part === "object" && "value" in part && typeof part.value === "string") {
messageText += part.value
}
}
} else if (typeof text.content === "string") {
messageText = text.content
}

// Recalculate using string
if (messageText) {
tokenCount = await this.client.countTokens(messageText, this.currentRequestCancellation.token)
console.debug(`Roo Code <Language Model API>: Recalculated token count: ${tokenCount}`)
}
}
} else {
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
return 0
Expand Down