Commit

fixing OpenAI Assistants to allow streaming files

OvidijusParsiunas committed Apr 7, 2024
1 parent 64bb3f4 commit d0404db
Showing 5 changed files with 49 additions and 22 deletions.
26 changes: 20 additions & 6 deletions component/src/services/openAI/openAIAssistantIO.ts
@@ -1,6 +1,7 @@
import {AssistantFunctionHandler, OpenAIAssistant, OpenAINewAssistant} from '../../types/openAI';
import {MessageStream} from '../../views/chat/messages/stream/messageStream';
import {OpenAIConverseBodyInternal} from '../../types/openAIInternal';
import {OpenAIAssistantFiles} from './utils/openAIAssistantFiles';
import {OpenAIAssistantUtils} from './utils/openAIAssistantUtils';
import {DirectConnection} from '../../types/directConnection';
import {MessageLimitUtils} from '../utils/messageLimitUtils';
import {MessageContentI} from '../../types/messagesInternal';
@@ -37,6 +38,8 @@ export class OpenAIAssistantIO extends DirectServiceIO {
private readonly shouldFetchHistory: boolean = false;
private waitingForStreamResponse = false;
private readonly isSSEStream: boolean = false;
private streamedMessageId: string | undefined;
private messageStream: MessageStream | undefined;
fetchHistory?: () => Promise<ResponseI[]>;

constructor(deepChat: DeepChat) {
@@ -110,9 +113,10 @@ export class OpenAIAssistantIO extends DirectServiceIO {
this.waitingForStreamResponse = false;
if (!this.connectSettings) throw new Error('Request settings have not been set up');
this.rawBody.assistant_id ??= this.config.assistant_id || (await this.createNewAssistant());
this.streamedMessageId = undefined;
// here instead of constructor as messages may be loaded later
if (!this.searchedForThreadId) this.searchPreviousMessagesForThreadId(messages.messages);
const file_ids = files ? await OpenAIAssistantFiles.storeFiles(this, messages, files) : undefined;
const file_ids = files ? await OpenAIAssistantUtils.storeFiles(this, messages, files) : undefined;
this.connectSettings.method = 'POST';
this.callService(messages, pMessages, file_ids);
}
@@ -172,7 +176,7 @@ export class OpenAIAssistantIO extends DirectServiceIO {
if (!isHistory && this.deepChat.responseInterceptor) {
threadMessages = (await this.deepChat.responseInterceptor?.(threadMessages)) as OpenAIAssistantMessagesResult;
}
return OpenAIAssistantFiles.processAPIMessages(this, threadMessages, isHistory);
return OpenAIAssistantUtils.processAPIMessages(this, threadMessages, isHistory);
}

async extractPollResultData(result: OpenAIRunResult): PollResult {
@@ -241,9 +245,19 @@ export class OpenAIAssistantIO extends DirectServiceIO {
return {makingAnotherRequest: true};
}

private parseStreamResult(result: OpenAIAssistantInitReqResult) {
private async parseStreamResult(result: OpenAIAssistantInitReqResult) {
if (result.delta?.content) {
return {text: result.delta.content[0].text.value};
if (!this.streamedMessageId) {
this.streamedMessageId = result.id;
} else if (this.streamedMessageId !== result.id) {
this.streamedMessageId = result.id;
this.messageStream?.newMessage();
}
if (result.delta.content.length > 1) {
const messages = await OpenAIAssistantUtils.processSteamMessages(this, result.delta.content);
return {text: messages[0].text, files: messages[1].files};
}
return {text: result.delta.content[0].text?.value};
}
if (!this.sessionId && result.thread_id) {
this.sessionId = result.thread_id;
@@ -256,6 +270,6 @@
private async createStreamRun(body: any) {
body.stream = true;
this.waitingForStreamResponse = true;
await Stream.request(this, body, this.messages as Messages, true, true);
this.messageStream = (await Stream.request(this, body, this.messages as Messages, true, true)) as MessageStream;
}
}
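
To make the delta handling above easier to follow, here is a minimal, self-contained sketch of the decision logic that parseStreamResult now performs - simplified names and a plain state object, not the component's actual classes:

```ts
// Illustrative sketch only: simplified stand-ins for the real MessageStream / service state.
type StreamContentPart = {type?: string; text?: {value: string}; image_file?: {file_id: string}};
type StreamDeltaEvent = {id: string; delta?: {content?: StreamContentPart[]}};

interface AssistantStreamState {
  streamedMessageId?: string;
  startNewMessage: () => void; // stands in for MessageStream.newMessage() in the diff above
}

function handleStreamDelta(state: AssistantStreamState, event: StreamDeltaEvent) {
  const content = event.delta?.content;
  if (!content || content.length === 0) return undefined;
  if (!state.streamedMessageId) {
    // first delta of the run - remember which message is being streamed
    state.streamedMessageId = event.id;
  } else if (state.streamedMessageId !== event.id) {
    // the assistant started a second message within the same run,
    // so the current bubble is finalised and a fresh one is opened
    state.streamedMessageId = event.id;
    state.startNewMessage();
  }
  // a delta carrying more than one content part mixes text with generated files
  const needsFileProcessing = content.length > 1;
  return {text: content[0].text?.value, needsFileProcessing};
}
```

In the commit itself the file-bearing branch defers to OpenAIAssistantUtils.processSteamMessages so that streamed file annotations are downloaded and converted the same way as non-streamed history messages.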
29 changes: 16 additions & 13 deletions component/src/services/openAI/utils/{openAIAssistantFiles.ts → openAIAssistantUtils.ts}
@@ -8,7 +8,7 @@ import {ServiceIO} from '../../serviceIO';

type FileDetails = {fileId: string; path?: string; name?: string}[];

export class OpenAIAssistantFiles {
export class OpenAIAssistantUtils {
public static async storeFiles(serviceIO: ServiceIO, messages: Messages, files: File[]) {
const headers = serviceIO.connectSettings.headers;
if (!headers) return;
@@ -60,7 +60,7 @@ export class OpenAIAssistantFiles {
resolve({
src: (event.target as FileReader).result as string,
name: fileDetails[index].name,
type: OpenAIAssistantFiles.getType(fileDetails, index),
type: OpenAIAssistantUtils.getType(fileDetails, index),
});
};
});
@@ -82,7 +82,7 @@ export class OpenAIAssistantFiles {
fileDetails.push({
path: annotation.text,
fileId: annotation.file_path.file_id,
name: OpenAIAssistantFiles.getFileName(annotation.text),
name: OpenAIAssistantUtils.getFileName(annotation.text),
});
}
});
@@ -101,7 +101,7 @@ export class OpenAIAssistantFiles {
role?: string, content?: OpenAIAssistantContent) {
let files: MessageFile[] | undefined;
if (fileDetails.length > 0) {
files = await OpenAIAssistantFiles.getFiles(io, fileDetails);
files = await OpenAIAssistantUtils.getFiles(io, fileDetails);
if (content?.text?.value) {
files.forEach((file, index) => {
if (!file.src) return;
@@ -116,12 +116,12 @@ export class OpenAIAssistantFiles {
}

public static async getFilesAndText(io: ServiceIO, message: OpenAIAssistantData, content?: OpenAIAssistantContent) {
const fileDetails = OpenAIAssistantFiles.getFileDetails(message, content);
const fileDetails = OpenAIAssistantUtils.getFileDetails(message, content);
// gets files and replaces hyperlinks with base64 file encodings
return await OpenAIAssistantFiles.getFilesAndNewText(io, fileDetails, message.role, content);
return await OpenAIAssistantUtils.getFilesAndNewText(io, fileDetails, message.role, content);
}

private static parseMesages(result: OpenAIAssistantMessagesResult, isHistory: boolean) {
private static parseResult(result: OpenAIAssistantMessagesResult, isHistory: boolean) {
let messages = [];
if (isHistory) {
messages = result.data;
@@ -140,7 +140,7 @@ export class OpenAIAssistantFiles {

// test this using this prompt and it should give 2 text messages and a file:
// "give example data for a csv and create a suitable bar chart"
private static parseContent(io: DirectServiceIO, messages: OpenAIAssistantData[]) {
private static parseMessages(io: DirectServiceIO, messages: OpenAIAssistantData[]) {
const parsedContent: Promise<{text?: string; files?: MessageFile[]}>[] = [];
messages.forEach(async (data) => {
data.content
@@ -151,15 +151,18 @@ export class OpenAIAssistantFiles {
return 0;
})
.forEach(async (content) => {
parsedContent.push(OpenAIAssistantFiles.getFilesAndText(io, data, content));
parsedContent.push(OpenAIAssistantUtils.getFilesAndText(io, data, content));
});
});
return parsedContent;
return Promise.all(parsedContent);
}

public static async processSteamMessages(io: DirectServiceIO, content: OpenAIAssistantContent[]) {
return OpenAIAssistantUtils.parseMessages(io, [{content, role: 'assistant'}]);
}

public static async processAPIMessages(io: DirectServiceIO, result: OpenAIAssistantMessagesResult, isHistory: boolean) {
const messages = OpenAIAssistantFiles.parseMesages(result, isHistory);
const parsedContent = OpenAIAssistantFiles.parseContent(io, messages);
return Promise.all(parsedContent);
const messages = OpenAIAssistantUtils.parseResult(result, isHistory);
return OpenAIAssistantUtils.parseMessages(io, messages);
}
}
2 changes: 1 addition & 1 deletion component/src/types/openAIResult.ts
@@ -9,7 +9,7 @@ export type OpenAIAssistantInitReqResult = OpenAIRunResult & {
error?: {code: string; message: string};
// this is used exclusively for streams
delta?: {
content?: {text: {value: string}}[];
content?: OpenAIAssistantContent[];
step_details?: {
tool_calls?: ToolCalls;
};
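
OpenAIAssistantContent is defined elsewhere in the component and is not shown in this diff; as a rough orientation (an assumption based on the public OpenAI Assistants API, not the repo's actual type), a content part looks something like:

```ts
// Assumed shape, based on the OpenAI Assistants API message content parts -
// not copied from the component's own OpenAIAssistantContent type.
type AssistantContentPart =
  | {type: 'text'; text: {value: string; annotations?: {text: string; file_path?: {file_id: string}}[]}}
  | {type: 'image_file'; image_file: {file_id: string}};
```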
3 changes: 2 additions & 1 deletion component/src/utils/HTTP/stream.ts
@@ -86,6 +86,7 @@ export class Stream {
RequestUtils.displayError(messages, parsedError);
});
});
return stream;
}

public static simulate(messages: Messages, sh: StreamHandlers, result: ResponseI) {
@@ -136,7 +137,7 @@
}
if (response?.files) {
messages.addNewMessage({files: response.files});
stream?.markFileAded();
stream?.markFileAdded();
}
}
}
11 changes: 10 additions & 1 deletion component/src/views/chat/messages/stream/messageStream.ts
@@ -105,7 +105,16 @@ export class MessageStream {
this._hasStreamEnded = true;
}

public markFileAded() {
public markFileAdded() {
this._fileAdded = true;
}

public newMessage() {
this.finaliseStreamedMessage();
this._elements = undefined;
this._streamedContent = '';
this._fileAdded = false;
this._hasStreamEnded = false;
this._activeMessageRole = undefined;
}
}
