Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Updating memory and fixing bugs #394

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/code/targets/6_multi_modal_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

# You can use the following to show the image
if image_location != "content blocked":
image_bytes = await azure_sql_memory._storage_io.read_file(image_location) # type: ignore
image_bytes = await azure_sql_memory.storage_io.read_file(image_location) # type: ignore

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)
Expand Down
2 changes: 1 addition & 1 deletion pyrit/common/display_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def display_response(response_piece: PromptRequestPiece, memory: MemoryInt
and is_in_ipython_session()
):
image_location = response_piece.converted_value
image_bytes = await memory._storage_io.read_file(image_location)
image_bytes = await memory.storage_io.read_file(image_location)

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)
Expand Down
4 changes: 2 additions & 2 deletions pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData
from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingDataEntry
from pyrit.memory.memory_interface import MemoryInterface

from pyrit.memory.azure_sql_memory import AzureSQLMemory
Expand All @@ -13,7 +13,7 @@
__all__ = [
"AzureSQLMemory",
"DuckDBMemory",
"EmbeddingData",
"EmbeddingDataEntry",
"MemoryInterface",
"MemoryEmbedding",
"MemoryExporter",
Expand Down
44 changes: 26 additions & 18 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pyrit.common import default_values
from pyrit.common.singleton import Singleton
from pyrit.memory.memory_models import EmbeddingData, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_models import EmbeddingDataEntry, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.score import Score
Expand Down Expand Up @@ -49,8 +49,6 @@ def __init__(
sas_token: Optional[str] = None,
verbose: bool = False,
):
super(AzureSQLMemory, self).__init__()

self._connection_string = default_values.get_required_value(
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
)
Expand All @@ -63,8 +61,6 @@ def __init__(
)
except ValueError:
self._sas_token = None # To use delegation SAS
# Handle for Azure Blob Storage when using Azure SQL memory.
self._storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token)

self.results_path = self._container_url

Expand All @@ -76,6 +72,12 @@ def __init__(
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

super(AzureSQLMemory, self).__init__()

def _init_storage_io(self):
# Handle for Azure Blob Storage when using Azure SQL memory.
self.storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token)

def _create_auth_token(self) -> AccessToken:
azure_credentials = DefaultAzureCredential()
return azure_credentials.get_token(self.TOKEN_URL)
Expand Down Expand Up @@ -137,13 +139,13 @@ def _create_tables_if_not_exist(self):
except Exception as e:
logger.error(f"Error during table creation: {e}")

def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
self._insert_entries(entries=embedding_data)

def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Base]:
def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptMemoryEntry Base objects that have the specified orchestrator ID.

Expand All @@ -152,13 +154,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Ba
Can be retrieved by calling orchestrator.get_identifier()["id"]

Returns:
list[Base]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID.
list[PromptRequestPiece]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID.
"""
try:
sql_condition = text(
"ISJSON(orchestrator_identifier) = 1 AND JSON_VALUE(orchestrator_identifier, '$.id') = :json_id"
).bindparams(json_id=str(orchestrator_id))
result = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore
entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]

return result
except Exception as e:
logger.exception(
Expand All @@ -177,10 +181,14 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.conversation_id == conversation_id,
) # type: ignore

result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []
Expand All @@ -206,11 +214,11 @@ def dispose_engine(self):
self.engine.dispose()
logger.info("Engine disposed successfully.")

def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result = self.query_entries(EmbeddingData)
result = self.query_entries(EmbeddingDataEntry)
return result

def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
Expand All @@ -231,10 +239,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
Expand Down Expand Up @@ -266,10 +276,8 @@ def get_prompt_request_piece_by_memory_labels(
# for safe parameter passing, preventing SQL injection
sql_condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})

result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=sql_condition
) # type: ignore

entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
Expand Down
34 changes: 20 additions & 14 deletions pyrit/memory/duckdb_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.engine.base import Engine
from contextlib import closing

from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.common.path import RESULTS_PATH
from pyrit.common.singleton import Singleton
Expand Down Expand Up @@ -46,12 +46,15 @@ def __init__(
else:
self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve()
self.results_path = str(RESULTS_PATH)
# Handles disk-based storage for DuckDB local memory.
self._storage_io = DiskStorageIO()

self.engine = self._create_engine(has_echo=verbose)
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

def _init_storage_io(self):
# Handles disk-based storage for DuckDB local memory.
self.storage_io = DiskStorageIO()

def _create_engine(self, *, has_echo: bool) -> Engine:
"""Creates the SQLAlchemy engine for DuckDB.

