Commit

more backend tests (#182)
heatherlogan-scottlogic authored Aug 24, 2023
1 parent 2b4e4da commit 35c2f88
Showing 10 changed files with 770 additions and 75 deletions.
113 changes: 91 additions & 22 deletions backend/package-lock.json

Some generated files are not rendered by default.

9 changes: 4 additions & 5 deletions backend/src/defence.ts
@@ -92,7 +92,7 @@ function getMaxMessageLength(defences: DefenceInfo[]) {
}

function getRandomSequenceEnclosurePrePrompt(defences: DefenceInfo[]) {
return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", "");
return getConfigValue(defences, "RANDOM_SEQUENCE_ENCLOSURE", "prePrompt", retrievalQAPrePromptSecure);
}

function getRandomSequenceEnclosureLength(defences: DefenceInfo[]) {
@@ -179,9 +179,8 @@ function escapeXml(unsafe: string) {
case "'":
return "'";
case '"':
return """;
default:
return c;
return """;
}
});
}
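
Only the quote-handling arm of the escaper is visible in this hunk. A self-contained sketch of the escaping behaviour, assuming the conventional five XML entities and a `[<>&'"]` character class (both assumptions, since the rest of the function is collapsed above):

```ts
// Minimal standalone sketch of an XML escaper matching the entities shown
// above. The character class and the <, >, & arms are assumed, not shown here.
function escapeXmlSketch(unsafe: string): string {
  return unsafe.replace(/[<>&'"]/g, (c: string): string => {
    switch (c) {
      case "<":
        return "&lt;";
      case ">":
        return "&gt;";
      case "&":
        return "&amp;";
      case "'":
        return "&apos;";
      default:
        return "&quot;"; // only '"' can reach here given the character class
    }
  });
}

console.log(escapeXmlSketch(`<a href="it's">`));
// &lt;a href=&quot;it&apos;s&quot;&gt;
```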
@@ -281,11 +280,11 @@ export {
activateDefence,
configureDefence,
deactivateDefence,
detectTriggeredDefences,
getEmailWhitelistVar,
getInitialDefences,
getQALLMprePrompt,
getSystemRole,
isDefenceActive,
transformMessage,
detectTriggeredDefences,
getEmailWhitelistVar,
};
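
The first hunk in this file changes the fallback passed to getConfigValue. That helper is not part of this diff; the following is a hypothetical sketch of its shape, with the DefenceInfo and config types invented for illustration (the real types live elsewhere in backend/src and may differ):

```ts
// Hypothetical shapes, for illustration only.
interface DefenceConfigItem {
  id: string;
  value: string;
}
interface DefenceInfo {
  id: string;
  config: DefenceConfigItem[];
}

// Assumed behaviour: find the defence, find the named config entry,
// and fall back to the supplied default when either is missing.
function getConfigValue(
  defences: DefenceInfo[],
  defenceId: string,
  configId: string,
  defaultValue: string
): string {
  const defence = defences.find((d) => d.id === defenceId);
  const configItem = defence?.config.find((c) => c.id === configId);
  return configItem?.value ?? defaultValue;
}
```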
46 changes: 34 additions & 12 deletions backend/src/langchain.ts
@@ -26,6 +26,16 @@ let qaChain: RetrievalQAChain | null = null;
// chain we use in prompt evaluation request
let promptEvaluationChain: SequentialChain | null = null;

function setQAChain(chain: RetrievalQAChain | null) {
console.debug("Setting QA chain.");
qaChain = chain;
}

function setPromptEvaluationChain(chain: SequentialChain | null) {
console.debug("Setting evaluation chain.");
promptEvaluationChain = chain;
}
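
Given the commit title, the likely purpose of these new setters is to let tests swap the module-level chains for stubs. A sketch, assuming Jest and assuming queryDocuments delegates to the injected qaChain (its body is not shown in this diff):

```ts
import { setQAChain, queryDocuments } from "./langchain";

test("queryDocuments answers via the injected QA chain", async () => {
  // Minimal stub standing in for a RetrievalQAChain; only .call is faked.
  const mockChain = { call: jest.fn().mockResolvedValue({ text: "42" }) };
  setQAChain(mockChain as never);

  const reply = await queryDocuments("What is the answer?");

  expect(mockChain.call).toHaveBeenCalled();
  expect(reply).toBeDefined();
});
```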

function getFilepath(currentPhase: PHASE_NAMES = PHASE_NAMES.SANDBOX) {
let filePath = "resources/documents/";
switch (currentPhase) {
@@ -66,7 +76,9 @@ function getQAPromptTemplate(prePrompt: string) {
console.debug("Using default retrieval QA pre-prompt");
prePrompt = retrievalQAPrePrompt;
}
return PromptTemplate.fromTemplate(prePrompt + qAcontextTemplate);
const fullPrompt = prePrompt + qAcontextTemplate;
const template: PromptTemplate = PromptTemplate.fromTemplate(fullPrompt);
return template;
}
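
For reference, PromptTemplate.fromTemplate parses {placeholders} out of the string, so the concatenation above yields one template whose input variables come from qAcontextTemplate. A small sketch (the exact qAcontextTemplate text is an assumption):

```ts
import { PromptTemplate } from "langchain/prompts";

// Assumed shape of qAcontextTemplate -- the real string lives in the
// prompt templates module and may differ.
const qAcontextTemplate = "\nContext: {context}\nQuestion: {question}\nAnswer:";

async function demo() {
  const template = PromptTemplate.fromTemplate(
    "Answer only from the provided documents." + qAcontextTemplate
  );
  const prompt = await template.format({
    context: "The answer is 42.",
    question: "What is the answer?",
  });
  console.log(prompt);
}
```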

// QA Chain - ask the chat model a question about the documents
@@ -100,18 +112,22 @@ async function initQAModel(
const qaPrompt = getQAPromptTemplate(prePrompt);

// set chain to retrieval QA chain
qaChain = RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
prompt: qaPrompt,
});
setQAChain(
RetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
prompt: qaPrompt,
})
);
console.debug("QA chain initialised.");
}

// initialise the prompt evaluation model
function initPromptEvaluationModel(openAiApiKey: string) {
if (!openAiApiKey) {
console.debug("No OpenAI API key set to initialise prompt evaluation model");
console.debug(
"No OpenAI API key set to initialise prompt evaluation model"
);
return;
}

// create chain to detect prompt injection
const promptInjectionPrompt = PromptTemplate.fromTemplate(
promptInjectionEvalTemplate
@@ -141,12 +157,14 @@ function initPromptEvaluationModel(openAiApiKey: string) {
outputKey: "maliciousInputEval",
});

promptEvaluationChain = new SequentialChain({
const sequentialChain = new SequentialChain({
chains: [promptInjectionChain, maliciousInputChain],
inputVariables: ["prompt"],
outputVariables: ["promptInjectionEval", "maliciousInputEval"],
});
console.debug("Prompt evaluation chain initialised");
setPromptEvaluationChain(sequentialChain);

console.debug("Prompt evaluation chain initialised.");
}
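
With the input and output keys declared above, the sequential chain is consumed with a single call that returns both verdicts. A sketch of that contract:

```ts
import { SequentialChain } from "langchain/chains";

// Key names match the inputVariables/outputVariables declared in this hunk.
async function runEvaluation(chain: SequentialChain, input: string) {
  const response = await chain.call({ prompt: input });
  return {
    promptInjectionEval: response.promptInjectionEval as string,
    maliciousInputEval: response.maliciousInputEval as string,
  };
}
```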

// ask the question and return models answer
@@ -172,19 +190,17 @@ async function queryPromptEvaluationModel(input: string) {
console.debug("Prompt evaluation chain not initialised.");
return { isMalicious: false, reason: "" };
}

console.log(`Checking '${input}' for malicious prompts`);

const response = await promptEvaluationChain.call({
prompt: input,
});

const promptInjectionEval = formatEvaluationOutput(
response.promptInjectionEval
);
const maliciousInputEval = formatEvaluationOutput(
response.maliciousInputEval
);

console.debug(
"Prompt injection eval: " + JSON.stringify(promptInjectionEval)
);
@@ -210,7 +226,7 @@ function formatEvaluationOutput(response: string) {
// split response on first full stop or comma
const splitResponse = response.split(/\.|,/);
const answer = splitResponse[0]?.replace(/\W/g, "").toLowerCase();
const reason = splitResponse[1];
const reason = splitResponse[1]?.trim();
return {
isMalicious: answer === "yes",
reason: reason,
@@ -228,7 +244,13 @@

export {
initQAModel,
getFilepath,
getQAPromptTemplate,
getDocuments,
initPromptEvaluationModel,
queryDocuments,
queryPromptEvaluationModel,
formatEvaluationOutput,
setQAChain,
setPromptEvaluationChain,
};
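
formatEvaluationOutput is newly exported, presumably so the new tests can exercise it directly. A worked example of the parsing shown above (split on the first '.' or ',', normalise the verdict, keep the following clause as the reason):

```ts
formatEvaluationOutput("Yes, this prompt tries to override the system role.");
// -> { isMalicious: true, reason: "this prompt tries to override the system role" }

formatEvaluationOutput("No.");
// -> { isMalicious: false, reason: "" }
```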
