Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #2364: Allow UserMemory to work with custom providers #2365

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/crewai/memory/contextual/contextual_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def build_context_for_task(self, task, context) -> str:
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
if self.memory_provider == "mem0":
context.append(self._fetch_user_context(query))
context.append(self._fetch_user_context(query))
return "\n".join(filter(None, context))

def _fetch_stm_context(self, query) -> str:
Expand Down Expand Up @@ -97,8 +96,9 @@ def _fetch_user_context(self, query: str) -> str:
user_memories = self.um.search(query)
if not user_memories:
return ""

formatted_memories = "\n".join(
f"- {result['memory']}" for result in user_memories
f"- {result['memory'] if self.um._memory_provider == 'mem0' else result['context']}"
for result in user_memories
)
return f"User memories/preferences:\n{formatted_memories}"
102 changes: 86 additions & 16 deletions src/crewai/memory/user/user_memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from pydantic import PrivateAttr

from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage


class UserMemory(Memory):
Expand All @@ -11,35 +14,102 @@ class UserMemory(Memory):
MemoryItem instances.
"""

def __init__(self, crew=None):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
_memory_provider: Optional[str] = PrivateAttr()

def __init__(
self,
crew=None,
embedder_config: Optional[Dict[str, Any]] = None,
storage: Optional[Any] = None,
path: Optional[str] = None,
memory_config: Optional[Dict[str, Any]] = None
):
"""
Initialize UserMemory with the specified storage provider.

Args:
crew: Optional crew object that may contain memory configuration
embedder_config: Optional configuration for the embedder
storage: Optional pre-configured storage instance
path: Optional path for storage
memory_config: Optional explicit memory configuration
"""
# Get memory provider from crew or directly from memory_config
memory_provider = None
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
elif memory_config is not None:
memory_provider = memory_config.get("provider")

if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
storage = Mem0Storage(type="user", crew=crew)
else:
storage = (
storage
if storage
else RAGStorage(
type="user",
allow_reset=True,
embedder_config=embedder_config,
crew=crew,
path=path,
)
)
storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage)
super().__init__(storage=storage)
self._memory_provider = memory_provider

def save(
self,
value,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
super().save(data, metadata)
"""
Save user memory data with appropriate formatting based on the storage provider.

Args:
value: The data to save
metadata: Optional metadata to associate with the memory
agent: Optional agent name to associate with the memory
"""
if self._memory_provider == "mem0":
data = f"Remember the details about the user: {value}"
else:
data = value
super().save(data, metadata, agent)

def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
):
results = self.storage.search(
) -> List[Any]:
"""
Search for user memories that match the query.

Args:
query: The search query
limit: Maximum number of results to return
score_threshold: Minimum similarity score for results

Returns:
List of matching memory items
"""
return self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
)
return results

def reset(self) -> None:
"""Reset the user memory storage."""
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the user memory: {e}")
58 changes: 58 additions & 0 deletions tests/memory/user_memory_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from unittest.mock import PropertyMock, patch

import pytest

from crewai.memory.storage.rag_storage import RAGStorage
from crewai.memory.user.user_memory import UserMemory


@patch('crewai.memory.storage.mem0_storage.Mem0Storage')
@patch('crewai.memory.storage.mem0_storage.MemoryClient')
def test_user_memory_provider_selection(mock_memory_client, mock_mem0_storage):
"""Test that UserMemory selects the correct storage provider based on config."""
# Setup - Mock Mem0Storage to avoid API key requirement
mock_mem0_storage.return_value = mock_mem0_storage

# Test with mem0 provider
with patch('crewai.memory.user.user_memory.RAGStorage'):
# Create UserMemory with mem0 provider
memory_config = {"provider": "mem0"}
user_memory = UserMemory(memory_config=memory_config)

# Verify Mem0Storage was used
mock_mem0_storage.assert_called_once()

# Reset mocks
mock_mem0_storage.reset_mock()

# Test with default provider (RAGStorage)
with patch('crewai.memory.user.user_memory.RAGStorage', return_value=mock_mem0_storage) as mock_rag:
# Create UserMemory with no provider specified
user_memory = UserMemory()

# Verify RAGStorage was used
mock_rag.assert_called_once()


@patch('crewai.memory.user.user_memory.UserMemory._memory_provider', new_callable=PropertyMock)
def test_user_memory_save_formatting(mock_memory_provider):
"""Test that UserMemory formats data correctly based on provider."""
# Test with mem0 provider
mock_memory_provider.return_value = "mem0"
with patch('crewai.memory.memory.Memory.save') as mock_save:
user_memory = UserMemory()
user_memory.save("test data")

# Verify data was formatted for mem0
args, _ = mock_save.call_args
assert "Remember the details about the user: test data" in args[0]

# Test with RAG provider
mock_memory_provider.return_value = None
with patch('crewai.memory.memory.Memory.save') as mock_save:
user_memory = UserMemory()
user_memory.save("test data")

# Verify data was not formatted
args, _ = mock_save.call_args
assert args[0] == "test data"