
Commit bd9c07a

Provider registry (#50)
* Change the AIProvider to a registry, to enable extensions to register new providers
* Remove 'llm-models/utils.ts' module, replaced by the provider registry
* Make instructions usable from the provider registry
* Use the provider registry in the settings panel
* lint
* Variable renaming for consistency and docstring
* Rename the token module 'tokens.ts'
* Prevent removing provider from the registry
* Fix the settings generator script
1 parent c7d428d commit bd9c07a

12 files changed (+326, -174 lines)
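
The headline change is that chat and completion providers are no longer hard-wired into the extension: any other JupyterLab/JupyterLite extension can now contribute one by requiring the new IAIProviderRegistry token and calling its add() method. The sketch below illustrates the idea; the '@jupyterlite/ai' import path for the token, the '@langchain/ollama' package and the completer/schema placeholders are assumptions for illustration only, while the registry API itself follows the diff below.

// Hypothetical third-party plugin registering an extra provider.
import {
  JupyterFrontEnd,
  JupyterFrontEndPlugin
} from '@jupyterlab/application';
import { ChatOllama } from '@langchain/ollama'; // assumed to be installed
import { IAIProviderRegistry } from '@jupyterlite/ai'; // token export path assumed

// Placeholders a real extension would implement itself (a completer class and
// a JSON settings schema for the provider's settings form).
declare const OllamaCompleter: any;
declare const ollamaSettingsSchema: any;

const ollamaProviderPlugin: JupyterFrontEndPlugin<void> = {
  id: 'my-extension:ollama-provider',
  autoStart: true,
  requires: [IAIProviderRegistry],
  activate: (app: JupyterFrontEnd, registry: IAIProviderRegistry): void => {
    // Same fields as the built-in entries in src/llm-models/index.ts below.
    registry.add({
      name: 'Ollama',
      chatModel: ChatOllama,
      completer: OllamaCompleter,
      settingsSchema: ollamaSettingsSchema
    });
  }
};

export default ollamaProviderPlugin;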

schema/ai-provider.json → schema/provider-registry.json (+1 -1)

@@ -1,6 +1,6 @@
 {
   "title": "AI provider",
-  "description": "Provider settings",
+  "description": "Provider registry settings",
   "jupyter.lab.setting-icon": "@jupyterlite/ai:jupyternaut-lite",
   "jupyter.lab.setting-icon-label": "JupyterLite AI Chat",
   "type": "object",

scripts/settings-generator.js (+3 -3)

@@ -140,19 +140,19 @@ Object.entries(providers).forEach(([name, desc], index) => {
 });

 // Build the index.ts file
-const indexContent = [];
+const indexContent = ["import { IDict } from '../../tokens';", ''];
 Object.keys(providers).forEach(name => {
   indexContent.push(`import ${name} from './_generated/${name}.json';`);
 });

-indexContent.push('', 'const ProviderSettings: { [name: string]: any } = {');
+indexContent.push('', 'const ProviderSettings: IDict<any> = {');

 Object.keys(providers).forEach((name, index) => {
   indexContent.push(
     ` ${name}` + (index < Object.keys(providers).length - 1 ? ',' : '')
   );
 });
-indexContent.push('};', '', 'export default ProviderSettings;', '');
+indexContent.push('};', '', 'export { ProviderSettings };', '');
 fs.writeFile(
   path.join(schemasDir, 'index.ts'),
   indexContent.join('\n'),
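
For reference, with the default providers added later in this commit (Anthropic, ChromeAI, MistralAI, OpenAI), the generated src/settings/schemas/index.ts would now come out roughly as follows. This is a sketch, not the committed file: the exact provider list and spacing come from the generator's providers table, which is not shown in this excerpt.

// Approximate generator output after this change: a named export typed with IDict.
import { IDict } from '../../tokens';

import Anthropic from './_generated/Anthropic.json';
import ChromeAI from './_generated/ChromeAI.json';
import MistralAI from './_generated/MistralAI.json';
import OpenAI from './_generated/OpenAI.json';

const ProviderSettings: IDict<any> = {
  Anthropic,
  ChromeAI,
  MistralAI,
  OpenAI
};

export { ProviderSettings };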

src/chat-handler.ts (+16 -13)

@@ -17,9 +17,8 @@ import {
   SystemMessage
 } from '@langchain/core/messages';
 import { UUID } from '@lumino/coreutils';
-import { getErrorMessage } from './llm-models';
 import { chatSystemPrompt } from './provider';
-import { IAIProvider } from './token';
+import { IAIProviderRegistry } from './tokens';
 import { jupyternautLiteIcon } from './icons';

 /**
@@ -37,17 +36,21 @@ export type ConnectionMessage = {
 export class ChatHandler extends ChatModel {
   constructor(options: ChatHandler.IOptions) {
     super(options);
-    this._aiProvider = options.aiProvider;
-    this._prompt = chatSystemPrompt({ provider_name: this._aiProvider.name });
+    this._providerRegistry = options.providerRegistry;
+    this._prompt = chatSystemPrompt({
+      provider_name: this._providerRegistry.currentName
+    });

-    this._aiProvider.providerChanged.connect(() => {
-      this._errorMessage = this._aiProvider.chatError;
-      this._prompt = chatSystemPrompt({ provider_name: this._aiProvider.name });
+    this._providerRegistry.providerChanged.connect(() => {
+      this._errorMessage = this._providerRegistry.chatError;
+      this._prompt = chatSystemPrompt({
+        provider_name: this._providerRegistry.currentName
+      });
     });
   }

   get provider(): BaseChatModel | null {
-    return this._aiProvider.chatModel;
+    return this._providerRegistry.currentChatModel;
   }

   /**
@@ -95,7 +98,7 @@ export class ChatHandler extends ChatModel {
     };
     this.messageAdded(msg);

-    if (this._aiProvider.chatModel === null) {
+    if (this._providerRegistry.currentChatModel === null) {
       const errorMsg: IChatMessage = {
         id: UUID.uuid4(),
         body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`,
@@ -134,7 +137,7 @@ export class ChatHandler extends ChatModel {
     let content = '';

     try {
-      for await (const chunk of await this._aiProvider.chatModel.stream(
+      for await (const chunk of await this._providerRegistry.currentChatModel.stream(
         messages
       )) {
         content += chunk.content ?? chunk;
@@ -144,7 +147,7 @@ export class ChatHandler extends ChatModel {
       this._history.messages.push(botMsg);
       return true;
     } catch (reason) {
-      const error = getErrorMessage(this._aiProvider.name, reason);
+      const error = this._providerRegistry.formatErrorMessage(reason);
       const errorMsg: IChatMessage = {
         id: UUID.uuid4(),
         body: `**${error}**`,
@@ -171,7 +174,7 @@ export class ChatHandler extends ChatModel {
     super.messageAdded(message);
   }

-  private _aiProvider: IAIProvider;
+  private _providerRegistry: IAIProviderRegistry;
   private _personaName = 'AI';
   private _prompt: string;
   private _errorMessage: string = '';
@@ -181,6 +184,6 @@ export class ChatHandler extends ChatModel {

 export namespace ChatHandler {
   export interface IOptions extends ChatModel.IOptions {
-    aiProvider: IAIProvider;
+    providerRegistry: IAIProviderRegistry;
   }
 }
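
Taken together, the chat handler now depends only on a small read-only surface of the registry. The sketch below is inferred from the calls above rather than copied from src/tokens.ts (which is not shown in this excerpt); member and signal types are assumptions.

import { ISignal } from '@lumino/signaling';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';

// Inferred from this diff; not the actual IAIProviderRegistry declaration.
interface IProviderRegistryForChat {
  /** Name of the currently selected provider. */
  readonly currentName: string;
  /** Chat model of the current provider, or null when none is configured. */
  readonly currentChatModel: BaseChatModel | null;
  /** Error message produced while setting up the current chat model. */
  readonly chatError: string;
  /** Emitted whenever the selected provider or its settings change. */
  readonly providerChanged: ISignal<unknown, void>;
  /** Turns a provider-specific error object into a displayable message. */
  formatErrorMessage(error: unknown): string;
}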

src/completion-provider.ts (+7 -7)

@@ -5,7 +5,7 @@ import {
 } from '@jupyterlab/completer';

 import { IBaseCompleter } from './llm-models';
-import { IAIProvider } from './token';
+import { IAIProviderRegistry } from './tokens';

 /**
  * The generic completion provider to register to the completion provider manager.
@@ -14,10 +14,10 @@ export class CompletionProvider implements IInlineCompletionProvider {
   readonly identifier = '@jupyterlite/ai';

   constructor(options: CompletionProvider.IOptions) {
-    this._aiProvider = options.aiProvider;
+    this._providerRegistry = options.providerRegistry;
     this._requestCompletion = options.requestCompletion;

-    this._aiProvider.providerChanged.connect(() => {
+    this._providerRegistry.providerChanged.connect(() => {
       if (this.completer) {
         this.completer.requestCompletion = this._requestCompletion;
       }
@@ -28,14 +28,14 @@ export class CompletionProvider implements IInlineCompletionProvider {
    * Get the current completer name.
    */
   get name(): string {
-    return this._aiProvider.name;
+    return this._providerRegistry.currentName;
   }

   /**
    * Get the current completer.
    */
   get completer(): IBaseCompleter | null {
-    return this._aiProvider.completer;
+    return this._providerRegistry.currentCompleter;
   }

   async fetch(
@@ -45,13 +45,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
     return this.completer?.fetch(request, context);
   }

-  private _aiProvider: IAIProvider;
+  private _providerRegistry: IAIProviderRegistry;
   private _requestCompletion: () => void;
 }

 export namespace CompletionProvider {
   export interface IOptions {
-    aiProvider: IAIProvider;
+    providerRegistry: IAIProviderRegistry;
     requestCompletion: () => void;
   }
 }

src/index.ts (+25 -21)

@@ -21,10 +21,11 @@ import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

 import { ChatHandler } from './chat-handler';
 import { CompletionProvider } from './completion-provider';
-import { AIProvider } from './provider';
+import { AIProviders } from './llm-models';
+import { AIProviderRegistry } from './provider';
 import { aiSettingsRenderer } from './settings/panel';
 import { renderSlashCommandOption } from './slash-commands';
-import { IAIProvider } from './token';
+import { IAIProviderRegistry } from './tokens';

 const autocompletionRegistryPlugin: JupyterFrontEndPlugin<IAutocompletionRegistry> =
   {
@@ -57,11 +58,11 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
   id: '@jupyterlite/ai:chat',
   description: 'LLM chat extension',
   autoStart: true,
-  requires: [IAIProvider, IRenderMimeRegistry, IAutocompletionRegistry],
+  requires: [IAIProviderRegistry, IRenderMimeRegistry, IAutocompletionRegistry],
   optional: [INotebookTracker, ISettingRegistry, IThemeManager],
   activate: async (
     app: JupyterFrontEnd,
-    aiProvider: IAIProvider,
+    providerRegistry: IAIProviderRegistry,
     rmRegistry: IRenderMimeRegistry,
     autocompletionRegistry: IAutocompletionRegistry,
     notebookTracker: INotebookTracker | null,
@@ -77,8 +78,8 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
     }

     const chatHandler = new ChatHandler({
-      aiProvider: aiProvider,
-      activeCellManager: activeCellManager
+      providerRegistry,
+      activeCellManager
     });

     let sendWithShiftEnter = false;
@@ -135,47 +136,47 @@
 const completerPlugin: JupyterFrontEndPlugin<void> = {
   id: '@jupyterlite/ai:completer',
   autoStart: true,
-  requires: [IAIProvider, ICompletionProviderManager],
+  requires: [IAIProviderRegistry, ICompletionProviderManager],
   activate: (
     app: JupyterFrontEnd,
-    aiProvider: IAIProvider,
+    providerRegistry: IAIProviderRegistry,
     manager: ICompletionProviderManager
   ): void => {
     const completer = new CompletionProvider({
-      aiProvider,
+      providerRegistry,
       requestCompletion: () => app.commands.execute('inline-completer:invoke')
     });
     manager.registerInlineProvider(completer);
   }
 };

-const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
-  id: '@jupyterlite/ai:ai-provider',
+const providerRegistryPlugin: JupyterFrontEndPlugin<IAIProviderRegistry> = {
+  id: '@jupyterlite/ai:provider-registry',
   autoStart: true,
   requires: [IFormRendererRegistry, ISettingRegistry],
   optional: [IRenderMimeRegistry],
-  provides: IAIProvider,
+  provides: IAIProviderRegistry,
   activate: (
     app: JupyterFrontEnd,
     editorRegistry: IFormRendererRegistry,
     settingRegistry: ISettingRegistry,
     rmRegistry?: IRenderMimeRegistry
-  ): IAIProvider => {
-    const aiProvider = new AIProvider();
+  ): IAIProviderRegistry => {
+    const providerRegistry = new AIProviderRegistry();

     editorRegistry.addRenderer(
-      '@jupyterlite/ai:ai-provider.AIprovider',
-      aiSettingsRenderer({ rmRegistry })
+      '@jupyterlite/ai:provider-registry.AIprovider',
+      aiSettingsRenderer({ providerRegistry, rmRegistry })
     );
     settingRegistry
-      .load(aiProviderPlugin.id)
+      .load(providerRegistryPlugin.id)
       .then(settings => {
         const updateProvider = () => {
           // Update the settings to the AI providers.
           const providerSettings = (settings.get('AIprovider').composite ?? {
             provider: 'None'
           }) as ReadonlyPartialJSONObject;
-          aiProvider.setProvider(
+          providerRegistry.setProvider(
             providerSettings.provider as string,
             providerSettings
           );
@@ -186,17 +187,20 @@
       })
       .catch(reason => {
         console.error(
-          `Failed to load settings for ${aiProviderPlugin.id}`,
+          `Failed to load settings for ${providerRegistryPlugin.id}`,
           reason
         );
       });

-    return aiProvider;
+    // Initialize the registry with the default providers
+    AIProviders.forEach(provider => providerRegistry.add(provider));
+
+    return providerRegistry;
   }
 };

 export default [
-  aiProviderPlugin,
+  providerRegistryPlugin,
   autocompletionRegistryPlugin,
   chatPlugin,
   completerPlugin
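
src/provider.ts, where AIProviderRegistry itself lives, is not part of the excerpt above. The plugin code implies that add() stores a provider description by name (and, per the commit message, entries are never removed), while setProvider() instantiates the selected provider's chat model from the user settings and signals the change. A rough behavioural sketch under those assumptions, not the actual implementation:

import { Signal } from '@lumino/signaling';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';

// Minimal provider description, mirroring the fields used in this diff.
interface IProviderEntry {
  name: string;
  chatModel: new (options: any) => BaseChatModel;
  errorMessage?: (error: any) => string;
}

class ProviderRegistrySketch {
  /** Emitted whenever the selected provider changes. */
  readonly providerChanged = new Signal<ProviderRegistrySketch, void>(this);

  /** Register a provider description; existing entries are never removed. */
  add(provider: IProviderEntry): void {
    if (!this._providers.has(provider.name)) {
      this._providers.set(provider.name, provider);
    }
  }

  /** Select a provider by name and rebuild its chat model from the settings. */
  setProvider(name: string, settings: ReadonlyPartialJSONObject): void {
    const entry = this._providers.get(name);
    this._name = name;
    try {
      this._chatModel = entry ? new entry.chatModel({ ...settings }) : null;
      this._chatError = '';
    } catch (error: any) {
      this._chatModel = null;
      this._chatError = entry?.errorMessage?.(error) ?? String(error);
    }
    this.providerChanged.emit(undefined);
  }

  // currentName, currentChatModel and chatError would simply expose the
  // fields below through getters.
  private _providers = new Map<string, IProviderEntry>();
  private _name = 'None';
  private _chatModel: BaseChatModel | null = null;
  private _chatError = '';
}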

src/llm-models/index.ts (+49 -2)

@@ -1,3 +1,50 @@
+import { ChatAnthropic } from '@langchain/anthropic';
+import { ChromeAI } from '@langchain/community/experimental/llms/chrome_ai';
+import { ChatMistralAI } from '@langchain/mistralai';
+import { ChatOpenAI } from '@langchain/openai';
+
+import { AnthropicCompleter } from './anthropic-completer';
+import { CodestralCompleter } from './codestral-completer';
+import { ChromeCompleter } from './chrome-completer';
+import { OpenAICompleter } from './openai-completer';
+
+import { instructions } from '../settings/instructions';
+import { ProviderSettings } from '../settings/schemas';
+
+import { IAIProvider } from '../tokens';
+
 export * from './base-completer';
-export * from './codestral-completer';
-export * from './utils';
+
+const AIProviders: IAIProvider[] = [
+  {
+    name: 'Anthropic',
+    chatModel: ChatAnthropic,
+    completer: AnthropicCompleter,
+    settingsSchema: ProviderSettings.Anthropic,
+    errorMessage: (error: any) => error.error.error.message
+  },
+  {
+    name: 'ChromeAI',
+    // TODO: fix
+    // @ts-expect-error: missing properties
+    chatModel: ChromeAI,
+    completer: ChromeCompleter,
+    instructions: instructions.ChromeAI,
+    settingsSchema: ProviderSettings.ChromeAI
+  },
+  {
+    name: 'MistralAI',
+    chatModel: ChatMistralAI,
+    completer: CodestralCompleter,
+    instructions: instructions.MistralAI,
+    settingsSchema: ProviderSettings.MistralAI
+  },
+  {
+    name: 'OpenAI',
+    chatModel: ChatOpenAI,
+    completer: OpenAICompleter,
+    settingsSchema: ProviderSettings.OpenAI
+  }
+];
+
+export { AIProviders };
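
From these four entries, the shape of the IAIProvider description that the registry consumes (declared in the renamed src/tokens.ts, which is not shown here) can be read off field by field. The sketch below is inferred from this list; the field types are approximations, and the optionality of instructions and errorMessage is assumed from the entries that omit them.

// Inferred shape of a provider entry; not the actual declaration in src/tokens.ts.
interface IAIProviderSketch {
  /** Name used to register, select and display the provider. */
  name: string;
  /** LangChain chat model class used to build the chat model instance. */
  chatModel: any;
  /** Inline completer class for this provider. */
  completer: any;
  /** Provider-specific usage instructions (type not visible in this diff). */
  instructions?: any;
  /** JSON schema driving the provider's settings form. */
  settingsSchema: any;
  /** Maps a provider-specific error object to a readable message. */
  errorMessage?: (error: any) => string;
}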
