Skip to content

Commit

Permalink
Merge pull request #10 from copilot-extensions/sgoedecke/update-api-and-handle-wrong-model-names
Browse files Browse the repository at this point in the history

v2 API, o1 models metadata, and support typos in model names
  • Loading branch information
sgoedecke authored Sep 16, 2024
2 parents 96014d6 + 3032705 commit 601cc27
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 26 deletions.
14 changes: 8 additions & 6 deletions src/functions/describe-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ export class describeModel extends Tool {
properties: {
model: {
type: "string",
description:
'The model to describe. Looks like "registry/model-name". For example, `azureml/Phi-3-medium-128k-instruct` or `azure-openai/gpt-4o',
description: [
'The model to describe. Looks like "model-name". For example, `Phi-3-medium-128k-instruct` or `gpt-4o`.',
'The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.',
'If the model name is not found in the list of models, pick the closest matching model from the list.',
].join("\n"),
},
},
required: ["model"],
Expand All @@ -30,12 +33,11 @@ export class describeModel extends Tool {
const systemMessage = [
"The user is asking about the AI model with the following details:",
`\tModel Name: ${model.name}`,
`\tModel Version: ${model.model_version}`,
`\tModel Version: ${model.version}`,
`\tPublisher: ${model.publisher}`,
`\tModel Family: ${model.model_family}`,
`\tModel Registry: ${model.model_registry}`,
`\tModel Registry: ${model.registryName}`,
`\tLicense: ${model.license}`,
`\tTask: ${model.task}`,
`\tTask: ${model.inferenceTasks.join(", ")}`,
`\tDescription: ${model.description}`,
`\tSummary: ${model.summary}`,
"\n",
Expand Down
1 change: 1 addition & 0 deletions src/functions/execute-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Example Queries (IMPORTANT: Phrasing doesn't have to match):
"The name of the model to execute. It is ONLY the name of the model, not the publisher or registry.",
"For example: `gpt-4o`, or `cohere-command-r-plus`.",
"The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.",
"If the model name is not found in the list of models, pick the closest matching model from the list.",
].join("\n"),
},
instruction: {
Expand Down
2 changes: 1 addition & 1 deletion src/functions/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class listModels extends Tool {
"That list of models is as follows:",
JSON.stringify(
models.map((model) => ({
name: model.friendly_name,
name: model.displayName,
publisher: model.publisher,
description: model.summary,
}))
Expand Down
29 changes: 22 additions & 7 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const server = createServer(async (request, response) => {

// List of functions that are available to be called
const modelsAPI = new ModelsAPI(apiKey);

const functions = [listModels, describeModel, executeModel, recommendModel];

// Use the Copilot API to determine which function to execute
Expand All @@ -66,6 +67,7 @@ const server = createServer(async (request, response) => {
// Prepend a system message that includes the list of models, so that
// tool calls can better select the right model to use.
const models = await modelsAPI.listModels();

const toolCallMessages = [
{
role: "system" as const,
Expand All @@ -75,13 +77,28 @@ const server = createServer(async (request, response) => {
"Here is a list of some of the models available to the user:",
"<-- LIST OF MODELS -->",
JSON.stringify(
models.map((model) => ({
friendly_name: model.friendly_name,
[...models.map((model) => ({
friendly_name: model.displayName,
name: model.name,
publisher: model.publisher,
registry: model.model_registry,
registry: model.registryName,
description: model.summary,
}))
})),
{
friendly_name: "OpenAI o1-mini",
name: "o1-mini",
publisher: "openai",
model_registry: "azure-openai",
description: "Smaller, faster, and 80% cheaper than o1-preview, performs well at code generation and small context operations."
},
{
friendly_name: "OpenAI o1-preview",
name: "o1-preview",
publisher: "openai",
model_registry: "azure-openai",
description: "Focused on advanced reasoning and solving complex problems, including math and science tasks. Ideal for applications that require deep contextual understanding and agentic workflows."
},
]
),
"<-- END OF LIST OF MODELS -->",
].join("\n"),
Expand Down Expand Up @@ -148,13 +165,11 @@ const server = createServer(async (request, response) => {
console.timeEnd("function-exec");

try {
// We should keep all optional parameters out of this call, so it can work for any model.
const stream = await modelsAPI.inference.chat.completions.create({
model: functionCallRes.model,
messages: functionCallRes.messages,
stream: true,
stream_options: {
include_usage: false,
},
});

console.time("streaming");
Expand Down
48 changes: 36 additions & 12 deletions src/models-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ import OpenAI from "openai";

// Model is the structure of a model in the model catalog.
export interface Model {
id: string;
name: string;
friendly_name: string;
model_version: number;
displayName: string;
version: string;
publisher: string;
model_family: string;
model_registry: string;
registryName: string;
license: string;
task: string;
description: string;
inferenceTasks: string[];
description?: string;
summary: string;
}

Expand Down Expand Up @@ -44,19 +42,23 @@ export class ModelsAPI {
}

async getModel(modelName: string): Promise<Model> {
const modelFromIndex = await this.getModelFromIndex(modelName);

const modelRes = await fetch(
"https://modelcatalog.azure-api.net/v1/model/" + modelName
`https://eastus.api.azureml.ms/asset-gallery/v1.0/${modelFromIndex.registryName}/models/${modelFromIndex.name}/version/${modelFromIndex.version}`,
);
if (!modelRes.ok) {
throw new Error(`Failed to fetch ${modelName} from the model catalog.`);
throw new Error(`Failed to fetch ${modelName} details from the model catalog.`);
}
const model = (await modelRes.json()) as Model;
return model;
}

async getModelSchema(modelName: string): Promise<ModelSchema> {
const modelFromIndex = await this.getModelFromIndex(modelName);

const modelSchemaRes = await fetch(
`https://modelcatalogcachev2-ebendjczf0c5dzca.b02.azurefd.net/widgets/en/Serverless/${modelName.toLowerCase()}.json`
`https://modelcatalogcachev2-ebendjczf0c5dzca.b02.azurefd.net/widgets/en/Serverless/${modelFromIndex.registryName.toLowerCase()}/${modelFromIndex.name.toLowerCase()}.json`
);
if (!modelSchemaRes.ok) {
throw new Error(
Expand All @@ -73,14 +75,36 @@ export class ModelsAPI {
}

const modelsRes = await fetch(
"https://modelcatalog.azure-api.net/v1/models"
"https://eastus.api.azureml.ms/asset-gallery/v1.0/models",
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
filters: [
{ field: "freePlayground", values: ["true"], operator: "eq" },
{ field: "labels", values: ["latest"], operator: "eq" },
],
order: [{ field: "displayName", direction: "Asc" }],
}),
}
);
if (!modelsRes.ok) {
throw new Error("Failed to fetch models from the model catalog");
}

const models = (await modelsRes.json()) as Model[];
const models = (await modelsRes.json()).summaries as Model[];
this._models = models;
return models;
}

/**
 * Resolves a model entry from the model list by its exact `name`.
 *
 * When `this._models` is not yet set, it is filled from `listModels()`
 * before searching. The comparison is a strict, case-sensitive equality
 * check against each entry's `name`; the first match is returned.
 *
 * @param modelName - The `name` of the model to look up.
 * @returns The first entry in the list whose `name` equals `modelName`.
 * @throws Error if no entry in the list has that name.
 */
async getModelFromIndex(modelName: string): Promise<Model> {
  // Populate the list on first use; later calls reuse what is already stored.
  if (!this._models) {
    this._models = await this.listModels();
  }

  for (const candidate of this._models) {
    if (candidate.name === modelName) {
      return candidate;
    }
  }

  // Nothing matched: report it with the same message callers already see.
  throw new Error(`Failed to fetch ${modelName} from the model catalog.`);
}
}

0 comments on commit 601cc27

Please sign in to comment.