diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index b4b5f29204d..0d42c082a91 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -1,11 +1,12 @@ // npx vitest run api/providers/__tests__/openai.spec.ts -import { OpenAiHandler } from "../openai" +import { OpenAiHandler, getOpenAiModels } from "../openai" import { ApiHandlerOptions } from "../../../shared/api" import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" import { openAiModelInfoSaneDefaults } from "@roo-code/types" import { Package } from "../../../shared/package" +import axios from "axios" const mockCreate = vitest.fn() @@ -68,6 +69,13 @@ vitest.mock("openai", () => { } }) +// Mock axios for getOpenAiModels tests +vitest.mock("axios", () => ({ + default: { + get: vitest.fn(), + }, +})) + describe("OpenAiHandler", () => { let handler: OpenAiHandler let mockOptions: ApiHandlerOptions @@ -776,3 +784,143 @@ describe("OpenAiHandler", () => { }) }) }) + +describe("getOpenAiModels", () => { + beforeEach(() => { + vi.mocked(axios.get).mockClear() + }) + + it("should return empty array when baseUrl is not provided", async () => { + const result = await getOpenAiModels(undefined, "test-key") + expect(result).toEqual([]) + expect(axios.get).not.toHaveBeenCalled() + }) + + it("should return empty array when baseUrl is empty string", async () => { + const result = await getOpenAiModels("", "test-key") + expect(result).toEqual([]) + expect(axios.get).not.toHaveBeenCalled() + }) + + it("should trim whitespace from baseUrl", async () => { + const mockResponse = { + data: { + data: [{ id: "gpt-4" }, { id: "gpt-3.5-turbo" }], + }, + } + vi.mocked(axios.get).mockResolvedValueOnce(mockResponse) + + const result = await getOpenAiModels(" https://api.openai.com/v1 ", "test-key") + + expect(axios.get).toHaveBeenCalledWith("https://api.openai.com/v1/models", expect.any(Object)) + expect(result).toEqual(["gpt-4", "gpt-3.5-turbo"]) + }) + + it("should handle baseUrl with trailing spaces", async () => { + const mockResponse = { + data: { + data: [{ id: "model-1" }, { id: "model-2" }], + }, + } + vi.mocked(axios.get).mockResolvedValueOnce(mockResponse) + + const result = await getOpenAiModels("https://api.example.com/v1 ", "test-key") + + expect(axios.get).toHaveBeenCalledWith("https://api.example.com/v1/models", expect.any(Object)) + expect(result).toEqual(["model-1", "model-2"]) + }) + + it("should handle baseUrl with leading spaces", async () => { + const mockResponse = { + data: { + data: [{ id: "model-1" }], + }, + } + vi.mocked(axios.get).mockResolvedValueOnce(mockResponse) + + const result = await getOpenAiModels(" https://api.example.com/v1", "test-key") + + expect(axios.get).toHaveBeenCalledWith("https://api.example.com/v1/models", expect.any(Object)) + expect(result).toEqual(["model-1"]) + }) + + it("should return empty array for invalid URL after trimming", async () => { + const result = await getOpenAiModels(" not-a-valid-url ", "test-key") + expect(result).toEqual([]) + expect(axios.get).not.toHaveBeenCalled() + }) + + it("should include authorization header when apiKey is provided", async () => { + const mockResponse = { + data: { + data: [{ id: "model-1" }], + }, + } + vi.mocked(axios.get).mockResolvedValueOnce(mockResponse) + + await getOpenAiModels("https://api.example.com/v1", "test-api-key") + + expect(axios.get).toHaveBeenCalledWith( + "https://api.example.com/v1/models", + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer test-api-key", + }), + }), + ) + }) + + it("should include custom headers when provided", async () => { + const mockResponse = { + data: { + data: [{ id: "model-1" }], + }, + } + vi.mocked(axios.get).mockResolvedValueOnce(mockResponse) + + const customHeaders = { + "X-Custom-Header": "custom-value", + } + + await getOpenAiModels("https://api.example.com/v1", "test-key", customHeaders) + + expect(axios.get).toHaveBeenCalledWith( + "https://api.example.com/v1/models", + expect.objectContaining({ + headers: expect.objectContaining({ + "X-Custom-Header": "custom-value", + Authorization: "Bearer test-key", + }), + }), + ) + }) + + it("should handle API errors gracefully", async () => { + vi.mocked(axios.get).mockRejectedValueOnce(new Error("Network error")) + + const result = await getOpenAiModels("https://api.example.com/v1", "test-key") + + expect(result).toEqual([]) + }) + + it("should handle malformed response data", async () => { + vi.mocked(axios.get).mockResolvedValueOnce({ data: null }) + + const result = await getOpenAiModels("https://api.example.com/v1", "test-key") + + expect(result).toEqual([]) + }) + + it("should deduplicate model IDs", async () => { + const mockResponse = { + data: { + data: [{ id: "gpt-4" }, { id: "gpt-4" }, { id: "gpt-3.5-turbo" }, { id: "gpt-4" }], + }, + } + vi.mocked(axios.get).mockResolvedValueOnce(mockResponse) + + const result = await getOpenAiModels("https://api.example.com/v1", "test-key") + + expect(result).toEqual(["gpt-4", "gpt-3.5-turbo"]) + }) +}) diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index f5e4e4c985e..85abcf1a690 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -416,7 +416,10 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH return [] } - if (!URL.canParse(baseUrl)) { + // Trim whitespace from baseUrl to handle cases where users accidentally include spaces + const trimmedBaseUrl = baseUrl.trim() + + if (!URL.canParse(trimmedBaseUrl)) { return [] } @@ -434,7 +437,7 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH config["headers"] = headers } - const response = await axios.get(`${baseUrl}/models`, config) + const response = await axios.get(`${trimmedBaseUrl}/models`, config) const modelsArray = response.data?.data?.map((model: any) => model.id) || [] return [...new Set(modelsArray)] } catch (error) {