diff --git a/chat-server/src/app.test.ts b/chat-server/src/app.test.ts
index c76759884..5e4ec9c5b 100644
--- a/chat-server/src/app.test.ts
+++ b/chat-server/src/app.test.ts
@@ -7,7 +7,7 @@ import {
   makeOpenAiEmbedFunc,
 } from "chat-core";
 import { errorHandler, makeApp, makeHandleTimeoutMiddleware } from "./app";
-import { ConversationsService } from "./services/conversations";
+import { makeConversationsService } from "./services/conversations";
 import { makeDataStreamer } from "./services/dataStreamer";
 import { makeOpenAiLlm } from "./services/llm";
 import { config } from "./config";
@@ -20,7 +20,7 @@ describe("App", () => {
     config.mongodb.vectorSearchIndexName
   );

-  const conversations = new ConversationsService(
+  const conversations = makeConversationsService(
     mongodb.db,
     config.llm.systemPrompt
   );
diff --git a/chat-server/src/app.ts b/chat-server/src/app.ts
index 0d4619024..231ee8f4e 100644
--- a/chat-server/src/app.ts
+++ b/chat-server/src/app.ts
@@ -16,7 +16,7 @@ import {
 } from "chat-core";
 import { DataStreamer } from "./services/dataStreamer";
 import { ObjectId } from "mongodb";
-import { ConversationsServiceInterface } from "./services/conversations";
+import { ConversationsService } from "./services/conversations";
 import { getRequestId, logRequest, sendErrorResponse } from "./utils";
 import {
   Llm,
@@ -91,7 +91,7 @@ export const makeApp = async ({
   embed: EmbedFunc;
   store: EmbeddedContentStore;
   dataStreamer: DataStreamer;
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
   llm: Llm;
   maxRequestTimeoutMs?: number;
   findNearestNeighborsOptions?: Partial<FindNearestNeighborsOptions>;
diff --git a/chat-server/src/index.ts b/chat-server/src/index.ts
index 4adf8661e..bb406cc79 100644
--- a/chat-server/src/index.ts
+++ b/chat-server/src/index.ts
@@ -6,7 +6,7 @@ import {
   makeDatabaseConnection,
   makeOpenAiEmbedFunc,
 } from "chat-core";
-import { ConversationsService } from "./services/conversations";
+import { makeConversationsService } from "./services/conversations";
 import { makeDataStreamer } from "./services/dataStreamer";
 import { makeOpenAiLlm } from "./services/llm";
 import { config } from "./config";
@@ -21,7 +21,7 @@ const startServer = async () => {
     config.mongodb.vectorSearchIndexName
   );

-  const conversations = new ConversationsService(
+  const conversations = makeConversationsService(
     mongodb.db,
     config.llm.systemPrompt
   );
@@ -44,7 +44,6 @@ const startServer = async () => {
     searchBoosters: config.conversations?.searchBoosters,
     userQueryPreprocessor: config.conversations?.userQueryPreprocessor,
     maxRequestTimeoutMs: config.maxRequestTimeoutMs,
-
   });

   const server = app.listen(PORT, () => {
diff --git a/chat-server/src/routes/conversations/addMessageToConversation.test.ts b/chat-server/src/routes/conversations/addMessageToConversation.test.ts
index 12bfc84c1..163d021ef 100644
--- a/chat-server/src/routes/conversations/addMessageToConversation.test.ts
+++ b/chat-server/src/routes/conversations/addMessageToConversation.test.ts
@@ -14,7 +14,7 @@ import {
   Conversation,
   ConversationsService,
   Message,
-  ConversationsServiceInterface,
+  makeConversationsService,
 } from "../../services/conversations";
 import express, { Express } from "express";
 import {
@@ -50,7 +50,7 @@ describe("POST /conversations/:conversationId/messages", () => {
   let dataStreamer: ReturnType<typeof makeDataStreamer>;
   let findNearestNeighborsOptions: Partial<FindNearestNeighborsOptions>;
   let store: EmbeddedContentStore;
-  let conversations: ConversationsServiceInterface;
+  let conversations: ConversationsService;
   let app: Express;

   beforeAll(async () => {
@@ -235,7 +235,7 @@ describe("POST /conversations/:conversationId/messages", () => {
     });

     test("Should respond 500 if error with conversation service", async () => {
-      const mockBrokenConversationsService: ConversationsServiceInterface = {
+      const mockBrokenConversationsService: ConversationsService = {
         async create() {
           throw new Error("mock error");
         },
@@ -333,13 +333,13 @@ describe("POST /conversations/:conversationId/messages", () => {
   });

   let conversationId: ObjectId,
-    conversations: ConversationsServiceInterface,
+    conversations: ConversationsService,
     app: Express;
   let testMongo: MongoDB;
   beforeEach(async () => {
     const dbName = `test-${Date.now()}`;
     testMongo = new MongoDB(MONGODB_CONNECTION_URI, dbName);
-    conversations = new ConversationsService(
+    conversations = makeConversationsService(
       testMongo.db,
       config.llm.systemPrompt
     );
diff --git a/chat-server/src/routes/conversations/addMessageToConversation.ts b/chat-server/src/routes/conversations/addMessageToConversation.ts
index 754699b5a..ab43bca30 100644
--- a/chat-server/src/routes/conversations/addMessageToConversation.ts
+++ b/chat-server/src/routes/conversations/addMessageToConversation.ts
@@ -17,7 +17,7 @@ import {
 } from "chat-core";
 import {
   Conversation,
-  ConversationsServiceInterface,
+  ConversationsService,
   Message,
   conversationConstants,
 } from "../../services/conversations";
@@ -68,7 +68,7 @@ export const AddMessageRequest = SomeExpressRequest.merge(

 export interface AddMessageToConversationRouteParams {
   store: EmbeddedContentStore;
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
   embed: EmbedFunc;
   llm: Llm;
   dataStreamer: DataStreamer;
@@ -387,7 +387,7 @@ export async function sendStaticNonResponse({
   dataStreamer,
   res,
 }: {
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
   conversationId: ObjectId;
   preprocessedUserMessageContent?: string;
   latestMessageText: string;
@@ -455,7 +455,7 @@ interface AddMessagesToDatabaseParams {
   preprocessedUserMessageContent?: string;
   assistantMessageContent: string;
   assistantMessageReferences: References;
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
 }
 export async function addMessagesToDatabase({
   conversationId,
diff --git a/chat-server/src/routes/conversations/createConversation.ts b/chat-server/src/routes/conversations/createConversation.ts
index 3bd787980..863cbf9d7 100644
--- a/chat-server/src/routes/conversations/createConversation.ts
+++ b/chat-server/src/routes/conversations/createConversation.ts
@@ -4,8 +4,12 @@ import {
   Request as ExpressRequest,
 } from "express";
 import { z } from "zod";
-import { ConversationsServiceInterface } from "../../services/conversations";
-import { ApiConversation, convertConversationFromDbToApi, isValidIp } from "./utils";
+import { ConversationsService } from "../../services/conversations";
+import {
+  ApiConversation,
+  convertConversationFromDbToApi,
+  isValidIp,
+} from "./utils";
 import { getRequestId, logRequest, sendErrorResponse } from "../../utils";
 import { SomeExpressRequest } from "../../middleware/validateRequestSchema";
@@ -22,7 +26,7 @@ export const CreateConversationRequest = SomeExpressRequest.merge(
 );

 export interface CreateConversationRouteParams {
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
 }

 export function makeCreateConversationRoute({
diff --git a/chat-server/src/routes/conversations/index.ts b/chat-server/src/routes/conversations/index.ts
index 4d6f7ad18..0aada439f 100644
--- a/chat-server/src/routes/conversations/index.ts
+++ b/chat-server/src/routes/conversations/index.ts
@@ -7,7 +7,7 @@ import {
   OpenAiStreamingResponse,
 } from "../../services/llm";
 import { DataStreamer } from "../../services/dataStreamer";
-import { ConversationsServiceInterface } from "../../services/conversations";
+import { ConversationsService } from "../../services/conversations";
 import { EmbeddedContentStore } from "chat-core";
 import { RateMessageRequest, makeRateMessageRoute } from "./rateMessage";
 import {
@@ -28,7 +28,7 @@ export interface ConversationsRouterParams {
   embed: EmbedFunc;
   dataStreamer: DataStreamer;
   store: EmbeddedContentStore;
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
   findNearestNeighborsOptions?: Partial<FindNearestNeighborsOptions>;
   searchBoosters?: SearchBooster[];
   userQueryPreprocessor?: QueryPreprocessorFunc;
diff --git a/chat-server/src/routes/conversations/rateMessage.test.ts b/chat-server/src/routes/conversations/rateMessage.test.ts
index 2235a4d7d..c95950694 100644
--- a/chat-server/src/routes/conversations/rateMessage.test.ts
+++ b/chat-server/src/routes/conversations/rateMessage.test.ts
@@ -5,7 +5,7 @@ import { MongoDB } from "chat-core";
 import {
   Conversation,
   Message,
-  ConversationsServiceInterface,
+  ConversationsService,
 } from "../../services/conversations";
 import { Express } from "express";
 import { ObjectId } from "mongodb";
@@ -19,7 +19,7 @@ describe("POST /conversations/:conversationId/messages/:messageId/rating", () => {
   const endpointUrl =
     CONVERSATIONS_API_V1_PREFIX + "/:conversationId/messages/:messageId/rating";
   let app: Express;
-  let conversations: ConversationsServiceInterface;
+  let conversations: ConversationsService;
   let conversation: Conversation;
   let testMsg: Message;
   let testEndpointUrl: string;
diff --git a/chat-server/src/routes/conversations/rateMessage.ts b/chat-server/src/routes/conversations/rateMessage.ts
index 9fdcd6643..fdc828b08 100644
--- a/chat-server/src/routes/conversations/rateMessage.ts
+++ b/chat-server/src/routes/conversations/rateMessage.ts
@@ -1,8 +1,7 @@
 import { ObjectId } from "mongodb";
-import { strict as assert } from "assert";
 import {
   Conversation,
-  ConversationsServiceInterface,
+  ConversationsService,
 } from "../../services/conversations";
 import {
   Request as ExpressRequest,
@@ -33,7 +32,7 @@ export const RateMessageRequest = SomeExpressRequest.merge(
 );

 export interface RateMessageRouteParams {
-  conversations: ConversationsServiceInterface;
+  conversations: ConversationsService;
 }

 export function makeRateMessageRoute({
diff --git a/chat-server/src/services/conversations.test.ts b/chat-server/src/services/conversations.test.ts
index 54d486f25..c0a803b1a 100644
--- a/chat-server/src/services/conversations.test.ts
+++ b/chat-server/src/services/conversations.test.ts
@@ -1,6 +1,6 @@
 import "dotenv/config";
 import { MongoDB } from "chat-core";
-import { Conversation, ConversationsService } from "./conversations";
+import { Conversation, makeConversationsService } from "./conversations";
 import { BSON } from "mongodb";
 import { config } from "../config";

@@ -24,7 +24,7 @@ describe("Conversations Service", () => {
     await mongodb.close();
   });

-  const conversationsService = new ConversationsService(
+  const conversationsService = makeConversationsService(
     mongodb.db,
     config.llm.systemPrompt
   );
diff --git a/chat-server/src/services/conversations.ts b/chat-server/src/services/conversations.ts
index 4caa6e244..ac4c3396a 100644
--- a/chat-server/src/services/conversations.ts
+++ b/chat-server/src/services/conversations.ts
@@ -1,12 +1,5 @@
-import {
-  ObjectId,
-  Db,
-  Collection,
-  OpenAiChatMessage,
-  OpenAiMessageRole,
-} from "chat-core";
+import { ObjectId, Db, OpenAiChatMessage, OpenAiMessageRole } from "chat-core";
 import { References } from "chat-core";
-import { Llm } from "./llm";
 import { LlmConfig } from "../AppConfig";

 export interface Message {
@@ -53,7 +46,7 @@ export interface RateMessageParams {
   messageId: ObjectId;
   rating: boolean;
 }
-export interface ConversationsServiceInterface {
+export interface ConversationsService {
   create: ({ ipAddress }: CreateConversationParams) => Promise<Conversation>;
   addConversationMessage: ({
     conversationId,
@@ -80,95 +73,91 @@ so I cannot respond to your message. Please try again later.
 However, here are some links that might provide some helpful information for your message:`,
 };

-export class ConversationsService implements ConversationsServiceInterface {
-  private database: Db;
-  private conversationsCollection: Collection<Conversation>;
-  private systemPrompt: LlmConfig["systemPrompt"];
-  constructor(database: Db, systemPrompt: LlmConfig["systemPrompt"]) {
-    this.database = database;
-    this.conversationsCollection =
-      this.database.collection<Conversation>("conversations");
-    this.systemPrompt = systemPrompt;
-  }
-  async create({ ipAddress }: CreateConversationParams) {
-    const newConversation = {
-      _id: new ObjectId(),
-      ipAddress,
-      messages: [this.createMessageFromChatMessage(this.systemPrompt)],
-      createdAt: new Date(),
-    };
-    const insertResult = await this.conversationsCollection.insertOne(
-      newConversation
-    );
+export function makeConversationsService(
+  database: Db,
+  systemPrompt: LlmConfig["systemPrompt"]
+): ConversationsService {
+  const conversationsCollection =
+    database.collection<Conversation>("conversations");
+  return {
+    async create({ ipAddress }: CreateConversationParams) {
+      const newConversation = {
+        _id: new ObjectId(),
+        ipAddress,
+        messages: [createMessageFromOpenAIChatMessage(systemPrompt)],
+        createdAt: new Date(),
+      };
+      const insertResult = await conversationsCollection.insertOne(
+        newConversation
+      );

-    if (
-      !insertResult.acknowledged ||
-      insertResult.insertedId !== newConversation._id
-    ) {
-      throw new Error("Failed to insert conversation");
-    }
-    return newConversation;
-  }
+      if (
+        !insertResult.acknowledged ||
+        insertResult.insertedId !== newConversation._id
+      ) {
+        throw new Error("Failed to insert conversation");
+      }
+      return newConversation;
+    },

-  async addConversationMessage({
-    conversationId,
-    content,
-    role,
-    references,
-  }: AddConversationMessageParams) {
-    const newMessage = this.createMessageFromChatMessage({ role, content });
-    if (references) {
-      newMessage.references = references;
-    }
+    async addConversationMessage({
+      conversationId,
+      content,
+      role,
+      references,
+    }: AddConversationMessageParams) {
+      const newMessage = createMessageFromOpenAIChatMessage({ role, content });
+      if (references) {
+        newMessage.references = references;
+      }

-    const updateResult = await this.conversationsCollection.updateOne(
-      {
-        _id: conversationId,
-      },
-      {
-        $push: {
-          messages: newMessage,
-        },
-      }
-    );
-    if (!updateResult.acknowledged || updateResult.matchedCount === 0) {
-      throw new Error("Failed to insert message");
-    }
-    return newMessage;
-  }
+      const updateResult = await conversationsCollection.updateOne(
+        {
+          _id: conversationId,
+        },
+        {
+          $push: {
+            messages: newMessage,
+          },
+        }
+      );
+      if (!updateResult.acknowledged || updateResult.matchedCount === 0) {
+        throw new Error("Failed to insert message");
+      }
+      return newMessage;
+    },

-  async findById({ _id }: FindByIdParams) {
-    const conversation = await this.conversationsCollection.findOne({ _id });
-    return conversation;
-  }
+    async findById({ _id }: FindByIdParams) {
+      const conversation = await conversationsCollection.findOne({ _id });
+      return conversation;
+    },

-  async rateMessage({ conversationId, messageId, rating }: RateMessageParams) {
-    const updateResult = await this.conversationsCollection.updateOne(
-      {
-        _id: conversationId,
-      },
-      {
-        $set: {
-          "messages.$[message].rating": rating,
-        },
-      },
-      {
-        arrayFilters: [
-          { "message.id": messageId, "message.role": "assistant" },
-        ],
-      }
-    );
-    if (!updateResult.acknowledged || updateResult.matchedCount === 0) {
-      throw new Error("Failed to rate message");
-    }
-    return true;
-  }
-
-  private createMessageFromChatMessage(
-    chatMessage: OpenAiChatMessage
-  ): Message {
-    return createMessageFromOpenAIChatMessage(chatMessage);
-  }
+    async rateMessage({
+      conversationId,
+      messageId,
+      rating,
+    }: RateMessageParams) {
+      const updateResult = await conversationsCollection.updateOne(
+        {
+          _id: conversationId,
+        },
+        {
+          $set: {
+            "messages.$[message].rating": rating,
+          },
+        },
+        {
+          arrayFilters: [
+            { "message.id": messageId, "message.role": "assistant" },
+          ],
+        }
+      );
+      if (!updateResult.acknowledged || updateResult.matchedCount === 0) {
+        throw new Error("Failed to rate message");
+      }
+      return true;
+    },
+  };
 }

 export function createMessageFromOpenAIChatMessage(
diff --git a/chat-server/src/testHelpers.ts b/chat-server/src/testHelpers.ts
index dbe261308..f7b860974 100644
--- a/chat-server/src/testHelpers.ts
+++ b/chat-server/src/testHelpers.ts
@@ -7,7 +7,7 @@ import {

 import { makeOpenAiLlm } from "./services/llm";
 import { makeDataStreamer } from "./services/dataStreamer";
-import { ConversationsService } from "./services/conversations";
+import { makeConversationsService } from "./services/conversations";
 import { makeApp } from "./app";

 import { config as conf } from "./config";
@@ -32,7 +32,7 @@ export async function makeConversationsRoutesDefaults() {
   const searchBoosters = conf.conversations!.searchBoosters;
   const userQueryPreprocessor = conf.conversations!.userQueryPreprocessor;

-  const conversations = new ConversationsService(
+  const conversations = makeConversationsService(
     mongodb.db,
     conf.llm.systemPrompt
   );
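For reference, the call-site pattern this patch standardizes on (visible in `index.ts`, the tests, and `testHelpers.ts` above): callers build the service with the `makeConversationsService` factory instead of `new ConversationsService(...)`, and the result is typed as the `ConversationsService` interface. A minimal sketch of that usage, assuming only the imports shown in the diff; the helper name `buildConversationsService` is hypothetical:

```ts
import { Db } from "chat-core";
import {
  ConversationsService,
  makeConversationsService,
} from "./services/conversations";
import { config } from "./config";

// Hypothetical helper: given an already-connected Db (e.g. `mongodb.db` from
// makeDatabaseConnection in index.ts), build the conversations service.
// The factory closes over the collection and system prompt, so no class
// instance or `this` binding is involved.
export function buildConversationsService(db: Db): ConversationsService {
  return makeConversationsService(db, config.llm.systemPrompt);
}
```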