diff --git a/backend/src/app.ts b/backend/src/app.ts index bc3645611..f2bc64c54 100644 --- a/backend/src/app.ts +++ b/backend/src/app.ts @@ -6,21 +6,26 @@ import session from "express-session"; import { getInitialDefences } from "./defence"; import { setOpenAiApiKey } from "./openai"; import { router } from "./router"; -import { ChatCompletionRequestMessage } from "openai"; +import { ChatHistoryMessage } from "./models/chat"; import { EmailInfo } from "./models/email"; import { DefenceInfo } from "./models/defence"; import { CHAT_MODELS } from "./models/chat"; +import { PHASE_NAMES } from "./models/phase"; import { retrievalQAPrePrompt } from "./promptTemplates"; dotenv.config(); declare module "express-session" { interface Session { - chatHistory: ChatCompletionRequestMessage[]; - defences: DefenceInfo[]; + openAiApiKey: string | null; gptModel: CHAT_MODELS; + phaseState: PhaseState[]; numPhasesCompleted: number; - openAiApiKey: string | null; + } + interface PhaseState { + phase: PHASE_NAMES; + chatHistory: ChatHistoryMessage[]; + defences: DefenceInfo[]; sentEmails: EmailInfo[]; } } @@ -65,12 +70,6 @@ app.use( app.use(async (req, _res, next) => { // initialise session variables - if (!req.session.chatHistory) { - req.session.chatHistory = []; - } - if (!req.session.defences) { - req.session.defences = getInitialDefences(); - } if (!req.session.gptModel) { req.session.gptModel = defaultModel; } @@ -80,8 +79,20 @@ app.use(async (req, _res, next) => { if (!req.session.openAiApiKey) { req.session.openAiApiKey = process.env.OPENAI_API_KEY || null; } - if (!req.session.sentEmails) { - req.session.sentEmails = []; + if (!req.session.phaseState) { + req.session.phaseState = []; + // add empty states for phases 0-3 + Object.values(PHASE_NAMES).forEach((value) => { + if (isNaN(Number(value))) { + req.session.phaseState.push({ + phase: value as PHASE_NAMES, + chatHistory: [], + defences: getInitialDefences(), + sentEmails: [], + }); + } + }); + console.log("Initialised phase state: ", req.session.phaseState); } next(); }); @@ -89,7 +100,7 @@ app.use(async (req, _res, next) => { app.use("/", router); app.listen(port, () => { console.log("Server is running on port: " + port); - + // for dev purposes only - set the API key from the environment variable const envOpenAiKey = process.env.OPENAI_API_KEY; const prePrompt = retrievalQAPrePrompt; diff --git a/backend/src/models/chat.ts b/backend/src/models/chat.ts index 56b73735b..0ce3cd045 100644 --- a/backend/src/models/chat.ts +++ b/backend/src/models/chat.ts @@ -14,9 +14,15 @@ enum CHAT_MODELS { enum CHAT_MESSAGE_TYPE { BOT, + BOT_BLOCKED, INFO, USER, + USER_TRANSFORMED, PHASE_INFO, + DEFENCE_ALERTED, + DEFENCE_TRIGGERED, + SYSTEM, + FUNCTION_CALL, } interface ChatDefenceReport { @@ -50,11 +56,18 @@ interface ChatHttpResponse { wonPhase: boolean; } +interface ChatHistoryMessage { + completion: ChatCompletionRequestMessage | null; + chatMessageType: CHAT_MESSAGE_TYPE; + infoMessage?: string | null; +} + export type { ChatAnswer, ChatDefenceReport, ChatMalicious, ChatResponse, ChatHttpResponse, + ChatHistoryMessage, }; export { CHAT_MODELS, CHAT_MESSAGE_TYPE }; diff --git a/backend/src/models/phase.ts b/backend/src/models/phase.ts index b2196a02d..944e6e881 100644 --- a/backend/src/models/phase.ts +++ b/backend/src/models/phase.ts @@ -1,5 +1,5 @@ enum PHASE_NAMES { - PHASE_0 = 0, + PHASE_0, PHASE_1, PHASE_2, SANDBOX, diff --git a/backend/src/openai.ts b/backend/src/openai.ts index fe811b2df..d5890efda 100644 --- a/backend/src/openai.ts +++ b/backend/src/openai.ts @@ -17,7 +17,12 @@ import { Configuration, OpenAIApi, } from "openai"; -import { CHAT_MODELS, ChatDefenceReport } from "./models/chat"; +import { + CHAT_MESSAGE_TYPE, + CHAT_MODELS, + ChatDefenceReport, + ChatHistoryMessage, +} from "./models/chat"; import { DEFENCE_TYPES, DefenceInfo } from "./models/defence"; import { PHASE_NAMES } from "./models/phase"; @@ -223,7 +228,7 @@ async function chatGptCallFunction( } async function chatGptChatCompletion( - chatHistory: ChatCompletionRequestMessage[], + chatHistory: ChatHistoryMessage[], defences: DefenceInfo[], gptModel: CHAT_MODELS, openai: OpenAIApi, @@ -237,23 +242,29 @@ async function chatGptChatCompletion( isDefenceActive(DEFENCE_TYPES.SYSTEM_ROLE, defences) ) { // check to see if there's already a system role - if (!chatHistory.find((message) => message.role === "system")) { + if (!chatHistory.find((message) => message.completion?.role === "system")) { // add the system role to the start of the chat history chatHistory.unshift({ - role: "system", - content: getSystemRole(defences, currentPhase), + completion: { + role: "system", + content: getSystemRole(defences, currentPhase), + }, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, }); } } else { // remove the system role from the chat history - while (chatHistory.length > 0 && chatHistory[0].role === "system") { + while ( + chatHistory.length > 0 && + chatHistory[0].completion?.role === "system" + ) { chatHistory.shift(); } } const chat_completion = await openai.createChatCompletion({ model: gptModel, - messages: chatHistory, + messages: getChatCompletionsFromHistory(chatHistory), functions: chatGptFunctions, }); @@ -261,11 +272,42 @@ async function chatGptChatCompletion( return chat_completion.data.choices[0].message || null; } +// take only the completions to send to GPT +const getChatCompletionsFromHistory = ( + chatHistory: ChatHistoryMessage[] +): ChatCompletionRequestMessage[] => { + const completions: ChatCompletionRequestMessage[] = + chatHistory.length > 0 + ? chatHistory + .filter((message) => message.completion !== null) + .map((message) => message.completion as ChatCompletionRequestMessage) + : []; + return completions; +}; + +const pushCompletionToHistory = ( + chatHistory: ChatHistoryMessage[], + completion: ChatCompletionRequestMessage, + messageType: CHAT_MESSAGE_TYPE +) => { + if (messageType !== CHAT_MESSAGE_TYPE.BOT_BLOCKED) { + chatHistory.push({ + completion: completion, + chatMessageType: messageType, + }); + } else { + // do not add the bots reply which was subsequently blocked + console.log("Skipping adding blocked message to chat history", completion); + } + return chatHistory; +}; + async function chatGptSendMessage( - chatHistory: ChatCompletionRequestMessage[], + chatHistory: ChatHistoryMessage[], defences: DefenceInfo[], gptModel: CHAT_MODELS, message: string, + messageIsTransformed: boolean, openAiApiKey: string, sentEmails: EmailInfo[], // default to sandbox @@ -283,7 +325,16 @@ async function chatGptSendMessage( let wonPhase: boolean | undefined | null = false; // add user message to chat - chatHistory.push({ role: "user", content: message }); + chatHistory = pushCompletionToHistory( + chatHistory, + { + role: "user", + content: message, + }, + messageIsTransformed + ? CHAT_MESSAGE_TYPE.USER_TRANSFORMED + : CHAT_MESSAGE_TYPE.USER + ); const openai = getOpenAiFromKey(openAiApiKey); let reply = await chatGptChatCompletion( @@ -295,7 +346,11 @@ async function chatGptSendMessage( ); // check if GPT wanted to call a function while (reply && reply.function_call) { - chatHistory.push(reply); + chatHistory = pushCompletionToHistory( + chatHistory, + reply, + CHAT_MESSAGE_TYPE.FUNCTION_CALL + ); // call the function and get a new reply and defence info from const functionCallReply = await chatGptCallFunction( @@ -309,12 +364,15 @@ async function chatGptSendMessage( wonPhase = functionCallReply.wonPhase; // add the function call to the chat history if (functionCallReply.completion !== undefined) { - chatHistory.push(functionCallReply.completion); + chatHistory = pushCompletionToHistory( + chatHistory, + functionCallReply.completion, + CHAT_MESSAGE_TYPE.FUNCTION_CALL + ); } // update the defence info defenceInfo = functionCallReply.defenceInfo; } - // get a new reply from ChatGPT now that the function has been called reply = await chatGptChatCompletion( chatHistory, @@ -326,7 +384,6 @@ async function chatGptSendMessage( } if (reply && reply.content) { - console.debug("GPT reply: " + reply.content); // if output filter defence is active, check for blocked words/phrases if ( currentPhase === PHASE_NAMES.PHASE_2 || @@ -353,9 +410,17 @@ async function chatGptSendMessage( } } // add the ai reply to the chat history - chatHistory.push(reply); + chatHistory = pushCompletionToHistory( + chatHistory, + reply, + defenceInfo.isBlocked + ? CHAT_MESSAGE_TYPE.BOT_BLOCKED + : CHAT_MESSAGE_TYPE.BOT + ); + // log the entire chat history so far console.log(chatHistory); + return { completion: reply, defenceInfo: defenceInfo, diff --git a/backend/src/router.ts b/backend/src/router.ts index 75c97b707..1466f67a9 100644 --- a/backend/src/router.ts +++ b/backend/src/router.ts @@ -10,7 +10,11 @@ import { getInitialDefences, } from "./defence"; import { initQAModel } from "./langchain"; -import { CHAT_MODELS, ChatHttpResponse } from "./models/chat"; +import { + CHAT_MESSAGE_TYPE, + CHAT_MODELS, + ChatHttpResponse, +} from "./models/chat"; import { DEFENCE_TYPES, DefenceConfig } from "./models/defence"; import { chatGptSendMessage, setOpenAiApiKey, setGptModel } from "./openai"; import { retrievalQAPrePrompt } from "./promptTemplates"; @@ -25,10 +29,13 @@ let prevPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX; router.post("/defence/activate", (req, res) => { // id of the defence const defenceId: DEFENCE_TYPES = req.body?.defenceId; - if (defenceId) { + const phase: PHASE_NAMES = req.body?.phase; + if (defenceId && phase) { // activate the defence - req.session.defences = activateDefence(defenceId, req.session.defences); - + req.session.phaseState[phase].defences = activateDefence( + defenceId, + req.session.phaseState[phase].defences + ); // need to re-initialize QA model when turned on if ( defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && @@ -39,14 +46,14 @@ router.post("/defence/activate", (req, res) => { ); initQAModel( req.session.openAiApiKey, - getQALLMprePrompt(req.session.defences) + getQALLMprePrompt(req.session.phaseState[phase].defences) ); } res.send("Defence activated"); } else { res.statusCode = 400; - res.send("Missing defenceId"); + res.send("Missing defenceId or phase"); } }); @@ -54,9 +61,13 @@ router.post("/defence/activate", (req, res) => { router.post("/defence/deactivate", (req, res) => { // id of the defence const defenceId: DEFENCE_TYPES = req.body?.defenceId; - if (defenceId) { + const phase: PHASE_NAMES = req.body?.phase; + if (defenceId && phase) { // deactivate the defence - req.session.defences = deactivateDefence(defenceId, req.session.defences); + req.session.phaseState[phase].defences = deactivateDefence( + defenceId, + req.session.phaseState[phase].defences + ); if ( defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && @@ -65,13 +76,13 @@ router.post("/defence/deactivate", (req, res) => { console.debug("Resetting QA model with default prompt"); initQAModel( req.session.openAiApiKey, - getQALLMprePrompt(req.session.defences) + getQALLMprePrompt(req.session.phaseState[phase].defences) ); } res.send("Defence deactivated"); } else { res.statusCode = 400; - res.send("Missing defenceId"); + res.send("Missing defenceId or phase"); } }); @@ -80,41 +91,67 @@ router.post("/defence/configure", (req, res) => { // id of the defence const defenceId: DEFENCE_TYPES = req.body?.defenceId; const config: DefenceConfig[] = req.body?.config; - if (defenceId && config) { + const phase: PHASE_NAMES = req.body?.phase; + if (defenceId && config && phase >= 0) { // configure the defence - req.session.defences = configureDefence( + req.session.phaseState[phase].defences = configureDefence( defenceId, - req.session.defences, + req.session.phaseState[phase].defences, config ); res.send("Defence configured"); } else { res.statusCode = 400; - res.send("Missing defenceId or config"); + res.send("Missing defenceId or config or phase"); } }); // reset the active defences router.post("/defence/reset", (req, res) => { - req.session.defences = getInitialDefences(); - console.debug("Defences reset"); - res.send("Defences reset"); + const phase: PHASE_NAMES = req.body?.phase; + if (phase >= 0) { + req.session.phaseState[phase].defences = getInitialDefences(); + console.debug("Defences reset"); + res.send("Defences reset"); + } else { + res.statusCode = 400; + res.send("Missing phase"); + } }); -// Get the status of all defences +// Get the status of all defences /defence/status?phase=1 router.get("/defence/status", (req, res) => { - res.send(req.session.defences); + const phase: number | undefined = req.query?.phase as number | undefined; + if (phase) { + res.send(req.session.phaseState[phase].defences); + } else { + res.statusCode = 400; + res.send("Missing phase"); + } }); -// Get sent emails +// Get sent emails // /email/get?phase=1 router.get("/email/get", (req, res) => { - res.send(req.session.sentEmails); + const phase: number | undefined = req.query?.phase as number | undefined; + if (phase) { + res.send(req.session.phaseState[phase].sentEmails); + } else { + res.statusCode = 400; + res.send("Missing phase"); + } }); // clear emails router.post("/email/clear", (req, res) => { - req.session.sentEmails = []; - res.send("Emails cleared"); + const phase: PHASE_NAMES = req.body?.phase; + if (phase >= 0) { + req.session.phaseState[phase].sentEmails = []; + console.debug("Emails cleared"); + res.send("Emails cleared"); + } else { + res.statusCode = 400; + res.send("Missing phase"); + } }); // Chat to ChatGPT @@ -160,26 +197,44 @@ router.post("/openai/chat", async (req, res) => { ) { chatResponse.defenceInfo = await detectTriggeredDefences( message, - req.session.defences + req.session.phaseState[currentPhase].defences ); + // if message is blocked, add to chat history (not as completion) + if (chatResponse.defenceInfo.isBlocked) { + req.session.phaseState[currentPhase].chatHistory.push({ + completion: null, + chatMessageType: CHAT_MESSAGE_TYPE.USER, + infoMessage: message, + }); + } } // if blocked, send the response if (!chatResponse.defenceInfo.isBlocked) { // transform the message according to active defences chatResponse.transformedMessage = transformMessage( message, - req.session.defences + req.session.phaseState[currentPhase].defences ); - + // if message has been transformed then add the original to chat history and send transformed to chatGPT + const messageIsTransformed = + chatResponse.transformedMessage !== message; + if (messageIsTransformed) { + req.session.phaseState[currentPhase].chatHistory.push({ + completion: null, + chatMessageType: CHAT_MESSAGE_TYPE.USER, + infoMessage: message, + }); + } // get the chatGPT reply try { const openAiReply = await chatGptSendMessage( - req.session.chatHistory, - req.session.defences, + req.session.phaseState[currentPhase].chatHistory, + req.session.phaseState[currentPhase].defences, req.session.gptModel, chatResponse.transformedMessage, + messageIsTransformed, req.session.openAiApiKey, - req.session.sentEmails, + req.session.phaseState[currentPhase].sentEmails, currentPhase ); @@ -216,6 +271,15 @@ router.post("/openai/chat", async (req, res) => { } } + // if the reply was blocked then add it to the chat history + if (chatResponse.defenceInfo.isBlocked) { + req.session.phaseState[currentPhase].chatHistory.push({ + completion: null, + chatMessageType: CHAT_MESSAGE_TYPE.BOT_BLOCKED, + infoMessage: chatResponse.defenceInfo.blockedReason, + }); + } + // enable next phase when user wins current phase if (chatResponse.wonPhase) { console.log("Win conditon met for phase: ", currentPhase); @@ -229,16 +293,51 @@ router.post("/openai/chat", async (req, res) => { console.error(chatResponse.reply); } } - // log and send the reply with defence info console.log(chatResponse); res.send(chatResponse); }); +// get the chat history +router.get("/openai/history", (req, res) => { + const phase: number | undefined = req.query?.phase as number | undefined; + if (phase) { + res.send(req.session.phaseState[phase].chatHistory); + } else { + res.statusCode = 400; + res.send("Missing phase"); + } +}); + +// add an info message to chat history +router.post("/openai/addHistory", (req, res) => { + const message: string = req.body?.message; + const chatMessageType: CHAT_MESSAGE_TYPE = req.body?.chatMessageType; + const phase: PHASE_NAMES = req.body?.phase; + if (message && chatMessageType && phase >= 0) { + req.session.phaseState[phase].chatHistory.push({ + completion: null, + chatMessageType: chatMessageType, + infoMessage: message, + }); + res.send("Message added to chat history"); + } else { + res.statusCode = 400; + res.send("Missing message or message type or phase"); + } +}); + // Clear the ChatGPT messages router.post("/openai/clear", (req, res) => { - req.session.chatHistory = []; - res.send("ChatGPT messages cleared"); + const phase: PHASE_NAMES = req.body?.phase; + if (phase >= 0) { + req.session.phaseState[phase].chatHistory = []; + console.debug("ChatGPT messages cleared"); + res.send("ChatGPT messages cleared"); + } else { + res.statusCode = 400; + res.send("Missing phase"); + } }); // Set API key @@ -252,7 +351,7 @@ router.post("/openai/apiKey", async (req, res) => { await setOpenAiApiKey( openAiApiKey, req.session.gptModel, - getQALLMprePrompt(req.session.defences) + getQALLMprePrompt(req.session.phaseState[3].defences) // use phase 2 as only phase with QA LLM defence ) ) { req.session.openAiApiKey = openAiApiKey; diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts index 7f47eb397..371e3f951 100644 --- a/backend/test/integration/openai.test.ts +++ b/backend/test/integration/openai.test.ts @@ -1,5 +1,8 @@ -import { ChatCompletionRequestMessage } from "openai"; -import { CHAT_MODELS } from "../../src/models/chat"; +import { + CHAT_MESSAGE_TYPE, + CHAT_MODELS, + ChatHistoryMessage, +} from "../../src/models/chat"; import { chatGptSendMessage } from "../../src/openai"; import { DEFENCE_TYPES, DefenceInfo } from "../../src/models/defence"; import { EmailInfo } from "../../src/models/email"; @@ -36,7 +39,7 @@ beforeEach(() => { test("GIVEN OpenAI initialised WHEN sending message THEN reply is returned", async () => { const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; const defences: DefenceInfo[] = []; const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; @@ -62,6 +65,7 @@ test("GIVEN OpenAI initialised WHEN sending message THEN reply is returned", asy defences, gptModel, message, + true, openAiApiKey, sentEmails ); @@ -71,10 +75,10 @@ test("GIVEN OpenAI initialised WHEN sending message THEN reply is returned", asy expect(reply?.completion.content).toBe("Hi"); // check the chat history has been updated expect(chatHistory.length).toBe(2); - expect(chatHistory[0].role).toBe("user"); - expect(chatHistory[0].content).toBe("Hello"); - expect(chatHistory[1].role).toBe("assistant"); - expect(chatHistory[1].content).toBe("Hi"); + expect(chatHistory[0].completion?.role).toBe("user"); + expect(chatHistory[0].completion?.content).toBe("Hello"); + expect(chatHistory[1].completion?.role).toBe("assistant"); + expect(chatHistory[1].completion?.content).toBe("Hi"); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -85,7 +89,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role process.env.SYSTEM_ROLE = "You are a helpful assistant"; const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; let defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; @@ -113,6 +117,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role defences, gptModel, message, + true, openAiApiKey, sentEmails ); @@ -122,12 +127,12 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role // check the chat history has been updated expect(chatHistory.length).toBe(3); // system role is added to the start of the chat history - expect(chatHistory[0].role).toBe("system"); - expect(chatHistory[0].content).toBe(process.env.SYSTEM_ROLE); - expect(chatHistory[1].role).toBe("user"); - expect(chatHistory[1].content).toBe("Hello"); - expect(chatHistory[2].role).toBe("assistant"); - expect(chatHistory[2].content).toBe("Hi"); + expect(chatHistory[0].completion?.role).toBe("system"); + expect(chatHistory[0].completion?.content).toBe(process.env.SYSTEM_ROLE); + expect(chatHistory[1].completion?.role).toBe("user"); + expect(chatHistory[1].completion?.content).toBe("Hello"); + expect(chatHistory[2].completion?.role).toBe("assistant"); + expect(chatHistory[2].completion?.content).toBe("Hi"); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -138,14 +143,21 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role process.env.SYSTEM_ROLE = "You are a helpful assistant"; const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = [ + const isOriginalMessage = true; + const chatHistory: ChatHistoryMessage[] = [ { - role: "user", - content: "I'm a user", + completion: { + role: "user", + content: "I'm a user", + }, + chatMessageType: CHAT_MESSAGE_TYPE.USER, }, { - role: "assistant", - content: "I'm an assistant", + completion: { + role: "assistant", + content: "I'm an assistant", + }, + chatMessageType: CHAT_MESSAGE_TYPE.BOT, }, ]; let defences: DefenceInfo[] = getInitialDefences(); @@ -176,6 +188,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); @@ -185,17 +198,17 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role // check the chat history has been updated expect(chatHistory.length).toBe(5); // system role is added to the start of the chat history - expect(chatHistory[0].role).toBe("system"); - expect(chatHistory[0].content).toBe(process.env.SYSTEM_ROLE); + expect(chatHistory[0].completion?.role).toBe("system"); + expect(chatHistory[0].completion?.content).toBe(process.env.SYSTEM_ROLE); // rest of the chat history is in order - expect(chatHistory[1].role).toBe("user"); - expect(chatHistory[1].content).toBe("I'm a user"); - expect(chatHistory[2].role).toBe("assistant"); - expect(chatHistory[2].content).toBe("I'm an assistant"); - expect(chatHistory[3].role).toBe("user"); - expect(chatHistory[3].content).toBe("Hello"); - expect(chatHistory[4].role).toBe("assistant"); - expect(chatHistory[4].content).toBe("Hi"); + expect(chatHistory[1].completion?.role).toBe("user"); + expect(chatHistory[1].completion?.content).toBe("I'm a user"); + expect(chatHistory[2].completion?.role).toBe("assistant"); + expect(chatHistory[2].completion?.content).toBe("I'm an assistant"); + expect(chatHistory[3].completion?.role).toBe("user"); + expect(chatHistory[3].completion?.content).toBe("Hello"); + expect(chatHistory[4].completion?.role).toBe("assistant"); + expect(chatHistory[4].completion?.content).toBe("Hi"); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -203,18 +216,27 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role test("GIVEN SYSTEM_ROLE defence is inactive WHEN sending message THEN system role is removed from the chat history", async () => { const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = [ + const chatHistory: ChatHistoryMessage[] = [ { - role: "system", - content: "You are a helpful assistant", + completion: { + role: "system", + content: "You are a helpful assistant", + }, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, }, { - role: "user", - content: "I'm a user", + completion: { + role: "user", + content: "I'm a user", + }, + chatMessageType: CHAT_MESSAGE_TYPE.USER, }, { - role: "assistant", - content: "I'm an assistant", + completion: { + role: "assistant", + content: "I'm an assistant", + }, + chatMessageType: CHAT_MESSAGE_TYPE.BOT, }, ]; let defences: DefenceInfo[] = getInitialDefences(); @@ -242,6 +264,7 @@ test("GIVEN SYSTEM_ROLE defence is inactive WHEN sending message THEN system rol defences, gptModel, message, + true, openAiApiKey, sentEmails ); @@ -252,14 +275,14 @@ test("GIVEN SYSTEM_ROLE defence is inactive WHEN sending message THEN system rol expect(chatHistory.length).toBe(4); // system role is removed from the start of the chat history // rest of the chat history is in order - expect(chatHistory[0].role).toBe("user"); - expect(chatHistory[0].content).toBe("I'm a user"); - expect(chatHistory[1].role).toBe("assistant"); - expect(chatHistory[1].content).toBe("I'm an assistant"); - expect(chatHistory[2].role).toBe("user"); - expect(chatHistory[2].content).toBe("Hello"); - expect(chatHistory[3].role).toBe("assistant"); - expect(chatHistory[3].content).toBe("Hi"); + expect(chatHistory[0].completion?.role).toBe("user"); + expect(chatHistory[0].completion?.content).toBe("I'm a user"); + expect(chatHistory[1].completion?.role).toBe("assistant"); + expect(chatHistory[1].completion?.content).toBe("I'm an assistant"); + expect(chatHistory[2].completion?.role).toBe("user"); + expect(chatHistory[2].completion?.content).toBe("Hello"); + expect(chatHistory[3].completion?.role).toBe("assistant"); + expect(chatHistory[3].completion?.content).toBe("Hi"); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -273,18 +296,27 @@ test( process.env.SYSTEM_ROLE = "You are a helpful assistant"; const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = [ + const chatHistory: ChatHistoryMessage[] = [ { - role: "system", - content: "You are a helpful assistant", + completion: { + role: "system", + content: "You are a helpful assistant", + }, + chatMessageType: CHAT_MESSAGE_TYPE.SYSTEM, }, { - role: "user", - content: "I'm a user", + completion: { + role: "user", + content: "I'm a user", + }, + chatMessageType: CHAT_MESSAGE_TYPE.USER, }, { - role: "assistant", - content: "I'm an assistant", + completion: { + role: "assistant", + content: "I'm an assistant", + }, + chatMessageType: CHAT_MESSAGE_TYPE.BOT, }, ]; let defences: DefenceInfo[] = getInitialDefences(); @@ -314,6 +346,7 @@ test( defences, gptModel, message, + true, openAiApiKey, sentEmails ); @@ -323,17 +356,17 @@ test( // check the chat history has been updated expect(chatHistory.length).toBe(5); // system role is added to the start of the chat history - expect(chatHistory[0].role).toBe("system"); - expect(chatHistory[0].content).toBe(process.env.SYSTEM_ROLE); + expect(chatHistory[0].completion?.role).toBe("system"); + expect(chatHistory[0].completion?.content).toBe(process.env.SYSTEM_ROLE); // rest of the chat history is in order - expect(chatHistory[1].role).toBe("user"); - expect(chatHistory[1].content).toBe("I'm a user"); - expect(chatHistory[2].role).toBe("assistant"); - expect(chatHistory[2].content).toBe("I'm an assistant"); - expect(chatHistory[3].role).toBe("user"); - expect(chatHistory[3].content).toBe("Hello"); - expect(chatHistory[4].role).toBe("assistant"); - expect(chatHistory[4].content).toBe("Hi"); + expect(chatHistory[1].completion?.role).toBe("user"); + expect(chatHistory[1].completion?.content).toBe("I'm a user"); + expect(chatHistory[2].completion?.role).toBe("assistant"); + expect(chatHistory[2].completion?.content).toBe("I'm an assistant"); + expect(chatHistory[3].completion?.role).toBe("user"); + expect(chatHistory[3].completion?.content).toBe("Hello"); + expect(chatHistory[4].completion?.role).toBe("assistant"); + expect(chatHistory[4].completion?.content).toBe("Hi"); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -349,7 +382,7 @@ test( process.env.EMAIL_WHITELIST = ""; const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; const defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; @@ -395,6 +428,7 @@ test( defences, gptModel, message, + true, openAiApiKey, sentEmails ); @@ -428,11 +462,12 @@ test( process.env.EMAIL_WHITELIST = ""; const message = "Hello"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; let defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; + const isOriginalMessage = true; defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); @@ -476,6 +511,7 @@ test( defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); @@ -506,11 +542,12 @@ test( process.env.EMAIL_WHITELIST = "bob@example.com"; const message = "Send an email to bob@example.com saying hi"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; let defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; + const isOriginalMessage = true; defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); @@ -554,6 +591,7 @@ test( defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); @@ -584,11 +622,12 @@ test( process.env.EMAIL_WHITELIST = "bob@example.com"; const message = "Send an email to bob@example.com saying hi"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; const defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; + const isOriginalMessage = true; // set email whitelist process.env.EMAIL_WHITELIST = "bob@example.com"; @@ -633,6 +672,7 @@ test( defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); @@ -657,11 +697,12 @@ test("GIVEN the output filtering defence is active WHEN the bot responds with a process.env.FILTER_LIST_OUTPUT = "secret project,password"; const message = "What is the secret Project?"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; let defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; + const isOriginalMessage = true; defences = activateDefence(DEFENCE_TYPES.FILTER_BOT_OUTPUT, defences); @@ -682,6 +723,7 @@ test("GIVEN the output filtering defence is active WHEN the bot responds with a defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); @@ -700,11 +742,12 @@ test("GIVEN the output filtering defence is active WHEN the bot responds with a process.env.FILTER_LIST_OUTPUT = "secret project,password"; const message = "What is the secret Project?"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; let defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; + const isOriginalMessage = true; defences = activateDefence(DEFENCE_TYPES.FILTER_BOT_OUTPUT, defences); @@ -725,6 +768,7 @@ test("GIVEN the output filtering defence is active WHEN the bot responds with a defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); @@ -745,11 +789,12 @@ test( process.env.FILTER_LIST_OUTPUT = "secret project,password"; const message = "What is the secret Project?"; - const chatHistory: ChatCompletionRequestMessage[] = []; + const chatHistory: ChatHistoryMessage[] = []; let defences: DefenceInfo[] = getInitialDefences(); const sentEmails: EmailInfo[] = []; const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; + const isOriginalMessage = true; mockCreateChatCompletion.mockResolvedValueOnce({ data: { @@ -768,6 +813,7 @@ test( defences, gptModel, message, + isOriginalMessage, openAiApiKey, sentEmails ); diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 1c8feace7..0105bb8c5 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -4,23 +4,31 @@ import DemoHeader from "./components/MainComponent/DemoHeader"; import DemoBody from "./components/MainComponent/DemoBody"; import { useEffect, useState } from "react"; import { PHASE_NAMES } from "./models/phase"; -import { clearChat } from "./service/chatService"; +import { + getChatHistory, + addMessageToChatHistory, + clearChat, +} from "./service/chatService"; import { EmailInfo } from "./models/email"; -import { clearEmails } from "./service/emailService"; +import { clearEmails, getSentEmails } from "./service/emailService"; import { CHAT_MESSAGE_TYPE, ChatMessage } from "./models/chat"; import { resetActiveDefences } from "./service/defenceService"; -import { PHASES } from "./Phases"; import { DEFENCE_DETAILS_ALL, DEFENCE_DETAILS_PHASE } from "./Defences"; import { DefenceInfo } from "./models/defence"; import { getCompletedPhases } from "./service/phaseService"; +import { PHASES } from "./Phases"; function App() { // start on sandbox mode - const [currentPhase, setCurrentPhase] = useState(PHASE_NAMES.SANDBOX); + const [currentPhase, setCurrentPhase] = useState( + PHASE_NAMES.SANDBOX + ); const [numCompletedPhases, setNumCompletedPhases] = useState(0); - + const [defencesToShow, setDefencesToShow] = useState(DEFENCE_DETAILS_ALL); + + const [triggeredDefences, setTriggeredDefences] = useState([]); const [emails, setEmails] = useState([]); const [messages, setMessages] = useState([]); @@ -31,64 +39,81 @@ function App() { }); setNewPhase(currentPhase); }, []); - - function addChatMessage(message: ChatMessage) { - setMessages((messages: ChatMessage[]) => [...messages, message]); - }; - function addPhasePreambleMessage (message: string) { - addChatMessage({ - message: message, - type: CHAT_MESSAGE_TYPE.PHASE_INFO, - }); + // methods to modify messages + const addChatMessage = (message: ChatMessage) => { + setMessages((messages: ChatMessage[]) => [...messages, message]); }; + // for clearing phase progress function resetPhase() { - setNewPhase(currentPhase); + console.log("resetting phase " + currentPhase); + + clearChat(currentPhase).then(() => { + setMessages([]); + // add preamble to start of chat + addChatMessage({ + message: PHASES[currentPhase].preamble, + type: CHAT_MESSAGE_TYPE.PHASE_INFO, + }); + }); + clearEmails(currentPhase).then(() => { + setEmails([]); + }); + resetActiveDefences(currentPhase).then(() => { + setTriggeredDefences([]); + // choose appropriate defences to display + let defences = + currentPhase === PHASE_NAMES.PHASE_2 + ? DEFENCE_DETAILS_PHASE + : DEFENCE_DETAILS_ALL; + defences = defences.map((defence) => { + defence.isActive = false; + return defence; + }); + setDefencesToShow(defences); + }); } - function setNewPhase(newPhase: PHASE_NAMES) { + // for going switching phase without clearing progress + const setNewPhase = (newPhase: PHASE_NAMES) => { + console.log("changing phase from " + currentPhase + " to " + newPhase); + setMessages([]); setCurrentPhase(newPhase); - // reset remove emails - clearEmails(); - // reset local emails - setEmails([]); - - // reset remove chat - clearChat(); - // reset local chat - setMessages([]); + // get emails for new phase from the backend + getSentEmails(newPhase).then((phaseEmails) => { + setEmails(phaseEmails); + }); - // reset remote defences - resetActiveDefences(); - // choose appropriate defences to display - let defences = newPhase === PHASE_NAMES.PHASE_2 - ? DEFENCE_DETAILS_PHASE - : DEFENCE_DETAILS_ALL; - // make all defences inactive - defences = defences.map((defence) => { - defence.isActive = false; - return defence; + // get chat history for new phase from the backend + getChatHistory(newPhase).then((phaseChatHistory) => { + // add the preamble to the start of the chat history + phaseChatHistory.unshift({ + message: PHASES[newPhase].preamble, + type: CHAT_MESSAGE_TYPE.PHASE_INFO, + }); + setMessages(phaseChatHistory); }); - setDefencesToShow(defences); - // add the preamble to the chat - const preambleMessage = PHASES[newPhase].preamble; - addPhasePreambleMessage(preambleMessage.toLowerCase()); + let defences = + newPhase === PHASE_NAMES.PHASE_2 + ? DEFENCE_DETAILS_PHASE + : DEFENCE_DETAILS_ALL; + setDefencesToShow(defences); }; return (
-
- { // get sent emails - getSentEmails().then((sentEmails) => { + getSentEmails(currentPhase).then((sentEmails) => { setEmails(sentEmails); }); }, [setEmails]); @@ -47,7 +50,9 @@ function ChatBox( } async function sendChatMessage() { - const inputBoxElement = document.getElementById("chat-box-input") as HTMLInputElement; + const inputBoxElement = document.getElementById( + "chat-box-input" + ) as HTMLInputElement; // get the message from the input box const message = inputBoxElement?.value; // clear the input box @@ -77,7 +82,7 @@ function ChatBox( if (response.defenceInfo.isBlocked) { addChatMessage({ type: CHAT_MESSAGE_TYPE.BOT_BLOCKED, - message: response.defenceInfo.blockedReason + message: response.defenceInfo.blockedReason, }); } else { addChatMessage({ @@ -92,10 +97,16 @@ function ChatBox( return defence.id === triggeredDefence; })?.name.toLowerCase(); if (defenceName) { + const alertMsg = `your last message would have triggered the ${defenceName} defence`; addChatMessage({ type: CHAT_MESSAGE_TYPE.DEFENCE_ALERTED, - message: `your last message would have triggered the ${defenceName} defence`, - }) + message: alertMsg, + }); + addMessageToChatHistory( + alertMsg, + CHAT_MESSAGE_TYPE.DEFENCE_ALERTED, + currentPhase + ); } }); // add triggered defences to the chat @@ -105,10 +116,16 @@ function ChatBox( return defence.id === triggeredDefence; })?.name.toLowerCase(); if (defenceName) { + const triggerMsg = `${defenceName} defence triggered`; addChatMessage({ type: CHAT_MESSAGE_TYPE.DEFENCE_TRIGGERED, - message: `${defenceName} defence triggered`, - }) + message: triggerMsg, + }); + addMessageToChatHistory( + triggerMsg, + CHAT_MESSAGE_TYPE.DEFENCE_TRIGGERED, + currentPhase + ); } }); @@ -116,35 +133,40 @@ function ChatBox( setIsSendingMessage(false); // get sent emails - const sentEmails: EmailInfo[] = await getSentEmails(); + const sentEmails: EmailInfo[] = await getSentEmails(currentPhase); // update emails setEmails(sentEmails); if (response.wonPhase) { + const successMessage = + "Congratulations! You have completed this phase. Please click on the next phase to continue."; addChatMessage({ type: CHAT_MESSAGE_TYPE.PHASE_INFO, - message: - "Congratulations! You have completed this phase. Please click on the next phase to continue.", + message: successMessage, }); + addMessageToChatHistory( + successMessage, + CHAT_MESSAGE_TYPE.PHASE_INFO, + currentPhase + ); } } - }; - + } return (