Skip to content

Commit

Permalink
fix: changed rag tool to use functions instead of execute
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Nov 17, 2024
1 parent 2af98f1 commit 5b41de3
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 121 deletions.
92 changes: 39 additions & 53 deletions gptme/tools/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@
from pathlib import Path

from ..config import get_project_config
from ..message import Message
from ..util import get_project_dir
from .base import ConfirmFunc, ToolSpec, ToolUse
from .base import ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand All @@ -80,59 +79,50 @@

instructions = """
Use RAG to index and search project documentation.
Commands:
- index [paths...] - Index documents in specified paths
- search <query> - Search indexed documents
- status - Show index status
"""

examples = f"""
User: Index the current directory
Assistant: Let me index the current directory with RAG.
{ToolUse("rag", ["index"], "").to_output()}
{ToolUse("ipython", [], "rag_index()").to_output()}
System: Indexed 1 paths
User: Search for documentation about functions
Assistant: I'll search for function-related documentation.
{ToolUse("rag", ["search", "function", "documentation"], "").to_output()}
{ToolUse("ipython", [], 'rag_search("function documentation")').to_output()}
System: ### docs/api.md
Functions are documented using docstrings...
User: Show index status
Assistant: I'll check the current status of the RAG index.
{ToolUse("rag", ["status"], "").to_output()}
{ToolUse("ipython", [], "rag_status()").to_output()}
System: Index contains 42 documents
"""


def execute_rag(code: str, args: list[str], confirm: ConfirmFunc) -> Message:
"""Execute RAG commands."""
def rag_index(*paths: str, glob: str | None = None) -> str:
    """Index the documents under each of the given paths.

    Args:
        *paths: Paths handed one by one to ``indexer.index_directory``.
            When none are given, the current directory (".") is used.
        glob: Optional pattern forwarded as the ``glob_pattern`` keyword.
            When falsy, the keyword is not passed at all.

    Returns:
        A summary string stating how many paths were processed.
    """
    assert indexer is not None, "RAG indexer not initialized"
    targets = paths if paths else (".",)
    # Build the optional keyword arguments once; they are the same for every path.
    extra = {}
    if glob:
        extra["glob_pattern"] = glob
    for target in targets:
        indexer.index_directory(Path(target), **extra)
    return f"Indexed {len(targets)} paths"


def rag_search(query: str) -> str:
    """Search the index and render the hits as heading-plus-snippet sections.

    Each hit is rendered as a ``### <source>`` heading (taken from the hit's
    ``metadata["source"]``) followed by the first 200 characters of its
    content and a trailing ``...``. Sections are separated by a blank line.

    Args:
        query: The search string passed to ``indexer.search``.

    Returns:
        The rendered sections joined into one string.
    """
    assert indexer is not None, "RAG indexer not initialized"
    # indexer.search returns a pair; only the first element is used here.
    hits, _ = indexer.search(query)
    sections = []
    for hit in hits:
        source = hit.metadata["source"]
        snippet = hit.content[:200]
        sections.append(f"### {source}\n{snippet}...")
    return "\n\n".join(sections)


def rag_status() -> str:
"""Show index status."""
assert indexer is not None, "RAG indexer not initialized"
command = args[0] if args else "help"

if command == "help":
return Message("system", "Available commands: index, search, status")
elif command == "index":
paths = args[1:] or ["."]
for path in paths:
indexer.index_directory(Path(path))
return Message("system", f"Indexed {len(paths)} paths")
elif command == "search":
query = " ".join(args[1:])
docs, _ = indexer.search(query)
return Message(
"system",
"\n\n".join(
f"### {doc.metadata['source']}\n{doc.content[:200]}..." for doc in docs
),
)
elif command == "status":
return Message(
"system", f"Index contains {indexer.collection.count()} documents"
)
else:
return Message("system", f"Unknown command: {command}")
return f"Index contains {indexer.collection.count()} documents"


def init() -> ToolSpec:
Expand All @@ -141,22 +131,19 @@ def init() -> ToolSpec:
return tool

project_dir = get_project_dir()
if not project_dir:
return tool
index_path = Path("~/.cache/gptme/rag").expanduser()
collection = "default"
if project_dir and (config := get_project_config(project_dir)):
index_path = Path(config.rag.get("index_path", index_path)).expanduser()
collection = config.rag.get("collection", project_dir.name)

import gptme_rag # fmt: skip

config = get_project_config(project_dir)
if config:
# Initialize RAG with configuration
global indexer
import gptme_rag # fmt: skip

indexer = gptme_rag.Indexer(
persist_directory=Path(
config.rag.get("index_path", "~/.cache/gptme/rag")
).expanduser(),
# TODO: use a better default collection name? (e.g. project name)
collection_name=config.rag.get("collection", "gptme_docs"),
)
global indexer
indexer = gptme_rag.Indexer(
persist_directory=index_path,
collection_name=collection,
)
return tool


Expand All @@ -165,8 +152,7 @@ def init() -> ToolSpec:
desc="RAG (Retrieval-Augmented Generation) for context-aware assistance",
instructions=instructions,
examples=examples,
block_types=["rag"],
execute=execute_rag,
functions=[rag_index, rag_search, rag_status],
available=_HAS_RAG,
init=init,
)
Expand Down
16 changes: 8 additions & 8 deletions gptme/util/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ def context():


@context.command("generate")
@click.argument("path", type=click.Path(exists=True))
def context_generate(_path: str):
"""Generate context from a directory."""
pass
# from ..context import generate_context # fmt: skip

# ctx = generate_context(path)
# print(ctx)
@click.argument("query")
def context_generate(query: str):
"""Retrieve context for a given query."""
from ..context import RAGContextProvider # fmt: skip

