Using DEFENCE_TYPES enum (#186)
gsproston-scottlogic authored Aug 24, 2023
1 parent 35c2f88 commit af9b228
Showing 10 changed files with 79 additions and 76 deletions.
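
The enum itself is declared in backend/src/models/defence.ts, which the changed files import but whose own hunk is not loaded in this view. A plausible sketch of that declaration, inferred only from the members referenced in this commit, is shown below; the string values are an assumption, chosen to mirror the old string literals so that existing DefenceInfo ids and triggeredDefences entries would keep matching.

// Hypothetical reconstruction of backend/src/models/defence.ts (not shown in this diff).
// Member names are taken from the call sites changed in this commit; the values are
// assumed to equal the previous string literals so runtime behaviour stays the same.
export enum DEFENCE_TYPES {
  CHARACTER_LIMIT = "CHARACTER_LIMIT",
  EMAIL_WHITELIST = "EMAIL_WHITELIST",
  LLM_EVALUATION = "LLM_EVALUATION",
  QA_LLM_INSTRUCTIONS = "QA_LLM_INSTRUCTIONS",
  RANDOM_SEQUENCE_ENCLOSURE = "RANDOM_SEQUENCE_ENCLOSURE",
  SYSTEM_ROLE = "SYSTEM_ROLE",
  XML_TAGGING = "XML_TAGGING",
}

Under that assumption, comparisons such as defence.id === id and expect(...).toContain(DEFENCE_TYPES.LLM_EVALUATION) would behave exactly as the old string-literal versions did, while call sites gain compile-time checking and autocompletion.
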
36 changes: 18 additions & 18 deletions backend/src/defence.ts
@@ -45,22 +45,22 @@ const getInitialDefences = (): DefenceInfo[] => {
];
};

-function activateDefence(id: string, defences: DefenceInfo[]) {
+function activateDefence(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
// return the updated list of defences
return defences.map((defence) =>
defence.id === id ? { ...defence, isActive: true } : defence
);
}

-function deactivateDefence(id: string, defences: DefenceInfo[]) {
+function deactivateDefence(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
// return the updated list of defences
return defences.map((defence) =>
defence.id === id ? { ...defence, isActive: false } : defence
);
}

function configureDefence(
-id: string,
+id: DEFENCE_TYPES,
defences: DefenceInfo[],
config: DefenceConfig[]
) {
@@ -72,7 +72,7 @@ function configureDefence(

function getConfigValue(
defences: DefenceInfo[],
-defenceId: string,
+defenceId: DEFENCE_TYPES,
configId: string,
defaultValue: string
) {
@@ -85,20 +85,20 @@ function getConfigValue(
function getMaxMessageLength(defences: DefenceInfo[]) {
return getConfigValue(
defences,
-"CHARACTER_LIMIT",
+DEFENCE_TYPES.CHARACTER_LIMIT,
"maxMessageLength",
String(280)
);
}

function getRandomSequenceEnclosurePrePrompt(defences: DefenceInfo[]) {
-return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", retrievalQAPrePromptSecure);
+return getConfigValue(defences, DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, "prePrompt", retrievalQAPrePromptSecure);
}

function getRandomSequenceEnclosureLength(defences: DefenceInfo[]) {
return getConfigValue(
defences,
-"RANDOM_SEQUENCE_ENCLOSURE",
+DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE,
"length",
String(10)
);
@@ -117,19 +117,19 @@ function getSystemRole(
case PHASE_NAMES.PHASE_2:
return process.env.SYSTEM_ROLE_PHASE_2 || "";
default:
-return getConfigValue(defences, "SYSTEM_ROLE", "systemRole", "");
+return getConfigValue(defences, DEFENCE_TYPES.SYSTEM_ROLE, "systemRole", "");
}
}

function getEmailWhitelistVar(defences: DefenceInfo[]) {
-return getConfigValue(defences, "EMAIL_WHITELIST", "whitelist", "");
+return getConfigValue(defences, DEFENCE_TYPES.EMAIL_WHITELIST, "whitelist", "");
}

function getQALLMprePrompt(defences: DefenceInfo[]) {
-return getConfigValue(defences, "QA_LLM_INSTRUCTIONS", "prePrompt", "");
+return getConfigValue(defences, DEFENCE_TYPES.QA_LLM_INSTRUCTIONS, "prePrompt", "");
}

-function isDefenceActive(id: string, defences: DefenceInfo[]) {
+function isDefenceActive(id: DEFENCE_TYPES, defences: DefenceInfo[]) {
return defences.find((defence) => defence.id === id && defence.isActive)
? true
: false;
@@ -207,13 +207,13 @@ function transformXmlTagging(message: string) {
//apply defence string transformations to original message
function transformMessage(message: string, defences: DefenceInfo[]) {
let transformedMessage: string = message;
-if (isDefenceActive("RANDOM_SEQUENCE_ENCLOSURE", defences)) {
+if (isDefenceActive(DEFENCE_TYPES.RANDOM_SEQUENCE_ENCLOSURE, defences)) {
transformedMessage = transformRandomSequenceEnclosure(
transformedMessage,
defences
);
}
-if (isDefenceActive("XML_TAGGING", defences)) {
+if (isDefenceActive(DEFENCE_TYPES.XML_TAGGING, defences)) {
transformedMessage = transformXmlTagging(transformedMessage);
}
if (message == transformedMessage) {
Expand Down Expand Up @@ -242,9 +242,9 @@ async function detectTriggeredDefences(
if (message.length > maxMessageLength) {
console.debug("CHARACTER_LIMIT defence triggered.");
// add the defence to the list of triggered defences
-defenceReport.triggeredDefences.push("CHARACTER_LIMIT");
+defenceReport.triggeredDefences.push(DEFENCE_TYPES.CHARACTER_LIMIT);
// check if the defence is active
-if (isDefenceActive("CHARACTER_LIMIT", defences)) {
+if (isDefenceActive(DEFENCE_TYPES.CHARACTER_LIMIT, defences)) {
// block the message
defenceReport.isBlocked = true;
defenceReport.blockedReason = "Message is too long";
@@ -257,14 +257,14 @@
if (detectXMLTags(message)) {
console.debug("XML_TAGGING defence triggered.");
// add the defence to the list of triggered defences
-defenceReport.triggeredDefences.push("XML_TAGGING");
+defenceReport.triggeredDefences.push(DEFENCE_TYPES.XML_TAGGING);
}

// evaluate the message for prompt injection
const evalPrompt = await queryPromptEvaluationModel(message);
if (evalPrompt.isMalicious) {
-defenceReport.triggeredDefences.push("LLM_EVALUATION");
-if (isDefenceActive("LLM_EVALUATION", defences)) {
+defenceReport.triggeredDefences.push(DEFENCE_TYPES.LLM_EVALUATION);
+if (isDefenceActive(DEFENCE_TYPES.LLM_EVALUATION, defences)) {
console.debug("LLM evalutation defence active.");
defenceReport.isBlocked = true;
defenceReport.blockedReason =
4 changes: 2 additions & 2 deletions backend/src/email.ts
@@ -1,4 +1,4 @@
-import { DefenceInfo } from "./models/defence";
+import { DEFENCE_TYPES, DefenceInfo } from "./models/defence";
import { EmailInfo } from "./models/email";
import { getEmailWhitelistVar, isDefenceActive } from "./defence";
import { PHASE_NAMES } from "./models/phase";
@@ -11,7 +11,7 @@ function getEmailWhitelistValues(defences: DefenceInfo[]) {

// if defense active return the whitelist of emails and domains
function getEmailWhitelist(defences: DefenceInfo[]) {
-if (!isDefenceActive("EMAIL_WHITELIST", defences)) {
+if (!isDefenceActive(DEFENCE_TYPES.EMAIL_WHITELIST, defences)) {
return "As the email whitelist defence is not active, any email address can be emailed.";
} else {
return (
8 changes: 4 additions & 4 deletions backend/src/openai.ts
@@ -13,7 +13,7 @@ import {
OpenAIApi,
} from "openai";
import { CHAT_MODELS, ChatDefenceReport } from "./models/chat";
-import { DefenceInfo } from "./models/defence";
+import { DEFENCE_TYPES, DefenceInfo } from "./models/defence";
import { PHASE_NAMES } from "./models/phase";

// OpenAI config
@@ -165,8 +165,8 @@ async function chatGptCallFunction(
isAllowedToSendEmail = true;
} else {
// trigger email defence even if it is not active
-defenceInfo.triggeredDefences.push("EMAIL_WHITELIST");
-if (isDefenceActive("EMAIL_WHITELIST", defences)) {
+defenceInfo.triggeredDefences.push(DEFENCE_TYPES.EMAIL_WHITELIST);
+if (isDefenceActive(DEFENCE_TYPES.EMAIL_WHITELIST, defences)) {
// do not send email if defence is on and set to blocked
defenceInfo.isBlocked = true;
defenceInfo.blockedReason =
@@ -229,7 +229,7 @@ async function chatGptChatCompletion(
// system role is always active on phases
if (
currentPhase !== PHASE_NAMES.SANDBOX ||
-isDefenceActive("SYSTEM_ROLE", defences)
+isDefenceActive(DEFENCE_TYPES.SYSTEM_ROLE, defences)
) {
// check to see if there's already a system role
if (!chatHistory.find((message) => message.role === "system")) {
12 changes: 6 additions & 6 deletions backend/src/router.ts
@@ -11,7 +11,7 @@ import {
} from "./defence";
import { initQAModel } from "./langchain";
import { CHAT_MODELS, ChatHttpResponse } from "./models/chat";
-import { DefenceConfig } from "./models/defence";
+import { DEFENCE_TYPES, DefenceConfig } from "./models/defence";
import { chatGptSendMessage, setOpenAiApiKey, setGptModel } from "./openai";
import { retrievalQAPrePrompt } from "./promptTemplates";
import { PHASE_NAMES } from "./models/phase";
@@ -24,13 +24,13 @@ let prevPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX;
// Activate a defence
router.post("/defence/activate", (req, res) => {
// id of the defence
-const defenceId: string = req.body?.defenceId;
+const defenceId: DEFENCE_TYPES = req.body?.defenceId;
if (defenceId) {
// activate the defence
req.session.defences = activateDefence(defenceId, req.session.defences);

// need to re-initialize QA model when turned on
-if (defenceId === "QA_LLM_INSTRUCTIONS" && req.session.openAiApiKey) {
+if (defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && req.session.openAiApiKey) {
console.debug(
"Activating qa llm instruction defence - reinitializing qa model"
);
@@ -47,12 +47,12 @@ router.post("/defence/activate", (req, res) => {
// Deactivate a defence
router.post("/defence/deactivate", (req, res) => {
// id of the defence
-const defenceId: string = req.body?.defenceId;
+const defenceId: DEFENCE_TYPES = req.body?.defenceId;
if (defenceId) {
// deactivate the defence
req.session.defences = deactivateDefence(defenceId, req.session.defences);

-if (defenceId === "QA_LLM_INSTRUCTIONS" && req.session.openAiApiKey) {
+if (defenceId === DEFENCE_TYPES.QA_LLM_INSTRUCTIONS && req.session.openAiApiKey) {
console.debug("Resetting QA model with default prompt");
initQAModel(req.session.openAiApiKey, getQALLMprePrompt(req.session.defences));
}
@@ -66,7 +66,7 @@ router.post("/defence/deactivate", (req, res) => {
// Configure a defence
router.post("/defence/configure", (req, res) => {
// id of the defence
-const defenceId: string = req.body?.defenceId;
+const defenceId: DEFENCE_TYPES = req.body?.defenceId;
const config: DefenceConfig[] = req.body?.config;
if (defenceId && config) {
// configure the defence
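
One caveat the router hunks leave implicit: annotating req.body?.defenceId as DEFENCE_TYPES is a compile-time claim only, since TypeScript types are erased at runtime, so an arbitrary string in the request body still reaches activateDefence, deactivateDefence, and configureDefence unchecked. A minimal runtime guard is sketched below; it is illustrative, not part of this commit, and assumes DEFENCE_TYPES is a string enum as reconstructed above.

// Hypothetical guard, not present in commit af9b228.
// Assumes: import { DEFENCE_TYPES } from "./models/defence";
function isDefenceType(value: unknown): value is DEFENCE_TYPES {
  // Object.values on a string enum yields its string values.
  return (
    typeof value === "string" &&
    (Object.values(DEFENCE_TYPES) as string[]).includes(value)
  );
}

// Usage sketch inside a route handler:
//   const candidate = req.body?.defenceId;
//   if (!isDefenceType(candidate)) {
//     res.status(400).send("Unknown defenceId");
//     return;
//   }
//   const defenceId: DEFENCE_TYPES = candidate;
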
17 changes: 9 additions & 8 deletions backend/test/integration/defences.test.ts
@@ -1,5 +1,6 @@
import { activateDefence, detectTriggeredDefences, getInitialDefences } from "../../src/defence";
import { initPromptEvaluationModel } from "../../src/langchain";
+import { DEFENCE_TYPES } from "../../src/models/defence";

// Define a mock implementation for the createChatCompletion method
const mockCall = jest.fn();
@@ -31,14 +31,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt is malicious WHEN detect

let defences = getInitialDefences();
// activate the defence
-defences = activateDefence("LLM_EVALUATION", defences);
+defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(true);
-expect(result.triggeredDefences).toContain("LLM_EVALUATION");
+expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});

test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers malice detection WHEN detectTriggeredDefences is called THEN defence is triggered AND defence is blocked", async () => {
Expand All @@ -53,14 +54,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers malice det

let defences = getInitialDefences();
// activate the defence
-defences = activateDefence("LLM_EVALUATION", defences);
+defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(true);
-expect(result.triggeredDefences).toContain("LLM_EVALUATION");
+expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});

test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers prompt injection detection WHEN detectTriggeredDefences is called THEN defence is triggered AND defence is blocked", async () => {
Expand All @@ -75,14 +76,14 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt only triggers prompt inj

let defences = getInitialDefences();
// activate the defence
-defences = activateDefence("LLM_EVALUATION", defences);
+defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(true);
-expect(result.triggeredDefences).toContain("LLM_EVALUATION");
+expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});

test("GIVEN LLM_EVALUATION defence is active AND prompt not is malicious WHEN detectTriggeredDefences is called THEN defence is not triggered AND defence is not blocked", async () => {
Expand All @@ -97,7 +98,7 @@ test("GIVEN LLM_EVALUATION defence is active AND prompt not is malicious WHEN de

let defences = getInitialDefences();
// activate the defence
-defences = activateDefence("LLM_EVALUATION", defences);
+defences = activateDefence(DEFENCE_TYPES.LLM_EVALUATION, defences);
// create a malicious prompt
const message = "some kind of malicious prompt";
// detect triggered defences
@@ -124,5 +125,5 @@ test("GIVEN LLM_EVALUATION defence is not active AND prompt is malicious WHEN de
const result = await detectTriggeredDefences(message, defences);
// check that the defence is triggered and the message is blocked
expect(result.isBlocked).toBe(false);
-expect(result.triggeredDefences).toContain("LLM_EVALUATION");
+expect(result.triggeredDefences).toContain(DEFENCE_TYPES.LLM_EVALUATION);
});
16 changes: 8 additions & 8 deletions backend/test/integration/openai.test.ts
@@ -1,7 +1,7 @@
import { ChatCompletionRequestMessage } from "openai";
import { CHAT_MODELS } from "../../src/models/chat";
import { chatGptSendMessage } from "../../src/openai";
-import { DefenceInfo } from "../../src/models/defence";
+import { DEFENCE_TYPES, DefenceInfo } from "../../src/models/defence";
import { EmailInfo } from "../../src/models/email";
import { activateDefence, getInitialDefences } from "../../src/defence";

@@ -91,7 +91,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

-defences = activateDefence("SYSTEM_ROLE", defences);
+defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion.mockResolvedValueOnce({
@@ -154,7 +154,7 @@ test("GIVEN SYSTEM_ROLE defence is active WHEN sending message THEN system role
const openAiApiKey = "sk-12345";

// activate the SYSTEM_ROLE defence
-defences = activateDefence("SYSTEM_ROLE", defences);
+defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion.mockResolvedValueOnce({
@@ -292,7 +292,7 @@ test(
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

-defences = activateDefence("SYSTEM_ROLE", defences);
+defences = activateDefence(DEFENCE_TYPES.SYSTEM_ROLE, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion.mockResolvedValueOnce({
@@ -410,7 +410,7 @@ test(
expect(reply?.defenceInfo.isBlocked).toBe(false);
// EMAIL_WHITELIST defence is triggered
expect(reply?.defenceInfo.triggeredDefences.length).toBe(1);
-expect(reply?.defenceInfo.triggeredDefences[0]).toBe("EMAIL_WHITELIST");
+expect(reply?.defenceInfo.triggeredDefences[0]).toBe(DEFENCE_TYPES.EMAIL_WHITELIST);

// restore the mock
mockCreateChatCompletion.mockRestore();
@@ -432,7 +432,7 @@ test(
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

-defences = activateDefence("EMAIL_WHITELIST", defences);
+defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion
@@ -486,7 +486,7 @@ test(
expect(reply?.defenceInfo.isBlocked).toBe(true);
// EMAIL_WHITELIST defence is triggered
expect(reply?.defenceInfo.triggeredDefences.length).toBe(1);
-expect(reply?.defenceInfo.triggeredDefences[0]).toBe("EMAIL_WHITELIST");
+expect(reply?.defenceInfo.triggeredDefences[0]).toBe(DEFENCE_TYPES.EMAIL_WHITELIST);

// restore the mock
mockCreateChatCompletion.mockRestore();
@@ -508,7 +508,7 @@ test(
const gptModel = CHAT_MODELS.GPT_4;
const openAiApiKey = "sk-12345";

-defences = activateDefence("EMAIL_WHITELIST", defences);
+defences = activateDefence(DEFENCE_TYPES.EMAIL_WHITELIST, defences);

// Mock the createChatCompletion function
mockCreateChatCompletion