diff --git a/.gitignore b/.gitignore
index 905dc80..7bec718 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ data/
 Llama-3.2-1B-Instruct-Complaint/
 dataset/
 cache/
+checkpoints/
 
 
 
diff --git a/README.md b/README.md
index d318f02..919b5db 100644
--- a/README.md
+++ b/README.md
@@ -285,7 +285,6 @@ The codebase explores several key areas of model adaptation:
       - Graceful shutdown
       - Comprehensive logging
 
-
 13. **Semantic Router** (`15_semantic_router.py`)
     - Semantic-based query routing system
     - Core components:
@@ -333,7 +332,7 @@ The codebase explores several key areas of model adaptation:
     - Motion data types:
       - Orientation
       - Position
-      - Velocity 
+      - Velocity
       - Acceleration
       - Angular velocity
     - Features:
@@ -352,7 +351,6 @@ The codebase explores several key areas of model adaptation:
       - Missing frame handling
       - Exception tracking
 
-
 **PyTorch Experiments** (`src/poc/`)
     - Neural Network Architecture Studies
       - ResNet Implementation (`resnet.py`, `resnet_02.py`, `resnet_03.py`)
@@ -391,7 +389,7 @@ The codebase explores several key areas of model adaptation:
       - Early stopping
       - Learning rate scheduling
 
-** 19. SIFT Fine-Tuning** (`19_llm_fine_tuning_sift.py`)
+**19. SIFT Fine-Tuning** (`19_llm_fine_tuning_sift.py`)
     - Selective Instance Fine-Tuning (SIFT) Implementation
       - Real-time model adaptation
       - Semantic similarity search
@@ -431,7 +429,6 @@ The codebase explores several key areas of model adaptation:
       - Real-time visualization
       - Training summaries
 
-
 ## Development Standards
 
 - PEP 8 and black formatting
diff --git a/src/19_llm_fine_tuning_sift.py b/src/19_llm_fine_tuning_sift.py
index 5a932f2..2faaebf 100644
--- a/src/19_llm_fine_tuning_sift.py
+++ b/src/19_llm_fine_tuning_sift.py
@@ -9,9 +9,10 @@
 from sift.sift_trainer import SIFTTrainer
 from sift.sift_visualization import SIFTVisualizer
 import sys
+import numpy as np
 
 # Configure logging
-logging.basicConfig(level=logging.ERROR, format="%(message)s")
+logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)
 
 
@@ -49,12 +50,11 @@ def main():
     metrics = MetricsComputer()
     visualizer = SIFTVisualizer()
 
-    # Initialize trainer
+    # Initialize trainer with enhanced parameters
     trainer = SIFTTrainer(
         llm_name="unsloth/Llama-3.2-1B",
         embedding_model="BAAI/bge-large-en-v1.5",
         index_dir="cache/faiss",
-        max_length=512,
     )
 
     # Sample test prompts and training data
@@ -139,20 +139,31 @@ def format_step(step: int, metrics: dict) -> str:
     # Training loop
     last_prompt_stats = None
 
+    # Add these near the start of the main() function after initializing trainer
+    metrics_tracker = {
+        'global_losses': [],
+        'prompt_losses': [],
+        'uncertainties': [],
+        'steps_per_prompt': []
+    }
+
     for prompt_idx, prompt in enumerate(test_prompts):
         try:
             # Disable all other loggers
             logging.getLogger("sift.sift_trainer").setLevel(logging.WARNING)
             logging.getLogger("tqdm").setLevel(logging.WARNING)
 
-            selected_examples = trainer.select_examples(prompt, training_data)
+            # Use enhanced selection method
+            selected_examples = trainer.select_examples_sift(prompt, training_data)
+            logger.info(f"Selected {len(selected_examples)} examples for fine-tuning")
             if not selected_examples:
+                logger.warning("No examples selected - skipping prompt")
                 continue
 
             # Re-enable logging and reprint header
             logging.getLogger("sift.sift_trainer").setLevel(logging.INFO)
 
-            clear_screen()
+            #clear_screen()
 
             # Reprint header after clear
             logger.info(format_last_prompt_summary(last_prompt_stats))
@@ -166,36 +177,57 @@ def format_step(step: int, metrics: dict) -> str:
                     if step_metrics is None:
                         continue
 
-                    current_loss = min(
-                        step_metrics.get("loss", float("inf")), max_loss_threshold
-                    )
-                    prev_loss = (
-                        prompt_stats["losses"][-1] if prompt_stats["losses"] else None
-                    )
-
+                    # Compute kernel-based uncertainty with stability check
+                    uncertainties = []
+                    for _ in range(3):  # Multiple measurements for stability
+                        uncertainty = trainer.compute_kernel_uncertainty(
+                            prompt, selected_examples[:i + 1]
+                        )
+                        if uncertainty is not None and not np.isnan(uncertainty):
+                            uncertainties.append(uncertainty)
+                            
+                    if not uncertainties:
+                        continue
+                            
+                    uncertainty = np.median(uncertainties)  # Use median for robustness
+                    current_loss = min(step_metrics.get("loss", float("inf")), max_loss_threshold)
+                    
+                    # Update tracking
+                    metrics_tracker['global_losses'].append(current_loss)
+                    metrics_tracker['uncertainties'].append(uncertainty)
+                    
                     prompt_stats["losses"].append(current_loss)
