Phase 3: Full Modular RAG - Advanced Indexing, Fine-Tuning & Auto-Optimization #259

Parent Epic: #256 - RAG Modulo Evolution
Depends On: #258 - Phase 2 (Early Modular RAG)

Timeline: 8-10 weeks
Priority: Medium
Complexity: High

Overview

Phase 3 completes the Modular RAG implementation with advanced capabilities: retriever fine-tuning, multi-index strategies, chunk optimization, and auto-optimization. This phase focuses on maximizing quality and creating a production-ready, self-improving system.

Current State (After Phase 2)

What We Have:

  • ✅ Complete Advanced RAG (Phase 1): Reranking, compression, multi-query, hierarchical chunking, RRF
  • ✅ Early Modular RAG (Phase 2): Semantic routing, scheduling, verification, knowledge graph foundation
  • ✅ All foundational components in place

What's Missing:

  • ❌ Retriever fine-tuning with domain data
  • ❌ Advanced indexing strategies (multi-index, chunk optimization)
  • ❌ Adaptive retrieval with performance feedback
  • ❌ Continuous learning pipeline
  • ❌ Full orchestration with dynamic pipeline assembly
  • ❌ Auto-optimization based on user feedback

Goals and Success Criteria

Goals

  1. Implement retriever fine-tuning with LM supervision
  2. Build multi-index strategy for different content types
  3. Implement chunk optimization (retrieve small, use large)
  4. Create adaptive retrieval with performance monitoring
  5. Build continuous learning and auto-optimization pipeline
  6. Complete production-ready Modular RAG system

Success Criteria

Quantitative:

  • Retrieval quality: MRR@10 > 0.90 (vs 0.80 baseline)
  • Fine-tuned retriever: +15% improvement over baseline
  • Multi-index: +10% improvement for structured content (tables, code)
  • End-to-end latency: <5s for 95th percentile
  • System uptime: >99.9%

Qualitative:

  • Production-ready with monitoring and observability
  • Self-improving through continuous learning
  • Adaptive to different domains and use cases
  • Enterprise-grade reliability and scalability

Implementation Plan

Week 15-17: Retriever Fine-Tuning

Task 1: LM-Supervised Fine-Tuning

New File: backend/rag_solution/retrieval/fine_tuning.py

"""Retriever fine-tuning with LM supervision."""

import logging
from typing import Any
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from vectordbs.data_types import QueryResult
from core.config import Settings

logger = logging.getLogger(__name__)


class RetrieverFineTuner:
    """Fine-tune retriever with positive/negative examples."""
    
    def __init__(self, base_model: str, settings: Settings):
        self.base_model = base_model
        self.settings = settings
        self.model = None
    
    def generate_training_data(
        self,
        queries: list[str],
        collection_id: str,
        llm_judge: Any
    ) -> list[InputExample]:
        """
        Generate positive/negative pairs using LLM as judge.
        
        For each query:
        1. Retrieve candidates
        2. Use LLM to judge relevance (positive/negative)
        3. Create training pairs
        """
        training_examples = []
        
        for query in queries:
            # Retrieve candidates
            candidates = self._retrieve_candidates(query, collection_id)
            
            # Use LLM to judge relevance
            judgments = llm_judge.judge_relevance(query, candidates)
            
            # Create positive/negative pairs
            for candidate, relevance in zip(candidates, judgments):
                if relevance > 0.7:
                    # Positive pair
                    training_examples.append(
                        InputExample(texts=[query, candidate.text], label=1.0)
                    )
                elif relevance < 0.3:
                    # Negative pair (hard negative)
                    training_examples.append(
                        InputExample(texts=[query, candidate.text], label=0.0)
                    )
        
        logger.info(f"Generated {len(training_examples)} training examples")
        return training_examples
    
    def fine_tune(
        self,
        training_examples: list[InputExample],
        epochs: int = 3,
        batch_size: int = 16
    ) -> SentenceTransformer:
        """Fine-tune retriever model."""
        logger.info(f"Fine-tuning retriever with {len(training_examples)} examples")
        
        # Load base model
        self.model = SentenceTransformer(self.base_model)
        
        # Create DataLoader
        train_dataloader = DataLoader(
            training_examples,
            shuffle=True,
            batch_size=batch_size
        )
        
        # Define loss (contrastive loss for positive/negative pairs)
        train_loss = losses.CosineSimilarityLoss(self.model)
        
        # Train
        self.model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=epochs,
            warmup_steps=100,
            show_progress_bar=True
        )
        
        logger.info("Fine-tuning complete")
        return self.model
    
    def save_model(self, output_path: str):
        """Save fine-tuned model."""
        if self.model:
            self.model.save(output_path)
            logger.info(f"Model saved to {output_path}")
    
    def _retrieve_candidates(
        self,
        query: str,
        collection_id: str,
        top_k: int = 50
    ) -> list[QueryResult]:
        """Retrieve candidate documents for training."""
        # Use existing retriever to get candidates
        # This should be diverse (not just top results)
        ...


class LLMJudge:
    """Use LLM to judge retrieval relevance for training data."""
    
    def __init__(self, llm_provider: Any):
        self.llm = llm_provider
    
    def judge_relevance(
        self,
        query: str,
        candidates: list[QueryResult]
    ) -> list[float]:
        """
        Judge relevance of candidates to query.
        
        Returns relevance scores (0.0-1.0) for each candidate.
        """
        scores = []
        
        for candidate in candidates:
            prompt = f"""Rate the relevance of the following passage to the query on a scale of 0-10.

Query: {query}

Passage: {candidate.text}

Relevance score (0-10):"""
            
            response = self.llm.generate(prompt, max_tokens=5)
            
            try:
                # Parse score
                score = float(response.strip()) / 10.0
                scores.append(max(0.0, min(1.0, score)))
            except ValueError:
                # Default to neutral if parsing fails
                scores.append(0.5)
        
        return scores

Integration: Create fine-tuning pipeline

# backend/rag_solution/cli/commands/fine_tune.py

import click
from rag_solution.retrieval.fine_tuning import RetrieverFineTuner, LLMJudge

# NOTE: get_settings() and get_llm_provider() used below are assumed project helpers
# (settings loader and LLM provider factory) and would be imported here.


@click.command()
@click.option('--collection-id', required=True, help='Collection to train on')
@click.option('--queries-file', required=True, help='File with training queries')
@click.option('--output-model', required=True, help='Output path for fine-tuned model')
@click.option('--epochs', default=3, help='Training epochs')
def fine_tune_retriever(collection_id, queries_file, output_model, epochs):
    """Fine-tune retriever for specific collection."""
    
    # Load queries
    with open(queries_file) as f:
        queries = [line.strip() for line in f]
    
    # Initialize fine-tuner
    fine_tuner = RetrieverFineTuner(
        base_model="sentence-transformers/all-MiniLM-L6-v2",
        settings=get_settings()
    )
    
    # Generate training data
    llm_judge = LLMJudge(get_llm_provider())
    training_data = fine_tuner.generate_training_data(
        queries, collection_id, llm_judge
    )
    
    # Fine-tune
    model = fine_tuner.fine_tune(training_data, epochs=epochs)
    
    # Save
    fine_tuner.save_model(output_model)
    
    click.echo(f"Fine-tuned model saved to {output_model}")

Week 18-20: Advanced Indexing

Task 2: Multi-Index Strategy

New File: backend/rag_solution/data_ingestion/multi_index.py

"""Multi-index strategy for different content types."""

from enum import Enum
from typing import Any
from vectordbs.data_types import Document, DocumentChunk, QueryResult
from vectordbs.vector_store import VectorStore


class ContentType(str, Enum):
    """Content types for multi-index."""
    TEXT = "text"
    TABLE = "table"
    CODE = "code"
    IMAGE = "image"


class MultiIndexManager:
    """Manage separate indexes for different content types."""
    
    def __init__(self, vector_store: VectorStore):
        self.vector_store = vector_store
        # One index per content type, named per collection as
        # f"{collection_base_name}_{content_type.value}"
        # (see index_document and search_multi_index below).
    
    def index_document(self, document: Document, collection_base_name: str):
        """Index document chunks into appropriate indexes."""
        
        # Classify chunks by content type
        classified_chunks = self._classify_chunks(document.chunks)
        
        # Index each type separately
        for content_type, chunks in classified_chunks.items():
            index_name = f"{collection_base_name}_{content_type.value}"
            self.vector_store.add_documents(index_name, chunks)
    
    def search_multi_index(
        self,
        query: str,
        collection_base_name: str,
        content_types: list[ContentType] | None = None,
        top_k_per_index: int = 10
    ) -> list[QueryResult]:
        """Search across multiple indexes and merge results."""
        
        if content_types is None:
            content_types = list(ContentType)
        
        # Search each index
        all_results = []
        for content_type in content_types:
            index_name = f"{collection_base_name}_{content_type.value}"
            results = self.vector_store.retrieve_documents(
                query, index_name, top_k_per_index
            )
            
            # Tag results with content type
            for result in results:
                result.metadata["content_type"] = content_type.value
            
            all_results.extend(results)
        
        # Merge and rerank
        merged = self._merge_results(all_results)
        
        return merged
    
    def _classify_chunks(
        self,
        chunks: list[DocumentChunk]
    ) -> dict[ContentType, list[DocumentChunk]]:
        """Classify chunks by content type."""
        classified = {ct: [] for ct in ContentType}
        
        for chunk in chunks:
            content_type = self._detect_content_type(chunk)
            classified[content_type].append(chunk)
        
        return classified
    
    def _detect_content_type(self, chunk: DocumentChunk) -> ContentType:
        """Detect content type of chunk."""
        metadata = chunk.metadata
        
        # Check metadata flags
        if metadata.get("table_index", 0) > 0:
            return ContentType.TABLE
        
        if metadata.get("image_index", 0) > 0:
            return ContentType.IMAGE
        
        # Detect code by patterns
        if self._is_code(chunk.text):
            return ContentType.CODE
        
        return ContentType.TEXT
    
    def _is_code(self, text: str) -> bool:
        """Detect if text is code."""
        code_indicators = [
            "def ", "class ", "function ", "import ", "return ",
            "for ", "while ", "if ", "{", "}", "[]", "()", "//", "/*"
        ]
        indicator_count = sum(1 for ind in code_indicators if ind in text)
        return indicator_count >= 3
    
    def _merge_results(
        self,
        results: list[QueryResult]
    ) -> list[QueryResult]:
        """Merge results from multiple indexes."""
        # Sort by score
        results.sort(key=lambda x: x.score, reverse=True)
        
        # Deduplicate while preserving order
        seen = set()
        merged = []
        for result in results:
            key = (result.document_id, result.text)
            if key not in seen:
                seen.add(key)
                merged.append(result)
        
        return merged

Integration: Update SearchService

from rag_solution.data_ingestion.multi_index import MultiIndexManager, ContentType

class SearchService:
    def __init__(self, db: Session, settings: Settings) -> None:
        # ... existing init ...
        self._multi_index: MultiIndexManager | None = None
    
    @property
    def multi_index(self) -> MultiIndexManager | None:
        if self.settings.enable_multi_index and self._multi_index is None:
            self._multi_index = MultiIndexManager(self.pipeline_service.vector_store)
        return self._multi_index
    
    async def search(self, search_input: SearchInput) -> SearchOutput:
        # ... existing routing/scheduling ...
        
        # MULTI-INDEX RETRIEVAL (NEW)
        if self.multi_index:
            # Detect which content types are relevant
            relevant_types = self._detect_relevant_content_types(search_input.question)
            
            logger.info(f"Searching content types: {relevant_types}")
            query_results = self.multi_index.search_multi_index(
                search_input.question,
                collection_base_name,
                content_types=relevant_types
            )
        else:
            # Standard single-index retrieval
            query_results = await self.pipeline_service.retrieve(...)
        
        # ... continue with reranking, generation ...
    
    def _detect_relevant_content_types(self, query: str) -> list[ContentType]:
        """Detect which content types are relevant to query."""
        types = [ContentType.TEXT]  # Always include text
        
        if any(kw in query.lower() for kw in ["table", "data", "statistics", "numbers"]):
            types.append(ContentType.TABLE)
        
        if any(kw in query.lower() for kw in ["code", "function", "implementation", "example"]):
            types.append(ContentType.CODE)
        
        if any(kw in query.lower() for kw in ["image", "figure", "diagram", "picture"]):
            types.append(ContentType.IMAGE)
        
        return types
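
On the ingestion side (the `data_ingestion/ingestion.py` change listed under Modified Files below), the hookup could look like the following sketch; the service and attribute names are assumptions:

from rag_solution.data_ingestion.multi_index import MultiIndexManager

class DocumentIngestionService:  # hypothetical wrapper around the existing ingestion path
    def __init__(self, vector_store, settings):
        self.vector_store = vector_store
        self.settings = settings
        self.multi_index = MultiIndexManager(vector_store) if settings.enable_multi_index else None

    def ingest(self, document, collection_base_name: str):
        if self.multi_index:
            # Route each chunk to its per-content-type index (text/table/code/image)
            self.multi_index.index_document(document, collection_base_name)
        else:
            # Existing single-index path
            self.vector_store.add_documents(collection_base_name, document.chunks)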

Task 3: Chunk Optimization (Retrieve Small, Use Large)

Update File: backend/rag_solution/retrieval/chunk_optimizer.py

"""Chunk optimization for retrieval."""

import re

from vectordbs.data_types import DocumentChunk, QueryResult


class ChunkOptimizer:
    """Optimize chunks for retrieval and generation."""
    
    def __init__(self, settings):
        self.settings = settings
    
    def optimize_for_retrieval(
        self,
        chunks: list[DocumentChunk]
    ) -> list[DocumentChunk]:
        """Create smaller chunks optimized for retrieval."""
        optimized = []
        
        for chunk in chunks:
            # Split into smaller chunks for better precision
            small_chunks = self._split_to_sentences(chunk.text)
            
            for i, small_text in enumerate(small_chunks):
                optimized_chunk = DocumentChunk(
                    chunk_id=f"{chunk.chunk_id}_small_{i}",
                    text=small_text,
                    embeddings=chunk.embeddings,  # Will be recomputed
                    metadata={
                        **chunk.metadata,
                        "parent_chunk_id": chunk.chunk_id,
                        "optimization": "small_for_retrieval"
                    },
                    document_id=chunk.document_id
                )
                optimized.append(optimized_chunk)
        
        return optimized
    
    def expand_for_generation(
        self,
        retrieved_results: list[QueryResult]
    ) -> list[QueryResult]:
        """Expand small chunks to include surrounding context."""
        expanded = []
        
        for result in retrieved_results:
            # Get parent chunk or surrounding context
            parent_chunk_id = result.metadata.get("parent_chunk_id")
            
            if parent_chunk_id:
                # Retrieve full parent chunk
                parent_text = self._get_parent_chunk_text(parent_chunk_id)
                
                # Create expanded result
                expanded_result = QueryResult(
                    text=parent_text,
                    score=result.score,
                    document_id=result.document_id,
                    metadata={
                        **result.metadata,
                        "original_text": result.text,
                        "expanded": True
                    }
                )
                expanded.append(expanded_result)
            else:
                expanded.append(result)
        
        return expanded
    
    def _split_to_sentences(self, text: str) -> list[str]:
        """Split text into sentences."""
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if len(s.strip()) > 20]
    
    def _get_parent_chunk_text(self, parent_chunk_id: str) -> str:
        """Retrieve parent chunk text from storage."""
        # Implementation depends on storage backend
        ...
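
A sketch of how the two halves fit together: index the small chunks at ingestion time, retrieve against them at query time, then expand before generation. The `vector_store`, `document`, `collection_name`, `query`, and `settings` names are assumed to be in scope in the surrounding service, and the calls mirror the interfaces used elsewhere in this issue:

optimizer = ChunkOptimizer(settings)

# Ingestion time: index sentence-level chunks for higher retrieval precision
small_chunks = optimizer.optimize_for_retrieval(document.chunks)
vector_store.add_documents(collection_name, small_chunks)

# Query time: retrieve small chunks, then expand each hit to its parent chunk
# so the generator sees the full surrounding context
query_results = vector_store.retrieve_documents(query, collection_name, 20)
expanded_results = optimizer.expand_for_generation(query_results)
# expanded_results then feed into reranking and generation as before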

Week 21-24: Auto-Optimization & Production Polish

Task 4: Adaptive Retrieval

New File: backend/rag_solution/orchestration/adaptive_retrieval.py

"""Adaptive retrieval with performance feedback."""

from collections import defaultdict
from typing import Any
from datetime import datetime, timedelta
import numpy as np


class AdaptiveRetriever:
    """Adapt retrieval strategy based on performance feedback."""
    
    def __init__(self):
        self.performance_history = defaultdict(list)
        self.strategy_scores = defaultdict(float)
    
    def record_feedback(
        self,
        query: str,
        strategy: str,
        success: bool,
        latency: float,
        user_rating: float | None = None
    ):
        """Record performance feedback for a query."""
        self.performance_history[strategy].append({
            "timestamp": datetime.utcnow(),
            "query": query,
            "success": success,
            "latency": latency,
            "user_rating": user_rating
        })
        
        # Update strategy score
        self._update_strategy_score(strategy)
    
    def select_strategy(self, query: str) -> str:
        """Select best strategy based on historical performance."""
        # Get strategy scores
        scores = {
            strategy: self._get_recent_score(strategy)
            for strategy in ["simple", "hybrid", "cot", "multi_query"]
        }
        
        # Add exploration bonus (epsilon-greedy)
        epsilon = 0.1
        if np.random.random() < epsilon:
            # Explore: random strategy
            return np.random.choice(list(scores.keys()))
        else:
            # Exploit: best strategy
            return max(scores.items(), key=lambda x: x[1])[0]
    
    def _update_strategy_score(self, strategy: str):
        """Update rolling average score for strategy."""
        recent_results = self._get_recent_results(strategy, days=7)
        
        if not recent_results:
            return
        
        # Compute average metrics over the recent window
        success_rate = sum(r["success"] for r in recent_results) / len(recent_results)
        avg_latency = np.mean([r["latency"] for r in recent_results])

        ratings = [r["user_rating"] for r in recent_results if r["user_rating"] is not None]
        # Neutral rating term when no explicit user feedback exists yet
        rating_term = (np.mean(ratings) / 5.0) if ratings else 0.5

        # Combine metrics (weights sum to 1.0; lower latency raises the score)
        score = (
            0.5 * success_rate
            + 0.2 * (1 / (1 + avg_latency))
            + 0.3 * rating_term
        )
        
        self.strategy_scores[strategy] = score
    
    def _get_recent_score(self, strategy: str) -> float:
        """Get recent performance score."""
        return self.strategy_scores.get(strategy, 0.5)  # Default to neutral
    
    def _get_recent_results(self, strategy: str, days: int = 7) -> list[dict]:
        """Get results from last N days."""
        cutoff = datetime.utcnow() - timedelta(days=days)
        return [
            r for r in self.performance_history[strategy]
            if r["timestamp"] > cutoff
        ]
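
A sketch of how the adaptive retriever could be wired into SearchService, selecting a strategy up front and recording feedback afterward; the `_run_strategy` dispatch helper is hypothetical:

import time

class SearchService:
    def __init__(self, db, settings):
        # ... existing init ...
        self.adaptive_retriever = AdaptiveRetriever()

    async def search(self, search_input):
        strategy = self.adaptive_retriever.select_strategy(search_input.question)
        start = time.monotonic()

        output = await self._run_strategy(strategy, search_input)  # hypothetical dispatch helper

        self.adaptive_retriever.record_feedback(
            query=search_input.question,
            strategy=strategy,
            success=output is not None,
            latency=time.monotonic() - start,
            user_rating=None,  # attached later when the user rates the answer
        )
        return output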

Task 5: Continuous Learning

New File: backend/rag_solution/orchestration/continuous_learning.py

"""Continuous learning and auto-optimization."""

import logging
from typing import Any
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)


class ContinuousLearner:
    """Continuously improve system based on usage patterns."""
    
    def __init__(self, db, settings):
        self.db = db
        self.settings = settings
        self.learning_interval = timedelta(days=7)
    
    async def run_learning_cycle(self):
        """Run one learning cycle."""
        logger.info("Starting continuous learning cycle")
        
        # 1. Analyze recent queries
        query_patterns = await self._analyze_query_patterns()
        
        # 2. Identify underperforming areas
        problem_areas = await self._identify_problem_areas()
        
        # 3. Generate training data for fine-tuning
        if problem_areas:
            training_data = await self._generate_training_data(problem_areas)
            await self._trigger_fine_tuning(training_data)
        
        # 4. Update retrieval strategies
        await self._update_retrieval_strategies(query_patterns)
        
        # 5. Optimize indexes
        await self._optimize_indexes()
        
        logger.info("Continuous learning cycle complete")
    
    async def _analyze_query_patterns(self) -> dict[str, Any]:
        """Analyze recent query patterns."""
        # Get queries from last week
        cutoff = datetime.utcnow() - self.learning_interval
        
        queries = await self._get_recent_queries(since=cutoff)
        
        # Analyze patterns
        patterns = {
            "common_topics": self._extract_topics(queries),
            "avg_complexity": self._compute_avg_complexity(queries),
            "content_type_distribution": self._analyze_content_types(queries)
        }
        
        return patterns
    
    async def _identify_problem_areas(self) -> list[dict]:
        """Identify areas with poor performance."""
        # Get low-rated queries
        problem_queries = await self._get_low_rated_queries()
        
        # Analyze failure modes
        problem_areas = []
        for query in problem_queries:
            issue = self._diagnose_issue(query)
            problem_areas.append({
                "query": query["question"],
                "issue": issue,
                "rating": query["rating"]
            })
        
        return problem_areas
    
    async def _generate_training_data(
        self,
        problem_areas: list[dict]
    ) -> list[Any]:
        """Generate training data for problematic queries."""
        training_data = []
        
        for problem in problem_areas:
            # Use LLM to generate better examples
            examples = await self._generate_better_examples(problem)
            training_data.extend(examples)
        
        return training_data
    
    async def _trigger_fine_tuning(self, training_data: list[Any]):
        """Trigger retriever fine-tuning with new data."""
        logger.info(f"Triggering fine-tuning with {len(training_data)} examples")
        # Schedule async fine-tuning job
        ...
    
    async def _update_retrieval_strategies(self, patterns: dict):
        """Update retrieval strategy selection based on patterns."""
        # Adjust routing thresholds based on observed patterns
        ...
    
    async def _optimize_indexes(self):
        """Optimize vector indexes."""
        logger.info("Optimizing indexes")
        # Trigger index optimization (rebuild if needed)
        ...
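
The cycle needs a scheduler. A minimal sketch that runs it as a background task on the configured interval (the startup wiring is an assumption):

import asyncio

async def continuous_learning_loop(learner: ContinuousLearner) -> None:
    """Run learning cycles forever, one per learning interval."""
    while True:
        try:
            await learner.run_learning_cycle()
        except Exception:
            logger.exception("Learning cycle failed; retrying next interval")
        await asyncio.sleep(learner.learning_interval.total_seconds())

# Started once at application startup, e.g.:
# asyncio.create_task(continuous_learning_loop(ContinuousLearner(db, settings)))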

Files to Create/Modify

New Files (Phase 3)

  1. backend/rag_solution/retrieval/fine_tuning.py (~400 lines)
  2. backend/rag_solution/data_ingestion/multi_index.py (~300 lines)
  3. backend/rag_solution/retrieval/chunk_optimizer.py (~200 lines)
  4. backend/rag_solution/orchestration/adaptive_retrieval.py (~250 lines)
  5. backend/rag_solution/orchestration/continuous_learning.py (~400 lines)
  6. backend/rag_solution/cli/commands/fine_tune.py (~150 lines)
  7. Test files (~1200 lines total):
    • backend/tests/unit/test_fine_tuning.py
    • backend/tests/unit/test_multi_index.py
    • backend/tests/unit/test_chunk_optimizer.py
    • backend/tests/unit/test_adaptive_retrieval.py
    • backend/tests/integration/test_phase3_pipeline.py
    • backend/tests/performance/test_phase3_benchmarks.py

Modified Files (Phase 3)

  1. backend/core/config.py - Add Phase 3 settings (see the sketch after this list)
  2. backend/rag_solution/services/search_service.py - Integrate Phase 3 components
  3. backend/rag_solution/data_ingestion/ingestion.py - Support multi-index
  4. backend/pyproject.toml - Add ML dependencies (torch, transformers)
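
The config.py additions would mostly be feature flags and fine-tuning defaults. A rough sketch, assuming the existing Settings class is a Pydantic settings model; apart from enable_multi_index (already used in the SearchService snippet above), the field names are assumptions:

# backend/core/config.py (sketch of Phase 3 additions)
from pydantic_settings import BaseSettings  # assumed; matches a typical FastAPI setup

class Settings(BaseSettings):
    # ... existing settings ...

    # Phase 3 feature flags
    enable_multi_index: bool = False
    enable_chunk_optimization: bool = False
    enable_adaptive_retrieval: bool = False
    enable_continuous_learning: bool = False

    # Fine-tuning defaults
    fine_tune_base_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    fine_tune_epochs: int = 3
    fine_tune_batch_size: int = 16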

Total Estimate: ~3,000 lines new code, ~500 lines modifications

Testing Strategy

Unit Tests

  • Fine-tuning: Test training data generation, model training
  • Multi-index: Test classification, merging
  • Chunk optimizer: Test splitting and expansion (see the pytest sketch after this list)
  • Adaptive retrieval: Test feedback recording and strategy selection
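
A minimal pytest sketch for the chunk optimizer case, reusing the DocumentChunk constructor shape from the snippets above:

# backend/tests/unit/test_chunk_optimizer.py (sketch)
from rag_solution.retrieval.chunk_optimizer import ChunkOptimizer
from vectordbs.data_types import DocumentChunk


def test_optimize_for_retrieval_splits_into_sentence_chunks():
    optimizer = ChunkOptimizer(settings=None)
    chunk = DocumentChunk(
        chunk_id="c1",
        text="RAG retrieves documents before generation. It grounds answers in sources. Short.",
        embeddings=None,
        metadata={},
        document_id="d1",
    )

    small_chunks = optimizer.optimize_for_retrieval([chunk])

    # "Short." is under the 20-character minimum and is dropped
    assert len(small_chunks) == 2
    assert all(c.metadata["parent_chunk_id"] == "c1" for c in small_chunks)
    assert all(c.chunk_id.startswith("c1_small_") for c in small_chunks)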

Integration Tests

  • Full pipeline with all Phase 3 features
  • End-to-end fine-tuning workflow
  • Multi-index retrieval quality

Performance Benchmarks

  • Retrieval quality: MRR@10 > 0.90
  • Fine-tuned model: +15% improvement
  • Multi-index: +10% for structured content
  • Latency: <5s for p95

Production Readiness

Monitoring

  • Performance metrics dashboard
  • Retrieval quality tracking
  • Error rate monitoring
  • Latency percentiles

Observability

  • Distributed tracing
  • Structured logging
    • Metrics collection (Prometheus; see the instrumentation sketch after this list)
  • Alerting (critical failures)
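
A small instrumentation sketch with prometheus_client; the metric names and the place they are recorded are assumptions:

from prometheus_client import Counter, Histogram

SEARCH_REQUESTS = Counter(
    "rag_search_requests_total", "Total search requests", ["strategy", "status"]
)
SEARCH_LATENCY = Histogram(
    "rag_search_latency_seconds", "End-to-end search latency in seconds", ["strategy"]
)

# Inside SearchService.search (wiring is an assumption):
# with SEARCH_LATENCY.labels(strategy=strategy).time():
#     output = await self._execute_pipeline(search_input)
# SEARCH_REQUESTS.labels(strategy=strategy, status="success").inc()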

Scalability

  • Horizontal scaling for inference
  • Distributed indexing
  • Caching strategies
  • Load balancing

Acceptance Criteria

  • All unit tests passing (>90% coverage)
  • Retrieval quality benchmarks met
  • Fine-tuning pipeline working
  • Multi-index implementation complete
  • Continuous learning operational
  • Production monitoring in place
  • Documentation complete
  • Performance targets met

Rollout Plan

  1. Week 15-17: Fine-tuning foundation
  2. Week 18-20: Advanced indexing
  3. Week 21-22: Auto-optimization
  4. Week 23: Production hardening
  5. Week 24: Release and monitoring

Related Issues

Completion

After Phase 3, RAG Modulo will be a production-ready, full Modular RAG system with:

  • Complete Advanced RAG capabilities
  • Intelligent routing and orchestration
  • Answer verification and quality assurance
  • Self-improving through continuous learning
  • Enterprise-grade reliability and scalability

This completes the RAG Modulo Evolution roadmap!
