diff --git a/backend/src/defence.ts b/backend/src/defence.ts index a8e2597cc..9359f72f6 100644 --- a/backend/src/defence.ts +++ b/backend/src/defence.ts @@ -45,14 +45,14 @@ const getInitialDefences = (): DefenceInfo[] => { ]; }; -function activateDefence(id: string, defences: DefenceInfo[]) { +function activateDefence(id: DEFENCE_TYPES, defences: DefenceInfo[]) { // return the updated list of defences return defences.map((defence) => defence.id === id ? { ...defence, isActive: true } : defence ); } -function deactivateDefence(id: string, defences: DefenceInfo[]) { +function deactivateDefence(id: DEFENCE_TYPES, defences: DefenceInfo[]) { // return the updated list of defences return defences.map((defence) => defence.id === id ? { ...defence, isActive: false } : defence @@ -60,7 +60,7 @@ function deactivateDefence(id: string, defences: DefenceInfo[]) { } function configureDefence( - id: string, + id: DEFENCE_TYPES, defences: DefenceInfo[], config: DefenceConfig[] ) { @@ -72,7 +72,7 @@ function configureDefence( function getConfigValue( defences: DefenceInfo[], - defenceId: string, + defenceId: DEFENCE_TYPES, configId: string, defaultValue: string ) { @@ -85,20 +85,20 @@ function getConfigValue( function getMaxMessageLength(defences: DefenceInfo[]) { return getConfigValue( defences, - "CHARACTER_LIMIT", + DEFENCE_TYPES.CHARACTER_LIMIT, "maxMessageLength", String(280) ); } function getRandomSequenceEnclosurePrePrompt(defences: DefenceInfo[]) { - return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", retrievalQAPrePromptSecure); + return getConfigValue(defences, DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, "prePrompt", retrievalQAPrePromptSecure); } function getRandomSequenceEnclosureLength(defences: DefenceInfo[]) { return getConfigValue( defences, - "RANDOM_SEQUENCE_ENCLOSURE", + DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, "length", String(10) ); @@ -117,19 +117,19 @@ function getSystemRole( case PHASE_NAMES.PHASE_2: return process.env.SYSTEM_ROLE_PHASE_2 || ""; default: - return getConfigValue(defences, "SYSTEM_ROLE", "systemRole", ""); + return getConfigValue(defences, DEFENCE_TYPES.SYSTEM_ROLE, "systemRole", ""); } } function getEmailWhitelistVar(defences: DefenceInfo[]) { - return getConfigValue(defences, "EMAIL_WHITELIST", "whitelist", ""); + return getConfigValue(defences, DEFENCE_TYPES.EMAIL_WHITELIST, "whitelist", ""); } function getQALLMprePrompt(defences: DefenceInfo[]) { - return getConfigValue(defences, "QA_LLM_INSTRUCTIONS", "prePrompt", ""); + return getConfigValue(defences, DEFENCE_TYPES.QA_LLM_INSTRUCTIONS, "prePrompt", ""); } -function isDefenceActive(id: string, defences: DefenceInfo[]) { +function isDefenceActive(id: DEFENCE_TYPES, defences: DefenceInfo[]) { return defences.find((defence) => defence.id === id && defence.isActive) ? true : false; @@ -207,13 +207,13 @@ function transformXmlTagging(message: string) { //apply defence string transformations to original message function transformMessage(message: string, defences: DefenceInfo[]) { let transformedMessage: string = message; - if (isDefenceActive("RANDOM_SEQUENCE_ENCLOSURE", defences)) { + if (isDefenceActive(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences)) { transformedMessage = transformRandomSequenceEnclosure( transformedMessage, defences ); } - if (isDefenceActive("XML_TAGGING", defences)) { + if (isDefenceActive(DEFENCE_TYPES.XML_TAGGING, defences)) { transformedMessage = transformXmlTagging(transformedMessage); } if (message == transformedMessage) { @@ -242,9 +242,9 @@ async function detectTriggeredDefences( if (message.length > maxMessageLength) { console.debug("CHARACTER_LIMIT defence triggered."); // add the defence to the list of triggered defences - defenceReport.triggeredDefences.push("CHARACTER_LIMIT"); + defenceReport.triggeredDefences.push(DEFENCE_TYPES.CHARACTER_LIMIT); // check if the defence is active - if (isDefenceActive("CHARACTER_LIMIT", defences)) { + if (isDefenceActive(DEFENCE_TYPES.CHARACTER_LIMIT, defences)) { // block the message defenceReport.isBlocked = true; defenceReport.blockedReason = "Message is too long"; @@ -257,14 +257,14 @@ async function detectTriggeredDefences( if (detectXMLTags(message)) { console.debug("XML_TAGGING defence triggered."); // add the defence to the list of triggered defences - defenceReport.triggeredDefences.push("XML_TAGGING"); + defenceReport.triggeredDefences.push(DEFENCE_TYPES.XML_TAGGING); } // evaluate the message for prompt injection const evalPrompt = await queryPromptEvaluationModel(message); if (evalPrompt.isMalicious) { - defenceReport.triggeredDefences.push("LLM_EVALUATION"); - if (isDefenceActive("LLM_EVALUATION", defences)) { + defenceReport.triggeredDefences.push(DEFENCE_TYPES.LLM_EVALUATION); + if (isDefenceActive(DEFENCE_TYPES.LLM_EVALUATION, defences)) { console.debug("LLM evalutation defence active."); defenceReport.isBlocked = true; defenceReport.blockedReason = diff --git a/backend/src/email.ts b/backend/src/email.ts index e60f4473c..7cad34ef5 100644 --- a/backend/src/email.ts +++ b/backend/src/email.ts @@ -1,4 +1,4 @@ -import { DefenceInfo } from "./models/defence"; +import { DEFENCE_TYPES, DefenceInfo } from "./models/defence"; import { EmailInfo } from "./models/email"; import { getEmailWhitelistVar, isDefenceActive } from "./defence"; import { PHASE_NAMES } from "./models/phase"; @@ -11,7 +11,7 @@ function getEmailWhitelistValues(defences: DefenceInfo[]) { // if defense active return the whitelist of emails and domains function getEmailWhitelist(defences: DefenceInfo[]) { - if (!isDefenceActive("EMAIL_WHITELIST", defences)) { + if (!isDefenceActive(DEFENCE_TYPES.EMAIL_WHITELIST, defences)) { return "As the email whitelist defence is not active, any email address can be emailed."; } else { return ( diff --git a/backend/src/openai.ts b/backend/src/openai.ts index bb0875d20..3da8401e3 100644 --- a/backend/src/openai.ts +++ b/backend/src/openai.ts @@ -13,7 +13,7 @@ import { OpenAIApi, } from "openai"; import { CHAT_MODELS, ChatDefenceReport } from "./models/chat"; -import { DefenceInfo } from "./models/defence"; +import { DEFENCE_TYPES, DefenceInfo } from "./models/defence"; import { PHASE_NAMES } from "./models/phase"; // OpenAI config @@ -165,8 +165,8 @@ async function chatGptCallFunction( isAllowedToSendEmail = true; } else { // trigger email defence even if it is not active - defenceInfo.triggeredDefences.push("EMAIL_WHITELIST"); - if (isDefenceActive("EMAIL_WHITELIST", defences)) { + defenceInfo.triggeredDefences.push(DEFENCE_TYPES.EMAIL_WHITELIST); + if (isDefenceActive(DEFENCE_TYPES.EMAIL_WHITELIST, defences)) { // do not send email if defence is on and set to blocked defenceInfo.isBlocked = true; defenceInfo.blockedReason = @@ -229,7 +229,7 @@ async function chatGptChatCompletion( // system role is always active on phases if ( currentPhase !== PHASE_NAMES.SANDBOX || - isDefenceActive("SYSTEM_ROLE", defences) + isDefenceActive(DEFENCE_TYPES.SYSTEM_ROLE, defences) ) { // check to see if there's already a system role if (!chatHistory.find((message) => message.role === "system")) { diff --git a/backend/src/router.ts b/backend/src/router.ts index ca641b149..70e139bdc 100644 --- a/backend/src/router.ts +++ b/backend/src/router.ts @@ -11,7 +11,7 @@ import { } from "./defence"; import { initQAModel } from "./langchain"; import { CHAT_MODELS, ChatHttpResponse } from "./models/chat"; -import { DefenceConfig } from "./models/defence"; +import { DEFENCE_TYPES, DefenceConfig } from "./models/defence"; import { chatGptSendMessage, setOpenAiApiKey, setGptModel } from "./openai"; import { retrievalQAPrePrompt } from "./promptTemplates"; import { PHASE_NAMES } from "./models/phase"; @@ -24,13 +24,13 @@ let prevPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX; // Activate a defence router.post("/defence/activate", (req, res) => { // id of the defence - const defenceId: string = req.body?.defenceId; + const defenceId: DEFENCE_TYPES = req.body?.defenceId; if (defenceId) { // activate the defence req.session.defences = activateDefence(defenceId, req.session.defences); // need to re-initialize QA model when turned on - if (defenceId === "QA_LLM_INSTRUCTIONS" && req.session.openAiApiKey) { + if (defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && req.session.openAiApiKey) { console.debug( "Activating qa llm instruction defence - reinitializing qa model" ); @@ -47,12 +47,12 @@ router.post("/defence/activate", (req, res) => { // Deactivate a defence router.post("/defence/deactivate", (req, res) => { // id of the defence - const defenceId: string = req.body?.defenceId; + const defenceId: DEFENCE_TYPES = req.body?.defenceId; if (defenceId) { // deactivate the defence req.session.defences = deactivateDefence(defenceId, req.session.defences); - if (defenceId === "QA_LLM_INSTRUCTIONS" && req.session.openAiApiKey) { + if (defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && req.session.openAiApiKey) { console.debug("Resetting QA model with default prompt"); initQAModel(req.session.openAiApiKey, getQALLMprePrompt(req.session.defences)); } @@ -66,7 +66,7 @@ router.post("/defence/deactivate", (req, res) => { // Configure a defence router.post("/defence/configure", (req, res) => { // id of the defence - const defenceId: string = req.body?.defenceId; + const defenceId: DEFENCE_TYPES = req.body?.defenceId; const config: DefenceConfig[] = req.body?.config; if (defenceId && config) { // configure the defence diff --git a/backend/test/integration/defences.test.ts b/backend/test/integration/defences.test.ts index 350e4e943..910cdc363 100644 --- a/backend/test/integration/defences.test.ts +++ b/backend/test/integration/defences.test.ts @@ -1,5 +1,6 @@ import { activateDefence, detectTriggeredDefences, getInitialDefences } from "../../src/defence"; import { initPromptEvaluationModel } from "../../src/langchain"; +import { DEFENCE_TYPES } from "../../src/models/defence"; // Define a mock implementation for the createChatCompletion method const mockCall = jest.fn(); @@ -31,14 +32,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt is malicious WHEN detect let defences = getInitialDefences(); // activate the defence - defences = activateDefence("LLM_EVALUATION", defences); + defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences); // create a malicious prompt const message = "some kind of malicious prompt"; // detect triggered defences const result = await detectTriggeredDefences(message, defences); // check that the defence is triggered and the message is blocked expect(result.isBlocked).toBe(true); - expect(result.triggeredDefences).toContain("LLM_EVALUATION"); + expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION); }); test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers malice detection WHEN detectTriggeredDefences is called THEN defence is triggered AND defence is blocked", async () => { @@ -53,14 +54,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers malice det let defences = getInitialDefences(); // activate the defence - defences = activateDefence("LLM_EVALUATION", defences); + defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences); // create a malicious prompt const message = "some kind of malicious prompt"; // detect triggered defences const result = await detectTriggeredDefences(message, defences); // check that the defence is triggered and the message is blocked expect(result.isBlocked).toBe(true); - expect(result.triggeredDefences).toContain("LLM_EVALUATION"); + expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION); }); test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers prompt injection detection WHEN detectTriggeredDefences is called THEN defence is triggered AND defence is blocked", async () => { @@ -75,14 +76,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers prompt inj let defences = getInitialDefences(); // activate the defence - defences = activateDefence("LLM_EVALUATION", defences); + defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences); // create a malicious prompt const message = "some kind of malicious prompt"; // detect triggered defences const result = await detectTriggeredDefences(message, defences); // check that the defence is triggered and the message is blocked expect(result.isBlocked).toBe(true); - expect(result.triggeredDefences).toContain("LLM_EVALUATION"); + expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION); }); test("GIVEN LLM_EVALUATION defence is active AND prompt not is malicious WHEN detectTriggeredDefences is called THEN defence is not triggered AND defence is not blocked", async () => { @@ -97,7 +98,7 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt not is malicious WHEN de let defences = getInitialDefences(); // activate the defence - defences = activateDefence("LLM_EVALUATION", defences); + defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences); // create a malicious prompt const message = "some kind of malicious prompt"; // detect triggered defences @@ -124,5 +125,5 @@ test("GIVEN LLM_EVALUATION defence is not active AND prompt is malicious WHEN de const result = await detectTriggeredDefences(message, defences); // check that the defence is triggered and the message is blocked expect(result.isBlocked).toBe(false); - expect(result.triggeredDefences).toContain("LLM_EVALUATION"); + expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION); }); \ No newline at end of file diff --git a/backend/test/integration/openai.test.ts b/backend/test/integration/openai.test.ts index 096c193a7..62b3518b5 100644 --- a/backend/test/integration/openai.test.ts +++ b/backend/test/integration/openai.test.ts @@ -1,7 +1,7 @@ import { ChatCompletionRequestMessage } from "openai"; import { CHAT_MODELS } from "../../src/models/chat"; import { chatGptSendMessage } from "../../src/openai"; -import { DefenceInfo } from "../../src/models/defence"; +import { DEFENCE_TYPES, DefenceInfo } from "../../src/models/defence"; import { EmailInfo } from "../../src/models/email"; import { activateDefence, getInitialDefences } from "../../src/defence"; @@ -91,7 +91,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; - defences = activateDefence("SYSTEM_ROLE", defences); + defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences); // Mock the createChatCompletion function mockCreateChatCompletion.mockResolvedValueOnce({ @@ -154,7 +154,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role const openAiApiKey = "sk-12345"; // activate the SYSTEM_ROLE defence - defences = activateDefence("SYSTEM_ROLE", defences); + defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences); // Mock the createChatCompletion function mockCreateChatCompletion.mockResolvedValueOnce({ @@ -292,7 +292,7 @@ test( const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; - defences = activateDefence("SYSTEM_ROLE", defences); + defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences); // Mock the createChatCompletion function mockCreateChatCompletion.mockResolvedValueOnce({ @@ -410,7 +410,7 @@ test( expect(reply?.defenceInfo.isBlocked).toBe(false); // EMAIL_WHITELIST defence is triggered expect(reply?.defenceInfo.triggeredDefences.length).toBe(1); - expect(reply?.defenceInfo.triggeredDefences[0]).toBe("EMAIL_WHITELIST"); + expect(reply?.defenceInfo.triggeredDefences[0]).toBe(DEFENCE_TYPES.EMAIL_WHITELIST); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -432,7 +432,7 @@ test( const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; - defences = activateDefence("EMAIL_WHITELIST", defences); + defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); // Mock the createChatCompletion function mockCreateChatCompletion @@ -486,7 +486,7 @@ test( expect(reply?.defenceInfo.isBlocked).toBe(true); // EMAIL_WHITELIST defence is triggered expect(reply?.defenceInfo.triggeredDefences.length).toBe(1); - expect(reply?.defenceInfo.triggeredDefences[0]).toBe("EMAIL_WHITELIST"); + expect(reply?.defenceInfo.triggeredDefences[0]).toBe(DEFENCE_TYPES.EMAIL_WHITELIST); // restore the mock mockCreateChatCompletion.mockRestore(); @@ -508,7 +508,7 @@ test( const gptModel = CHAT_MODELS.GPT_4; const openAiApiKey = "sk-12345"; - defences = activateDefence("EMAIL_WHITELIST", defences); + defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); // Mock the createChatCompletion function mockCreateChatCompletion diff --git a/backend/test/unit/defence.test.ts b/backend/test/unit/defence.test.ts index 23378508d..125234ee8 100644 --- a/backend/test/unit/defence.test.ts +++ b/backend/test/unit/defence.test.ts @@ -9,6 +9,7 @@ import { isDefenceActive, transformMessage, } from "../../src/defence"; +import { DEFENCE_TYPES } from "../../src/models/defence"; import { PHASE_NAMES } from "../../src/models/phase"; import { retrievalQAPrePromptSecure } from "../../src/promptTemplates"; @@ -18,21 +19,21 @@ beforeEach(() => { }); test("GIVEN defence is not active WHEN activating defence THEN defence is active", () => { - const defence = "SYSTEM_ROLE"; + const defence = DEFENCE_TYPES.SYSTEM_ROLE; const defences = getInitialDefences(); const updatedActiveDefences = activateDefence(defence, defences); expect(isDefenceActive(defence, updatedActiveDefences)).toBe(true); }); test("GIVEN defence is active WHEN activating defence THEN defence is active", () => { - const defence = "SYSTEM_ROLE"; + const defence = DEFENCE_TYPES.SYSTEM_ROLE; const defences = getInitialDefences(); const updatedActiveDefences = activateDefence(defence, defences); expect(isDefenceActive(defence, updatedActiveDefences)).toBe(true); }); test("GIVEN defence is active WHEN deactivating defence THEN defence is not active", () => { - const defence = "SYSTEM_ROLE"; + const defence = DEFENCE_TYPES.SYSTEM_ROLE; const defences = getInitialDefences(); activateDefence(defence, defences); const updatedActiveDefences = deactivateDefence(defence, defences); @@ -40,14 +41,14 @@ test("GIVEN defence is active WHEN deactivating defence THEN defence is not acti }); test("GIVEN defence is not active WHEN deactivating defence THEN defence is not active", () => { - const defence = "SYSTEM_ROLE"; + const defence = DEFENCE_TYPES.SYSTEM_ROLE; const defences = getInitialDefences(); const updatedActiveDefences = deactivateDefence(defence, defences); expect(updatedActiveDefences).not.toContain(defence); }); test("GIVEN defence is active WHEN checking if defence is active THEN return true", () => { - const defence = "SYSTEM_ROLE"; + const defence = DEFENCE_TYPES.SYSTEM_ROLE; const defences = getInitialDefences(); const updatedDefences = activateDefence(defence, defences); const isActive = isDefenceActive(defence, updatedDefences); @@ -55,7 +56,7 @@ test("GIVEN defence is active WHEN checking if defence is active THEN return tru }); test("GIVEN defence is not active WHEN checking if defence is active THEN return false", () => { - const defence = "SYSTEM_ROLE"; + const defence = DEFENCE_TYPES.SYSTEM_ROLE; const defences = getInitialDefences(); const isActive = isDefenceActive(defence, defences); expect(isActive).toBe(false); @@ -77,7 +78,7 @@ test("GIVEN RANDOM_SEQUENCE_ENCLOSURE defence is active WHEN transforming messag const message = "Hello"; let defences = getInitialDefences(); // activate RSE defence - defences = activateDefence("RANDOM_SEQUENCE_ENCLOSURE", defences); + defences = activateDefence(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences); // regex to match the transformed message with const regex = new RegExp( @@ -110,7 +111,7 @@ test("GIVEN XML_TAGGING defence is active WHEN transforming message THEN message const message = "Hello"; const defences = getInitialDefences(); // activate XML_TAGGING defence - const updatedDefences = activateDefence("XML_TAGGING", defences); + const updatedDefences = activateDefence(DEFENCE_TYPES.XML_TAGGING, defences); const transformedMessage = transformMessage(message, updatedDefences); // expect the message to be surrounded by XML tags expect(transformedMessage).toBe("" + message + ""); @@ -121,7 +122,7 @@ test("GIVEN XML_TAGGING defence is active AND message contains XML tags WHEN tra const escapedMessage = "<>&'""; const defences = getInitialDefences(); // activate XML_TAGGING defence - const updatedDefences = activateDefence("XML_TAGGING", defences); + const updatedDefences = activateDefence(DEFENCE_TYPES.XML_TAGGING, defences); const transformedMessage = transformMessage(message, updatedDefences); // expect the message to be surrounded by XML tags expect(transformedMessage).toBe( @@ -146,9 +147,9 @@ test( const message = "Hello"; let defences = getInitialDefences(); // activate CHARACTER_LIMIT defence - defences = activateDefence("CHARACTER_LIMIT", defences); + defences = activateDefence(DEFENCE_TYPES.CHARACTER_LIMIT, defences); // configure CHARACTER_LIMIT defence - defences = configureDefence("CHARACTER_LIMIT", defences, [ + defences = configureDefence(DEFENCE_TYPES.CHARACTER_LIMIT, defences, [ { id: "maxMessageLength", value: String(3), @@ -157,7 +158,7 @@ test( const defenceReport = await detectTriggeredDefences(message, defences); expect(defenceReport.blockedReason).toBe("Message is too long"); expect(defenceReport.isBlocked).toBe(true); - expect(defenceReport.triggeredDefences).toContain("CHARACTER_LIMIT"); + expect(defenceReport.triggeredDefences).toContain(DEFENCE_TYPES.CHARACTER_LIMIT); } ); @@ -169,9 +170,9 @@ test( const message = "Hello"; const defences = getInitialDefences(); // activate CHARACTER_LIMIT defence - activateDefence("CHARACTER_LIMIT", defences); + activateDefence(DEFENCE_TYPES.CHARACTER_LIMIT, defences); // configure CHARACTER_LIMIT defence - configureDefence("CHARACTER_LIMIT", defences, [ + configureDefence(DEFENCE_TYPES.CHARACTER_LIMIT, defences, [ { id: "maxMessageLength", value: String(280), @@ -192,7 +193,7 @@ test( const message = "Hello"; let defences = getInitialDefences(); // configure CHARACTER_LIMIT defence - defences = configureDefence("CHARACTER_LIMIT", defences, [ + defences = configureDefence(DEFENCE_TYPES.CHARACTER_LIMIT, defences, [ { id: "maxMessageLength", value: String(3), @@ -201,7 +202,7 @@ test( const defenceReport = await detectTriggeredDefences(message, defences); expect(defenceReport.blockedReason).toBe(null); expect(defenceReport.isBlocked).toBe(false); - expect(defenceReport.triggeredDefences).toContain("CHARACTER_LIMIT"); + expect(defenceReport.triggeredDefences).toContain(DEFENCE_TYPES.CHARACTER_LIMIT); } ); @@ -211,11 +212,11 @@ test("GIVEN message contains XML tags WHEN detecting triggered defences THEN XML const defenceReport = await detectTriggeredDefences(message, defences); expect(defenceReport.blockedReason).toBe(null); expect(defenceReport.isBlocked).toBe(false); - expect(defenceReport.triggeredDefences).toContain("XML_TAGGING"); + expect(defenceReport.triggeredDefences).toContain(DEFENCE_TYPES.XML_TAGGING); }); test("GIVEN setting max message length WHEN configuring defence THEN defence is configured", () => { - const defence = "CHARACTER_LIMIT"; + const defence = DEFENCE_TYPES.CHARACTER_LIMIT; let defences = getInitialDefences(); // configure CHARACTER_LIMIT defence defences = configureDefence(defence, defences, [ @@ -257,7 +258,7 @@ test("GIVEN a new system role has been set WHEN getting system role THEN return let defences = getInitialDefences(); let systemRole = getSystemRole(defences); expect(systemRole).toBe("system role"); - defences = configureDefence("SYSTEM_ROLE", defences, [ + defences = configureDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences, [ { id: "systemRole", value: "new system role" }, ]); systemRole = getSystemRole(defences); @@ -296,7 +297,7 @@ test("GIVEN QA LLM instructions have not been configured WHEN getting QA LLM ins test("GIVEN QA LLM instructions have been configured WHEN getting QA LLM instructions THEN return configured prompt", () => { const newQaLlmInstructions = "new QA LLM instructions"; let defences = getInitialDefences(); - defences = configureDefence("QA_LLM_INSTRUCTIONS", defences, [ + defences = configureDefence(DEFENCE_TYPES.QA_LLM_INSTRUCTIONS, defences, [ { id: "prePrompt", value: newQaLlmInstructions }, ]); const qaLlmInstructions = getQALLMprePrompt(defences); @@ -304,7 +305,7 @@ test("GIVEN QA LLM instructions have been configured WHEN getting QA LLM instruc }); test("GIVEN setting email whitelist WHEN configuring defence THEN defence is configured", () => { - const defence = "EMAIL_WHITELIST"; + const defence = DEFENCE_TYPES.EMAIL_WHITELIST; const defences = getInitialDefences(); // configure EMAIL_WHITELIST defence configureDefence(defence, defences, [ @@ -331,7 +332,7 @@ test("GIVEN setting email whitelist WHEN configuring defence THEN defence is con }); test("GIVEN user configures random sequence enclosure WHEN configuring defence THEN defence is configured", () => { - const defence = "RANDOM_SEQUENCE_ENCLOSURE"; + const defence = DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE; let defences = getInitialDefences(); const newPrePrompt = "new pre prompt"; const newLength = String(100); diff --git a/backend/test/unit/email.test.ts b/backend/test/unit/email.test.ts index 4598bb746..b04a686b2 100644 --- a/backend/test/unit/email.test.ts +++ b/backend/test/unit/email.test.ts @@ -4,6 +4,7 @@ import { isEmailInWhitelist, sendEmail, } from "../../src/email"; +import { DEFENCE_TYPES } from "../../src/models/defence"; import { PHASE_NAMES } from "../../src/models/phase"; beforeEach(() => { @@ -153,7 +154,7 @@ test("GIVEN EMAIL_WHITELIST envionrment variable is set WHEN getting whitelist A process.env.EMAIL_WHITELIST = "bob@example.com,kate@example.com"; let defences = getInitialDefences(); // activate email whitelist defence - defences = activateDefence("EMAIL_WHITELIST", defences); + defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); const whitelist = getEmailWhitelist(defences); expect(whitelist).toBe( "The whitelisted emails and domains are: " + process.env.EMAIL_WHITELIST @@ -173,7 +174,7 @@ test("GIVEN email is not in whitelist WHEN checking whitelist THEN false is retu process.env.EMAIL_WHITELIST = "bob@example.com,kate@example.com"; let defences = getInitialDefences(); // activate email whitelist defence - defences = activateDefence("EMAIL_WHITELIST", defences); + defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); const address = "malicious@user.com"; const isWhitelisted = isEmailInWhitelist(address, defences); expect(isWhitelisted).toBe(false); @@ -183,7 +184,7 @@ test("GIVEN email is in whitelist WHEN checking whitelist THEN true is returned" process.env.EMAIL_WHITELIST = "bob@example.com,kate@example.com"; let defences = getInitialDefences(); // activate email whitelist defence - defences = activateDefence("EMAIL_WHITELIST", defences); + defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); const address = "bob@example.com"; const isWhitelisted = isEmailInWhitelist(address, defences); expect(isWhitelisted).toBe(true); @@ -193,7 +194,7 @@ test("GIVEN email domain is in whitelist WHEN checking whitelist THEN true is re process.env.EMAIL_WHITELIST = "@example.com"; let defences = getInitialDefences(); // activate email whitelist defence - defences = activateDefence("EMAIL_WHITELIST", defences); + defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences); const address = "bob@example.com"; const isWhitelisted = isEmailInWhitelist(address, defences); expect(isWhitelisted).toBe(true); diff --git a/frontend/src/models/defence.ts b/frontend/src/models/defence.ts index e9eda0e76..c72e224b4 100644 --- a/frontend/src/models/defence.ts +++ b/frontend/src/models/defence.ts @@ -2,10 +2,10 @@ enum DEFENCE_TYPES { CHARACTER_LIMIT = "CHARACTER_LIMIT", EMAIL_WHITELIST = "EMAIL_WHITELIST", LLM_EVALUATION = "LLM_EVALUATION", + QA_LLM_INSTRUCTIONS = "QA_LLM_INSTRUCTIONS", RANDOM_SEQUENCE_ENCLOSURE = "RANDOM_SEQUENCE_ENCLOSURE", SYSTEM_ROLE = "SYSTEM_ROLE", XML_TAGGING = "XML_TAGGING", - QA_LLM_INSTRUCTIONS = "QA_LLM_INSTRUCTIONS", } class DefenceConfig { diff --git a/frontend/src/service/defenceService.ts b/frontend/src/service/defenceService.ts index 90a42fa43..0cd83be0c 100644 --- a/frontend/src/service/defenceService.ts +++ b/frontend/src/service/defenceService.ts @@ -1,4 +1,4 @@ -import { DefenceConfig, DefenceInfo } from "../models/defence"; +import { DEFENCE_TYPES, DefenceConfig, DefenceInfo } from "../models/defence"; import { sendRequest } from "./backendService"; const PATH = "defence/"; @@ -60,9 +60,9 @@ function validateStringConfig(config: string) { function validateDefence(id: string, configName: string, config: string) { switch (id) { - case "CHARACTER_LIMIT": + case DEFENCE_TYPES.CHARACTER_LIMIT: return validateNumberConfig(config); - case "RANDOM_SEQUENCE_ENCLOSURE": + case DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE: return configName === "length" ? validateNumberConfig(config) : validateStringConfig(config);