-                    prompt_stats["prompt_best"] = min(
-                        prompt_stats["prompt_best"], current_loss
-                    )
+                    prompt_stats["prompt_best"] = min(prompt_stats["prompt_best"], current_loss)
 
                     if current_loss < global_best_loss:
                         global_best_loss = current_loss
+                        # Save best model checkpoint
+                        trainer.save_checkpoint(f"checkpoints/best_model_{prompt_idx}")
 
-                    # Only log every few steps
+                    # Log progress
                     if i % 1 == 0:
                         logger.info(
                             format_step(
                                 i,
                                 {
                                     "loss": current_loss,
-                                    "prev_loss": prev_loss,
+                                    "prev_loss": prompt_stats["losses"][-2] if len(prompt_stats["losses"]) > 1 else None,
                                     "global_best": global_best_loss,
-                                    "uncertainty": step_metrics.get("uncertainty", 0),
+                                    "uncertainty": uncertainty,
                                 },
                             )
                         )
 
+                    # Enhanced stopping check with stability verification
+                    if (i >= min_examples and 
+                        trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and
+                        len(prompt_stats["losses"]) >= 3 and
+                        np.std(prompt_stats["losses"][-3:]) < 0.1):
+                        logger.info(f"Stopping early at step {i} due to convergence")
+                        break
+
                 except Exception as e:
+                    logger.error(f"Error in training step: {str(e)}")
                     continue
 
             # Store and log summary
@@ -241,6 +273,45 @@ def format_step(step: int, metrics: dict) -> str:
 
     logger.info("=" * 89)
 
+    # After training loop, add visualization
+    if prompt_stats_history:
+        metrics_data = {
+            "loss": [stat["losses"] for stat in prompt_stats_history],
+            "uncertainty": [
+                stat.get("uncertainties", []) for stat in prompt_stats_history
+            ],
+        }
+
+        visualizer.plot_metrics_over_time(metrics_data)
+        visualizer.plot_uncertainty_vs_performance(
+            uncertainty=[
+                stat.get("uncertainties", [])[-1]
+                for stat in prompt_stats_history
+                if stat.get("uncertainties")
+            ],
+            performance=[stat["prompt_best"] for stat in prompt_stats_history],
+            save_path="uncertainty_vs_performance.png",
+        )
+        visualizer.plot_adaptive_stopping(
+            metrics={
+                "uncertainty": [
+                    u
+                    for stat in prompt_stats_history
+                    for u in stat.get("uncertainties", [])
+                ],
+                "compute": list(
+                    range(
+                        sum(
+                            len(stat.get("uncertainties", []))
+                            for stat in prompt_stats_history
+                        )
+                    )
+                ),
+            },
+            alpha=0.1,
+            save_path="adaptive_stopping.png",
+        )
+
 
 if __name__ == "__main__":
     main()
diff --git a/src/sift/sift_data_loading.py b/src/sift/sift_data_loading.py
index 53f55c7..4932720 100644
--- a/src/sift/sift_data_loading.py
+++ b/src/sift/sift_data_loading.py
@@ -16,7 +16,7 @@
         "Please install it with: pip install zstandard"
     )
 
-logging.basicConfig(level=logging.ERROR)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
diff --git a/src/sift/sift_embedding_cache.py b/src/sift/sift_embedding_cache.py
index 09263ac..493634a 100644
--- a/src/sift/sift_embedding_cache.py
+++ b/src/sift/sift_embedding_cache.py
@@ -4,7 +4,7 @@
 from typing import Tuple
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.ERROR)
+logger.setLevel(logging.INFO)
 
 
 def clear_embedding_cache(cache_dir: str = "cache/embeddings"):
diff --git a/src/sift/sift_eval.py b/src/sift/sift_eval.py
index 2b7c183..81a21a2 100644
--- a/src/sift/sift_eval.py
+++ b/src/sift/sift_eval.py
@@ -11,76 +11,3 @@ class EvaluationMetrics:
     bits_per_byte: float
     perplexity: float
     uncertainty: float
