Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: 调整模型列表,将自定义模型放在前面显示 #5180

Merged
merged 4 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ export interface LLMModel {
displayName?: string;
available: boolean;
provider: LLMModelProvider;
sorted: number;
}

export interface LLMModelProvider {
id: string;
providerName: string;
providerType: string;
sorted: number;
}

export abstract class LLMApi {
Expand Down
4 changes: 4 additions & 0 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,17 @@ export class ChatGPTApi implements LLMApi {
return [];
}

//由于目前 OpenAI 的 disableListModels 默认为 true,所以当前实际不会运行到这场
let seq = 1000; //同 Constant.ts 中的排序保持一致
return chatModels.map((m) => ({
name: m.id,
available: true,
sorted: seq++,
provider: {
id: "openai",
providerName: "OpenAI",
providerType: "openai",
sorted: 1,
},
}));
}
Expand Down
19 changes: 19 additions & 0 deletions app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -320,86 +320,105 @@ const tencentModels = [

const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];

let seq = 1000; // 内置的模型序号生成器从1000开始
export const DEFAULT_MODELS = [
...openaiModels.map((name) => ({
name,
available: true,
sorted: seq++, // Global sequence sort(index)
provider: {
id: "openai",
providerName: "OpenAI",
providerType: "openai",
sorted: 1, // 这里是固定的,确保顺序与之前内置的版本一致
},
})),
...openaiModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "azure",
providerName: "Azure",
providerType: "azure",
sorted: 2,
},
})),
...googleModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "google",
providerName: "Google",
providerType: "google",
sorted: 3,
},
})),
...anthropicModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "anthropic",
providerName: "Anthropic",
providerType: "anthropic",
sorted: 4,
},
})),
...baiduModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "baidu",
providerName: "Baidu",
providerType: "baidu",
sorted: 5,
},
})),
...bytedanceModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "bytedance",
providerName: "ByteDance",
providerType: "bytedance",
sorted: 6,
},
})),
...alibabaModes.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "alibaba",
providerName: "Alibaba",
providerType: "alibaba",
sorted: 7,
},
})),
...tencentModels.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "tencent",
providerName: "Tencent",
providerType: "tencent",
sorted: 8,
},
})),
...moonshotModes.map((name) => ({
name,
available: true,
sorted: seq++,
provider: {
id: "moonshot",
providerName: "Moonshot",
providerType: "moonshot",
sorted: 9,
},
})),
] as const;
Expand Down
48 changes: 44 additions & 4 deletions app/utils/model.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
import { DEFAULT_MODELS } from "../constant";
import { LLMModel } from "../client/api";

const CustomSeq = {
val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
cache: new Map<string, number>(),
next: (id: string) => {
if (CustomSeq.cache.has(id)) {
return CustomSeq.cache.get(id) as number;
} else {
let seq = CustomSeq.val++;
CustomSeq.cache.set(id, seq);
return seq;
}
},
};

const customProvider = (providerName: string) => ({
id: providerName.toLowerCase(),
providerName: providerName,
providerType: "custom",
sorted: CustomSeq.next(providerName),
});

/**
* Sorts an array of models based on specified rules.
*
* First, sorted by provider; if the same, sorted by model
*/
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
models.sort((a, b) => {
if (a.provider && b.provider) {
let cmp = a.provider.sorted - b.provider.sorted;
return cmp === 0 ? a.sorted - b.sorted : cmp;
} else {
return a.sorted - b.sorted;
}
});

export function collectModelTable(
models: readonly LLMModel[],
frostime marked this conversation as resolved.
Show resolved Hide resolved
customModels: string,
Expand All @@ -17,6 +47,7 @@ export function collectModelTable(
available: boolean;
name: string;
displayName: string;
sorted: number;
provider?: LLMModel["provider"]; // Marked as optional
isDefault?: boolean;
}
Expand Down Expand Up @@ -84,6 +115,7 @@ export function collectModelTable(
displayName: displayName || customModelName,
available,
provider, // Use optional chaining
sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
};
}
}
Expand All @@ -99,13 +131,16 @@ export function collectModelTableWithDefaultModel(
) {
let modelTable = collectModelTable(models, customModels);
if (defaultModel && defaultModel !== "") {
if (defaultModel.includes('@')) {
if (defaultModel.includes("@")) {
if (defaultModel in modelTable) {
modelTable[defaultModel].isDefault = true;
}
} else {
for (const key of Object.keys(modelTable)) {
if (modelTable[key].available && key.split('@').shift() == defaultModel) {
if (
modelTable[key].available &&
key.split("@").shift() == defaultModel
) {
modelTable[key].isDefault = true;
break;
}
Expand All @@ -123,7 +158,9 @@ export function collectModels(
customModels: string,
) {
const modelTable = collectModelTable(models, customModels);
const allModels = Object.values(modelTable);
let allModels = Object.values(modelTable);

allModels = sortModelTable(allModels);

return allModels;
}
Expand All @@ -138,7 +175,10 @@ export function collectModelsWithDefaultModel(
customModels,
defaultModel,
);
const allModels = Object.values(modelTable);
let allModels = Object.values(modelTable);

allModels = sortModelTable(allModels);

return allModels;
}

Expand Down
Loading