Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: begin work on RAG, refactoring, improved docs and tests #258

Merged
merged 24 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
92b175a
feat: started working on RAG (again)
ErikBjare Nov 15, 2024
9961166
fix: fixed bugs in rag
ErikBjare Nov 15, 2024
e485c5a
fix: fixed tests after refactor
ErikBjare Nov 15, 2024
6cf10be
build(deps): updated dependencies
ErikBjare Nov 15, 2024
82e4ce1
build(deps): updated gptme-rag
ErikBjare Nov 15, 2024
43452b4
fix: made rag support optional
ErikBjare Nov 15, 2024
b1d5408
test: fixed tests
ErikBjare Nov 17, 2024
f2bd6a4
docs: fixed docs
ErikBjare Nov 17, 2024
31cfbc2
docs: fixed docs for computer tool
ErikBjare Nov 17, 2024
075c765
docs: fixed docs, made some funcs private to hide them from autodocs
ErikBjare Nov 17, 2024
b2d41b5
test: fixed tests
ErikBjare Nov 17, 2024
19a7d33
Apply suggestions from code review
ErikBjare Nov 17, 2024
b64e805
fix: fixed when running in a directory without gptme.toml
ErikBjare Nov 17, 2024
573514f
fix: fixed lint
ErikBjare Nov 17, 2024
72bc976
test: fixed tests for rag
ErikBjare Nov 17, 2024
aa57b67
build(deps): updated gptme-rag
ErikBjare Nov 17, 2024
2af98f1
fix: add get_project_dir helper function
ErikBjare Nov 17, 2024
3346767
fix: changed rag tool to use functions instead of execute
ErikBjare Nov 17, 2024
5e053f0
Apply suggestions from code review
ErikBjare Nov 17, 2024
26aeced
fix: more refactoring of rag, made typechecking of imports stricter, …
ErikBjare Nov 18, 2024
d403e54
fix: fixed rag tool tests
ErikBjare Nov 19, 2024
75787da
ci: fixed typechecking
ErikBjare Nov 19, 2024
ea0bc6d
fix: hopefully finally fixed rag tests now
ErikBjare Nov 19, 2024
0e1596d
test: made test_chain more reliable for gpt-4o-mini
ErikBjare Nov 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 11 additions & 26 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ The tools can be grouped into the following categories:

- `Chats`_

- Context management

- `RAG`_

Shell
-----

Expand Down Expand Up @@ -112,34 +116,15 @@ Chats
Computer
--------

.. include:: computer-use-warning.rst

.. automodule:: gptme.tools.computer
:members:
:noindex:

The computer tool provides direct interaction with the desktop environment through X11, allowing for:

- Keyboard input simulation
- Mouse control (movement, clicks, dragging)
- Screen capture with automatic scaling
- Cursor position tracking

To use the computer tool, see the instructions for :doc:`server`.

Example usage::

# Type text
computer(action="type", text="Hello, World!")
RAG
---

# Move mouse and click
computer(action="mouse_move", coordinate=(100, 100))
computer(action="left_click")

# Take screenshot
computer(action="screenshot")

# Send keyboard shortcuts
computer(action="key", text="Control_L+c")

The tool automatically handles screen resolution scaling to ensure optimal performance with LLM vision capabilities.