-
-
-class MetricsComputer:
-    """Compute and track evaluation metrics."""
-
-    def __init__(self):
-        self.metrics_history = []
-
-    def compute_bits_per_byte(
-        self, logits: torch.Tensor, labels: torch.Tensor
-    ) -> float:
-        """Compute bits per byte metric."""
-        loss = torch.nn.functional.cross_entropy(
-            logits.view(-1, logits.size(-1)), labels.view(-1), reduction="mean"
-        )
-        return (loss / np.log(2)).item()
-
-    def compute_perplexity(self, logits: torch.Tensor, labels: torch.Tensor) -> float:
-        """Compute perplexity."""
-        loss = torch.nn.functional.cross_entropy(
-            logits.view(-1, logits.size(-1)), labels.view(-1), reduction="mean"
-        )
-        return torch.exp(loss).item()
-
-    def compute_metrics(
-        self, 
-        outputs: Dict[str, Any], 
-        labels: Optional[torch.Tensor] = None, 
-        uncertainty: Optional[float] = None
-    ) -> Dict[str, float]:
-        """Compute training metrics."""
-        metrics = {
-            'loss': outputs['loss'],
-            'uncertainty': uncertainty if uncertainty is not None else 0.0
-        }
-        
-        if outputs['logits'] is not None and labels is not None:
-            # Add any additional metrics computation here
-            pass
-            
-        return metrics
-
-    def get_metrics_summary(self) -> Dict[str, List[float]]:
-        """Get summary of tracked metrics."""
-        return {
-            "bits_per_byte": [m.bits_per_byte for m in self.metrics_history],
-            "perplexity": [m.perplexity for m in self.metrics_history],
-            "uncertainty": [m.uncertainty for m in self.metrics_history],
-        }
-
-
-class AdaptiveStoppingMetrics:
-    """Track metrics for adaptive stopping."""
-
-    def __init__(self, alpha: float = 0.1):
-        self.alpha = alpha
-        self.uncertainty_history = []
-        self.compute_history = []
-
-    def should_stop(self, uncertainty: float, step: int) -> bool:
-        """Determine if training should stop based on uncertainty."""
-        if step < 5:  # Minimum number of steps
-            return False
-        
-        # Add your stopping criterion here
-        return uncertainty < self.alpha
-
-    def get_stopping_summary(self) -> Dict[str, list]:
-        """Get summary of stopping metrics."""
-        return {
-            "uncertainty": self.uncertainty_history,
-            "compute": self.compute_history,
-        }
diff --git a/src/sift/sift_metrics.py b/src/sift/sift_metrics.py
index 7275289..1f4a54d 100644
--- a/src/sift/sift_metrics.py
+++ b/src/sift/sift_metrics.py
@@ -3,6 +3,7 @@
 import logging
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 class MetricsComputer:
     def __init__(self):
@@ -94,24 +95,32 @@ def __init__(self, alpha: float = 0.1, window_size: int = 5):
         self.stopping_points = []
         self.uncertainty_history = []
     
-    def should_stop(self, uncertainty: float, step: int) -> bool:
-        """Determine if training should stop based on uncertainty."""
+    def should_stop(self, uncertainty: float, step: int, loss_history: List[float]) -> bool:
+        """Enhanced stopping criterion using both uncertainty and loss trends."""
         try:
             self.uncertainty_history.append(uncertainty)
             
             if step < self.window_size:
                 return False
             
-            # Get recent uncertainties
+            # Get recent metrics
             recent_uncertainties = self.uncertainty_history[-self.window_size:]
-            avg_uncertainty = sum(recent_uncertainties) / len(recent_uncertainties)
+            recent_losses = loss_history[-self.window_size:]
             
-            # Stop if average uncertainty is below threshold
-            should_stop = avg_uncertainty < self.alpha
+            # Compute trends
+            uncertainty_trend = np.polyfit(range(len(recent_uncertainties)), recent_uncertainties, 1)[0]
+            loss_trend = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]
+            
+            # Stop if both metrics are stable or improving
+            should_stop = (
+                uncertainty_trend <= 0 and  # Uncertainty is decreasing
+                loss_trend <= 0 and        # Loss is decreasing
+                np.mean(recent_uncertainties) < self.alpha
+            )
             
             if should_stop:
                 self.stopping_points.append(step)
-                logger.info(f"Stopping at step {step} with uncertainty {avg_uncertainty:.4f}")
+                logger.info(f"Stopping at step {step} with uncertainty {np.mean(recent_uncertainties):.4f}")
             
             return should_stop
             
diff --git a/src/sift/sift_trainer.py b/src/sift/sift_trainer.py
index 978f32f..1cd2edc 100644
--- a/src/sift/sift_trainer.py
+++ b/src/sift/sift_trainer.py
@@ -18,8 +18,12 @@
 import torch.cuda
 import faiss
 from sentence_transformers import SentenceTransformer
+from functools import wraps
+import signal
+from .sift_metrics import MetricsComputer, AdaptiveStoppingMetrics
+from .sift_visualization import SIFTVisualizer
 
-logging.basicConfig(level=logging.ERROR)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 dotenv.load_dotenv()
@@ -29,77 +33,129 @@
 @dataclass
 class SIFTConfig:
     """Configuration for SIFT and test-time fine-tuning"""
-
     lambda_param: float = 0.1
     num_candidates: int = 5
     batch_size: int = 1
     learning_rate: float = 5e-5
     max_length: int = 512
