Skip to content

Commit

Permalink
FIX: Updating memory and fixing bugs (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlundeen2 authored Sep 24, 2024
1 parent a069c88 commit 8d13072
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 60 deletions.
2 changes: 1 addition & 1 deletion doc/code/targets/6_multi_modal_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

# You can use the following to show the image
if image_location != "content blocked":
image_bytes = await azure_sql_memory._storage_io.read_file(image_location) # type: ignore
image_bytes = await azure_sql_memory.storage_io.read_file(image_location) # type: ignore

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)
Expand Down
2 changes: 1 addition & 1 deletion pyrit/common/display_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def display_response(response_piece: PromptRequestPiece, memory: MemoryInt
and is_in_ipython_session()
):
image_location = response_piece.converted_value
image_bytes = await memory._storage_io.read_file(image_location)
image_bytes = await memory.storage_io.read_file(image_location)

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)
Expand Down
4 changes: 2 additions & 2 deletions pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData
from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingDataEntry
from pyrit.memory.memory_interface import MemoryInterface

from pyrit.memory.azure_sql_memory import AzureSQLMemory
Expand All @@ -13,7 +13,7 @@
__all__ = [
"AzureSQLMemory",
"DuckDBMemory",
"EmbeddingData",
"EmbeddingDataEntry",
"MemoryInterface",
"MemoryEmbedding",
"MemoryExporter",
Expand Down
44 changes: 26 additions & 18 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pyrit.common import default_values
from pyrit.common.singleton import Singleton
from pyrit.memory.memory_models import EmbeddingData, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_models import EmbeddingDataEntry, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.score import Score
Expand Down Expand Up @@ -49,8 +49,6 @@ def __init__(
sas_token: Optional[str] = None,
verbose: bool = False,
):
super(AzureSQLMemory, self).__init__()

self._connection_string = default_values.get_required_value(
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
)
Expand All @@ -63,8 +61,6 @@ def __init__(
)
except ValueError:
self._sas_token = None # To use delegation SAS
# Handle for Azure Blob Storage when using Azure SQL memory.
self._storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token)

self.results_path = self._container_url