.. include:: computer-use-warning.rst
.. automodule:: gptme.tools.rag
:members:
:noindex:
121 changes: 121 additions & 0 deletions gptme/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Caching system for RAG functionality."""

import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

from .config import get_project_config

logger = logging.getLogger(__name__)


# Global cache instance
_cache: Optional["RAGCache"] = None


def get_cache() -> "RAGCache":
    """Return the process-wide RAGCache, creating it on first use."""
    global _cache
    if _cache is not None:
        return _cache
    _cache = RAGCache()
    return _cache


@dataclass
class CacheEntry:
"""Entry in the cache with metadata."""

data: Any
timestamp: float
ttl: float


class Cache:
"""Simple cache with TTL and size limits."""

def __init__(self, max_size: int = 1000, default_ttl: float = 3600):
self.max_size = max_size
self.default_ttl = default_ttl
self._cache: dict[str, CacheEntry] = {}

def get(self, key: str) -> Any | None:
"""Get a value from the cache."""
if key not in self._cache:
return None

entry = self._cache[key]
if time.time() - entry.timestamp > entry.ttl:
# Entry expired
del self._cache[key]
return None

return entry.data

def set(self, key: str, value: Any, ttl: float | None = None) -> None:
"""Set a value in the cache."""
# Enforce size limit
if len(self._cache) >= self.max_size:
# Remove oldest entry
oldest_key = min(self._cache.items(), key=lambda x: x[1].timestamp)[0]
del self._cache[oldest_key]

self._cache[key] = CacheEntry(
data=value, timestamp=time.time(), ttl=ttl or self.default_ttl
)

def clear(self) -> None:
"""Clear the cache."""
self._cache.clear()


class RAGCache:
"""Cache for RAG functionality."""

def __init__(self):
config = get_project_config(Path.cwd())
assert config
cache_config = config.rag.get("rag", {}).get("cache", {})

# Initialize caches with configured limits
self.embedding_cache = Cache(
max_size=cache_config.get("max_embeddings", 10000),
default_ttl=cache_config.get("embedding_ttl", 86400), # 24 hours
)
self.search_cache = Cache(
max_size=cache_config.get("max_searches", 1000),
default_ttl=cache_config.get("search_ttl", 3600), # 1 hour
)

@staticmethod
def _make_search_key(query: str, n_results: int) -> str:
"""Create a cache key for a search query."""
return f"{query}::{n_results}"

def get_embedding(self, text: str) -> list[float] | None:
"""Get cached embedding for text."""
return self.embedding_cache.get(text)

def set_embedding(self, text: str, embedding: list[float]) -> None:
"""Cache embedding for text."""
self.embedding_cache.set(text, embedding)

def get_search_results(
self, query: str, n_results: int
) -> tuple[list[Any], dict[str, Any]] | None:
"""Get cached search results."""
key = self._make_search_key(query, n_results)
return self.search_cache.get(key)

def set_search_results(
self, query: str, n_results: int, results: tuple[list[Any], dict[str, Any]]
) -> None:
"""Cache search results."""
key = self._make_search_key(query, n_results)
self.search_cache.set(key, results)

def clear(self) -> None:
"""Clear all caches."""
self.embedding_cache.clear()
self.search_cache.clear()
15 changes: 9 additions & 6 deletions gptme/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path

import tomlkit
Expand Down Expand Up @@ -41,6 +42,7 @@ class ProjectConfig:
"""Project-level configuration, such as which files to include in the context by default."""

files: list[str] = field(default_factory=list)
rag: dict = field(default_factory=dict)


ABOUT_ACTIVITYWATCH = """ActivityWatch is a free and open-source automated time-tracker that helps you track how you spend your time on your devices."""
Expand Down Expand Up @@ -72,12 +74,12 @@ class ProjectConfig:
def get_config() -> Config:
global _config
if _config is None:
_config = load_config()
_config = _load_config()
return _config


def load_config() -> Config:
config = _load_config()
def _load_config() -> Config:
config = _load_config_doc()
assert "prompt" in config, "prompt key missing in config"
assert "env" in config, "env key missing in config"
prompt = config.pop("prompt")
Expand All @@ -87,7 +89,7 @@ def load_config() -> Config:
return Config(prompt=prompt, env=env)


def _load_config() -> tomlkit.TOMLDocument:
def _load_config_doc() -> tomlkit.TOMLDocument:
# Check if the config file exists
if not os.path.exists(config_path):
# If not, create it and write some default settings
Expand All @@ -105,7 +107,7 @@ def _load_config() -> tomlkit.TOMLDocument:


def set_config_value(key: str, value: str) -> None: # pragma: no cover
doc: TOMLDocument | Container = _load_config()
doc: TOMLDocument | Container = _load_config_doc()

# Set the value
keypath = key.split(".")
Expand All @@ -120,9 +122,10 @@ def set_config_value(key: str, value: str) -> None: # pragma: no cover

# Reload config
global _config
_config = load_config()
_config = _load_config()


@lru_cache
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider specifying a maxsize for @lru_cache to prevent unbounded memory usage.

Suggested change
@lru_cache
@lru_cache(maxsize=128)

def get_project_config(workspace: Path) -> ProjectConfig | None:
project_config_paths = [
p
Expand Down
152 changes: 152 additions & 0 deletions gptme/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Context providers for enhancing messages with relevant context."""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path

import gptme_rag

from .cache import get_cache
from .config import get_project_config
from .message import Message

logger = logging.getLogger(__name__)


@dataclass
class Context:
    """Context information to be added to messages."""

    # The context text itself (e.g. a retrieved document excerpt).
    content: str
    # Where the content came from; consumers fall back to "unknown".
    source: str
    # Relevance score (computed as 1 - distance); higher is more relevant.
    relevance: float