-    device: str = (
-        "cuda"
-        if torch.cuda.is_available()
-        else "mps" if torch.backends.mps.is_available() else "cpu"
-    )
+    device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+
+def timeout(seconds):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            def handler(signum, frame):
+                raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")
+            
+            # Set the timeout handler
+            signal.signal(signal.SIGALRM, handler)
+            signal.alarm(seconds)
+            
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                # Disable the alarm
+                signal.alarm(0)
+            return result
+        return wrapper
+    return decorator
 
 
 class SIFTTrainer:
     def __init__(
         self,
-        llm_name: str = "unsloth/Llama-3.2-1B",
-        embedding_model: str = "BAAI/bge-large-en-v1.5",
+        llm_name: str,
+        embedding_model: str,
         index_dir: str = "cache/faiss",
-        max_length: int = 512,
+        cache_dir: str = "cache/embeddings",
+        window_size: int = 5,
+        min_steps: int = 3,
+        lambda_param: float = 0.1,
+        max_length: Optional[int] = None,
         embedding_dim: int = 1024,
-        uncertainty_buffer_size: int = 100,
-        gradient_accumulation_steps: int = 8,
-        max_grad_norm: float = 1.0,
+        config: Optional[SIFTConfig] = None,
     ):
-        """Initialize trainer with all required attributes."""
-        self.device = torch.device("cpu")
-        self.embedding_dim = embedding_dim  # Store embedding dimension
-        self.max_length = max_length
-
-        # Initialize paths
+        """Initialize SIFT trainer with models and configuration."""
+        # Initialize config first
+        self.config = config or SIFTConfig(
+            lambda_param=lambda_param,
+            max_length=max_length if max_length is not None else 512
+        )
+        
+        # Set device and dtype
+        if torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+            self.dtype = torch.float32  # MPS requires float32
+            torch.mps.empty_cache()
+        else:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.dtype = torch.float32  # Use float32 consistently
+        
+        # Initialize models and tokenizer with consistent dtype
+        logger.info(f"Loading models from {llm_name} to {self.device} with dtype {self.dtype}")
+        self.model = AutoModelForCausalLM.from_pretrained(
+            llm_name,
+            torch_dtype=self.dtype,
+            device_map={"": self.device},
+            use_flash_attention_2=False  # Disable flash attention to avoid dtype issues
+        )
+        
+        # Convert model parameters to consistent dtype
+        self.model = self.model.to(dtype=self.dtype)
+        
+        # Initialize tokenizer and embedding model
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
+        self.embedding_model = SentenceTransformer(embedding_model)
+        self.embedding_model = self.embedding_model.to(self.device)
+        
+        # Basic parameters
+        self.lambda_param = self.config.lambda_param
+        self.window_size = window_size
+        self.min_steps = min_steps
+        self.embedding_dim = embedding_dim
+        
+        # Initialize tracking components
+        self.metrics_computer = MetricsComputer()
+        self.adaptive_stopping = AdaptiveStoppingMetrics(
+            alpha=0.1, 
+            window_size=self.window_size
+        )
+        self.visualizer = SIFTVisualizer()
+        
+        # Initialize buffers
+        self._uncertainty_buffer = []
+        self._last_loss = None
+        
+        # Setup directories
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        
         self.index_dir = Path(index_dir)
         self.index_dir.mkdir(parents=True, exist_ok=True)
         self.index_path = self.index_dir / "embeddings.faiss"
         self.metadata_path = self.index_dir / "metadata.pkl"
-
-        # Initialize FAISS index and mappings
+        
+        # Load or create FAISS index
         self.index, self.text_to_id, self.id_to_text = self._load_or_create_index()
-
-        # Load embedding model
-        self.embedding_model = SentenceTransformer(embedding_model, device=self.device)
-
-        # Initialize tokenizer and model
-        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            llm_name, torch_dtype=torch.float32, low_cpu_mem_usage=True
-        )
-
-        self.model.eval()
-
-        # Training settings
-        self.gradient_accumulation_steps = gradient_accumulation_steps
-        self.current_accumulation_step = 0
-
-        # Initialize optimizer
+        
+        # Initialize optimizer and scheduler
         self.optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=5e-5, weight_decay=0.01
+            self.model.parameters(),
+            lr=self.config.learning_rate,
+            weight_decay=0.01
         )
-        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
+        
+        # Initialize scheduler (linear warmup and decay)
+        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            self.optimizer,
+            max_lr=self.config.learning_rate,
+            total_steps=1000,  # Adjust based on expected steps
+            pct_start=0.1
         )
