diff --git a/python/docs/src/user-guide/agentchat-user-guide/memory.ipynb b/python/docs/src/user-guide/agentchat-user-guide/memory.ipynb index dbe5eb847c0a..35b14939e82b 100644 --- a/python/docs/src/user-guide/agentchat-user-guide/memory.ipynb +++ b/python/docs/src/user-guide/agentchat-user-guide/memory.ipynb @@ -225,7 +225,7 @@ "\n", "- `autogen_ext.memory.chromadb.SentenceTransformerEmbeddingFunctionConfig`: A configuration class for the SentenceTransformer embedding function used by the `ChromaDBVectorMemory` store. Note that other embedding functions such as `autogen_ext.memory.openai.OpenAIEmbeddingFunctionConfig` can also be used with the `ChromaDBVectorMemory` store.\n", "\n", - "- `autogen_ext.memory.redis_memory.RedisMemory`: A memory store that uses a Redis vector database to store and retrieve information.\n" + "- `autogen_ext.memory.redis.RedisMemory`: A memory store that uses a Redis vector database to store and retrieve information.\n" ] }, { @@ -377,7 +377,7 @@ "from autogen_agentchat.agents import AssistantAgent\n", "from autogen_agentchat.ui import Console\n", "from autogen_core.memory import MemoryContent, MemoryMimeType\n", - "from autogen_ext.memory.redis_memory import RedisMemory, RedisMemoryConfig\n", + "from autogen_ext.memory.redis import RedisMemory, RedisMemoryConfig\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", "\n", "logger = getLogger()\n", diff --git a/python/packages/autogen-ext/src/autogen_ext/memory/redis/_redis_memory.py b/python/packages/autogen-ext/src/autogen_ext/memory/redis/_redis_memory.py index f828e4e2d7e1..98cb19e18837 100644 --- a/python/packages/autogen-ext/src/autogen_ext/memory/redis/_redis_memory.py +++ b/python/packages/autogen-ext/src/autogen_ext/memory/redis/_redis_memory.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Literal +from typing import Any, List, Literal from autogen_core import CancellationToken, Component from autogen_core.memory import Memory, MemoryContent, 
MemoryMimeType, MemoryQueryResult, UpdateContextResult @@ -217,20 +217,31 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke .. note:: To perform semantic search over stored memories RedisMemory creates a vector embedding - from the content field of a MemoryContent object. This content is assumed to be text, and - is passed to the vector embedding model specified in RedisMemoryConfig. + from the content field of a MemoryContent object. This content is assumed to be text, + JSON, or Markdown, and is passed to the vector embedding model specified in + RedisMemoryConfig. Args: content (MemoryContent): The memory content to store within Redis. cancellation_token (CancellationToken): Token passed to cease operation. Not used. """ - if content.mime_type != MemoryMimeType.TEXT: + if content.mime_type == MemoryMimeType.TEXT: + memory_content = content.content + mime_type = "text/plain" + elif content.mime_type == MemoryMimeType.JSON: + memory_content = serialize(content.content) + mime_type = "application/json" + elif content.mime_type == MemoryMimeType.MARKDOWN: + memory_content = content.content + mime_type = "text/markdown" + else: raise NotImplementedError( - f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT is currently supported." + f"Error: {content.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported." ) - + metadata = {"mime_type": mime_type} + metadata.update(content.metadata if content.metadata else {}) self.message_history.add_message( - {"role": "user", "content": content.content, "tool_call_id": serialize(content.metadata)} # type: ignore[reportArgumentType] + {"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)} # type: ignore[reportArgumentType] ) async def query( @@ -260,14 +271,19 @@ async def query( memoryQueryResult: Object containing memories relevant to the provided query. 
""" # get the query string, or raise an error for unsupported MemoryContent types - if isinstance(query, MemoryContent): - if query.mime_type != MemoryMimeType.TEXT: + if isinstance(query, str): + prompt = query + elif isinstance(query, MemoryContent): + if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN): + prompt = str(query.content) + elif query.mime_type == MemoryMimeType.JSON: + prompt = serialize(query.content) + else: raise NotImplementedError( - f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT is currently supported." + f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported." ) - prompt = query.content else: - prompt = query + raise TypeError("'query' must be either a string or MemoryContent") top_k = kwargs.pop("top_k", self.config.top_k) distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold) @@ -279,12 +295,22 @@ async def query( raw=False, ) - memories = [] + memories: List[MemoryContent] = [] for result in results: + metadata = deserialize(result["tool_call_id"]) # type: ignore[reportArgumentType] + mime_type = MemoryMimeType(metadata.pop("mime_type")) + if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN): + memory_content = result["content"] # type: ignore[reportArgumentType] + elif mime_type == MemoryMimeType.JSON: + memory_content = deserialize(result["content"]) # type: ignore[reportArgumentType] + else: + raise NotImplementedError( + f"Error: {mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, and MemoryMimeType.MARKDOWN are currently supported." 
+ ) memory = MemoryContent( - content=result["content"], # type: ignore[reportArgumentType] - mime_type=MemoryMimeType.TEXT, - metadata=deserialize(result["tool_call_id"]), # type: ignore[reportArgumentType] + content=memory_content, # type: ignore[reportArgumentType] + mime_type=mime_type, + metadata=metadata, ) memories.append(memory) # type: ignore[reportUknownMemberType] diff --git a/python/packages/autogen-ext/tests/memory/test_redis_memory.py b/python/packages/autogen-ext/tests/memory/test_redis_memory.py index 83d1b6781f20..2ab25f55c8f3 100644 --- a/python/packages/autogen-ext/tests/memory/test_redis_memory.py +++ b/python/packages/autogen-ext/tests/memory/test_redis_memory.py @@ -35,7 +35,9 @@ async def test_redis_memory_query_with_mock() -> None: config = RedisMemoryConfig() memory = RedisMemory(config=config) - mock_history.get_relevant.return_value = [{"content": "test content", "tool_call_id": '{"foo": "bar"}'}] + mock_history.get_relevant.return_value = [ + {"content": "test content", "tool_call_id": '{"foo": "bar", "mime_type": "text/plain"}'} + ] result = await memory.query("test") assert len(result.results) == 1 assert result.results[0].content == "test content" @@ -304,8 +306,7 @@ async def test_basic_workflow(semantic_config: RedisMemoryConfig) -> None: @pytest.mark.asyncio @pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally") -async def test_content_types(semantic_memory: RedisMemory) -> None: - """Test different content types with semantic memory.""" +async def test_text_memory_type(semantic_memory: RedisMemory) -> None: await semantic_memory.clear() # Test text content @@ -317,8 +318,104 @@ async def test_content_types(semantic_memory: RedisMemory) -> None: assert len(results.results) > 0 assert any("Simple text content" in str(r.content) for r in results.results) - # Test JSON content - json_data = {"key": "value", "number": 42} - json_content = MemoryContent(content=json_data, mime_type=MemoryMimeType.JSON) 
-    with pytest.raises(NotImplementedError):
-        await semantic_memory.add(json_content)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
+async def test_json_memory_type(semantic_memory: RedisMemory) -> None:
+    await semantic_memory.clear()
+
+    json_data = {"title": "Hitchhiker's Guide to the Galaxy", "The answer to life, the universe and everything.": 42}
+    await semantic_memory.add(
+        MemoryContent(content=json_data, mime_type=MemoryMimeType.JSON, metadata={"author": "Douglas Adams"})
+    )
+
+    results = await semantic_memory.query("what is the ultimate question of the universe?")
+    assert results.results[0].content == json_data
+
+    # metadata is stored next to the content but is not searched, so a metadata-only query finds nothing
+    results = await semantic_memory.query("who is Douglas Adams?")
+    assert len(results.results) == 0
+
+    # a bare dict is not an accepted query type: only str or MemoryContent are allowed
+    with pytest.raises(TypeError):
+        await semantic_memory.query({"question": "what is the ultimate question of the universe?"})  # type: ignore[arg-type]
+
+    # but we can query with JSON if it is wrapped in a MemoryContent container
+    results = await semantic_memory.query(
+        MemoryContent(
+            content={"question": "what is the ultimate question of the universe?"}, mime_type=MemoryMimeType.JSON
+        )
+    )
+    assert results.results[0].content == json_data
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
+async def test_markdown_memory_type(semantic_memory: RedisMemory) -> None:
+    await semantic_memory.clear()
+
+    markdown_data = """
+    This is an H1 header
+    ============
+
+    Paragraphs are separated by a blank line.
+
+    *Italics are within asteriks*, **bold text is within two asterisks**,
+    while `monospace is within back tics`.
+
+    Itemized lists are made with indented asterisks:
+
+      * this one
+      * that one
+      * the next one
+
+    > Block quotes are make with arrows
+    > like this.
+    >
+    > They can span multiple paragraphs,
+    > if you like.
+
+    Unicode is supported. ☺
+    """
+
+    await semantic_memory.add(
+        MemoryContent(content=markdown_data, mime_type=MemoryMimeType.MARKDOWN, metadata={"type": "markdown example"})
+    )
+
+    results = await semantic_memory.query("how can I make itemized lists, or italicize text with asterisks?")
+    assert results.results[0].content == markdown_data
+
+    # a plain string that happens to contain markdown is a valid query and must find the stored memory
+    results = await semantic_memory.query("**bold text is within 2 asterisks**, and *italics are within 1 asterisk*")
+    assert results.results[0].content == markdown_data
+    # the same markdown wrapped in a MemoryContent container must behave identically
+    results = await semantic_memory.query(
+        MemoryContent(
+            content="**bold text is within 2 asterisks**, and *italics are within 1 asterisk*",
+            mime_type=MemoryMimeType.MARKDOWN,
+        )
+    )
+    assert results.results[0].content == markdown_data
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
+async def test_query_arguments(semantic_memory: RedisMemory) -> None:
+    # check that the optional query arguments top_k and distance_threshold are honoured
+    await semantic_memory.clear()
+
+    await semantic_memory.add(MemoryContent(content="my favorite fruit are apples", mime_type=MemoryMimeType.TEXT))
+    await semantic_memory.add(MemoryContent(content="I also like cherries", mime_type=MemoryMimeType.TEXT))
+    await semantic_memory.add(MemoryContent(content="I like plums as well", mime_type=MemoryMimeType.TEXT))
+
+    # default search: all three stored memories are returned
+    results = await semantic_memory.query("what fruits do I like?")
+    assert len(results.results) == 3
+
+    # limit search to 2 results
+    results = await semantic_memory.query("what fruits do I like?", top_k=2)
+    assert len(results.results) == 2
+
+    # limit search to only close matches
+    results = await semantic_memory.query("my favorite fruit are what?", distance_threshold=0.2)
+    assert len(results.results) == 1