Expand Down Expand Up @@ -91,11 +94,11 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result: list[EmbeddingData] = self.query_entries(EmbeddingData)
result: list[EmbeddingDataEntry] = self.query_entries(EmbeddingDataEntry)
return result

def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> list[PromptRequestPiece]:
Expand All @@ -109,9 +112,10 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
result: list[PromptRequestPiece] = self.query_entries(
entries = self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
Expand All @@ -128,10 +132,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
Expand All @@ -156,9 +162,8 @@ def get_prompt_request_piece_by_memory_labels(
try:
conditions = [PromptMemoryEntry.labels.op("->>")(key) == value for key, value in memory_labels.items()]
query_condition = and_(*conditions)
result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=query_condition
) # type: ignore
entries = self.query_entries(PromptMemoryEntry, conditions=query_condition)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
Expand All @@ -178,10 +183,11 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Pr
list[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID.
"""
try:
result: list[PromptRequestPiece] = self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.orchestrator_identifier.op("->>")("id") == orchestrator_id,
) # type: ignore
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
Expand All @@ -196,7 +202,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest
"""
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in request_pieces])

def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
Expand Down
6 changes: 3 additions & 3 deletions pyrit/memory/memory_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional
from pyrit.embedding import AzureTextEmbedding
from pyrit.models import PromptRequestPiece, EmbeddingSupport
from pyrit.memory.memory_models import EmbeddingData
from pyrit.memory.memory_models import EmbeddingDataEntry


class MemoryEmbedding:
Expand All @@ -21,7 +21,7 @@ def __init__(self, *, embedding_model: Optional[EmbeddingSupport]):
raise ValueError("embedding_model must be set.")
self.embedding_model = embedding_model

def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingData:
def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingDataEntry:
"""
Generates metadata for a chat memory entry.

Expand All @@ -32,7 +32,7 @@ def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestP
ConversationMemoryEntryMetadata: The generated metadata.
"""
if prompt_request_piece.converted_value_data_type == "text":
embedding_data = EmbeddingData(
embedding_data = EmbeddingDataEntry(
embedding=self.embedding_model.generate_text_embedding(text=prompt_request_piece.converted_value)
.data[0]
.embedding,
Expand Down
15 changes: 11 additions & 4 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
group_conversation_request_pieces_by_sequence,
)

from pyrit.memory.memory_models import EmbeddingData
from pyrit.memory.memory_models import EmbeddingDataEntry
from pyrit.memory.memory_embedding import default_memory_embedding_factory, MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.models.storage_io import StorageIO
Expand All @@ -33,13 +33,14 @@ class MemoryInterface(abc.ABC):
"""

memory_embedding: MemoryEmbedding = None
_storage_io: StorageIO = None
storage_io: StorageIO = None
results_path: str = None

def __init__(self, embedding_model=None):
self.memory_embedding = embedding_model
# Initialize the MemoryExporter instance
self.exporter = MemoryExporter()
self._init_storage_io()

def enable_embedding(self, embedding_model=None):
self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model)
Expand All @@ -54,11 +55,17 @@ def get_all_prompt_pieces(self) -> Sequence[PromptRequestPiece]:
"""

@abc.abstractmethod
def get_all_embeddings(self) -> Sequence[EmbeddingData]:
def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
"""
Loads all EmbeddingData from the memory storage handler.
"""

@abc.abstractmethod
def _init_storage_io(self):
"""
Initialize the storage IO handler storage_io.
"""

@abc.abstractmethod
def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> MutableSequence[PromptRequestPiece]:
"""
Expand Down Expand Up @@ -91,7 +98,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest
"""

@abc.abstractmethod
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
Expand Down
2 changes: 1 addition & 1 deletion pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __str__(self):
return f": {self.role}: {self.converted_value}"


class EmbeddingData(Base): # type: ignore
class EmbeddingDataEntry(Base): # type: ignore
"""
Represents the embedding data associated with conversation entries in the database.
Each embedding is linked to a specific conversation entry via an id
Expand Down
Loading
Loading