-
-        self.max_grad_norm = max_grad_norm
-
-        # Initialize uncertainty tracking
-        self._uncertainty_buffer = []
-        self._uncertainty_buffer_size = uncertainty_buffer_size
-        self._last_loss = None
-
-        logger.info(f"Using device: {self.device}")
-        logger.info(f"Model loaded with dtype: {self.model.dtype}")
+        
+        # Initialize step counter
+        self.current_accumulation_step = 0
 
     def _load_or_create_index(
         self,
@@ -146,67 +202,160 @@ def save_index(self):
     def compute_embedding(self, text: str) -> Optional[np.ndarray]:
         """Compute embedding using dedicated embedding model."""
         try:
-            # Check if text is already in index
-            if text in self.text_to_id:
-                vector_id = self.text_to_id[text]
-                reconstructed = np.zeros((1, self.embedding_dim), dtype=np.float32)
-                self.index.reconstruct(vector_id, reconstructed[0])
-                return reconstructed
-
+            # Try cache first
+            cache_key = hashlib.md5(text.encode()).hexdigest()
+            cache_path = self.cache_dir / f"{cache_key}.npy"
+            
+            if cache_path.exists():
+                embedding = np.load(cache_path)
+                # Ensure 2D shape
+                if len(embedding.shape) == 1:
+                    embedding = embedding.reshape(1, -1)
+                return embedding
+            
             # Compute new embedding
             embedding = self.embedding_model.encode(
                 text,
-                normalize_embeddings=True,  # Important for cosine similarity
+                normalize_embeddings=True,
                 show_progress_bar=False,
             )
 
-            # Reshape for FAISS
+            # Ensure 2D shape
             embedding_np = np.array(embedding).reshape(1, -1)
 
-            # Add to index
-            vector_id = self.index.ntotal
-            self.index.add(embedding_np)
-            self.text_to_id[text] = vector_id
-            self.id_to_text[vector_id] = text
-
+            # Cache the result
+            np.save(cache_path, embedding_np)
             return embedding_np
 
         except Exception as e:
             logger.error(f"Error computing embedding: {e}")
             return None
 
-    def select_examples(
+    @timeout(10)  # 10 second timeout for kernel computation
+    def compute_kernel_uncertainty(self, x_star: str, selected_points: List[str]) -> float:
+        try:
+            # Get embeddings
+            x_star_emb = self.compute_embedding(x_star)
+            if x_star_emb is None or len(selected_points) == 0:
+                return float('inf')
+            
+            # Ensure x_star_emb is 2D
+            if len(x_star_emb.shape) == 1:
+                x_star_emb = x_star_emb.reshape(1, -1)
+            
+            # Compute embeddings with timeout protection
+            selected_embs = []
+            for point in selected_points:
+                emb = self.compute_embedding(point)
+                if emb is None:
+                    return float('inf')
+                # Ensure each embedding is 2D
+                if len(emb.shape) == 1:
+                    emb = emb.reshape(1, -1)
+                selected_embs.append(emb)
+            
+            # Stack embeddings and ensure correct shape
+            selected_embs = np.vstack([emb.reshape(1, -1) for emb in selected_embs])
+            
+            # Add small epsilon for numerical stability
+            epsilon = 1e-8
+            
+            # Compute kernel values with correct shapes
+            k_xx = np.dot(x_star_emb, x_star_emb.T) + epsilon
+            K_X = self.compute_kernel_matrix_batch(selected_embs) + epsilon * np.eye(len(selected_points))
+            k_X = np.dot(x_star_emb, selected_embs.T)
+            
+            # Add regularization
+            K_X_reg = K_X + self.lambda_param * np.eye(len(selected_points))
+            
+            try:
+                # Ensure shapes match for solve operation
+                if k_X.shape[1] != K_X_reg.shape[0]:
+                    logger.error(f"Shape mismatch: k_X: {k_X.shape}, K_X_reg: {K_X_reg.shape}")
+                    return float('inf')
+                    
+                # Use more stable solver with condition number check
+                if np.linalg.cond(K_X_reg) > 1e10:
+                    logger.warning("Poorly conditioned matrix, adding more regularization")
+                    K_X_reg += 0.1 * np.eye(len(selected_points))
+                    
+                # Reshape for solve operation
+                k_X = k_X.reshape(-1, 1)
+                uncertainty = float(k_xx - k_X.T @ np.linalg.solve(K_X_reg, k_X))
+                
+                # Validate output
+                if np.isnan(uncertainty) or np.isinf(uncertainty):
+                    return float('inf')
+                    
+                return float(np.clip(uncertainty, 0, 10.0))  # Clip to reasonable range
+                
+            except np.linalg.LinAlgError as e:
+                logger.warning(f"Matrix inversion failed: {e}")
+                return float('inf')
+                
+        except Exception as e:
+            logger.error(f"Error in kernel uncertainty computation: {e}")
+            return float('inf')
+
+    def select_examples_sift(
         self, prompt: str, candidates: List[str], n_examples: int = 5
     ) -> List[str]:
-        """Select examples using FAISS index."""
+        """Select examples using SIFT with improved robustness and logging."""
         try:
-            # Compute prompt embedding
-            prompt_embedding = self.compute_embedding(prompt)
-            if prompt_embedding is None:
-                raise ValueError("Failed to compute prompt embedding")
-
-            # Add candidates to index if not already present
-            for candidate in tqdm(candidates, desc="Processing candidates"):
-                if candidate not in self.text_to_id:
-                    self.compute_embedding(candidate)
-
-            # Search for nearest neighbors
-            D, I = self.index.search(prompt_embedding, n_examples)
-
-            # Get corresponding texts
-            selected = [self.id_to_text[int(i)] for i in I[0]]
-            logger.info(f"Selected {len(selected)} examples using FAISS")
-
-            # Save updated index
-            self.save_index()
-
+            selected = []
+            prompt_emb = self.compute_embedding(prompt)
+            
+            if prompt_emb is None:
+                logger.error("Failed to compute prompt embedding")
+                return selected
+            
+            # Add progress tracking
+            from tqdm import tqdm
+            
+            for i in range(n_examples):
+                min_uncertainty = float("inf")
+                best_candidate = None
+                
+                # Track candidates processing
+                for candidate in tqdm(candidates, desc=f"Processing candidates for example {i+1}/{n_examples}", leave=False):
+                    if candidate in selected:
+                        continue
+                        
+                    try:
+                        # Compute uncertainty with timeout protection
+                        test_selected = selected + [candidate]
+                        uncertainty = self.compute_kernel_uncertainty(prompt, test_selected)
+                        
+                        # Skip invalid uncertainties
+                        if uncertainty is None or np.isnan(uncertainty) or np.isinf(uncertainty):
+                            continue
+                            
+                        if uncertainty < min_uncertainty:
+                            min_uncertainty = uncertainty
+                            best_candidate = candidate
+                            
+                    except Exception as e:
+                        logger.warning(f"Error processing candidate: {str(e)}")
+                        continue
+                
+                # Check if we found a valid candidate
+                if best_candidate:
+                    selected.append(best_candidate)
+                    logger.info(f"Selected example {i+1}/{n_examples} with uncertainty: {min_uncertainty:.4f}")
+                else:
+                    logger.warning(f"No valid candidate found for example {i+1}")
+                    break
+                    
+                # Clear memory periodically
+                if i % 2 == 0:
+                    self.clear_memory()
+            
+            logger.info(f"Selected {len(selected)}/{n_examples} examples")
             return selected
-
+            
         except Exception as e:
             logger.error(f"Error in example selection: {str(e)}")
-            logger.error("Falling back to random selection")
-            indices = np.random.choice(len(candidates), n_examples, replace=False)
-            return [candidates[idx] for idx in indices]
+            return selected
 
     def compute_metrics(self, outputs) -> Dict[str, float]:
         """Compute metrics from model outputs."""
@@ -229,21 +378,30 @@ def compute_metrics(self, outputs) -> Dict[str, float]:
             logger.error(f"Error computing metrics: {e}")
             return None
 
+    @timeout(30)
     def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]:
-        """Perform fine-tuning step with proper metrics."""
         try:
             gc.collect()
+            if torch.backends.mps.is_available():
+                torch.mps.empty_cache()
 
             # Tokenize
             inputs = self.tokenizer(
                 example,
                 padding=True,
                 truncation=True,
-                max_length=128,
+                max_length=self.config.max_length,
                 return_tensors="pt",
             )
-
-            labels = inputs["input_ids"].clone()
+            
+            # Move to device and ensure correct dtype
+            inputs = {
+                "input_ids": self._ensure_tensor_dtype(inputs["input_ids"], is_index=True),
+                "attention_mask": self._ensure_tensor_dtype(inputs["attention_mask"], is_index=True)
+            }
+            
+            # Clone labels and ensure they're long type
+            labels = self._ensure_tensor_dtype(inputs["input_ids"].clone(), is_index=True)
 
             # Forward pass
             self.model.train()
@@ -253,53 +411,7 @@ def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]:
                 labels=labels,
             )
 