Expand All @@ -76,6 +72,12 @@ def __init__(
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

super(AzureSQLMemory, self).__init__()

def _init_storage_io(self):
    """
    Set up the storage IO handler used alongside Azure SQL memory.

    Builds an AzureBlobStorageIO from the container URL and SAS token already stored
    on this instance and exposes it as ``self.storage_io``.
    """
    blob_storage = AzureBlobStorageIO(
        container_url=self._container_url,
        sas_token=self._sas_token,
    )
    self.storage_io = blob_storage

def _create_auth_token(self) -> AccessToken:
    """
    Obtain an access token for ``self.TOKEN_URL`` via DefaultAzureCredential.

    Returns:
        AccessToken: The token returned by the credential's ``get_token`` call.
    """
    return DefaultAzureCredential().get_token(self.TOKEN_URL)
Expand Down Expand Up @@ -137,13 +139,13 @@ def _create_tables_if_not_exist(self):
except Exception as e:
logger.error(f"Error during table creation: {e}")

def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
self._insert_entries(entries=embedding_data)

def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Base]:
def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects that have the specified orchestrator ID.
Expand All @@ -152,13 +154,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Ba
Can be retrieved by calling orchestrator.get_identifier()["id"]
Returns:
list[Base]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID.
list[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID.
"""
try:
sql_condition = text(
"ISJSON(orchestrator_identifier) = 1 AND JSON_VALUE(orchestrator_identifier, '$.id') = :json_id"
).bindparams(json_id=str(orchestrator_id))
result = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore
entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]

return result
except Exception as e:
logger.exception(
Expand All @@ -177,10 +181,14 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.conversation_id == conversation_id,
) # type: ignore

result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []
Expand All @@ -206,11 +214,11 @@ def dispose_engine(self):
self.engine.dispose()
logger.info("Engine disposed successfully.")

def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result = self.query_entries(EmbeddingData)
result = self.query_entries(EmbeddingDataEntry)
return result

def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
Expand All @@ -231,10 +239,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
list[PromptRequestPiece]: A list of PromptRequestPiece objects whose IDs are in prompt_ids.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
Expand Down Expand Up @@ -266,10 +276,8 @@ def get_prompt_request_piece_by_memory_labels(
# for safe parameter passing, preventing SQL injection
sql_condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})

result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=sql_condition
) # type: ignore

entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
Expand Down
34 changes: 20 additions & 14 deletions pyrit/memory/duckdb_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.engine.base import Engine
from contextlib import closing

from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.common.path import RESULTS_PATH
from pyrit.common.singleton import Singleton
Expand Down Expand Up @@ -46,12 +46,15 @@ def __init__(
else:
self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve()
self.results_path = str(RESULTS_PATH)
# Handles disk-based storage for DuckDB local memory.
self._storage_io = DiskStorageIO()

self.engine = self._create_engine(has_echo=verbose)
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

def _init_storage_io(self):
    """
    Set up the storage IO handler used alongside DuckDB local memory.

    Exposes a disk-based DiskStorageIO instance as ``self.storage_io``.
    """
    disk_storage = DiskStorageIO()
    self.storage_io = disk_storage

def _create_engine(self, *, has_echo: bool) -> Engine:
"""Creates the SQLAlchemy engine for DuckDB.
Expand Down Expand Up @@ -91,11 +94,11 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result: list[EmbeddingData] = self.query_entries(EmbeddingData)
result: list[EmbeddingDataEntry] = self.query_entries(EmbeddingDataEntry)
return result

def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> list[PromptRequestPiece]:
Expand All @@ -109,9 +112,10 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
result: list[PromptRequestPiece] = self.query_entries(
entries = self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
Expand All @@ -128,10 +132,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
list[PromptRequestPiece]: A list of PromptRequestPiece objects whose IDs are in prompt_ids.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
Expand All @@ -156,9 +162,8 @@ def get_prompt_request_piece_by_memory_labels(
try:
conditions = [PromptMemoryEntry.labels.op("->>")(key) == value for key, value in memory_labels.items()]
query_condition = and_(*conditions)
result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=query_condition
) # type: ignore
entries = self.query_entries(PromptMemoryEntry, conditions=query_condition)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
Expand All @@ -178,10 +183,11 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Pr
list[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID.
"""
try:
result: list[PromptRequestPiece] = self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.orchestrator_identifier.op("->>")("id") == orchestrator_id,
) # type: ignore
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
Expand All @@ -196,7 +202,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest
"""
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in request_pieces])

def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
Expand Down
6 changes: 3 additions & 3 deletions pyrit/memory/memory_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional
from pyrit.embedding import AzureTextEmbedding
from pyrit.models import PromptRequestPiece, EmbeddingSupport
from pyrit.memory.memory_models import EmbeddingData
from pyrit.memory.memory_models import EmbeddingDataEntry


class MemoryEmbedding:
Expand All @@ -21,7 +21,7 @@ def __init__(self, *, embedding_model: Optional[EmbeddingSupport]):
raise ValueError("embedding_model must be set.")
self.embedding_model = embedding_model

def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingData:
def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingDataEntry:
"""
Generates metadata for a chat memory entry.
Expand All @@ -32,7 +32,7 @@ def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestP
EmbeddingDataEntry: The generated embedding entry.
"""
if prompt_request_piece.converted_value_data_type == "text":
embedding_data = EmbeddingData(
embedding_data = EmbeddingDataEntry(
embedding=self.embedding_model.generate_text_embedding(text=prompt_request_piece.converted_value)
.data[0]
.embedding,
Expand Down
15 changes: 11 additions & 4 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
group_conversation_request_pieces_by_sequence,
)

from pyrit.memory.memory_models import EmbeddingData
from pyrit.memory.memory_models import EmbeddingDataEntry
from pyrit.memory.memory_embedding import default_memory_embedding_factory, MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.models.storage_io import StorageIO
Expand All @@ -33,13 +33,14 @@ class MemoryInterface(abc.ABC):
"""

memory_embedding: MemoryEmbedding = None
_storage_io: StorageIO = None
storage_io: StorageIO = None
results_path: str = None

def __init__(self, embedding_model=None):
self.memory_embedding = embedding_model
# Initialize the MemoryExporter instance
self.exporter = MemoryExporter()
self._init_storage_io()

def enable_embedding(self, embedding_model=None):
self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model)
Expand All @@ -54,11 +55,17 @@ def get_all_prompt_pieces(self) -> Sequence[PromptRequestPiece]:
"""

@abc.abstractmethod
def get_all_embeddings(self) -> Sequence[EmbeddingData]:
def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
"""
Loads all EmbeddingDataEntry records from the memory storage handler.
"""

@abc.abstractmethod
def _init_storage_io(self):
"""
Initialize the storage IO handler and assign it to self.storage_io.
"""

@abc.abstractmethod
def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> MutableSequence[PromptRequestPiece]:
"""
Expand Down Expand Up @@ -91,7 +98,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest
"""

@abc.abstractmethod
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
Expand Down
2 changes: 1 addition & 1 deletion pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __str__(self):
return f": {self.role}: {self.converted_value}"


class EmbeddingData(Base): # type: ignore
class EmbeddingDataEntry(Base): # type: ignore
"""
Represents the embedding data associated with conversation entries in the database.
Each embedding is linked to a specific conversation entry via an id
Expand Down
Loading

0 comments on commit 8d13072

Please sign in to comment.