Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import { RequestyHandler } from "./providers/requesty"
import { HumanRelayHandler } from "./providers/human-relay"
import { FakeAIHandler } from "./providers/fake-ai"
import { XAIHandler } from "./providers/xai"
import { GroqHandler } from "./providers/groq"
import { ChutesHandler } from "./providers/chutes"

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
Expand Down Expand Up @@ -88,6 +90,10 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new FakeAIHandler(options)
case "xai":
return new XAIHandler(options)
case "groq":
return new GroqHandler(options)
case "chutes":
return new ChutesHandler(options)
default:
return new AnthropicHandler(options)
}
Expand Down
142 changes: 142 additions & 0 deletions src/api/providers/__tests__/chutes.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// npx jest src/api/providers/__tests__/chutes.test.ts

import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { ChutesModelId, chutesDefaultModelId, chutesModels } from "../../../shared/api"

import { ChutesHandler } from "../chutes"

jest.mock("openai", () => {
	// One shared `create` mock is closed over by the factory, so every
	// constructed client instance exposes the same spy.
	const mockCreate = jest.fn()
	const MockOpenAI = jest.fn(() => ({
		chat: { completions: { create: mockCreate } },
	}))
	return MockOpenAI
})

describe("ChutesHandler", () => {
	let handler: ChutesHandler
	let mockCreate: jest.Mock

	// Builds a minimal async-iterable that mimics the OpenAI SDK's streaming
	// response, yielding the provided chunks in order before completing.
	// Replaces the hand-rolled `{ [Symbol.asyncIterator]: ... next ... }`
	// boilerplate previously duplicated in each streaming test.
	const mockStream = (...chunks: unknown[]) => ({
		[Symbol.asyncIterator]: async function* () {
			yield* chunks
		},
	})

	beforeEach(() => {
		jest.clearAllMocks()
		// The mocked OpenAI constructor closes over a single shared `create`
		// mock, so a throwaway instance gives us a handle to it.
		mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create
		handler = new ChutesHandler({})
	})

	test("should use the correct Chutes base URL", () => {
		new ChutesHandler({})
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://llm.chutes.ai/v1" }))
	})

	test("should use the provided API key", () => {
		const chutesApiKey = "test-chutes-api-key"
		new ChutesHandler({ chutesApiKey })
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
	})

	test("should return default model when no model is specified", () => {
		const model = handler.getModel()
		expect(model.id).toBe(chutesDefaultModelId)
		expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
	})

	test("should return specified model when valid model is provided", () => {
		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
		const handlerWithModel = new ChutesHandler({ apiModelId: testModelId })
		const model = handlerWithModel.getModel()

		expect(model.id).toBe(testModelId)
		expect(model.info).toEqual(chutesModels[testModelId])
	})

	test("completePrompt method should return text from Chutes API", async () => {
		const expectedResponse = "This is a test response from Chutes"
		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
		const result = await handler.completePrompt("test prompt")
		expect(result).toBe(expectedResponse)
	})

	test("should handle errors in completePrompt", async () => {
		const errorMessage = "Chutes API error"
		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Chutes completion error: ${errorMessage}`)
	})

	test("createMessage should yield text content from stream", async () => {
		const testContent = "This is test content from Chutes stream"
		mockCreate.mockImplementationOnce(() => mockStream({ choices: [{ delta: { content: testContent } }] }))

		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
	})

	test("createMessage should yield usage data from stream", async () => {
		mockCreate.mockImplementationOnce(() =>
			mockStream({ choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }),
		)

		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
	})

	test("createMessage should pass correct parameters to Chutes client", async () => {
		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
		const modelInfo = chutesModels[modelId]
		const handlerWithModel = new ChutesHandler({ apiModelId: modelId })

		// Empty stream: only the request arguments matter for this test.
		mockCreate.mockImplementationOnce(() => mockStream())

		const systemPrompt = "Test system prompt for Chutes"
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]

		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
		await messageGenerator.next()

		expect(mockCreate).toHaveBeenCalledWith(
			expect.objectContaining({
				model: modelId,
				max_tokens: modelInfo.maxTokens,
				temperature: 0.5,
				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
				stream: true,
				stream_options: { include_usage: true },
			}),
		)
	})
})
142 changes: 142 additions & 0 deletions src/api/providers/__tests__/groq.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// npx jest src/api/providers/__tests__/groq.test.ts

import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { GroqModelId, groqDefaultModelId, groqModels } from "../../../shared/api"

import { GroqHandler } from "../groq"

jest.mock("openai", () => {
	// One shared `create` mock is closed over by the factory, so every
	// constructed client instance exposes the same spy.
	const mockCreate = jest.fn()
	const MockOpenAI = jest.fn(() => ({
		chat: { completions: { create: mockCreate } },
	}))
	return MockOpenAI
})

describe("GroqHandler", () => {
	let handler: GroqHandler
	let mockCreate: jest.Mock

	// Builds a minimal async-iterable that mimics the OpenAI SDK's streaming
	// response, yielding the provided chunks in order before completing.
	// Replaces the hand-rolled `{ [Symbol.asyncIterator]: ... next ... }`
	// boilerplate previously duplicated in each streaming test.
	const mockStream = (...chunks: unknown[]) => ({
		[Symbol.asyncIterator]: async function* () {
			yield* chunks
		},
	})

	beforeEach(() => {
		jest.clearAllMocks()
		// The mocked OpenAI constructor closes over a single shared `create`
		// mock, so a throwaway instance gives us a handle to it.
		mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create
		handler = new GroqHandler({})
	})

	test("should use the correct Groq base URL", () => {
		new GroqHandler({})
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.groq.com/openai/v1" }))
	})

	test("should use the provided API key", () => {
		const groqApiKey = "test-groq-api-key"
		new GroqHandler({ groqApiKey })
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: groqApiKey }))
	})

	test("should return default model when no model is specified", () => {
		const model = handler.getModel()
		expect(model.id).toBe(groqDefaultModelId)
		expect(model.info).toEqual(groqModels[groqDefaultModelId])
	})

	test("should return specified model when valid model is provided", () => {
		const testModelId: GroqModelId = "llama-3.3-70b-versatile"
		const handlerWithModel = new GroqHandler({ apiModelId: testModelId })
		const model = handlerWithModel.getModel()

		expect(model.id).toBe(testModelId)
		expect(model.info).toEqual(groqModels[testModelId])
	})

	test("completePrompt method should return text from Groq API", async () => {
		const expectedResponse = "This is a test response from Groq"
		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
		const result = await handler.completePrompt("test prompt")
		expect(result).toBe(expectedResponse)
	})

	test("should handle errors in completePrompt", async () => {
		const errorMessage = "Groq API error"
		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Groq completion error: ${errorMessage}`)
	})

	test("createMessage should yield text content from stream", async () => {
		const testContent = "This is test content from Groq stream"
		mockCreate.mockImplementationOnce(() => mockStream({ choices: [{ delta: { content: testContent } }] }))

		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
	})

	test("createMessage should yield usage data from stream", async () => {
		mockCreate.mockImplementationOnce(() =>
			mockStream({ choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }),
		)

		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
	})

	test("createMessage should pass correct parameters to Groq client", async () => {
		const modelId: GroqModelId = "llama-3.1-8b-instant"
		const modelInfo = groqModels[modelId]
		const handlerWithModel = new GroqHandler({ apiModelId: modelId })

		// Empty stream: only the request arguments matter for this test.
		mockCreate.mockImplementationOnce(() => mockStream())

		const systemPrompt = "Test system prompt for Groq"
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Groq" }]

		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
		await messageGenerator.next()

		expect(mockCreate).toHaveBeenCalledWith(
			expect.objectContaining({
				model: modelId,
				max_tokens: modelInfo.maxTokens,
				temperature: 0.5,
				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
				stream: true,
				stream_options: { include_usage: true },
			}),
		)
	})
})
Loading
Loading