diff --git a/doc/code/targets/6_multi_modal_targets.py b/doc/code/targets/6_multi_modal_targets.py index 1cc741f61..8e47333fe 100644 --- a/doc/code/targets/6_multi_modal_targets.py +++ b/doc/code/targets/6_multi_modal_targets.py @@ -92,7 +92,7 @@ # You can use the following to show the image if image_location != "content blocked": - image_bytes = await azure_sql_memory._storage_io.read_file(image_location) # type: ignore + image_bytes = await azure_sql_memory.storage_io.read_file(image_location) # type: ignore image_stream = io.BytesIO(image_bytes) image = Image.open(image_stream) diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 9d43ce617..93e8daf7b 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -25,7 +25,7 @@ async def display_response(response_piece: PromptRequestPiece, memory: MemoryInt and is_in_ipython_session() ): image_location = response_piece.converted_value - image_bytes = await memory._storage_io.read_file(image_location) + image_bytes = await memory.storage_io.read_file(image_location) image_stream = io.BytesIO(image_bytes) image = Image.open(image_stream) diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 178898c1e..c3c82664c 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingDataEntry from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.azure_sql_memory import AzureSQLMemory @@ -13,7 +13,7 @@ __all__ = [ "AzureSQLMemory", "DuckDBMemory", - "EmbeddingData", + "EmbeddingDataEntry", "MemoryInterface", "MemoryEmbedding", "MemoryExporter", diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 7989aa8d1..28a74f944 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -17,7 +17,7 @@ from pyrit.common import default_values from pyrit.common.singleton import Singleton -from pyrit.memory.memory_models import EmbeddingData, Base, PromptMemoryEntry, ScoreEntry +from pyrit.memory.memory_models import EmbeddingDataEntry, Base, PromptMemoryEntry, ScoreEntry from pyrit.memory.memory_interface import MemoryInterface from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.models.score import Score @@ -49,8 +49,6 @@ def __init__( sas_token: Optional[str] = None, verbose: bool = False, ): - super(AzureSQLMemory, self).__init__() - self._connection_string = default_values.get_required_value( env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string ) @@ -63,8 +61,6 @@ def __init__( ) except ValueError: self._sas_token = None # To use delegation SAS - # Handle for Azure Blob Storage when using Azure SQL memory. - self._storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token) self.results_path = self._container_url @@ -76,6 +72,12 @@ def __init__( self.SessionFactory = sessionmaker(bind=self.engine) self._create_tables_if_not_exist() + super(AzureSQLMemory, self).__init__() + + def _init_storage_io(self): + # Handle for Azure Blob Storage when using Azure SQL memory. + self.storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token) + def _create_auth_token(self) -> AccessToken: azure_credentials = DefaultAzureCredential() return azure_credentials.get_token(self.TOKEN_URL) @@ -137,13 +139,13 @@ def _create_tables_if_not_exist(self): except Exception as e: logger.error(f"Error during table creation: {e}") - def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None: + def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None: """ Inserts embedding data into memory storage """ self._insert_entries(entries=embedding_data) - def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Base]: + def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[PromptRequestPiece]: """ Retrieves a list of PromptMemoryEntry Base objects that have the specified orchestrator ID. @@ -152,13 +154,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Ba Can be retrieved by calling orchestrator.get_identifier()["id"] Returns: - list[Base]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID. + list[PromptRequestPiece]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID. """ try: sql_condition = text( "ISJSON(orchestrator_identifier) = 1 AND JSON_VALUE(orchestrator_identifier, '$.id') = :json_id" ).bindparams(json_id=str(orchestrator_id)) - result = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore + entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] + return result except Exception as e: logger.exception( @@ -177,10 +181,14 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID. """ try: - return self.query_entries( + entries = self.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id, ) # type: ignore + + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] + return result + except Exception as e: logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") return [] @@ -206,11 +214,11 @@ def dispose_engine(self): self.engine.dispose() logger.info("Engine disposed successfully.") - def get_all_embeddings(self) -> list[EmbeddingData]: + def get_all_embeddings(self) -> list[EmbeddingDataEntry]: """ Fetches all entries from the specified table and returns them as model instances. """ - result = self.query_entries(EmbeddingData) + result = self.query_entries(EmbeddingDataEntry) return result def get_all_prompt_pieces(self) -> list[PromptRequestPiece]: @@ -231,10 +239,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID. """ try: - return self.query_entries( + entries = self.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.id.in_(prompt_ids), - ) # type: ignore + ) + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] + return result except Exception as e: logger.exception( f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}" @@ -266,10 +276,8 @@ def get_prompt_request_piece_by_memory_labels( # for safe parameter passing, preventing SQL injection sql_condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) - result: list[PromptRequestPiece] = self.query_entries( - PromptMemoryEntry, conditions=sql_condition - ) # type: ignore - + entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition) + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] return result except Exception as e: logger.exception( diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 2b56ce157..2221491e7 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -12,7 +12,7 @@ from sqlalchemy.engine.base import Engine from contextlib import closing -from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base, ScoreEntry +from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry, Base, ScoreEntry from pyrit.memory.memory_interface import MemoryInterface from pyrit.common.path import RESULTS_PATH from pyrit.common.singleton import Singleton @@ -46,12 +46,15 @@ def __init__( else: self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve() self.results_path = str(RESULTS_PATH) - # Handles disk-based storage for DuckDB local memory. - self._storage_io = DiskStorageIO() + self.engine = self._create_engine(has_echo=verbose) self.SessionFactory = sessionmaker(bind=self.engine) self._create_tables_if_not_exist() + def _init_storage_io(self): + # Handles disk-based storage for DuckDB local memory. + self.storage_io = DiskStorageIO() + def _create_engine(self, *, has_echo: bool) -> Engine: """Creates the SQLAlchemy engine for DuckDB. @@ -91,11 +94,11 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]: result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] return result - def get_all_embeddings(self) -> list[EmbeddingData]: + def get_all_embeddings(self) -> list[EmbeddingDataEntry]: """ Fetches all entries from the specified table and returns them as model instances. """ - result: list[EmbeddingData] = self.query_entries(EmbeddingData) + result: list[EmbeddingDataEntry] = self.query_entries(EmbeddingDataEntry) return result def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> list[PromptRequestPiece]: @@ -109,9 +112,10 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID. """ try: - result: list[PromptRequestPiece] = self.query_entries( + entries = self.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id - ) # type: ignore + ) + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] return result except Exception as e: logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") @@ -128,10 +132,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID. """ try: - return self.query_entries( + entries = self.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.id.in_(prompt_ids), - ) # type: ignore + ) + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] + return result except Exception as e: logger.exception( f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}" @@ -156,9 +162,8 @@ def get_prompt_request_piece_by_memory_labels( try: conditions = [PromptMemoryEntry.labels.op("->>")(key) == value for key, value in memory_labels.items()] query_condition = and_(*conditions) - result: list[PromptRequestPiece] = self.query_entries( - PromptMemoryEntry, conditions=query_condition - ) # type: ignore + entries = self.query_entries(PromptMemoryEntry, conditions=query_condition) + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] return result except Exception as e: logger.exception( @@ -178,10 +183,11 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Pr list[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID. """ try: - result: list[PromptRequestPiece] = self.query_entries( + entries = self.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.orchestrator_identifier.op("->>")("id") == orchestrator_id, ) # type: ignore + result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries] return result except Exception as e: logger.exception( @@ -196,7 +202,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in request_pieces]) - def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None: + def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None: """ Inserts embedding data into memory storage """ diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index ad27c3d21..e3103bd3c 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from pyrit.embedding import AzureTextEmbedding from pyrit.models import PromptRequestPiece, EmbeddingSupport -from pyrit.memory.memory_models import EmbeddingData +from pyrit.memory.memory_models import EmbeddingDataEntry class MemoryEmbedding: @@ -21,7 +21,7 @@ def __init__(self, *, embedding_model: Optional[EmbeddingSupport]): raise ValueError("embedding_model must be set.") self.embedding_model = embedding_model - def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingData: + def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingDataEntry: """ Generates metadata for a chat memory entry. @@ -32,7 +32,7 @@ def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestP ConversationMemoryEntryMetadata: The generated metadata. """ if prompt_request_piece.converted_value_data_type == "text": - embedding_data = EmbeddingData( + embedding_data = EmbeddingDataEntry( embedding=self.embedding_model.generate_text_embedding(text=prompt_request_piece.converted_value) .data[0] .embedding, diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 9eaa8947c..38251549b 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -16,7 +16,7 @@ group_conversation_request_pieces_by_sequence, ) -from pyrit.memory.memory_models import EmbeddingData +from pyrit.memory.memory_models import EmbeddingDataEntry from pyrit.memory.memory_embedding import default_memory_embedding_factory, MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.models.storage_io import StorageIO @@ -33,13 +33,14 @@ class MemoryInterface(abc.ABC): """ memory_embedding: MemoryEmbedding = None - _storage_io: StorageIO = None + storage_io: StorageIO = None results_path: str = None def __init__(self, embedding_model=None): self.memory_embedding = embedding_model # Initialize the MemoryExporter instance self.exporter = MemoryExporter() + self._init_storage_io() def enable_embedding(self, embedding_model=None): self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) @@ -54,11 +55,17 @@ def get_all_prompt_pieces(self) -> Sequence[PromptRequestPiece]: """ @abc.abstractmethod - def get_all_embeddings(self) -> Sequence[EmbeddingData]: + def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ Loads all EmbeddingData from the memory storage handler. """ + @abc.abstractmethod + def _init_storage_io(self): + """ + Initialize the storage IO handler storage_io. + """ + @abc.abstractmethod def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> MutableSequence[PromptRequestPiece]: """ @@ -91,7 +98,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest """ @abc.abstractmethod - def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None: + def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None: """ Inserts embedding data into memory storage """ diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 84ff22b5e..3a1df1198 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -126,7 +126,7 @@ def __str__(self): return f": {self.role}: {self.converted_value}" -class EmbeddingData(Base): # type: ignore +class EmbeddingDataEntry(Base): # type: ignore """ Represents the embedding data associated with conversation entries in the database. Each embedding is linked to a specific conversation entry via an id diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 9560786e1..2efdcfa99 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -79,7 +79,7 @@ async def save_data(self, data: bytes) -> None: Saves the data to storage. """ self.value = str(await self.get_data_filename()) - await self._memory._storage_io.write_file(self.value, data) + await self._memory.storage_io.write_file(self.value, data) async def save_b64_image(self, data: str, output_filename: str = None) -> None: """ @@ -93,7 +93,7 @@ async def save_b64_image(self, data: str, output_filename: str = None) -> None: else: self.value = str(await self.get_data_filename()) image_bytes = base64.b64decode(data) - await self._memory._storage_io.write_file(self.value, image_bytes) + await self._memory.storage_io.write_file(self.value, image_bytes) async def read_data(self) -> bytes: """ @@ -105,11 +105,11 @@ async def read_data(self) -> bytes: if not self.value: raise RuntimeError("Prompt text not set") # Check if path exists - file_exists = await self._memory._storage_io.path_exists(path=self.value) + file_exists = await self._memory.storage_io.path_exists(path=self.value) if not file_exists: raise FileNotFoundError(f"File not found: {self.value}") # Read the contents from the path - return await self._memory._storage_io.read_file(self.value) + return await self._memory.storage_io.read_file(self.value) async def read_data_base64(self) -> str: """ @@ -122,7 +122,7 @@ async def get_sha256(self) -> str: input_bytes: bytes = None if self.data_on_disk(): - input_bytes = await self._memory._storage_io.read_file(self.value) + input_bytes = await self._memory.storage_io.read_file(self.value) else: input_bytes = self.value.encode("utf-8") @@ -142,7 +142,7 @@ async def get_data_filename(self) -> Union[Path, str]: if not self.data_directory: raise RuntimeError("Data directory not set") - await self._memory._storage_io.create_directory_if_not_exists(self.data_directory) + await self._memory.storage_io.create_directory_if_not_exists(self.data_directory) ticks = int(time.time() * 1_000_000) if self.is_url(str(self.data_directory)): diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index a0b4b5afc..10907df41 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -6,7 +6,7 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.analytics.conversation_analytics import ConversationAnalytics -from pyrit.memory.memory_models import EmbeddingData +from pyrit.memory.memory_models import EmbeddingDataEntry from tests.mocks import get_sample_conversation_entries @@ -52,8 +52,10 @@ def test_get_similar_chat_messages_by_embedding(mock_memory_interface, sample_co different_embedding = [0.9, 0.8, 0.7] mock_embeddings = [ - EmbeddingData(id=sample_conversations_entries[0].id, embedding=similar_embedding, embedding_type_name="model1"), - EmbeddingData( + EmbeddingDataEntry( + id=sample_conversations_entries[0].id, embedding=similar_embedding, embedding_type_name="model1" + ), + EmbeddingDataEntry( id=sample_conversations_entries[1].id, embedding=different_embedding, embedding_type_name="model2" ), ] diff --git a/tests/memory/test_azure_sql_memory.py b/tests/memory/test_azure_sql_memory.py index ab409abd8..3ce087a0d 100644 --- a/tests/memory/test_azure_sql_memory.py +++ b/tests/memory/test_azure_sql_memory.py @@ -13,7 +13,7 @@ from sqlalchemy import text from pyrit.memory import AzureSQLMemory -from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingDataEntry from pyrit.models import PromptRequestPiece, Score from pyrit.orchestrator.orchestrator_class import Orchestrator from pyrit.prompt_converter.base64_converter import Base64Converter @@ -100,12 +100,12 @@ def test_insert_embedding_entry(memory_interface: AzureSQLMemory): uuid = reattached_conversation_entry.id # Now that we have the uuid, we can create and insert the EmbeddingData entry - embedding_entry = EmbeddingData(id=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") + embedding_entry = EmbeddingDataEntry(id=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") memory_interface._insert_entry(embedding_entry) # Verify the EmbeddingData entry was inserted correctly with memory_interface.get_session() as session: # type: ignore - persisted_embedding_entry = session.query(EmbeddingData).filter_by(id=uuid).first() + persisted_embedding_entry = session.query(EmbeddingDataEntry).filter_by(id=uuid).first() assert persisted_embedding_entry is not None assert persisted_embedding_entry.embedding == [1, 2, 3] assert persisted_embedding_entry.embedding_type_name == "test_type" diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 0609a8449..942963217 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -14,7 +14,7 @@ from sqlalchemy.sql.sqltypes import NullType from pyrit.memory.duckdb_memory import DuckDBMemory -from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingDataEntry from pyrit.models import PromptRequestPiece, Score from pyrit.orchestrator.orchestrator_class import Orchestrator from pyrit.prompt_converter.base64_converter import Base64Converter @@ -251,12 +251,12 @@ def test_insert_embedding_entry(memory_interface): uuid = reattached_conversation_entry.id # Now that we have the uuid, we can create and insert the EmbeddingData entry - embedding_entry = EmbeddingData(id=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") + embedding_entry = EmbeddingDataEntry(id=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") memory_interface._insert_entry(embedding_entry) # Verify the EmbeddingData entry was inserted correctly with memory_interface.get_session() as session: - persisted_embedding_entry = session.query(EmbeddingData).filter_by(id=uuid).first() + persisted_embedding_entry = session.query(EmbeddingDataEntry).filter_by(id=uuid).first() assert persisted_embedding_entry is not None assert persisted_embedding_entry.embedding == [1, 2, 3] assert persisted_embedding_entry.embedding_type_name == "test_type" diff --git a/tests/test_prompt_normalizer.py b/tests/test_prompt_normalizer.py index 92ef5d6d2..4e7321c80 100644 --- a/tests/test_prompt_normalizer.py +++ b/tests/test_prompt_normalizer.py @@ -233,7 +233,7 @@ async def test_send_prompt_async_image_converter(): normalizer = PromptNormalizer(memory=MagicMock()) # Mock the async read_file method - normalizer._memory._storage_io.read_file = AsyncMock(return_value=b"mocked data") + normalizer._memory.storage_io.read_file = AsyncMock(return_value=b"mocked data") await normalizer.send_prompt_async(normalizer_request=NormalizerRequest([prompt]), target=prompt_target)