Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX Crescendo backtrack with same orchestrator ID and handling responses with markdown syntax #340

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 171 additions & 154 deletions doc/code/orchestrators/6_crescendo_orchestrator.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/code/orchestrators/6_crescendo_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
) as orchestrator:

# Running multiple turns can take a couple of minutes depending on LLM latency
score = await orchestrator.apply_crescendo_attack_async(max_rounds=5, max_backtracks=5) # type: ignore
score = await orchestrator.apply_crescendo_attack_async(max_rounds=10, max_backtracks=5) # type: ignore
orchestrator.print_conversation()

print(f"{score} {score.score_metadata} {score.score_rationale} ")
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ template: |
last_response_summary: Summarized last response, if there is no last response return empty strings,
rationale_behind_jailbreak: Detail the reasoning behind the development of the potential jailbreaks
}}
Do not put the JSON inside markdown code tags.

###Inputs###
Inputs for crafting the new jailbreak:
Expand Down
14 changes: 8 additions & 6 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import abc
import copy
from pathlib import Path
from typing import MutableSequence, Sequence
from typing import MutableSequence, Optional, Sequence
import uuid

from pyrit.common.path import RESULTS_PATH
Expand Down Expand Up @@ -213,16 +213,19 @@ def duplicate_conversation_for_new_orchestrator(self, *, new_orchestrator_id: st
self.add_request_pieces_to_memory(request_pieces=prompt_pieces)
return new_conversation_id

def duplicate_conversation_excluding_last_turn(self, *, new_orchestrator_id: str, conversation_id: str) -> str:
def duplicate_conversation_excluding_last_turn(
self, *, conversation_id: str, new_orchestrator_id: Optional[str] = None
) -> str:
"""
Duplicate a conversation, excluding the last turn. In this case, last turn is defined as before the last
user request (e.g. if there is half a turn, it just removes that half).

This can be useful when an attack strategy requires back tracking the last prompt/response pair.

Args:
new_orchestrator_id (str): The new orchestrator ID to assign to the duplicated conversations.
conversation_id (str): The conversation ID with existing conversations.
new_orchestrator_id (str, optional): The new orchestrator ID to assign to the duplicated conversations.
If no new orchestrator ID is provided, the orchestrator ID will remain the same. Defaults to None.
Returns:
The uuid for the new conversation.
"""
Expand Down Expand Up @@ -252,9 +255,8 @@ def duplicate_conversation_excluding_last_turn(self, *, new_orchestrator_id: str

for piece in prompt_pieces:
piece.id = uuid.uuid4()
if piece.orchestrator_identifier["id"] == new_orchestrator_id:
raise ValueError("The new orchestrator ID must be different from the existing orchestrator ID.")
piece.orchestrator_identifier["id"] = new_orchestrator_id
if new_orchestrator_id:
piece.orchestrator_identifier["id"] = new_orchestrator_id
piece.conversation_id = new_conversation_id

self.add_request_pieces_to_memory(request_pieces=prompt_pieces)
Expand Down
2 changes: 2 additions & 0 deletions pyrit/orchestrator/crescendo_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pyrit.exceptions.exception_classes import (
InvalidJsonException,
pyrit_json_retry,
remove_markdown_json,
)
from pyrit.models import PromptTemplate
from pyrit.models import Score
Expand Down Expand Up @@ -262,6 +263,7 @@ async def _get_attack_prompt(
.request_pieces[0]
.converted_value
)
response_text = remove_markdown_json(response_text)

expected_output = ["generated_question", "rationale_behind_jailbreak", "last_response_summary"]
try:
Expand Down
50 changes: 50 additions & 0 deletions tests/memory/test_memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,56 @@ def test_duplicate_conversation_excluding_last_turn(memory: MemoryInterface):
assert piece.sequence < 2


def test_duplicate_conversation_excluding_last_turn_same_orchestrator(memory: MemoryInterface):
    """Duplicating without ``new_orchestrator_id`` keeps the original orchestrator ID.

    Stores a two-turn conversation (user/assistant, user/assistant), duplicates it
    without passing ``new_orchestrator_id``, and checks that:

    * only the first turn (sequences 0 and 1) is copied into the new conversation,
    * the copied pieces still carry the orchestrator ID of the source conversation.
    """
    orchestrator1 = Orchestrator()
    # Captured up front so the final assertions compare against the source ID.
    expected_orchestrator_id = orchestrator1.get_identifier()["id"]
    conversation_id_1 = "11111"
    pieces = [
        PromptRequestPiece(
            role="user",
            original_value="original prompt text",
            conversation_id=conversation_id_1,
            sequence=0,
            orchestrator_identifier=orchestrator1.get_identifier(),
        ),
        PromptRequestPiece(
            role="assistant",
            original_value="original prompt text",
            conversation_id=conversation_id_1,
            sequence=1,
            orchestrator_identifier=orchestrator1.get_identifier(),
        ),
        PromptRequestPiece(
            role="user",
            original_value="original prompt text",
            conversation_id=conversation_id_1,
            sequence=2,
            orchestrator_identifier=orchestrator1.get_identifier(),
        ),
        PromptRequestPiece(
            role="assistant",
            original_value="original prompt text",
            conversation_id=conversation_id_1,
            sequence=3,
            orchestrator_identifier=orchestrator1.get_identifier(),
        ),
    ]
    memory.add_request_pieces_to_memory(request_pieces=pieces)
    assert len(memory.get_all_prompt_pieces()) == 4

    # No new_orchestrator_id is passed: the duplicate must reuse the existing one.
    new_conversation_id1 = memory.duplicate_conversation_excluding_last_turn(
        conversation_id=conversation_id_1,
    )

    # Four original pieces plus the two pieces of the first turn that were copied.
    all_memory = memory.get_all_prompt_pieces()
    assert len(all_memory) == 6

    duplicate_conversation = memory._get_prompt_pieces_with_conversation_id(conversation_id=new_conversation_id1)
    assert len(duplicate_conversation) == 2

    for piece in duplicate_conversation:
        # Only the first turn survives the duplication.
        assert piece.sequence < 2
        # The behaviour named by this test: the orchestrator ID is left unchanged.
        assert piece.orchestrator_identifier["id"] == expected_orchestrator_id


def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface):
orchestrator1 = Orchestrator()
conversation_id = "11111"
Expand Down
Loading