Skip to content

Commit

Permalink
Add ModelCapability enum and capability detection, fix text-only models in plus mode
Browse files Browse the repository at this point in the history
  • Loading branch information
logancyang committed Feb 12, 2025
1 parent 4feedf5 commit 784aa0b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 3 deletions.
22 changes: 20 additions & 2 deletions src/LLMProviders/chainRunner.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { ModelCapability } from "@/aiParams";
import { getStandaloneQuestion } from "@/chainUtils";
import {
ABORT_REASON,
Expand All @@ -22,6 +23,7 @@ import {
ImageProcessor,
MessageContent,
} from "@/utils";
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { Notice } from "obsidian";
import ChainManager from "./chainManager";
import { COPILOT_TOOL_NAMES, IntentAnalyzer } from "./intentAnalyzer";
Expand Down Expand Up @@ -334,6 +336,16 @@ class CopilotPlusChainRunner extends BaseChainRunner {
return content;
}

/**
 * Checks whether the given chat model is registered with a specific capability.
 *
 * The model's identifier is read from `modelName` or `model` (different
 * LangChain provider wrappers expose one or the other — TODO confirm this
 * covers every provider in use), then looked up in the user's configured
 * models via the chat model manager.
 *
 * @param model - The active LangChain chat model instance.
 * @param capability - The capability to test for.
 * @returns `true` only when a matching configured model declares the capability.
 */
private hasCapability(model: BaseChatModel, capability: ModelCapability): boolean {
  // Narrow, explicit view instead of repeated `as any` casts; `||` (not `??`)
  // is deliberate so an empty-string `modelName` still falls through to `model`.
  const identifiers = model as Partial<{ modelName: string; model: string }>;
  const modelName = identifiers.modelName || identifiers.model || "";
  const customModel = this.chainManager.chatModelManager.findModelByName(modelName);
  return customModel?.capabilities?.includes(capability) ?? false;
}

/**
 * Convenience predicate: does this model declare the VISION capability?
 * Multimodal (image-accepting) prompts are only built for such models.
 */
private isMultimodalModel(model: BaseChatModel): boolean {
  const supportsVision = this.hasCapability(model, ModelCapability.VISION);
  return supportsVision;
}

private async streamMultimodalResponse(
textContent: string,
userMessage: ChatMessage,
Expand Down Expand Up @@ -375,8 +387,14 @@ class CopilotPlusChainRunner extends BaseChainRunner {
messages.push({ role: "assistant", content: ai });
}

// Build message content with text and images
const content = await this.buildMessageContent(textContent, userMessage);
// Get the current chat model
const chatModelCurrent = this.chainManager.chatModelManager.getChatModel();
const isMultimodalCurrent = this.isMultimodalModel(chatModelCurrent);

// Build message content with text and images for multimodal models, or just text for text-only models
const content = isMultimodalCurrent
? await this.buildMessageContent(textContent, userMessage)
: textContent;

// Add current user message
messages.push({
Expand Down
5 changes: 5 additions & 0 deletions src/LLMProviders/chatModelManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -373,4 +373,9 @@ export default class ChatModelManager {
}
}
}

/**
 * Looks up a configured model by its exact (case-sensitive) name.
 *
 * @param modelName - The name to match against the user's active models.
 * @returns The matching {@link CustomModel}, or `undefined` when none exists.
 */
findModelByName(modelName: string): CustomModel | undefined {
  const { activeModels } = getSettings();
  return activeModels.find((candidate) => candidate.name === modelName);
}
}
7 changes: 7 additions & 0 deletions src/aiParams.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ export interface SetChainOptions {
refreshIndex?: boolean;
}

/**
 * Capabilities a chat model may advertise. Stored on {@link CustomModel}
 * via `capabilities` and checked at runtime (e.g. VISION gates whether
 * image content is included in prompts).
 */
export enum ModelCapability {
  // Model is a reasoning model (e.g. o1 / o3-mini style).
  REASONING = "reasoning",
  // Model accepts image inputs (multimodal).
  VISION = "vision",
  // NOTE: serialized value intentionally has no separator ("websearch").
  WEB_SEARCH = "websearch",
}

export interface CustomModel {
name: string;
provider: string;
Expand All @@ -79,6 +85,7 @@ export interface CustomModel {
maxTokens?: number;
context?: number;
believerExclusive?: boolean;
capabilities?: ModelCapability[];
// Embedding models only (Jina at the moment)
dimensions?: number;
// OpenAI specific fields
Expand Down
9 changes: 8 additions & 1 deletion src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CustomModel } from "@/aiParams";
import { CustomModel, ModelCapability } from "@/aiParams";
import { type CopilotSettings } from "@/settings/model";
import { ChainType } from "./chainFactory";

Expand Down Expand Up @@ -78,13 +78,15 @@ export const BUILTIN_CHAT_MODELS: CustomModel[] = [
enabled: true,
isBuiltIn: true,
core: true,
capabilities: [ModelCapability.VISION],
},
{
name: ChatModels.GPT_4o,
provider: ChatModelProviders.OPENAI,
enabled: true,
isBuiltIn: true,
core: true,
capabilities: [ModelCapability.VISION],
},
{
name: ChatModels.GPT_4o_mini,
Expand All @@ -98,19 +100,22 @@ export const BUILTIN_CHAT_MODELS: CustomModel[] = [
provider: ChatModelProviders.OPENAI,
enabled: true,
isBuiltIn: true,
capabilities: [ModelCapability.REASONING],
},
{
name: ChatModels.O3_mini,
provider: ChatModelProviders.OPENAI,
enabled: true,
isBuiltIn: true,
capabilities: [ModelCapability.REASONING],
},
{
name: ChatModels.CLAUDE_3_5_SONNET,
provider: ChatModelProviders.ANTHROPIC,
enabled: true,
isBuiltIn: true,
core: true,
capabilities: [ModelCapability.VISION],
},
{
name: ChatModels.CLAUDE_3_5_HAIKU,
Expand All @@ -135,12 +140,14 @@ export const BUILTIN_CHAT_MODELS: CustomModel[] = [
provider: ChatModelProviders.GOOGLE,
enabled: true,
isBuiltIn: true,
capabilities: [ModelCapability.VISION],
},
{
name: ChatModels.GEMINI_FLASH,
provider: ChatModelProviders.GOOGLE,
enabled: true,
isBuiltIn: true,
capabilities: [ModelCapability.VISION],
},
{
name: ChatModels.AZURE_OPENAI,
Expand Down

0 comments on commit 784aa0b

Please sign in to comment.