Update AI error handling to use the result type
Caleb-T-Owens committed Jun 25, 2024
1 parent b954b54 commit b8b8c18
Showing 9 changed files with 200 additions and 125 deletions.
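
Note: the diffs below lean on a small result module at `$lib/result`, which is among the changed files but not shown in this excerpt. What follows is a minimal sketch of the API these diffs assume, inferred from the `ok`/`err`/`isError`/`unwrap` call sites; the real module may differ in detail:

type Result<T, E> = { ok: true; value: T } | { ok: false; failure: E };

function ok<T, E>(value: T): Result<T, E> {
  return { ok: true, value };
}

function err<T, E>(failure: E): Result<T, E> {
  return { ok: false, failure };
}

// Type guard used below to short-circuit on failures.
function isError<T, E>(result: Result<T, E>): result is { ok: false; failure: E } {
  return !result.ok;
}

// Used in the tests: returns the value or throws the failure.
function unwrap<T, E>(result: Result<T, E>): T {
  if (isError(result)) throw result.failure;
  return result.value;
}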
16 changes: 13 additions & 3 deletions app/src/lib/ai/anthropicClient.ts
@@ -1,8 +1,12 @@
import { SHORT_DEFAULT_COMMIT_TEMPLATE, SHORT_DEFAULT_BRANCH_TEMPLATE } from '$lib/ai/prompts';
+ import { err, ok, type Result } from '$lib/result';
import { fetch, Body } from '@tauri-apps/api/http';
import type { AIClient, AnthropicModelName, Prompt } from '$lib/ai/types';

- type AnthropicAPIResponse = { content: { text: string }[] };
+ type AnthropicAPIResponse = {
+   content: { text: string }[];
+   error: { type: string; message: string };
+ };

export class AnthropicAIClient implements AIClient {
  defaultCommitTemplate = SHORT_DEFAULT_COMMIT_TEMPLATE;
@@ -13,7 +17,7 @@ export class AnthropicAIClient implements AIClient {
    private modelName: AnthropicModelName
  ) {}

-   async evaluate(prompt: Prompt) {
+   async evaluate(prompt: Prompt): Promise<Result<string, string>> {
    const body = Body.json({
      messages: prompt,
      max_tokens: 1024,
@@ -30,6 +34,12 @@ export class AnthropicAIClient implements AIClient {
      body
    });

-     return response.data.content[0].text;
+     if (response.ok && response.data?.content?.[0]?.text) {
+       return ok(response.data.content[0].text);
+     } else {
+       return err(
+         `Anthropic returned error code ${response.status} ${response.data?.error?.message}`
+       );
+     }
  }
}
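
Since `evaluate` now returns a `Result<string, string>` instead of throwing, call sites branch explicitly. A hypothetical caller (not part of this commit), assuming the result shape sketched above:

import { isError } from '$lib/result';
import type { AIClient, Prompt } from '$lib/ai/types';

async function generateMessage(client: AIClient, prompt: Prompt): Promise<string | undefined> {
  const result = await client.evaluate(prompt);
  if (isError(result)) {
    // The provider's message arrives as a plain value rather than a thrown exception.
    console.error(result.failure);
    return undefined;
  }
  return result.value;
}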
29 changes: 19 additions & 10 deletions app/src/lib/ai/butlerClient.ts
@@ -1,4 +1,5 @@
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
+ import { err, ok, type Result } from '$lib/result';
import type { AIClient, ModelKind, Prompt } from '$lib/ai/types';
import type { HttpClient } from '$lib/backend/httpClient';

@@ -12,16 +13,24 @@ export class ButlerAIClient implements AIClient {
    private modelKind: ModelKind
  ) {}

-   async evaluate(prompt: Prompt) {
-     const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
-       body: {
-         messages: prompt,
-         max_tokens: 400,
-         model_kind: this.modelKind
-       },
-       token: this.userToken
-     });
+   async evaluate(prompt: Prompt): Promise<Result<string, string>> {
+     try {
+       const response = await this.cloud.post<{ message: string }>('evaluate_prompt/predict.json', {
+         body: {
+           messages: prompt,
+           max_tokens: 400,
+           model_kind: this.modelKind
+         },
+         token: this.userToken
+       });

-     return response.message;
+       return ok(response.message);
+     } catch (e) {
+       if (e instanceof Error) {
+         return err(e.message);
+       } else {
+         return err('Failed to contact GitButler API');
+       }
+     }
  }
}
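
The Butler and OpenAI clients repeat the same try/catch-to-Result translation. One way it could be factored out — a hypothetical helper, not something this commit adds:

import { err, ok, type Result } from '$lib/result';

async function tryAsync<T>(run: () => Promise<T>, fallback: string): Promise<Result<T, string>> {
  try {
    return ok(await run());
  } catch (e) {
    // Keep real error messages; use the fallback for non-Error throws.
    return err(e instanceof Error ? e.message : fallback);
  }
}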
20 changes: 12 additions & 8 deletions app/src/lib/ai/ollamaClient.ts
@@ -1,5 +1,6 @@
import { LONG_DEFAULT_BRANCH_TEMPLATE, LONG_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types';
+ import { err, isError, ok, type Result } from '$lib/result';
import { isNonEmptyObject } from '$lib/utils/typeguards';
import { fetch, Body, Response } from '@tauri-apps/api/http';

@@ -81,15 +82,19 @@ export class OllamaClient implements AIClient {
    private modelName: string
  ) {}

-   async evaluate(prompt: Prompt) {
+   async evaluate(prompt: Prompt): Promise<Result<string, string>> {
    const messages = this.formatPrompt(prompt);
-     const response = await this.chat(messages);
+     const responseResult = await this.chat(messages);
+     if (isError(responseResult)) return responseResult;
+     const response = responseResult.value;

    const rawResponse = JSON.parse(response.message.content);
    if (!isOllamaChatMessageFormat(rawResponse)) {
-       throw new Error('Invalid response: ' + response.message.content);
+       return err('Invalid response: ' + response.message.content);
    }

-     return rawResponse.result;
+     return ok(rawResponse.result);
  }

/**
@@ -142,13 +147,12 @@ ${JSON.stringify(OLLAMA_CHAT_MESSAGE_FORMAT_SCHEMA, null, 2)}`
   *
   * @param messages - An array of LLMChatMessage objects representing the chat messages.
   * @param options - Optional LLMRequestOptions object for specifying additional options.
-    * @throws Error if the response is invalid.
   * @returns A Promise that resolves to an LLMResponse object representing the response from the LLM model.
   */
  private async chat(
    messages: Prompt,
    options?: OllamaRequestOptions
-   ): Promise<OllamaChatResponse> {
+   ): Promise<Result<OllamaChatResponse, string>> {
    const result = await this.fetchChat({
      model: this.modelName,
      stream: false,
@@ -158,9 +162,9 @@
    });

    if (!isOllamaChatResponse(result.data)) {
-       throw new Error('Invalid response\n' + JSON.stringify(result.data));
+       return err('Invalid response\n' + JSON.stringify(result.data));
    }

-     return result.data;
+     return ok(result.data);
  }
}
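
One gap worth noting: `JSON.parse(response.message.content)` in `evaluate` can still throw on malformed model output, so the method is not yet fully exception-free. A hedged sketch of folding that step into the same style (a hypothetical helper, not in this commit):

import { err, ok, type Result } from '$lib/result';

function safeParse(raw: string): Result<unknown, string> {
  try {
    return ok(JSON.parse(raw));
  } catch {
    return err('Invalid JSON response: ' + raw);
  }
}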
27 changes: 20 additions & 7 deletions app/src/lib/ai/openAIClient.ts
@@ -1,4 +1,5 @@
import { SHORT_DEFAULT_BRANCH_TEMPLATE, SHORT_DEFAULT_COMMIT_TEMPLATE } from '$lib/ai/prompts';
+ import { err, ok, type Result } from '$lib/result';
import type { OpenAIModelName, Prompt, AIClient } from '$lib/ai/types';
import type OpenAI from 'openai';

@@ -11,13 +12,25 @@ export class OpenAIClient implements AIClient {
    private openAI: OpenAI
  ) {}

-   async evaluate(prompt: Prompt) {
-     const response = await this.openAI.chat.completions.create({
-       messages: prompt,
-       model: this.modelName,
-       max_tokens: 400
-     });
+   async evaluate(prompt: Prompt): Promise<Result<string, string>> {
+     try {
+       const response = await this.openAI.chat.completions.create({
+         messages: prompt,
+         model: this.modelName,
+         max_tokens: 400
+       });

-     return response.choices[0].message.content || '';
+       if (response.choices[0]?.message.content) {
+         return ok(response.choices[0].message.content);
+       } else {
+         return err('OpenAI generated an empty message');
+       }
+     } catch (e) {
+       if (e instanceof Error) {
+         return err(e.message);
+       } else {
+         return err('Failed to contact OpenAI');
+       }
+     }
  }
}
88 changes: 50 additions & 38 deletions app/src/lib/ai/service.test.ts
@@ -11,7 +11,7 @@ import {
  type Prompt
} from '$lib/ai/types';
import { HttpClient } from '$lib/backend/httpClient';
- import * as toasts from '$lib/utils/toasts';
+ import { err, ok, unwrap, type Result } from '$lib/result';
import { Hunk } from '$lib/vbranches/types';
import { plainToInstance } from 'class-transformer';
import { expect, test, describe, vi } from 'vitest';
@@ -56,8 +56,8 @@ class DummyAIClient implements AIClient {
  defaultBranchTemplate = SHORT_DEFAULT_BRANCH_TEMPLATE;
  constructor(private response = 'lorem ipsum') {}

-   async evaluate(_prompt: Prompt) {
-     return this.response;
+   async evaluate(_prompt: Prompt): Promise<Result<string, string>> {
+     return ok(this.response);
  }
}

@@ -116,16 +116,14 @@ describe.concurrent('AIService', () => {
  test('With default configuration, When a user token is provided. It returns ButlerAIClient', async () => {
    const aiService = buildDefaultAIService();

-     expect(await aiService.buildClient('token')).toBeInstanceOf(ButlerAIClient);
+     expect(unwrap(await aiService.buildClient('token'))).toBeInstanceOf(ButlerAIClient);
  });

  test('With default configuration, When a user is undefined. It returns an error', async () => {
-     const toastErrorSpy = vi.spyOn(toasts, 'error');
    const aiService = buildDefaultAIService();

-     expect(await aiService.buildClient()).toBe(undefined);
-     expect(toastErrorSpy).toHaveBeenLastCalledWith(
-       "When using GitButler's API to summarize code, you must be logged in"
+     expect(await aiService.buildClient()).toStrictEqual(
+       err("When using GitButler's API to summarize code, you must be logged in")
    );
  });

@@ -137,21 +135,21 @@ describe.concurrent('AIService', () => {
    });
    const aiService = new AIService(gitConfig, cloud);

-     expect(await aiService.buildClient()).toBeInstanceOf(OpenAIClient);
+     expect(unwrap(await aiService.buildClient())).toBeInstanceOf(OpenAIClient);
  });

  test('When token is bring your own, When an OpenAI token is blank. It returns an error', async () => {
-     const toastErrorSpy = vi.spyOn(toasts, 'error');
    const gitConfig = new DummyGitConfigService({
      ...defaultGitConfig,
      [GitAIConfigKey.OpenAIKeyOption]: KeyOption.BringYourOwn,
      [GitAIConfigKey.OpenAIKey]: undefined
    });
    const aiService = new AIService(gitConfig, cloud);

-     expect(await aiService.buildClient()).toBe(undefined);
-     expect(toastErrorSpy).toHaveBeenLastCalledWith(
-       'When using OpenAI in a bring your own key configuration, you must provide a valid token'
+     expect(await aiService.buildClient()).toStrictEqual(
+       err(
+         'When using OpenAI in a bring your own key configuration, you must provide a valid token'
+       )
    );
  });

@@ -164,11 +162,10 @@ describe.concurrent('AIService', () => {
    });
    const aiService = new AIService(gitConfig, cloud);

-     expect(await aiService.buildClient()).toBeInstanceOf(AnthropicAIClient);
+     expect(unwrap(await aiService.buildClient())).toBeInstanceOf(AnthropicAIClient);
  });

  test('When ai provider is Anthropic, When token is bring your own, When an Anthropic token is blank. It returns an error', async () => {
-     const toastErrorSpy = vi.spyOn(toasts, 'error');
    const gitConfig = new DummyGitConfigService({
      ...defaultGitConfig,
      [GitAIConfigKey.ModelProvider]: ModelKind.Anthropic,
@@ -177,9 +174,10 @@ describe.concurrent('AIService', () => {
    });
    const aiService = new AIService(gitConfig, cloud);

-     expect(await aiService.buildClient()).toBe(undefined);
-     expect(toastErrorSpy).toHaveBeenLastCalledWith(
-       'When using Anthropic in a bring your own key configuration, you must provide a valid token'
+     expect(await aiService.buildClient()).toStrictEqual(
+       err(
+         'When using Anthropic in a bring your own key configuration, you must provide a valid token'
+       )
    );
  });
});
@@ -188,9 +186,13 @@ describe.concurrent('AIService', () => {
  test('When buildClient returns an error, it returns the error', async () => {
    const aiService = buildDefaultAIService();

-     vi.spyOn(aiService, 'buildClient').mockReturnValue((async () => undefined)());
+     vi.spyOn(aiService, 'buildClient').mockReturnValue(
+       (async () => err<AIClient, string>('Failed to build'))()
+     );

-     expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe(undefined);
+     expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toStrictEqual(
+       err('Failed to build')
+     );
  });

  test('When the AI returns a single line commit message, it returns it unchanged', async () => {
@@ -199,10 +201,12 @@ describe.concurrent('AIService', () => {
    const clientResponse = 'single line commit';

    vi.spyOn(aiService, 'buildClient').mockReturnValue(
-       (async () => new DummyAIClient(clientResponse))()
+       (async () => ok<AIClient, string>(new DummyAIClient(clientResponse)))()
    );

-     expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe('single line commit');
+     expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toStrictEqual(
+       ok('single line commit')
+     );
  });

  test('When the AI returns a title and body that is split by a single new line, it replaces it with two', async () => {
@@ -211,10 +215,12 @@ describe.concurrent('AIService', () => {
    const clientResponse = 'one\nnew line';

    vi.spyOn(aiService, 'buildClient').mockReturnValue(
-       (async () => new DummyAIClient(clientResponse))()
+       (async () => ok<AIClient, string>(new DummyAIClient(clientResponse)))()
    );

-     expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toBe('one\n\nnew line');
+     expect(await aiService.summarizeCommit({ hunks: exampleHunks })).toStrictEqual(
+       ok('one\n\nnew line')
+     );
  });

  test('When the commit is in brief mode, When the AI returns a title and body, it takes just the title', async () => {
@@ -223,22 +229,26 @@ describe.concurrent('AIService', () => {
    const clientResponse = 'one\nnew line';

    vi.spyOn(aiService, 'buildClient').mockReturnValue(
-       (async () => new DummyAIClient(clientResponse))()
+       (async () => ok<AIClient, string>(new DummyAIClient(clientResponse)))()
    );

-     expect(await aiService.summarizeCommit({ hunks: exampleHunks, useBriefStyle: true })).toBe(
-       'one'
-     );
+     expect(
+       await aiService.summarizeCommit({ hunks: exampleHunks, useBriefStyle: true })
+     ).toStrictEqual(ok('one'));
  });
});

describe.concurrent('#summarizeBranch', async () => {
  test('When buildClient returns an error, it returns the error', async () => {
    const aiService = buildDefaultAIService();

-     vi.spyOn(aiService, 'buildClient').mockReturnValue((async () => undefined)());
+     vi.spyOn(aiService, 'buildClient').mockReturnValue(
+       (async () => err<AIClient, string>('Failed to build client'))()
+     );

-     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(undefined);
+     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
+       err('Failed to build client')
+     );
  });

  test('When the AI client returns a string with spaces, it replaces them with hyphens', async () => {
@@ -247,10 +257,12 @@ describe.concurrent('AIService', () => {
    const clientResponse = 'with spaces included';

    vi.spyOn(aiService, 'buildClient').mockReturnValue(
-       (async () => new DummyAIClient(clientResponse))()
+       (async () => ok<AIClient, string>(new DummyAIClient(clientResponse)))()
    );

-     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe('with-spaces-included');
+     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
+       ok('with-spaces-included')
+     );
  });

  test('When the AI client returns multiple lines, it replaces them with hyphens', async () => {
@@ -259,11 +271,11 @@ describe.concurrent('AIService', () => {
    const clientResponse = 'with\nnew\nlines\nincluded';

    vi.spyOn(aiService, 'buildClient').mockReturnValue(
-       (async () => new DummyAIClient(clientResponse))()
+       (async () => ok<AIClient, string>(new DummyAIClient(clientResponse)))()
    );

-     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(
-       'with-new-lines-included'
+     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
+       ok('with-new-lines-included')
    );
  });

@@ -273,11 +285,11 @@ describe.concurrent('AIService', () => {
    const clientResponse = 'with\nnew lines\nincluded';

    vi.spyOn(aiService, 'buildClient').mockReturnValue(
-       (async () => new DummyAIClient(clientResponse))()
+       (async () => ok<AIClient, string>(new DummyAIClient(clientResponse)))()
    );

-     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toBe(
-       'with-new-lines-included'
+     expect(await aiService.summarizeBranch({ hunks: exampleHunks })).toStrictEqual(
+       ok('with-new-lines-included')
    );
  });
});