-            # Compute loss and backward
-            loss = outputs.loss / self.gradient_accumulation_steps
-            loss.backward()
-
-            # Add gradient clipping
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-
-            # Update weights if needed
-            self.current_accumulation_step += 1
-            if self.current_accumulation_step >= self.gradient_accumulation_steps:
-                self.optimizer.step()
-
-                # Update learning rate based on loss
-                self.scheduler.step(loss.item())
-
-                self.optimizer.zero_grad()
-                self.current_accumulation_step = 0
-
-            self.model.eval()
-
-            # Update uncertainty
-            loss_value = loss.item() * self.gradient_accumulation_steps
-            self._last_loss = loss_value
-            self._uncertainty_buffer.append(loss_value)
-            if len(self._uncertainty_buffer) > self._uncertainty_buffer_size:
-                self._uncertainty_buffer.pop(0)
-
-            # Compute metrics
-            metrics = self.compute_metrics(outputs)
-            if metrics is None:
-                return None
-
-            # Add training info
-            metrics.update(
-                {
-                    "labels": labels.detach(),
-                    "logits": (
-                        outputs.logits.detach() if hasattr(outputs, "logits") else None
-                    ),
-                }
-            )
-
-            # Clear memory
-            del outputs
-            gc.collect()
-
-            return metrics
+            return {"loss": outputs.loss.item()}
 
         except Exception as e:
             logger.error(f"Error in fine-tuning step: {str(e)}")
