feat: begin work on RAG, refactoring, improved docs and tests #258
Merged
Changes from 16 commits (24 commits in total, all by ErikBjare):

92b175a  feat: started working on RAG (again)
9961166  fix: fixed bugs in rag
e485c5a  fix: fixed tests after refactor
6cf10be  build(deps): updated dependencies
82e4ce1  build(deps): updated gptme-rag
43452b4  fix: made rag support optional
b1d5408  test: fixed tests
f2bd6a4  docs: fixed docs
31cfbc2  docs: fixed docs for computer tool
075c765  docs: fixed docs, made some funcs private to hide them from autodocs
b2d41b5  test: fixed tests
19a7d33  Apply suggestions from code review
b64e805  fix: fixed when running in a directory without gptme.toml
573514f  fix: fixed lint
72bc976  test: fixed tests for rag
aa57b67  build(deps): updated gptme-rag
2af98f1  fix: add get_project_dir helper function
3346767  fix: changed rag tool to use functions instead of execute
5e053f0  Apply suggestions from code review
26aeced  fix: more refactoring of rag, made typechecking of imports stricter, …
d403e54  fix: fixed rag tool tests
75787da  ci: fixed typechecking
ea0bc6d  fix: hopefully finally fixed rag tests now
0e1596d  test: made test_chain more reliable for gpt-4o-mini
New file (+121 lines):
"""Caching system for RAG functionality.""" | ||
|
||
import logging | ||
import time | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Any, Optional | ||
|
||
from .config import get_project_config | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# Global cache instance | ||
_cache: Optional["RAGCache"] = None | ||
|
||
|
||
def get_cache() -> "RAGCache": | ||
"""Get the global RAG cache instance.""" | ||
global _cache | ||
if _cache is None: | ||
_cache = RAGCache() | ||
return _cache | ||
|
||
|
||
@dataclass | ||
class CacheEntry: | ||
"""Entry in the cache with metadata.""" | ||
|
||
data: Any | ||
timestamp: float | ||
ttl: float | ||
|
||
|
||
class Cache: | ||
"""Simple cache with TTL and size limits.""" | ||
|
||
def __init__(self, max_size: int = 1000, default_ttl: float = 3600): | ||
self.max_size = max_size | ||
self.default_ttl = default_ttl | ||
self._cache: dict[str, CacheEntry] = {} | ||
|
||
def get(self, key: str) -> Any | None: | ||
"""Get a value from the cache.""" | ||
if key not in self._cache: | ||
return None | ||
|
||
entry = self._cache[key] | ||
if time.time() - entry.timestamp > entry.ttl: | ||
# Entry expired | ||
del self._cache[key] | ||
return None | ||
|
||
return entry.data | ||
|
||
def set(self, key: str, value: Any, ttl: float | None = None) -> None: | ||
"""Set a value in the cache.""" | ||
# Enforce size limit | ||
if len(self._cache) >= self.max_size: | ||
# Remove oldest entry | ||
oldest_key = min(self._cache.items(), key=lambda x: x[1].timestamp)[0] | ||
del self._cache[oldest_key] | ||
|
||
self._cache[key] = CacheEntry( | ||
data=value, timestamp=time.time(), ttl=ttl or self.default_ttl | ||
) | ||
|
||
def clear(self) -> None: | ||
"""Clear the cache.""" | ||
self._cache.clear() | ||
|
||
|
||
class RAGCache: | ||
"""Cache for RAG functionality.""" | ||
|
||
def __init__(self): | ||
config = get_project_config(Path.cwd()) | ||
assert config | ||
cache_config = config.rag.get("rag", {}).get("cache", {}) | ||
|
||
# Initialize caches with configured limits | ||
self.embedding_cache = Cache( | ||
max_size=cache_config.get("max_embeddings", 10000), | ||
default_ttl=cache_config.get("embedding_ttl", 86400), # 24 hours | ||
) | ||
self.search_cache = Cache( | ||
max_size=cache_config.get("max_searches", 1000), | ||
default_ttl=cache_config.get("search_ttl", 3600), # 1 hour | ||
) | ||
|
||
@staticmethod | ||
def _make_search_key(query: str, n_results: int) -> str: | ||
"""Create a cache key for a search query.""" | ||
return f"{query}::{n_results}" | ||
|
||
def get_embedding(self, text: str) -> list[float] | None: | ||
"""Get cached embedding for text.""" | ||
return self.embedding_cache.get(text) | ||
|
||
def set_embedding(self, text: str, embedding: list[float]) -> None: | ||
"""Cache embedding for text.""" | ||
self.embedding_cache.set(text, embedding) | ||
|
||
def get_search_results( | ||
self, query: str, n_results: int | ||
) -> tuple[list[Any], dict[str, Any]] | None: | ||
"""Get cached search results.""" | ||
key = self._make_search_key(query, n_results) | ||
return self.search_cache.get(key) | ||
|
||
def set_search_results( | ||
self, query: str, n_results: int, results: tuple[list[Any], dict[str, Any]] | ||
) -> None: | ||
"""Cache search results.""" | ||
key = self._make_search_key(query, n_results) | ||
self.search_cache.set(key, results) | ||
|
||
def clear(self) -> None: | ||
"""Clear all caches.""" | ||
self.embedding_cache.clear() | ||
self.search_cache.clear() |
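To make the TTL and size-limit behaviour of the Cache class above concrete, here is a small illustrative sketch. It is not part of the PR: it assumes the Cache class from the file above is in scope (its module path is not shown in this diff), and it uses deliberately tiny limits so that eviction and expiry are visible. RAGCache simply wraps two such caches, one for embeddings and one for search results.

import time

# Illustrative only: assumes `Cache` from the file above is importable.
cache = Cache(max_size=2, default_ttl=0.5)  # tiny limits for demonstration

cache.set("a", 1)
cache.set("b", 2)
assert cache.get("a") == 1           # hit, still within the 0.5 s TTL

cache.set("c", 3)                    # exceeds max_size=2 -> oldest entry ("a") is evicted
assert cache.get("a") is None

time.sleep(0.6)                      # wait past the default TTL
assert cache.get("b") is None        # expired entries are dropped on read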
New file (+152 lines):
"""Context providers for enhancing messages with relevant context.""" | ||
|
||
import logging | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
|
||
import gptme_rag | ||
|
||
from .cache import get_cache | ||
from .config import get_project_config | ||
from .message import Message | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class Context: | ||
"""Context information to be added to messages.""" | ||
|
||
content: str | ||
source: str | ||
relevance: float | ||
|
||
|
||
class ContextProvider(ABC): | ||
"""Base class for context providers.""" | ||
|
||
@abstractmethod | ||
def get_context(self, query: str, max_tokens: int = 1000) -> list[Context]: | ||
"""Get relevant context for a query.""" | ||
pass | ||
|
||
|
||
class RAGContextProvider(ContextProvider): | ||
"""Context provider using RAG.""" | ||
|
||
_has_rag = True # Class attribute for testing | ||
|
||
# TODO: refactor this to share code with rag tool | ||
def __init__(self): | ||
try: | ||
# Check if gptme-rag is installed | ||
import importlib.util | ||
|
||
if importlib.util.find_spec("gptme_rag") is None: | ||
logger.debug( | ||
"gptme-rag not installed, RAG context provider will not be available" | ||
) | ||
self._has_rag = False | ||
return | ||
|
||
# Check if we have a valid config | ||
config = get_project_config(Path.cwd()) | ||
if not config or not hasattr(config, "rag"): | ||
logger.debug("No RAG configuration found in gptme.toml") | ||
self._has_rag = False | ||
return | ||
|
||
self._has_rag = True | ||
|
||
# Storage configuration | ||
self.indexer = gptme_rag.Indexer( | ||
persist_directory=config.rag.get("index_path", "~/.cache/gptme/rag"), | ||
collection_name=config.rag.get("collection", "default"), | ||
) | ||
|
||
# Context enhancement configuration | ||
self.context_assembler = gptme_rag.ContextAssembler( | ||
max_tokens=config.rag.get("max_tokens", 2000) | ||
) | ||
self.auto_context = config.rag.get("auto_context", True) | ||
self.min_relevance = config.rag.get("min_relevance", 0.5) | ||
self.max_results = config.rag.get("max_results", 5) | ||
except Exception as e: | ||
Review comment: Catching all exceptions is too broad and can mask other issues. Consider catching specific exceptions instead.
            logger.debug(f"Failed to initialize RAG context provider: {e}")
            self._has_rag = False

    def get_context(self, query: str, max_tokens: int = 1000) -> list[Context]:
        """Get relevant context using RAG."""
        if not self._has_rag or not self.auto_context:
            return []

        try:
            # Check cache first
            cache = get_cache()
            cached_results = cache.get_search_results(query, self.max_results)

            if cached_results:
                docs, results = cached_results
                logger.debug(f"Using cached search results for query: {query}")
            else:
                # Search with configured limits
                docs, results = self.indexer.search(query, n_results=self.max_results)
                # Cache the results
                cache.set_search_results(query, self.max_results, (docs, results))
                logger.debug(f"Cached search results for query: {query}")

            contexts = []
            for i, doc in enumerate(docs):
                # Calculate relevance score (1 - distance)
                relevance = 1.0 - results["distances"][0][i]

                # Skip if below minimum relevance
                if relevance < self.min_relevance:
                    continue

                contexts.append(
                    Context(
                        content=doc.content,
                        source=doc.metadata.get("source", "unknown"),
                        relevance=relevance,
                    )
                )

            # Sort by relevance
            contexts.sort(key=lambda x: x.relevance, reverse=True)

            return contexts
        except Exception as e:
            logger.warning(f"Error getting RAG context: {e}")
            return []


def enhance_messages(messages: list[Message]) -> list[Message]:
    """Enhance messages with context from available providers."""
    providers = [RAGContextProvider()]
    enhanced_messages = []

    for msg in messages:
        if msg.role == "user":
            # Get context from all providers
            contexts = []
            for provider in providers:
                try:
                    contexts.extend(provider.get_context(msg.content))
                except Exception as e:
                    logger.warning(f"Error getting context from provider: {e}")

            # Add context as a system message before the user message
            if contexts:
                context_msg = "Relevant context:\n\n"
                for ctx in contexts:
                    context_msg += f"### {ctx.source}\n{ctx.content}\n\n"

                enhanced_messages.append(
                    Message(role="system", content=context_msg, hide=True)
                )

        enhanced_messages.append(msg)

    return enhanced_messages
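The net effect of enhance_messages is easiest to see from a call sketch. The following is illustrative, not from the PR: it assumes gptme-rag is installed, that the working directory has a gptme.toml with a [rag] section, and that Message and enhance_messages from the files above are in scope (their import paths are not shown in this diff).

# Illustrative only: assumes `Message` and `enhance_messages` from the diff
# above are in scope, gptme-rag is installed, and gptme.toml has a [rag] section.
msgs = [Message(role="user", content="How does indexing work in this project?")]

enhanced = enhance_messages(msgs)

# If the indexer returns documents above min_relevance, a hidden system
# message is inserted *before* the user message, roughly:
#   system (hide=True): "Relevant context:\n\n### <source>\n<content>\n\n..."
#   user:               "How does indexing work in this project?"
# If RAG is unavailable or nothing is relevant enough, the list is unchanged.
for m in enhanced:
    print(m.role, getattr(m, "hide", False))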
Review comment: Consider specifying a maxsize for @lru_cache to prevent unbounded memory usage.
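The function this comment refers to is elsewhere in the PR and not shown in this diff, so the following is only a generic sketch of the suggestion, with a made-up embed_text helper standing in for the decorated function.

from functools import lru_cache


# An @lru_cache with maxsize=None grows without bound and keeps every cached
# result alive for the life of the process; passing a maxsize caps memory.
# `embed_text` is a hypothetical stand-in, not code from the PR.
@lru_cache(maxsize=1024)
def embed_text(text: str) -> tuple[float, ...]:
    # placeholder "embedding"; real code would call an embedding model
    return tuple(float(ord(c)) for c in text[:8])


print(embed_text("hello"))
print(embed_text.cache_info())  # hits/misses/maxsize/currsize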