From 3411341905cd6bc2d6fec19649c9025698198518 Mon Sep 17 00:00:00 2001
From: Leon van Bokhorst
Date: Sun, 17 Nov 2024 16:10:31 +0100
Subject: [PATCH] feat(trainer): Enhance SIFT trainer initialization and
 metrics handling

- Add proper device/dtype handling for Mac M3 compatibility
- Initialize optimizer (AdamW) and scheduler (OneCycleLR)
- Add robust checkpoint saving with optimizer state
- Implement EMA loss tracking and visualization
- Add comprehensive metrics computation and tracking
- Fix tensor dtype handling for embedding indices
- Add adaptive stopping criterion with uncertainty tracking

Key changes:
- SIFTTrainer initialization (lines 65-158)
- Metrics computation and visualization (lines 583-620)
- Tensor dtype handling (lines 646-655)
- Checkpoint management (lines 552-573)

This commit improves training stability and adds better monitoring
capabilities while ensuring compatibility across different hardware
configurations.
---
 .gitignore                       |   1 +
 README.md                        |   7 +-
 src/19_llm_fine_tuning_sift.py   | 107 +++++-
 src/sift/sift_data_loading.py    |   2 +-
 src/sift/sift_embedding_cache.py |   2 +-
 src/sift/sift_eval.py            |  73 ----
 src/sift/sift_metrics.py         |  23 +-
 src/sift/sift_trainer.py         | 599 ++++++++++++++++++++++---------
 src/sift/sift_visualization.py   |  36 ++
 9 files changed, 573 insertions(+), 277 deletions(-)
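The "adaptive stopping criterion with uncertainty tracking" mentioned above is built on the kernel posterior variance that `compute_kernel_uncertainty` (added to `sift_trainer.py` below) estimates. As a reference for reading the diff, here is a minimal NumPy sketch of that quantity; the function name `posterior_variance` is illustrative, and the patch's version adds shape checks, condition-number checks, and clipping on top of the same formula:

    import numpy as np

    def posterior_variance(x_emb: np.ndarray, X_embs: np.ndarray,
                           lam: float = 0.1, eps: float = 1e-8) -> float:
        # sigma^2(x) = k(x, x) - k(x, X) (K_X + lam * I)^{-1} k(X, x),
        # with a linear kernel k(a, b) = a . b over normalized embeddings.
        n = len(X_embs)
        k_xx = float(x_emb @ x_emb) + eps
        K_reg = X_embs @ X_embs.T + (eps + lam) * np.eye(n)
        k_X = X_embs @ x_emb
        return k_xx - float(k_X @ np.linalg.solve(K_reg, k_X))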
diff --git a/.gitignore b/.gitignore
index 905dc80..7bec718 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ data/
 Llama-3.2-1B-Instruct-Complaint/
 dataset/
 cache/
+checkpoints/

diff --git a/README.md b/README.md
index d318f02..919b5db 100644
--- a/README.md
+++ b/README.md
@@ -285,7 +285,6 @@ The codebase explores several key areas of model adaptation:
   - Graceful shutdown
   - Comprehensive logging
-
 13. **Semantic Router** (`15_semantic_router.py`)
    - Semantic-based query routing system
    - Core components:
@@ -333,7 +332,7 @@ The codebase explores several key areas of model adaptation:
    - Motion data types:
      - Orientation
      - Position
-     - Velocity 
+     - Velocity
      - Acceleration
      - Angular velocity
    - Features:
@@ -352,7 +351,6 @@ The codebase explores several key areas of model adaptation:
    - Missing frame handling
    - Exception tracking
-
 - **PyTorch Experiments** (`src/poc/`)
   - Neural Network Architecture Studies
     - ResNet Implementation (`resnet.py`, `resnet_02.py`, `resnet_03.py`)
@@ -391,7 +389,7 @@ The codebase explores several key areas of model adaptation:
    - Early stopping
    - Learning rate scheduling

-** 19. SIFT Fine-Tuning** (`19_llm_fine_tuning_sift.py`)
+**19. SIFT Fine-Tuning** (`19_llm_fine_tuning_sift.py`)
 - Selective Instance Fine-Tuning (SIFT) Implementation
   - Real-time model adaptation
   - Semantic similarity search
@@ -431,7 +429,6 @@ The codebase explores several key areas of model adaptation:
   - Real-time visualization
   - Training summaries
-
 ## Development Standards

 - PEP 8 and black formatting
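The README section above summarizes the workflow implemented by the script changes that follow. Condensed from `src/19_llm_fine_tuning_sift.py` as modified in this patch (here `prompt` is a string and `training_data` a list of strings, both assumed), an adaptation pass looks roughly like this:

    from sift.sift_trainer import SIFTTrainer

    trainer = SIFTTrainer(
        llm_name="unsloth/Llama-3.2-1B",
        embedding_model="BAAI/bge-large-en-v1.5",
        index_dir="cache/faiss",
    )

    # Pick the most informative examples for this prompt, then adapt on them.
    examples = trainer.select_examples_sift(prompt, training_data)
    for i, example in enumerate(examples):
        step_metrics = trainer.fine_tune_step(example)
        uncertainty = trainer.compute_kernel_uncertainty(prompt, examples[: i + 1])
        if trainer.should_stop_adaptive(uncertainty, i, alpha=0.1):
            break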
trainer.save_checkpoint(f"checkpoints/best_model_{prompt_idx}") - # Only log every few steps + # Log progress if i % 1 == 0: logger.info( format_step( i, { "loss": current_loss, - "prev_loss": prev_loss, + "prev_loss": prompt_stats["losses"][-2] if len(prompt_stats["losses"]) > 1 else None, "global_best": global_best_loss, - "uncertainty": step_metrics.get("uncertainty", 0), + "uncertainty": uncertainty, }, ) ) + # Enhanced stopping check with stability verification + if (i >= min_examples and + trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and + len(prompt_stats["losses"]) >= 3 and + np.std(prompt_stats["losses"][-3:]) < 0.1): + logger.info(f"Stopping early at step {i} due to convergence") + break + except Exception as e: + logger.error(f"Error in training step: {str(e)}") continue # Store and log summary @@ -241,6 +273,45 @@ def format_step(step: int, metrics: dict) -> str: logger.info("=" * 89) + # After training loop, add visualization + if prompt_stats_history: + metrics_data = { + "loss": [stat["losses"] for stat in prompt_stats_history], + "uncertainty": [ + stat.get("uncertainties", []) for stat in prompt_stats_history + ], + } + + visualizer.plot_metrics_over_time(metrics_data) + visualizer.plot_uncertainty_vs_performance( + uncertainty=[ + stat.get("uncertainties", [])[-1] + for stat in prompt_stats_history + if stat.get("uncertainties") + ], + performance=[stat["prompt_best"] for stat in prompt_stats_history], + save_path="uncertainty_vs_performance.png", + ) + visualizer.plot_adaptive_stopping( + metrics={ + "uncertainty": [ + u + for stat in prompt_stats_history + for u in stat.get("uncertainties", []) + ], + "compute": list( + range( + sum( + len(stat.get("uncertainties", [])) + for stat in prompt_stats_history + ) + ) + ), + }, + alpha=0.1, + save_path="adaptive_stopping.png", + ) + if __name__ == "__main__": main() diff --git a/src/sift/sift_data_loading.py b/src/sift/sift_data_loading.py index 53f55c7..4932720 100644 --- a/src/sift/sift_data_loading.py +++ b/src/sift/sift_data_loading.py @@ -16,7 +16,7 @@ "Please install it with: pip install zstandard" ) -logging.basicConfig(level=logging.ERROR) +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/src/sift/sift_embedding_cache.py b/src/sift/sift_embedding_cache.py index 09263ac..493634a 100644 --- a/src/sift/sift_embedding_cache.py +++ b/src/sift/sift_embedding_cache.py @@ -4,7 +4,7 @@ from typing import Tuple logger = logging.getLogger(__name__) -logger.setLevel(logging.ERROR) +logger.setLevel(logging.INFO) def clear_embedding_cache(cache_dir: str = "cache/embeddings"): diff --git a/src/sift/sift_eval.py b/src/sift/sift_eval.py index 2b7c183..81a21a2 100644 --- a/src/sift/sift_eval.py +++ b/src/sift/sift_eval.py @@ -11,76 +11,3 @@ class EvaluationMetrics: bits_per_byte: float perplexity: float uncertainty: float - - -class MetricsComputer: - """Compute and track evaluation metrics.""" - - def __init__(self): - self.metrics_history = [] - - def compute_bits_per_byte( - self, logits: torch.Tensor, labels: torch.Tensor - ) -> float: - """Compute bits per byte metric.""" - loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.size(-1)), labels.view(-1), reduction="mean" - ) - return (loss / np.log(2)).item() - - def compute_perplexity(self, logits: torch.Tensor, labels: torch.Tensor) -> float: - """Compute perplexity.""" - loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.size(-1)), labels.view(-1), reduction="mean" - ) - return 
-                # Only log every few steps
+                # Log progress
                 if i % 1 == 0:
                     logger.info(
                         format_step(
                             i,
                             {
                                 "loss": current_loss,
-                                "prev_loss": prev_loss,
+                                "prev_loss": prompt_stats["losses"][-2] if len(prompt_stats["losses"]) > 1 else None,
                                 "global_best": global_best_loss,
-                                "uncertainty": step_metrics.get("uncertainty", 0),
+                                "uncertainty": uncertainty,
                             },
                         )
                     )

+                # Enhanced stopping check with stability verification
+                if (i >= min_examples and
+                    trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and
+                    len(prompt_stats["losses"]) >= 3 and
+                    np.std(prompt_stats["losses"][-3:]) < 0.1):
+                    logger.info(f"Stopping early at step {i} due to convergence")
+                    break
+
+            except Exception as e:
+                logger.error(f"Error in training step: {str(e)}")
                 continue

             # Store and log summary
@@ -241,6 +273,45 @@ def format_step(step: int, metrics: dict) -> str:
     logger.info("=" * 89)

+    # After training loop, add visualization
+    if prompt_stats_history:
+        metrics_data = {
+            "loss": [stat["losses"] for stat in prompt_stats_history],
+            "uncertainty": [
+                stat.get("uncertainties", []) for stat in prompt_stats_history
+            ],
+        }
+
+        visualizer.plot_metrics_over_time(metrics_data)
+        visualizer.plot_uncertainty_vs_performance(
+            uncertainty=[
+                stat.get("uncertainties", [])[-1]
+                for stat in prompt_stats_history
+                if stat.get("uncertainties")
+            ],
+            performance=[stat["prompt_best"] for stat in prompt_stats_history],
+            save_path="uncertainty_vs_performance.png",
+        )
+        visualizer.plot_adaptive_stopping(
+            metrics={
+                "uncertainty": [
+                    u
+                    for stat in prompt_stats_history
+                    for u in stat.get("uncertainties", [])
+                ],
+                "compute": list(
+                    range(
+                        sum(
+                            len(stat.get("uncertainties", []))
+                            for stat in prompt_stats_history
+                        )
+                    )
+                ),
+            },
+            alpha=0.1,
+            save_path="adaptive_stopping.png",
+        )
+

 if __name__ == "__main__":
     main()

diff --git a/src/sift/sift_data_loading.py b/src/sift/sift_data_loading.py
index 53f55c7..4932720 100644
--- a/src/sift/sift_data_loading.py
+++ b/src/sift/sift_data_loading.py
@@ -16,7 +16,7 @@
         "Please install it with: pip install zstandard"
     )

-logging.basicConfig(level=logging.ERROR)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

diff --git a/src/sift/sift_embedding_cache.py b/src/sift/sift_embedding_cache.py
index 09263ac..493634a 100644
--- a/src/sift/sift_embedding_cache.py
+++ b/src/sift/sift_embedding_cache.py
@@ -4,7 +4,7 @@
 from typing import Tuple

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.ERROR)
+logger.setLevel(logging.INFO)

 def clear_embedding_cache(cache_dir: str = "cache/embeddings"):

diff --git a/src/sift/sift_eval.py b/src/sift/sift_eval.py
index 2b7c183..81a21a2 100644
--- a/src/sift/sift_eval.py
+++ b/src/sift/sift_eval.py
@@ -11,76 +11,3 @@ class EvaluationMetrics:
     bits_per_byte: float
     perplexity: float
     uncertainty: float
-
-
-class MetricsComputer:
-    """Compute and track evaluation metrics."""
-
-    def __init__(self):
-        self.metrics_history = []
-
-    def compute_bits_per_byte(
-        self, logits: torch.Tensor, labels: torch.Tensor
-    ) -> float:
-        """Compute bits per byte metric."""
-        loss = torch.nn.functional.cross_entropy(
-            logits.view(-1, logits.size(-1)), labels.view(-1), reduction="mean"
-        )
-        return (loss / np.log(2)).item()
-
-    def compute_perplexity(self, logits: torch.Tensor, labels: torch.Tensor) -> float:
-        """Compute perplexity."""
-        loss = torch.nn.functional.cross_entropy(
-            logits.view(-1, logits.size(-1)), labels.view(-1), reduction="mean"
-        )
-        return torch.exp(loss).item()
-
-    def compute_metrics(
-        self,
-        outputs: Dict[str, Any],
-        labels: Optional[torch.Tensor] = None,
-        uncertainty: Optional[float] = None
-    ) -> Dict[str, float]:
-        """Compute training metrics."""
-        metrics = {
-            'loss': outputs['loss'],
-            'uncertainty': uncertainty if uncertainty is not None else 0.0
-        }
-
-        if outputs['logits'] is not None and labels is not None:
-            # Add any additional metrics computation here
-            pass
-
-        return metrics
-
-    def get_metrics_summary(self) -> Dict[str, List[float]]:
-        """Get summary of tracked metrics."""
-        return {
-            "bits_per_byte": [m.bits_per_byte for m in self.metrics_history],
-            "perplexity": [m.perplexity for m in self.metrics_history],
-            "uncertainty": [m.uncertainty for m in self.metrics_history],
-        }
-
-
-class AdaptiveStoppingMetrics:
-    """Track metrics for adaptive stopping."""
-
-    def __init__(self, alpha: float = 0.1):
-        self.alpha = alpha
-        self.uncertainty_history = []
-        self.compute_history = []
-
-    def should_stop(self, uncertainty: float, step: int) -> bool:
-        """Determine if training should stop based on uncertainty."""
-        if step < 5:  # Minimum number of steps
-            return False
-
-        # Add your stopping criterion here
-        return uncertainty < self.alpha
-
-    def get_stopping_summary(self) -> Dict[str, list]:
-        """Get summary of stopping metrics."""
-        return {
-            "uncertainty": self.uncertainty_history,
-            "compute": self.compute_history,
-        }

diff --git a/src/sift/sift_metrics.py b/src/sift/sift_metrics.py
index 7275289..1f4a54d 100644
--- a/src/sift/sift_metrics.py
+++ b/src/sift/sift_metrics.py
@@ -3,6 +3,7 @@
 import logging

 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)

 class MetricsComputer:
     def __init__(self):
@@ -94,24 +95,32 @@ def __init__(self, alpha: float = 0.1, window_size: int = 5):
         self.stopping_points = []
         self.uncertainty_history = []

-    def should_stop(self, uncertainty: float, step: int) -> bool:
-        """Determine if training should stop based on uncertainty."""
+    def should_stop(self, uncertainty: float, step: int, loss_history: List[float]) -> bool:
+        """Enhanced stopping criterion using both uncertainty and loss trends."""
         try:
             self.uncertainty_history.append(uncertainty)

             if step < self.window_size:
                 return False

-            # Get recent uncertainties
+            # Get recent metrics
             recent_uncertainties = self.uncertainty_history[-self.window_size:]
-            avg_uncertainty = sum(recent_uncertainties) / len(recent_uncertainties)
+            recent_losses = loss_history[-self.window_size:]

-            # Stop if average uncertainty is below threshold
-            should_stop = avg_uncertainty < self.alpha
+            # Compute trends
+            uncertainty_trend = np.polyfit(range(len(recent_uncertainties)), recent_uncertainties, 1)[0]
+            loss_trend = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]
+
+            # Stop if both metrics are stable or improving
+            should_stop = (
+                uncertainty_trend <= 0 and  # Uncertainty is decreasing
+                loss_trend <= 0 and  # Loss is decreasing
+                np.mean(recent_uncertainties) < self.alpha
+            )

             if should_stop:
                 self.stopping_points.append(step)
-                logger.info(f"Stopping at step {step} with uncertainty {avg_uncertainty:.4f}")
+                logger.info(f"Stopping at step {step} with uncertainty {np.mean(recent_uncertainties):.4f}")

             return should_stop
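The rewritten `should_stop` above replaces the plain moving average with least-squares trend slopes: `np.polyfit(x, y, 1)[0]` is the fitted line's slope, so a non-positive value means the metric is flat or falling. A small illustration (assuming numpy is already imported as `np` in `sift_metrics.py`, since the hunk shown doesn't add that import):

    import numpy as np

    losses = [0.92, 0.85, 0.81, 0.78, 0.76]
    slope = np.polyfit(range(len(losses)), losses, 1)[0]  # about -0.039
    assert slope <= 0  # would satisfy the loss_trend condition above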
diff --git a/src/sift/sift_trainer.py b/src/sift/sift_trainer.py
index 978f32f..1cd2edc 100644
--- a/src/sift/sift_trainer.py
+++ b/src/sift/sift_trainer.py
@@ -18,8 +18,12 @@
 import torch.cuda
 import faiss
 from sentence_transformers import SentenceTransformer
+from functools import wraps
+import signal
+
+from .sift_metrics import MetricsComputer, AdaptiveStoppingMetrics
+from .sift_visualization import SIFTVisualizer

-logging.basicConfig(level=logging.ERROR)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 dotenv.load_dotenv()

@@ -29,77 +33,129 @@
 @dataclass
 class SIFTConfig:
     """Configuration for SIFT and test-time fine-tuning"""
-
     lambda_param: float = 0.1
     num_candidates: int = 5
     batch_size: int = 1
     learning_rate: float = 5e-5
     max_length: int = 512
-    device: str = (
-        "cuda"
-        if torch.cuda.is_available()
-        else "mps" if torch.backends.mps.is_available() else "cpu"
-    )
+    device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+
+def timeout(seconds):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            def handler(signum, frame):
+                raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")
+
+            # Set the timeout handler
+            signal.signal(signal.SIGALRM, handler)
+            signal.alarm(seconds)
+
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                # Disable the alarm
+                signal.alarm(0)
+            return result
+        return wrapper
+    return decorator
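One caveat on the decorator above: `signal.SIGALRM` is Unix-only and must be installed from the main thread, so the `@timeout` guard will not work on Windows or inside worker threads, and `signal.alarm` only has whole-second granularity. Intended behavior, sketched as a standalone demo (relies on the `timeout` definition above):

    import time

    @timeout(2)
    def slow_op():
        time.sleep(5)  # exceeds the 2-second budget

    try:
        slow_op()
    except TimeoutError as e:
        print(e)  # "Function slow_op timed out after 2 seconds"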
 class SIFTTrainer:
     def __init__(
         self,
-        llm_name: str = "unsloth/Llama-3.2-1B",
-        embedding_model: str = "BAAI/bge-large-en-v1.5",
+        llm_name: str,
+        embedding_model: str,
         index_dir: str = "cache/faiss",
-        max_length: int = 512,
+        cache_dir: str = "cache/embeddings",
+        window_size: int = 5,
+        min_steps: int = 3,
+        lambda_param: float = 0.1,
+        max_length: Optional[int] = None,
         embedding_dim: int = 1024,
-        uncertainty_buffer_size: int = 100,
-        gradient_accumulation_steps: int = 8,
-        max_grad_norm: float = 1.0,
+        config: Optional[SIFTConfig] = None,
     ):
-        """Initialize trainer with all required attributes."""
-        self.device = torch.device("cpu")
-        self.embedding_dim = embedding_dim  # Store embedding dimension
-        self.max_length = max_length
-
-        # Initialize paths
+        """Initialize SIFT trainer with models and configuration."""
+        # Initialize config first
+        self.config = config or SIFTConfig(
+            lambda_param=lambda_param,
+            max_length=max_length if max_length is not None else 512
+        )
+
+        # Set device and dtype
+        if torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+            self.dtype = torch.float32  # MPS requires float32
+            torch.mps.empty_cache()
+        else:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.dtype = torch.float32  # Use float32 consistently
+
+        # Initialize models and tokenizer with consistent dtype
+        logger.info(f"Loading models from {llm_name} to {self.device} with dtype {self.dtype}")
+        self.model = AutoModelForCausalLM.from_pretrained(
+            llm_name,
+            torch_dtype=self.dtype,
+            device_map={"": self.device},
+            use_flash_attention_2=False  # Disable flash attention to avoid dtype issues
+        )
+
+        # Convert model parameters to consistent dtype
+        self.model = self.model.to(dtype=self.dtype)
+
+        # Initialize tokenizer and embedding model
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
+        self.embedding_model = SentenceTransformer(embedding_model)
+        self.embedding_model = self.embedding_model.to(self.device)
+
+        # Basic parameters
+        self.lambda_param = self.config.lambda_param
+        self.window_size = window_size
+        self.min_steps = min_steps
+        self.embedding_dim = embedding_dim
+
+        # Initialize tracking components
+        self.metrics_computer = MetricsComputer()
+        self.adaptive_stopping = AdaptiveStoppingMetrics(
+            alpha=0.1,
+            window_size=self.window_size
+        )
+        self.visualizer = SIFTVisualizer()
+
+        # Initialize buffers
+        self._uncertainty_buffer = []
+        self._last_loss = None
+
+        # Setup directories
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
         self.index_dir = Path(index_dir)
         self.index_dir.mkdir(parents=True, exist_ok=True)
         self.index_path = self.index_dir / "embeddings.faiss"
         self.metadata_path = self.index_dir / "metadata.pkl"
-
-        # Initialize FAISS index and mappings
+
+        # Load or create FAISS index
         self.index, self.text_to_id, self.id_to_text = self._load_or_create_index()
-
-        # Load embedding model
-        self.embedding_model = SentenceTransformer(embedding_model, device=self.device)
-
-        # Initialize tokenizer and model
-        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            llm_name, torch_dtype=torch.float32, low_cpu_mem_usage=True
-        )
-
-        self.model.eval()
-
-        # Training settings
-        self.gradient_accumulation_steps = gradient_accumulation_steps
-        self.current_accumulation_step = 0
-
-        # Initialize optimizer
+
+        # Initialize optimizer and scheduler
         self.optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=5e-5, weight_decay=0.01
+            self.model.parameters(),
+            lr=self.config.learning_rate,
+            weight_decay=0.01
         )
-        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6
+
+        # Initialize scheduler (linear warmup and decay)
+        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            self.optimizer,
+            max_lr=self.config.learning_rate,
+            total_steps=1000,  # Adjust based on expected steps
+            pct_start=0.1
         )
-
-        self.max_grad_norm = max_grad_norm
-
-        # Initialize uncertainty tracking
-        self._uncertainty_buffer = []
-        self._uncertainty_buffer_size = uncertainty_buffer_size
-        self._last_loss = None
-
-        logger.info(f"Using device: {self.device}")
-        logger.info(f"Model loaded with dtype: {self.model.dtype}")
+
+        # Initialize step counter
+        self.current_accumulation_step = 0

     def _load_or_create_index(
         self,

@@ -146,67 +202,160 @@ def save_index(self):

     def compute_embedding(self, text: str) -> Optional[np.ndarray]:
         """Compute embedding using dedicated embedding model."""
         try:
-            # Check if text is already in index
-            if text in self.text_to_id:
-                vector_id = self.text_to_id[text]
-                reconstructed = np.zeros((1, self.embedding_dim), dtype=np.float32)
-                self.index.reconstruct(vector_id, reconstructed[0])
-                return reconstructed
-
+            # Try cache first
+            cache_key = hashlib.md5(text.encode()).hexdigest()
+            cache_path = self.cache_dir / f"{cache_key}.npy"
+
+            if cache_path.exists():
+                embedding = np.load(cache_path)
+                # Ensure 2D shape
+                if len(embedding.shape) == 1:
+                    embedding = embedding.reshape(1, -1)
+                return embedding
+
             # Compute new embedding
             embedding = self.embedding_model.encode(
                 text,
-                normalize_embeddings=True,  # Important for cosine similarity
+                normalize_embeddings=True,
                 show_progress_bar=False,
             )

-            # Reshape for FAISS
+            # Ensure 2D shape
             embedding_np = np.array(embedding).reshape(1, -1)

-            # Add to index
-            vector_id = self.index.ntotal
-            self.index.add(embedding_np)
-            self.text_to_id[text] = vector_id
-            self.id_to_text[vector_id] = text
-
+            # Cache the result
+            np.save(cache_path, embedding_np)
             return embedding_np

         except Exception as e:
             logger.error(f"Error computing embedding: {e}")
             return None

-    def select_examples(
+    @timeout(10)  # 10 second timeout for kernel computation
+    def compute_kernel_uncertainty(self, x_star: str, selected_points: List[str]) -> float:
+        try:
+            # Get embeddings
+            x_star_emb = self.compute_embedding(x_star)
+            if x_star_emb is None or len(selected_points) == 0:
+                return float('inf')
+
+            # Ensure x_star_emb is 2D
+            if len(x_star_emb.shape) == 1:
+                x_star_emb = x_star_emb.reshape(1, -1)
+
+            # Compute embeddings with timeout protection
+            selected_embs = []
+            for point in selected_points:
+                emb = self.compute_embedding(point)
+                if emb is None:
+                    return float('inf')
+                # Ensure each embedding is 2D
+                if len(emb.shape) == 1:
+                    emb = emb.reshape(1, -1)
+                selected_embs.append(emb)
+
+            # Stack embeddings and ensure correct shape
+            selected_embs = np.vstack([emb.reshape(1, -1) for emb in selected_embs])
+
+            # Add small epsilon for numerical stability
+            epsilon = 1e-8
+
+            # Compute kernel values with correct shapes
+            k_xx = np.dot(x_star_emb, x_star_emb.T) + epsilon
+            K_X = self.compute_kernel_matrix_batch(selected_embs) + epsilon * np.eye(len(selected_points))
+            k_X = np.dot(x_star_emb, selected_embs.T)
+
+            # Add regularization
+            K_X_reg = K_X + self.lambda_param * np.eye(len(selected_points))
+
+            try:
+                # Ensure shapes match for solve operation
+                if k_X.shape[1] != K_X_reg.shape[0]:
+                    logger.error(f"Shape mismatch: k_X: {k_X.shape}, K_X_reg: {K_X_reg.shape}")
+                    return float('inf')
+
+                # Use more stable solver with condition number check
+                if np.linalg.cond(K_X_reg) > 1e10:
+                    logger.warning("Poorly conditioned matrix, adding more regularization")
+                    K_X_reg += 0.1 * np.eye(len(selected_points))
+
+                # Reshape for solve operation
+                k_X = k_X.reshape(-1, 1)
+                uncertainty = float(k_xx - k_X.T @ np.linalg.solve(K_X_reg, k_X))
+
+                # Validate output
+                if np.isnan(uncertainty) or np.isinf(uncertainty):
+                    return float('inf')
+
+                return float(np.clip(uncertainty, 0, 10.0))  # Clip to reasonable range
+
+            except np.linalg.LinAlgError as e:
+                logger.warning(f"Matrix inversion failed: {e}")
+                return float('inf')
+
+        except Exception as e:
+            logger.error(f"Error in kernel uncertainty computation: {e}")
+            return float('inf')
+
+    def select_examples_sift(
         self, prompt: str, candidates: List[str], n_examples: int = 5
     ) -> List[str]:
-        """Select examples using FAISS index."""
+        """Select examples using SIFT with improved robustness and logging."""
         try:
-            # Compute prompt embedding
-            prompt_embedding = self.compute_embedding(prompt)
-            if prompt_embedding is None:
-                raise ValueError("Failed to compute prompt embedding")
-
-            # Add candidates to index if not already present
-            for candidate in tqdm(candidates, desc="Processing candidates"):
-                if candidate not in self.text_to_id:
-                    self.compute_embedding(candidate)
-
-            # Search for nearest neighbors
-            D, I = self.index.search(prompt_embedding, n_examples)
-
-            # Get corresponding texts
-            selected = [self.id_to_text[int(i)] for i in I[0]]
-            logger.info(f"Selected {len(selected)} examples using FAISS")
-
-            # Save updated index
-            self.save_index()
-
+            selected = []
+            prompt_emb = self.compute_embedding(prompt)
+
+            if prompt_emb is None:
+                logger.error("Failed to compute prompt embedding")
+                return selected
+
+            # Add progress tracking
+            from tqdm import tqdm
+
+            for i in range(n_examples):
+                min_uncertainty = float("inf")
+                best_candidate = None
+
+                # Track candidates processing
+                for candidate in tqdm(candidates, desc=f"Processing candidates for example {i+1}/{n_examples}", leave=False):
+                    if candidate in selected:
+                        continue
+
+                    try:
+                        # Compute uncertainty with timeout protection
+                        test_selected = selected + [candidate]
+                        uncertainty = self.compute_kernel_uncertainty(prompt, test_selected)
+
+                        # Skip invalid uncertainties
+                        if uncertainty is None or np.isnan(uncertainty) or np.isinf(uncertainty):
+                            continue
+
+                        if uncertainty < min_uncertainty:
+                            min_uncertainty = uncertainty
+                            best_candidate = candidate
+
+                    except Exception as e:
+                        logger.warning(f"Error processing candidate: {str(e)}")
+                        continue
+
+                # Check if we found a valid candidate
+                if best_candidate:
+                    selected.append(best_candidate)
+                    logger.info(f"Selected example {i+1}/{n_examples} with uncertainty: {min_uncertainty:.4f}")
+                else:
+                    logger.warning(f"No valid candidate found for example {i+1}")
+                    break
+
+                # Clear memory periodically
+                if i % 2 == 0:
+                    self.clear_memory()
+
+            logger.info(f"Selected {len(selected)}/{n_examples} examples")
             return selected
-
         except Exception as e:
             logger.error(f"Error in example selection: {str(e)}")
-            logger.error("Falling back to random selection")
-            indices = np.random.choice(len(candidates), n_examples, replace=False)
-            return [candidates[idx] for idx in indices]
+            return selected

     def compute_metrics(self, outputs) -> Dict[str, float]:
         """Compute metrics from model outputs."""
@@ -229,21 +378,30 @@ def compute_metrics(self, outputs) -> Dict[str, float]:
             logger.error(f"Error computing metrics: {e}")
             return None

+    @timeout(30)
     def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]:
-        """Perform fine-tuning step with proper metrics."""
         try:
             gc.collect()
+            if torch.backends.mps.is_available():
+                torch.mps.empty_cache()

             # Tokenize
             inputs = self.tokenizer(
                 example,
                 padding=True,
                 truncation=True,
-                max_length=128,
+                max_length=self.config.max_length,
                 return_tensors="pt",
             )
-
-            labels = inputs["input_ids"].clone()
+
+            # Move to device and ensure correct dtype
+            inputs = {
+                "input_ids": self._ensure_tensor_dtype(inputs["input_ids"], is_index=True),
+                "attention_mask": self._ensure_tensor_dtype(inputs["attention_mask"], is_index=True)
+            }
+
+            # Clone labels and ensure they're long type
+            labels = self._ensure_tensor_dtype(inputs["input_ids"].clone(), is_index=True)

             # Forward pass
             self.model.train()
@@ -253,53 +411,7 @@ def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]:
                 labels=labels,
             )

-            # Compute loss and backward
-            loss = outputs.loss / self.gradient_accumulation_steps
-            loss.backward()
-
-            # Add gradient clipping
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-
-            # Update weights if needed
-            self.current_accumulation_step += 1
-            if self.current_accumulation_step >= self.gradient_accumulation_steps:
-                self.optimizer.step()
-
-                # Update learning rate based on loss
-                self.scheduler.step(loss.item())
-
-                self.optimizer.zero_grad()
-                self.current_accumulation_step = 0
-
-            self.model.eval()
-
-            # Update uncertainty
-            loss_value = loss.item() * self.gradient_accumulation_steps
-            self._last_loss = loss_value
-            self._uncertainty_buffer.append(loss_value)
-            if len(self._uncertainty_buffer) > self._uncertainty_buffer_size:
-                self._uncertainty_buffer.pop(0)
-
-            # Compute metrics
-            metrics = self.compute_metrics(outputs)
-            if metrics is None:
-                return None
-
-            # Add training info
-            metrics.update(
-                {
-                    "labels": labels.detach(),
-                    "logits": (
-                        outputs.logits.detach() if hasattr(outputs, "logits") else None
-                    ),
-                }
-            )
-
-            # Clear memory
-            del outputs
-            gc.collect()
-
-            return metrics
+            return {"loss": outputs.loss.item()}

         except Exception as e:
             logger.error(f"Error in fine-tuning step: {str(e)}")
@@ -365,39 +477,182 @@ def clear_cache(self):
         except Exception as e:
             logger.error(f"Failed to clear cache: {e}")

-
-class MetricsComputer:
-    def __init__(self):
-        self.metrics_history = {
-            "bits_per_byte": [],
"bits_per_byte": [], - "perplexity": [], - "uncertainty": [], + def compute_kernel_matrix_batch(self, embeddings: np.ndarray) -> np.ndarray: + try: + return np.dot(embeddings, embeddings.T) + except Exception as e: + logger.error(f"Error in kernel computation: {e}") + return np.zeros((embeddings.shape[0], embeddings.shape[0])) + + def get_training_summary(self) -> Dict[str, List[float]]: + """Get complete training summary for visualization.""" + return { + "loss": self._uncertainty_buffer, + "uncertainty": [self.compute_uncertainty(None) for _ in self._uncertainty_buffer], + "compute": list(range(len(self._uncertainty_buffer))) } - def compute_metrics( - self, outputs: Dict[str, Any], uncertainty: float = None - ) -> Dict[str, float]: - """Compute and store metrics.""" - if outputs is None: - return None + def clear_memory(self): + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() - metrics = { - "bits_per_byte": outputs.get("bits_per_byte", 0.0), - "perplexity": outputs.get("perplexity", 0.0), - "uncertainty": outputs.get( - "uncertainty", uncertainty if uncertainty is not None else 0.0 - ), - } - - # Store metrics - for key, value in metrics.items(): - self.metrics_history[key].append(value) + def should_stop_adaptive( + self, uncertainty: float, step: int, alpha: float = 0.1 + ) -> bool: + """Enhanced adaptive stopping criterion with stability checks.""" + try: + if step < self.min_steps: + return False + + # Add uncertainty to buffer + self._uncertainty_buffer.append(uncertainty) + + # Keep buffer size manageable + if len(self._uncertainty_buffer) > self.window_size: + self._uncertainty_buffer.pop(0) + + # Get recent uncertainties with outlier removal + recent_uncertainties = np.array(self._uncertainty_buffer[-self.window_size:]) + if len(recent_uncertainties) < self.window_size: + return False + + # Remove outliers using IQR method + q1, q3 = np.percentile(recent_uncertainties, [25, 75]) + iqr = q3 - q1 + mask = (recent_uncertainties >= q1 - 1.5 * iqr) & (recent_uncertainties <= q3 + 1.5 * iqr) + filtered_uncertainties = recent_uncertainties[mask] + + if len(filtered_uncertainties) < 3: # Require minimum number of valid points + return False + + # Compute robust statistics + avg_uncertainty = np.median(filtered_uncertainties) + uncertainty_std = np.std(filtered_uncertainties) + + # Dynamic threshold based on training progress + threshold = alpha * (1 + 1/np.sqrt(1 + step)) + + # Check both absolute and relative stability + is_stable = uncertainty_std < 0.1 * avg_uncertainty + is_low = avg_uncertainty < threshold + + should_stop = is_stable and is_low + + if should_stop: + logger.info(f"Stopping at step {step} with uncertainty {avg_uncertainty:.4f} (threshold: {threshold:.4f})") + + return should_stop + + except Exception as e: + logger.error(f"Error in stopping criterion: {e}") + return False - return metrics + def save_checkpoint(self, path: str): + """Save a checkpoint of the current model state.""" + try: + checkpoint_dir = Path(path) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + checkpoint = { + 'model_state_dict': self.model.state_dict(), + 'uncertainty_buffer': self._uncertainty_buffer, + 'current_accumulation_step': getattr(self, 'current_accumulation_step', 0), + } + + # Add optimizer and scheduler if they exist + if hasattr(self, 'optimizer'): + checkpoint['optimizer_state_dict'] = self.optimizer.state_dict() + if hasattr(self, 'scheduler'): + checkpoint['scheduler_state_dict'] = self.scheduler.state_dict() + + 
+            logger.info(f"Saved checkpoint to {path}")
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint: {e}")
+
+    def update_ema_loss(self, current_loss: float, alpha: float = 0.1) -> float:
+        """Update exponential moving average of loss."""
+        if not hasattr(self, '_ema_loss'):
+            self._ema_loss = current_loss
+        else:
+            self._ema_loss = alpha * current_loss + (1 - alpha) * self._ema_loss
+        return self._ema_loss
+
+    def update_and_visualize_metrics(
+        self,
+        current_loss: float,
+        uncertainty: float,
+        step: int,
+        save_path: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Update metrics and generate visualizations."""
+        # Update EMA loss
+        ema_loss = self.update_ema_loss(current_loss)
+
+        # Compute metrics
+        metrics = self.metrics_computer.compute_metrics(
+            {'loss': current_loss},
+            uncertainty=uncertainty
+        )
+
+        # Check stopping condition using both metrics
+        should_stop = self.adaptive_stopping.should_stop(
+            uncertainty=uncertainty,
+            step=step,
+            loss_history=self.metrics_computer.metrics_history['loss']
+        )
+
+        # Generate visualizations periodically
+        if step % 10 == 0:
+            self.visualizer.plot_adaptive_stopping(
+                metrics=self.get_training_summary(),
+                alpha=self.adaptive_stopping.alpha,
+                title=f"Training Progress - Step {step}",
+                save_path=save_path
+            )
+
+        return {
+            'metrics': metrics,
+            'should_stop': should_stop,
+            'ema_loss': ema_loss
+        }

-    def get_metrics_summary(self) -> Dict[str, List[float]]:
-        """Get the history of all metrics."""
-        return self.metrics_history
+    def generate_training_summary(self, save_dir: str = "training_summary"):
+        """Generate comprehensive training summary and visualizations."""
+        Path(save_dir).mkdir(parents=True, exist_ok=True)
+
+        # Get complete training history
+        summary = self.get_training_summary()
+
+        # Save metrics to JSON
+        with open(f"{save_dir}/metrics.json", "w") as f:
+            json.dump(summary, f, indent=2)
+
+        # Generate final visualizations
+        self.visualizer.plot_adaptive_stopping(
+            metrics=summary,
+            alpha=self.adaptive_stopping.alpha,
+            title="Final Training Summary",
+            save_path=f"{save_dir}/final_stopping_analysis.png"
+        )
+
+        # Save stopping points
+        stopping_summary = self.adaptive_stopping.get_stopping_summary()
+        with open(f"{save_dir}/stopping_points.json", "w") as f:
+            json.dump(stopping_summary, f, indent=2)
+
+    def _ensure_tensor_dtype(self, tensor: torch.Tensor, is_index: bool = False) -> torch.Tensor:
+        """Ensure tensor has correct dtype."""
+        if is_index:
+            # For indices (input_ids, attention masks), use long
+            tensor = tensor.to(dtype=torch.long)
+        else:
+            # For other tensors, use configured dtype
+            if tensor.dtype != self.dtype:
+                tensor = tensor.to(dtype=self.dtype)
+        return tensor.to(self.device)

 def main():
@@ -419,7 +674,7 @@ def main():
     ]

     logger.info("Testing example selection...")
-    selected = trainer.select_examples(prompt, candidates)
+    selected = trainer.select_examples_sift(prompt, candidates)
     logger.info("Selected examples:")
     for i, example in enumerate(selected, 1):
         logger.info(f"{i}. {example}")
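`save_checkpoint` above writes `checkpoint.pt`, but the patch does not add a loader. A matching restore, sketched here as a hypothetical helper under the same key names (not part of this patch), might look like:

    from pathlib import Path
    import torch

    def load_checkpoint(trainer, path: str):
        # Hypothetical counterpart to SIFTTrainer.save_checkpoint.
        ckpt = torch.load(Path(path) / "checkpoint.pt", map_location=trainer.device)
        trainer.model.load_state_dict(ckpt["model_state_dict"])
        trainer._uncertainty_buffer = ckpt.get("uncertainty_buffer", [])
        if "optimizer_state_dict" in ckpt:
            trainer.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            trainer.scheduler.load_state_dict(ckpt["scheduler_state_dict"])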
{example}") diff --git a/src/sift/sift_visualization.py b/src/sift/sift_visualization.py index 67cab8c..bbe3515 100644 --- a/src/sift/sift_visualization.py +++ b/src/sift/sift_visualization.py @@ -87,3 +87,39 @@ def plot_adaptive_stopping( if save_path: plt.savefig(save_path) plt.show() + + def plot_training_summary( + self, + losses: List[float], + uncertainties: List[float], + save_path: Optional[str] = None + ): + """Plot comprehensive training summary.""" + fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10)) + + # Loss over time + ax1.plot(losses) + ax1.set_title("Loss Progression") + ax1.set_xlabel("Step") + ax1.set_ylabel("Loss") + + # Uncertainty over time + ax2.plot(uncertainties) + ax2.set_title("Uncertainty Progression") + ax2.set_xlabel("Step") + ax2.set_ylabel("Uncertainty") + + # Loss distribution + ax3.hist(losses, bins=30) + ax3.set_title("Loss Distribution") + + # Loss vs Uncertainty + ax4.scatter(uncertainties, losses, alpha=0.5) + ax4.set_title("Loss vs Uncertainty") + ax4.set_xlabel("Uncertainty") + ax4.set_ylabel("Loss") + + plt.tight_layout() + if save_path: + plt.savefig(save_path) + plt.close()
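The new `plot_training_summary` renders a 2x2 panel (loss curve, uncertainty curve, loss histogram, loss-vs-uncertainty scatter) and closes the figure after optionally saving it. A typical call, assuming the loss and uncertainty lists collected in the main script's `metrics_tracker` during the training loop:

    from sift.sift_visualization import SIFTVisualizer

    viz = SIFTVisualizer()
    viz.plot_training_summary(
        losses=metrics_tracker["global_losses"],
        uncertainties=metrics_tracker["uncertainties"],
        save_path="training_summary.png",
    )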