class ContextProvider(ABC):
    """Base class for context providers."""

    @abstractmethod
    def get_context(self, query: str, max_tokens: int = 1000) -> list[Context]:
        """Get relevant context for a query.

        Args:
            query: Free-text query to find context for.
            max_tokens: Upper-bound hint on how much context to return.

        Returns:
            A list of Context items; may be empty.
        """
        pass


class RAGContextProvider(ContextProvider):
    """Context provider using RAG (gptme-rag) for document retrieval.

    Degrades gracefully: if gptme-rag or a project RAG config is missing,
    the provider initializes with ``_has_rag = False`` and returns no
    context instead of raising.
    """

    _has_rag = True  # Class attribute for testing

    # TODO: refactor this to share code with rag tool
    def __init__(self):
        try:
            # Check if gptme-rag is installed
            # NOTE(review): gptme_rag is also imported unconditionally at
            # module top, so this guard can only matter if that import is
            # made lazy — confirm and align the two.
            import importlib.util

            if importlib.util.find_spec("gptme_rag") is None:
                logger.debug(
                    "gptme-rag not installed, RAG context provider will not be available"
                )
                self._has_rag = False
                return

            # Check if we have a valid config
            config = get_project_config(Path.cwd())
            if not config or not hasattr(config, "rag"):
                logger.debug("No RAG configuration found in gptme.toml")
                self._has_rag = False
                return

            self._has_rag = True

            # Storage configuration
            self.indexer = gptme_rag.Indexer(
                persist_directory=config.rag.get("index_path", "~/.cache/gptme/rag"),
                collection_name=config.rag.get("collection", "default"),
            )

            # Context enhancement configuration
            self.context_assembler = gptme_rag.ContextAssembler(
                max_tokens=config.rag.get("max_tokens", 2000)
            )
            self.auto_context = config.rag.get("auto_context", True)
            self.min_relevance = config.rag.get("min_relevance", 0.5)
            self.max_results = config.rag.get("max_results", 5)
        except Exception as e:
            # Deliberately broad: provider construction must never break
            # message enhancement — we degrade to "no RAG" instead.
            logger.debug(f"Failed to initialize RAG context provider: {e}")
            self._has_rag = False

    def get_context(self, query: str, max_tokens: int = 1000) -> list[Context]:
        """Get relevant context using RAG.

        Returns an empty list when RAG is unavailable, auto-context is
        disabled, or retrieval fails for any reason.
        """
        if not self._has_rag or not self.auto_context:
            return []

        try:
            # Check cache first
            cache = get_cache()
            cached_results = cache.get_search_results(query, self.max_results)

            if cached_results:
                docs, results = cached_results
                logger.debug(f"Using cached search results for query: {query}")
            else:
                # Search with configured limits
                docs, results = self.indexer.search(query, n_results=self.max_results)
                # Cache the results
                cache.set_search_results(query, self.max_results, (docs, results))
                logger.debug(f"Cached search results for query: {query}")

            contexts = []
            for i, doc in enumerate(docs):
                # Calculate relevance score (1 - distance)
                relevance = 1.0 - results["distances"][0][i]

                # Skip if below minimum relevance
                if relevance < self.min_relevance:
                    continue

                contexts.append(
                    Context(
                        content=doc.content,
                        source=doc.metadata.get("source", "unknown"),
                        relevance=relevance,
                    )
                )

            # Most relevant first
            contexts.sort(key=lambda x: x.relevance, reverse=True)

            return contexts
        except Exception as e:
            logger.warning(f"Error getting RAG context: {e}")
            return []


def enhance_messages(messages: list[Message]) -> list[Message]:
    """Enhance messages with context from available providers.

    For each user message, relevant context is gathered from every
    provider and inserted as a hidden system message directly before it.
    Provider failures are logged and skipped.
    """
    providers: list[ContextProvider] = [RAGContextProvider()]
    result: list[Message] = []

    for message in messages:
        if message.role == "user":
            gathered: list[Context] = []
            for provider in providers:
                try:
                    gathered.extend(provider.get_context(message.content))
                except Exception as e:
                    logger.warning(f"Error getting context from provider: {e}")

            if gathered:
                sections = "".join(
                    f"### {ctx.source}\n{ctx.content}\n\n" for ctx in gathered
                )
                result.append(
                    Message(
                        role="system",
                        content="Relevant context:\n\n" + sections,
                        hide=True,
                    )
                )

        result.append(message)

    return result
Loading
Loading