Improve ai error handling (#4180)
* Introduce a result type

* Update AI error handling to use the result type

* Handle ollama json parse error

* Migrate using Error as the type that represents errors

* Remove now useless condition

* asdfasdf

* Use andThen

* Correct unit tests
Caleb-T-Owens authored Jun 27, 2024
1 parent 518cc8b commit dd0b4ec
Showing 11 changed files with 330 additions and 136 deletions.
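
The new $lib/result module referenced throughout the diff is one of the changed files not shown in this excerpt. Based on the helpers the changed files import (ok, buildFailureFromAny, map, andThen, wrap, wrapAsync, Result), a minimal sketch of its likely shape follows; the field names and exact signatures are assumptions, not the repository's actual implementation.

// Sketch of app/src/lib/result.ts — inferred from usage in this diff; the
// `ok`/`value`/`failure` field names and the signatures are assumptions.
export type Result<Ok, Err> = { ok: true; value: Ok } | { ok: false; failure: Err };

export function ok<Ok>(value: Ok): Result<Ok, never> {
	return { ok: true, value };
}

export function failure<Err>(failure: Err): Result<never, Err> {
	return { ok: false, failure };
}

// Coerce any thrown value or ad-hoc message into an Error-backed failure.
export function buildFailureFromAny(value: unknown): Result<never, Error> {
	return failure(value instanceof Error ? value : new Error(String(value)));
}

// Transform the success value; failures pass through untouched.
export function map<Ok, NewOk, Err>(
	result: Result<Ok, Err>,
	fn: (value: Ok) => NewOk
): Result<NewOk, Err> {
	return result.ok ? ok(fn(result.value)) : result;
}

// Like map, but the callback returns a Result itself, so failures short-circuit.
export function andThen<Ok, NewOk, Err>(
	result: Result<Ok, Err>,
	fn: (value: Ok) => Result<NewOk, Err>
): Result<NewOk, Err> {
	return result.ok ? fn(result.value) : result;
}

// Run a throwing function and capture any exception as a failure.
export function wrap<Ok, Err = Error>(fn: () => Ok): Result<Ok, Err> {
	try {
		return ok(fn());
	} catch (e) {
		return failure(e as Err);
	}
}

// Async variant, used around fetch and SDK calls.
export async function wrapAsync<Ok, Err = Error>(fn: () => Promise<Ok>): Promise<Result<Ok, Err>> {
	try {
		return ok(await fn());
	} catch (e) {
		return failure(e as Err);
	}
}
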
18 changes: 14 additions & 4 deletions app/src/lib/ai/anthropicClient.ts
@@ -1,8 +1,12 @@
import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
import { type AIClient, type AnthropicModelName, type Prompt } from '$lib/ai/types';
import { buildFailureFromAny, ok, type Result } from '$lib/result';
import { fetch, Body } from '@tauri-apps/api/http';
import type { AIClient, AnthropicModelName, Prompt } from '$lib/ai/types';

type AnthropicAPIResponse = { content: { text: string }[] };
type AnthropicAPIResponse = {
content: { text: string }[];
error: { type: string; message: string };
};

export class AnthropicAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
@@ -13,7 +17,7 @@ export class AnthropicAIClient implements AIClient {
private modelName: AnthropicModelName
) {}

async evaluate(prompt: Prompt) {
async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const body = Body.json({
messages: prompt,
max_tokens: 1024,
@@ -30,6 +34,12 @@ export class AnthropicAIClient implements AIClient {
body
});

return response.data.content[0].text;
if (response.ok && response.data?.content?.[0]?.text) {
return ok(response.data.content[0].text);
} else {
return buildFailureFromAny(
`Anthropic returned error code ${response.status} ${response.data?.error?.message}`
);
}
}
}
24 changes: 14 additions & 10 deletions app/src/lib/ai/butlerClient.ts
@@ -1,4 +1,5 @@
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { map, type Result } from '$lib/result';
import type { AIClient, ModelKind, Prompt } from '$lib/ai/types';
import type { HttpClient } from '$lib/backend/httpClient';

@@ -12,16 +13,19 @@ export class ButlerAIClient implements AIClient {
private modelKind: ModelKind
) {}

async evaluate(prompt: Prompt) {
const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
body: {
messages: prompt,
max_tokens: 400,
model_kind: this.modelKind
},
token: this.userToken
});
async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const response = await this.cloud.postSafe<{ message: string }>(
'evaluate_prompt/predict.json',
{
body: {
messages: prompt,
max_tokens: 400,
model_kind: this.modelKind
},
token: this.userToken
}
);

return response.message;
return map(response, ({ message }) => message);
}
}
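
The postSafe method on HttpClient is introduced elsewhere in this PR and is not part of the excerpt; presumably it wraps the existing throwing post call in wrapAsync so HTTP and network errors come back as a Result instead of an exception. A rough sketch, with the method body and option shape assumed:

// Hypothetical sketch of the HttpClient change — the real file is among the
// changed files not shown here; option types and structure are assumptions.
import { wrapAsync, type Result } from '$lib/result';

type RequestOptions = { body?: unknown; token?: string };

export class HttpClient {
	async post<T>(path: string, opts: RequestOptions): Promise<T> {
		// Existing implementation (not shown): performs the request and throws on failure.
		throw new Error('omitted in this excerpt');
	}

	// Same request, but callers receive failures as a Result value.
	async postSafe<T>(path: string, opts: RequestOptions): Promise<Result<T, Error>> {
		return await wrapAsync<T, Error>(async () => await this.post<T>(path, opts));
	}
}
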
55 changes: 33 additions & 22 deletions app/src/lib/ai/ollamaClient.ts
@@ -1,5 +1,6 @@
import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types';
import { andThen, buildFailureFromAny, ok, wrap, wrapAsync, type Result } from '$lib/result';
import { isNonEmptyObject } from '$lib/utils/typeguards';
import { fetch, Body, Response } from '@tauri-apps/api/http';

@@ -81,15 +82,22 @@ export class OllamaClient implements AIClient {
private modelName: string
) {}

async evaluate(prompt: Prompt) {
async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const messages = this.formatPrompt(prompt);
const response = await this.chat(messages);
const rawResponse = JSON.parse(response.message.content);
if (!isOllamaChatMessageFormat(rawResponse)) {
throw new Error('Invalid response: ' + response.message.content);
}

return rawResponse.result;
const responseResult = await this.chat(messages);

return andThen(responseResult, (response) => {
const rawResponseResult = wrap<unknown, Error>(() => JSON.parse(response.message.content));

return andThen(rawResponseResult, (rawResponse) => {
if (!isOllamaChatMessageFormat(rawResponse)) {
return buildFailureFromAny('Invalid response: ' + response.message.content);
}

return ok(rawResponse.result);
});
});
}

/**
@@ -124,31 +132,32 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
* @param request - The OllamaChatRequest object containing the request details.
* @returns A Promise that resolves to the Response object.
*/
private async fetchChat(request: OllamaChatRequest): Promise<Response<any>> {
private async fetchChat(request: OllamaChatRequest): Promise<Result<Response<any>, Error>> {
const url = new URL(OllamaAPEndpoint.Chat, this.endpoint);
const body = Body.json(request);
const result = await fetch(url.toString(), {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body
});
return result;
return await wrapAsync(
async () =>
await fetch(url.toString(), {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body
})
);
}

/**
* Sends a chat message to the LLM model and returns the response.
*
* @param messages - An array of LLMChatMessage objects representing the chat messages.
* @param options - Optional LLMRequestOptions object for specifying additional options.
* @throws Error if the response is invalid.
* @returns A Promise that resolves to an LLMResponse object representing the response from the LLM model.
*/
private async chat(
messages: Prompt,
options?: OllamaRequestOptions
): Promise<OllamaChatResponse> {
): Promise<Result<OllamaChatResponse, Error>> {
const result = await this.fetchChat({
model: this.modelName,
stream: false,
@@ -157,10 +166,12 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
format: 'json'
});

if (!isOllamaChatResponse(result.data)) {
throw new Error('Invalid response\n' + JSON.stringify(result.data));
}
return andThen(result, (result) => {
if (!isOllamaChatResponse(result.data)) {
return buildFailureFromAny('Invalid response\n' + JSON.stringify(result.data));
}

return result.data;
return ok(result.data);
});
}
}
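
Because andThen only runs its callback on success, a failed fetch, a JSON parse error, and a schema mismatch all collapse into the same failure path without nested try/catch. A toy illustration, assuming the Result sketch near the top of this page:

import { andThen, ok, wrap } from '$lib/result';

// Toy example only: JSON.parse throws here, so `parsed` is already a failure
// and the andThen callback never runs — the SyntaxError propagates as the result.
const parsed = wrap<unknown, Error>(() => JSON.parse('not json'));
const outcome = andThen(parsed, (value) => ok(String(value)));
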
22 changes: 16 additions & 6 deletions app/src/lib/ai/openAIClient.ts
@@ -1,6 +1,8 @@
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { andThen, buildFailureFromAny, ok, wrapAsync, type Result } from '$lib/result';
import type { OpenAIModelName, Prompt, AIClient } from '$lib/ai/types';
import type OpenAI from 'openai';
import type { ChatCompletion } from 'openai/resources/index.mjs';

export class OpenAIClient implements AIClient {
defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
@@ -11,13 +13,21 @@ export class OpenAIClient implements AIClient {
private openAI: OpenAI
) {}

async evaluate(prompt: Prompt) {
const response = await this.openAI.chat.completions.create({
messages: prompt,
model: this.modelName,
max_tokens: 400
async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const responseResult = await wrapAsync<ChatCompletion, Error>(async () => {
return await this.openAI.chat.completions.create({
messages: prompt,
model: this.modelName,
max_tokens: 400
});
});

return response.choices[0].message.content || '';
return andThen(responseResult, (response) => {
if (response.choices[0]?.message.content) {
return ok(response.choices[0]?.message.content);
} else {
return buildFailureFromAny('Open AI generated an empty message');
}
});
}
}
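
With this change every AIClient.evaluate returns Promise<Result<string, Error>> instead of a string from a call that may throw. The calling service is not included in the excerpt, but consumption would look roughly like this, assuming the Result shape sketched at the top:

import type { Result } from '$lib/result';

// Hypothetical caller — the real service layer is not part of this excerpt.
async function generateCommitMessage(
	evaluate: () => Promise<Result<string, Error>>
): Promise<string | undefined> {
	const result = await evaluate();
	if (result.ok) return result.value;

	// Failures now arrive as Error values instead of escaping as thrown exceptions.
	console.error('AI generation failed:', result.failure.message);
	return undefined;
}
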