provider = RAGContextProvider()
ctx = provider.get_context(query)
print(ctx)


@main.group()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include = ["gptme/server/static/**/*", "media/logo.png"]
gptme = "gptme.cli:main"
gptme-server = "gptme.server.cli:main"
gptme-eval = "gptme.eval.main:main"
gptme-util = "gptme.util.cli:main"
gptme-nc = "gptme.ncurses:main"

[tool.poetry.dependencies]
Expand Down
107 changes: 52 additions & 55 deletions tests/test_tools_rag.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Tests for the RAG tool."""

from collections.abc import Generator
from dataclasses import replace
from unittest.mock import patch

import pytest
from gptme import Message
from gptme.tools.base import ToolSpec
from gptme.tools.rag import _HAS_RAG
from gptme.tools.rag import init as init_rag
from gptme.tools.rag import rag_index, rag_search, rag_status


@pytest.fixture
Expand All @@ -35,7 +34,6 @@ def test_rag_tool_init():

def test_rag_tool_init_without_gptme_rag():
"""Test RAG tool initialization when gptme-rag is not available."""

tool = init_rag()
with (
patch("gptme.tools.rag._HAS_RAG", False),
Expand All @@ -47,73 +45,72 @@ def test_rag_tool_init_without_gptme_rag():
assert tool.available is False


def _m2str(tool_execute: Generator[Message, None, None] | Message) -> str:
"""Convert a execute() call to a string."""
if isinstance(tool_execute, Generator):
return tool_execute.send(None).content
elif isinstance(tool_execute, Message):
return tool_execute.content


def noconfirm(*args, **kwargs):
return True


@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed")
def test_rag_index_command(temp_docs, tmp_path):
"""Test the index command."""
def test_rag_index_function(temp_docs, tmp_path):
"""Test the index function."""
with patch("gptme.tools.rag.get_project_config") as mock_config:
mock_config.return_value.rag = {
"index_path": str(tmp_path),
"collection": "test",
}

tool = init_rag()
assert tool.execute
result = _m2str(tool.execute("", ["index", str(temp_docs)], noconfirm))
assert "Indexed" in result
# Initialize RAG
init_rag()

# Check status after indexing
result = _m2str(tool.execute("", ["status"], noconfirm))
assert "Index contains" in result
assert "2" in result # Should have indexed 2 documents
# Test indexing with specific path
result = rag_index(str(temp_docs))
assert "Indexed 1 paths" in result

# Test indexing with default path
# FIXME: this is really slow in the gptme directory,
# since it contains a lot of files (which are in gitignore, but not respected)
result = rag_index(glob="**/*.py")
assert "Indexed 1 paths" in result


@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed")
def test_rag_search_command(temp_docs):
"""Test the search command."""
tool = init_rag()
assert tool.execute
# Index first
_m2str(tool.execute("", ["index", str(temp_docs)], noconfirm))
def test_rag_search_function(temp_docs, tmp_path):
"""Test the search function."""
with patch("gptme.tools.rag.get_project_config") as mock_config:
mock_config.return_value.rag = {
"index_path": str(tmp_path),
"collection": "test",
}

# Initialize RAG and index documents
init_rag()
rag_index(str(temp_docs))

# Search for Python
result = _m2str(tool.execute("", ["search", "Python"], noconfirm))
assert "doc1.md" in result
assert "Python functions" in result
# Search for Python
result = rag_search("Python")
assert "doc1.md" in result
assert "Python functions" in result

# Search for testing
result = _m2str(tool.execute("", ["search", "testing"], noconfirm))
assert "doc2.md" in result
assert "testing practices" in result
# Search for testing
result = rag_search("testing")
assert "doc2.md" in result
assert "testing practices" in result


@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed")
def test_rag_help_command():
"""Test the help command."""
tool = init_rag()
assert tool.execute
result = _m2str(tool.execute("", ["help"], noconfirm))
assert "Available commands" in result
assert "index" in result
assert "search" in result
assert "status" in result
def test_rag_status_function(temp_docs, tmp_path):
"""Test the status function."""
with patch("gptme.tools.rag.get_project_config") as mock_config:
mock_config.return_value.rag = {
"index_path": str(tmp_path),
"collection": "test",
}

# Initialize RAG
init_rag()

@pytest.mark.skipif(not _HAS_RAG, reason="gptme-rag not installed")
def test_rag_invalid_command():
"""Test invalid command handling."""
tool = init_rag()
assert tool.execute
result = _m2str(tool.execute("", ["invalid"], noconfirm))
assert "Unknown command" in result
# Check initial status
result = rag_status()
assert "Index contains" in result
assert "0" in result

# Index documents and check status again
rag_index(str(temp_docs))
result = rag_status()
assert "Index contains" in result
assert "2" in result # Should have indexed 2 documents
8 changes: 3 additions & 5 deletions tests/test_util_cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Tests for the gptme-util CLI."""

import time
from pathlib import Path
from click.testing import CliRunner
import pytest

from click.testing import CliRunner
from gptme.logmanager import ConversationMeta
from gptme.util.cli import main


Expand Down Expand Up @@ -64,8 +65,6 @@ def test_chats_list(tmp_path, mocker):
)

# Create ConversationMeta objects for our test conversations
from gptme.logmanager import ConversationMeta
import time

conv1 = ConversationMeta(
name="2024-01-01-chat-one",
Expand Down Expand Up @@ -96,7 +95,6 @@ def test_chats_list(tmp_path, mocker):
assert "Messages: 2" in result.output # Second chat has 2 messages


@pytest.mark.skip("Waiting for context module PR")
def test_context_generate(tmp_path):
"""Test the context generate command."""
# Create a test file
Expand Down

0 comments on commit 5b41de3

Please sign in to comment.