Skip to content

Commit

Permalink
feat(endpoints): Add conv ID to headers passed to TGI (#1511)
Browse files (browse the repository at this point in the history)
  • Branch information (was still loading when this page was captured)
nsarrazin authored Oct 10, 2024
1 parent 81c4393 commit 6451bee
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import endpointLangserve, {
} from "./langserve/endpointLangserve";

import type { Tool, ToolCall, ToolResult } from "$lib/types/Tool";
import type { ObjectId } from "mongodb";

export type EndpointMessage = Omit<Message, "id">;

Expand All @@ -41,6 +42,7 @@ export interface EndpointParameters {
tools?: Tool[];
toolResults?: ToolResult[];
isMultimodal?: boolean;
conversationId?: ObjectId;
}

interface CommonEndpoint {
Expand Down
17 changes: 15 additions & 2 deletions src/lib/server/endpoints/openai/endpointOai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ export async function endpointOai(
"Tools are not supported for 'completions' mode, switch to 'chat_completions' instead"
);
}
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
return async ({ messages, preprompt, continueMessage, generateSettings, conversationId }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
Expand All @@ -171,12 +171,22 @@ export async function endpointOai(

const openAICompletion = await openai.completions.create(body, {
body: { ...body, ...extraBody },
headers: {
"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
},
});

return openAICompletionToTextGenerationStream(openAICompletion);
};
} else if (completion === "chat_completions") {
return async ({ messages, preprompt, generateSettings, tools, toolResults }) => {
return async ({
messages,
preprompt,
generateSettings,
tools,
toolResults,
conversationId,
}) => {
let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
await prepareMessages(messages, imageProcessor, !model.tools && model.multimodal);

Expand Down Expand Up @@ -240,6 +250,9 @@ export async function endpointOai(

const openChatAICompletion = await openai.chat.completions.create(body, {
body: { ...body, ...extraBody },
headers: {
"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
},
});

return openAIChatToTextGenerationStream(openChatAICompletion);
Expand Down
2 changes: 2 additions & 0 deletions src/lib/server/endpoints/tgi/endpointTgi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
tools,
toolResults,
isMultimodal,
conversationId,
}) => {
const messagesWithResizedFiles = await Promise.all(
messages.map((message) => prepareMessage(Boolean(isMultimodal), message, imageProcessor))
Expand Down Expand Up @@ -72,6 +73,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
info.headers = {
...info.headers,
Authorization: authorization,
"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
};
}
return fetch(endpointUrl, info);
Expand Down
1 change: 1 addition & 0 deletions src/lib/server/textGeneration/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export async function* generate(
generateSettings: assistant?.generateSettings,
toolResults,
isMultimodal: model.multimodal,
conversationId: conv._id,
})) {
// text generation completed
if (output.generated_text) {
Expand Down
1 change: 1 addition & 0 deletions src/lib/server/textGeneration/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ export async function* runTools(
type: input.type === "file" ? "str" : input.type,
})),
})),
conversationId: conv._id,
})) {
// model natively supports tool calls
if (output.token.toolCalls) {
Expand Down

0 comments on commit 6451bee

Please sign in to comment.