@@ -365,39 +477,182 @@ def clear_cache(self):
         except Exception as e:
             logger.error(f"Failed to clear cache: {e}")
 
-
-class MetricsComputer:
-    def __init__(self):
-        self.metrics_history = {
-            "bits_per_byte": [],
-            "perplexity": [],
-            "uncertainty": [],
+    def compute_kernel_matrix_batch(self, embeddings: np.ndarray) -> np.ndarray:
+        try:
+            return np.dot(embeddings, embeddings.T)
+        except Exception as e:
+            logger.error(f"Error in kernel computation: {e}")
+            return np.zeros((embeddings.shape[0], embeddings.shape[0]))
+
+    def get_training_summary(self) -> Dict[str, List[float]]:
+        """Get complete training summary for visualization."""
+        return {
+            "loss": self._uncertainty_buffer,
+            "uncertainty": [self.compute_uncertainty(None) for _ in self._uncertainty_buffer],
+            "compute": list(range(len(self._uncertainty_buffer)))
         }
 
-    def compute_metrics(
-        self, outputs: Dict[str, Any], uncertainty: float = None
-    ) -> Dict[str, float]:
-        """Compute and store metrics."""
-        if outputs is None:
-            return None
+    def clear_memory(self):
+        import gc
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
-        metrics = {
-            "bits_per_byte": outputs.get("bits_per_byte", 0.0),
-            "perplexity": outputs.get("perplexity", 0.0),
-            "uncertainty": outputs.get(
-                "uncertainty", uncertainty if uncertainty is not None else 0.0
-            ),
-        }
-
-        # Store metrics
-        for key, value in metrics.items():
-            self.metrics_history[key].append(value)
+    def should_stop_adaptive(
+        self, uncertainty: float, step: int, alpha: float = 0.1
+    ) -> bool:
+        """Enhanced adaptive stopping criterion with stability checks."""
+        try:
+            if step < self.min_steps:
+                return False
+            
+            # Add uncertainty to buffer
+            self._uncertainty_buffer.append(uncertainty)
+            
+            # Keep buffer size manageable
+            if len(self._uncertainty_buffer) > self.window_size:
+                self._uncertainty_buffer.pop(0)
+            
+            # Get recent uncertainties with outlier removal
+            recent_uncertainties = np.array(self._uncertainty_buffer[-self.window_size:])
+            if len(recent_uncertainties) < self.window_size:
+                return False
+            
+            # Remove outliers using IQR method
+            q1, q3 = np.percentile(recent_uncertainties, [25, 75])
+            iqr = q3 - q1
+            mask = (recent_uncertainties >= q1 - 1.5 * iqr) & (recent_uncertainties <= q3 + 1.5 * iqr)
+            filtered_uncertainties = recent_uncertainties[mask]
+            
+            if len(filtered_uncertainties) < 3:  # Require minimum number of valid points
+                return False
+            
+            # Compute robust statistics
+            avg_uncertainty = np.median(filtered_uncertainties)
+            uncertainty_std = np.std(filtered_uncertainties)
+            
+            # Dynamic threshold based on training progress
+            threshold = alpha * (1 + 1/np.sqrt(1 + step))
+            
+            # Check both absolute and relative stability
+            is_stable = uncertainty_std < 0.1 * avg_uncertainty
+            is_low = avg_uncertainty < threshold
+            
+            should_stop = is_stable and is_low
+            
+            if should_stop:
+                logger.info(f"Stopping at step {step} with uncertainty {avg_uncertainty:.4f} (threshold: {threshold:.4f})")
+            
+            return should_stop
+            
+        except Exception as e:
+            logger.error(f"Error in stopping criterion: {e}")
+            return False
 
-        return metrics
+    def save_checkpoint(self, path: str):
+        """Save a checkpoint of the current model state."""
+        try:
+            checkpoint_dir = Path(path)
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            
+            checkpoint = {
+                'model_state_dict': self.model.state_dict(),
+                'uncertainty_buffer': self._uncertainty_buffer,
+                'current_accumulation_step': getattr(self, 'current_accumulation_step', 0),
+            }
+            
+            # Add optimizer and scheduler if they exist
+            if hasattr(self, 'optimizer'):
+                checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
+            if hasattr(self, 'scheduler'):
+                checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
+            
+            torch.save(checkpoint, checkpoint_dir / 'checkpoint.pt')
+            logger.info(f"Saved checkpoint to {path}")
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint: {e}")
+
+    def update_ema_loss(self, current_loss: float, alpha: float = 0.1) -> float:
+        """Update exponential moving average of loss."""
+        if not hasattr(self, '_ema_loss'):
+            self._ema_loss = current_loss
+        else:
+            self._ema_loss = alpha * current_loss + (1 - alpha) * self._ema_loss
+        return self._ema_loss
+
+    def update_and_visualize_metrics(
+        self, 
+        current_loss: float, 
+        uncertainty: float, 
+        step: int,
+        save_path: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Update metrics and generate visualizations."""
+        # Update EMA loss
+        ema_loss = self.update_ema_loss(current_loss)
+        
+        # Compute metrics
+        metrics = self.metrics_computer.compute_metrics(
+            {'loss': current_loss}, 
+            uncertainty=uncertainty
+        )
+        
+        # Check stopping condition using both metrics
+        should_stop = self.adaptive_stopping.should_stop(
+            uncertainty=uncertainty,
+            step=step,
+            loss_history=self.metrics_computer.metrics_history['loss']
+        )
+        
+        # Generate visualizations periodically
+        if step % 10 == 0:
+            self.visualizer.plot_adaptive_stopping(
+                metrics=self.get_training_summary(),
+                alpha=self.adaptive_stopping.alpha,
+                title=f"Training Progress - Step {step}",
+                save_path=save_path
+            )
+        
+        return {
+            'metrics': metrics,
+            'should_stop': should_stop,
+            'ema_loss': ema_loss
+        }
 
-    def get_metrics_summary(self) -> Dict[str, List[float]]:
-        """Get the history of all metrics."""
-        return self.metrics_history
+    def generate_training_summary(self, save_dir: str = "training_summary"):
+        """Generate comprehensive training summary and visualizations."""
+        Path(save_dir).mkdir(parents=True, exist_ok=True)
+        
+        # Get complete training history
+        summary = self.get_training_summary()
+        
+        # Save metrics to JSON
+        with open(f"{save_dir}/metrics.json", "w") as f:
+            json.dump(summary, f, indent=2)
+        
+        # Generate final visualizations
+        self.visualizer.plot_adaptive_stopping(
+            metrics=summary,
+            alpha=self.adaptive_stopping.alpha,
+            title="Final Training Summary",
+            save_path=f"{save_dir}/final_stopping_analysis.png"
+        )
+        
+        # Save stopping points
+        stopping_summary = self.adaptive_stopping.get_stopping_summary()
+        with open(f"{save_dir}/stopping_points.json", "w") as f:
+            json.dump(stopping_summary, f, indent=2)
+
+    def _ensure_tensor_dtype(self, tensor: torch.Tensor, is_index: bool = False) -> torch.Tensor:
+        """Ensure tensor has correct dtype."""
+        if is_index:
+            # For indices (input_ids, attention masks), use long
+            tensor = tensor.to(dtype=torch.long)
+        else:
+            # For other tensors, use configured dtype
+            if tensor.dtype != self.dtype:
+                tensor = tensor.to(dtype=self.dtype)
+        return tensor.to(self.device)
 
 
 def main():
@@ -419,7 +674,7 @@ def main():
     ]
 
     logger.info("Testing example selection...")
-    selected = trainer.select_examples(prompt, candidates)
+    selected = trainer.select_examples_sift(prompt, candidates)
     logger.info("Selected examples:")
     for i, example in enumerate(selected, 1):
         logger.info(f"{i}. {example}")
diff --git a/src/sift/sift_visualization.py b/src/sift/sift_visualization.py
index 67cab8c..bbe3515 100644
--- a/src/sift/sift_visualization.py
+++ b/src/sift/sift_visualization.py
@@ -87,3 +87,39 @@ def plot_adaptive_stopping(
         if save_path:
             plt.savefig(save_path)
         plt.show()
+
+    def plot_training_summary(
+        self,
+        losses: List[float],
+        uncertainties: List[float],
+        save_path: Optional[str] = None
+    ):
+        """Plot comprehensive training summary."""
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
+        
+        # Loss over time
+        ax1.plot(losses)
+        ax1.set_title("Loss Progression")
+        ax1.set_xlabel("Step")
+        ax1.set_ylabel("Loss")
+        
+        # Uncertainty over time
+        ax2.plot(uncertainties)
+        ax2.set_title("Uncertainty Progression")
+        ax2.set_xlabel("Step")
+        ax2.set_ylabel("Uncertainty")
+        
+        # Loss distribution
+        ax3.hist(losses, bins=30)
+        ax3.set_title("Loss Distribution")
+        
+        # Loss vs Uncertainty
+        ax4.scatter(uncertainties, losses, alpha=0.5)
+        ax4.set_title("Loss vs Uncertainty")
+        ax4.set_xlabel("Uncertainty")
+        ax4.set_ylabel("Loss")
+        
+        plt.tight_layout()
+        if save_path:
+            plt.savefig(save_path)
+        plt.close()