From cfe5ce3c1983f7d1e107b07976c0b5875a13815a Mon Sep 17 00:00:00 2001 From: Leon van Bokhorst Date: Sun, 17 Nov 2024 16:31:29 +0100 Subject: [PATCH] fix: resolve training visualization and metrics issues - Add missing _cleanup_old_checkpoints method to SIFTTrainer - Create required directories for checkpoints and visualizations - Fix compute_validation_metrics implementation - Update .gitignore to properly handle output directories The changes ensure proper handling of model checkpoints and training visualizations while maintaining a clean project structure. Directory creation is now handled proactively to prevent file write errors. Reference: - SIFTTrainer cleanup implementation (lines 717-742) - Validation metrics computation (lines 744-779) - Directory structure in .gitignore (lines 23-25) --- .gitignore | 2 + src/19_llm_fine_tuning_sift.py | 40 ++++++- src/sift/sift_metrics.py | 30 +++++- src/sift/sift_trainer.py | 190 +++++++++++++++++++++++++++------ 4 files changed, 225 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 7bec718..52376fc 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,8 @@ Llama-3.2-1B-Instruct-Complaint/ dataset/ cache/ checkpoints/ +training_summary/ +visualizations/ diff --git a/src/19_llm_fine_tuning_sift.py b/src/19_llm_fine_tuning_sift.py index 2faaebf..c873d0b 100644 --- a/src/19_llm_fine_tuning_sift.py +++ b/src/19_llm_fine_tuning_sift.py @@ -61,6 +61,7 @@ def main(): n_test_prompts = 100 test_prompts = sampler.sample_test_prompts(n_prompts=n_test_prompts) training_data = sampler.get_training_subset(size=1000) + validation_data = sampler.get_training_subset(size=100) # Separate validation set logger.info( f"Sampled {len(test_prompts)} test prompts and {len(training_data)} training examples" @@ -144,9 +145,16 @@ def format_step(step: int, metrics: dict) -> str: 'global_losses': [], 'prompt_losses': [], 'uncertainties': [], - 'steps_per_prompt': [] + 'steps_per_prompt': [], + 'validation_loss': [], + 'validation_perplexity': [] } + # Create directories for outputs + Path("checkpoints").mkdir(parents=True, exist_ok=True) + Path("visualizations").mkdir(parents=True, exist_ok=True) + Path("training_summary").mkdir(parents=True, exist_ok=True) + for prompt_idx, prompt in enumerate(test_prompts): try: # Disable all other loggers @@ -202,7 +210,14 @@ def format_step(step: int, metrics: dict) -> str: if current_loss < global_best_loss: global_best_loss = current_loss # Save best model checkpoint - trainer.save_checkpoint(f"checkpoints/best_model_{prompt_idx}") + trainer.save_checkpoint_with_metrics( + f"checkpoints/best_model_{prompt_idx}", + { + "loss": current_loss, + "uncertainty": uncertainty, + "global_best": global_best_loss + } + ) # Log progress if i % 1 == 0: @@ -220,12 +235,29 @@ def format_step(step: int, metrics: dict) -> str: # Enhanced stopping check with stability verification if (i >= min_examples and - trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and + (trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) or + trainer.enhanced_early_stopping(current_loss)) and len(prompt_stats["losses"]) >= 3 and np.std(prompt_stats["losses"][-3:]) < 0.1): logger.info(f"Stopping early at step {i} due to convergence") break + # Add after line 193 (after current_loss assignment) + trainer.adjust_learning_rate(current_loss) + + metrics_tracker['steps_per_prompt'].append(i) + trainer.update_and_visualize_metrics( + current_loss, + uncertainty, + i, + save_path=f"visualizations/step_{prompt_idx}_{i}.png" + ) + + if i % 10 == 0: # Validate every 10 steps + val_metrics = trainer.compute_validation_metrics(validation_data) + metrics_tracker['validation_loss'].append(val_metrics['val_loss']) + metrics_tracker['validation_perplexity'].append(val_metrics['val_perplexity']) + except Exception as e: logger.error(f"Error in training step: {str(e)}") continue @@ -312,6 +344,8 @@ def format_step(step: int, metrics: dict) -> str: save_path="adaptive_stopping.png", ) + trainer.generate_training_summary(save_dir="training_summary") + if __name__ == "__main__": main() diff --git a/src/sift/sift_metrics.py b/src/sift/sift_metrics.py index 1f4a54d..400980b 100644 --- a/src/sift/sift_metrics.py +++ b/src/sift/sift_metrics.py @@ -1,6 +1,7 @@ import numpy as np from typing import Dict, List, Any, Optional import logging +import torch logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -133,4 +134,31 @@ def get_stopping_summary(self) -> Dict[str, List[int]]: return { 'stopping_points': self.stopping_points, 'uncertainty_history': self.uncertainty_history - } \ No newline at end of file + } + +def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]: + self.model.eval() + total_loss = 0.0 + total_perplexity = 0.0 + + with torch.no_grad(): + for example in validation_examples: + inputs = self.tokenizer( + example, + padding=True, + truncation=True, + max_length=self.config.max_length, + return_tensors="pt" + ).to(self.device) + + outputs = self.model(**inputs, labels=inputs["input_ids"]) + total_loss += outputs.loss.item() + total_perplexity += torch.exp(outputs.loss).item() + + avg_metrics = { + 'val_loss': total_loss / len(validation_examples), + 'val_perplexity': total_perplexity / len(validation_examples) + } + + self.model.train() + return avg_metrics \ No newline at end of file diff --git a/src/sift/sift_trainer.py b/src/sift/sift_trainer.py index 1cd2edc..ecbc043 100644 --- a/src/sift/sift_trainer.py +++ b/src/sift/sift_trainer.py @@ -411,6 +411,9 @@ def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]: labels=labels, ) + # Add gradient clipping + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + return {"loss": outputs.loss.item()} except Exception as e: @@ -506,16 +509,19 @@ def should_stop_adaptive( if step < self.min_steps: return False + # Use adaptive window size + adaptive_window = self.compute_adaptive_window_size(step) + # Add uncertainty to buffer self._uncertainty_buffer.append(uncertainty) # Keep buffer size manageable - if len(self._uncertainty_buffer) > self.window_size: + if len(self._uncertainty_buffer) > adaptive_window: self._uncertainty_buffer.pop(0) # Get recent uncertainties with outlier removal - recent_uncertainties = np.array(self._uncertainty_buffer[-self.window_size:]) - if len(recent_uncertainties) < self.window_size: + recent_uncertainties = np.array(self._uncertainty_buffer[-adaptive_window:]) + if len(recent_uncertainties) < adaptive_window: return False # Remove outliers using IQR method @@ -588,36 +594,33 @@ def update_and_visualize_metrics( save_path: Optional[str] = None ) -> Dict[str, Any]: """Update metrics and generate visualizations.""" - # Update EMA loss - ema_loss = self.update_ema_loss(current_loss) - - # Compute metrics - metrics = self.metrics_computer.compute_metrics( - {'loss': current_loss}, - uncertainty=uncertainty - ) - - # Check stopping condition using both metrics - should_stop = self.adaptive_stopping.should_stop( - uncertainty=uncertainty, - step=step, - loss_history=self.metrics_computer.metrics_history['loss'] - ) - - # Generate visualizations periodically - if step % 10 == 0: - self.visualizer.plot_adaptive_stopping( - metrics=self.get_training_summary(), - alpha=self.adaptive_stopping.alpha, - title=f"Training Progress - Step {step}", - save_path=save_path - ) - - return { - 'metrics': metrics, - 'should_stop': should_stop, - 'ema_loss': ema_loss - } + try: + # Update metrics history + self.metrics_computer.update({ + 'loss': current_loss, + 'uncertainty': uncertainty + }) + + # Get complete metrics history + metrics_history = self.metrics_computer.get_metrics_summary() + + # Generate visualization if save path is provided + if save_path: + self.visualizer.plot_training_summary( + losses=metrics_history['loss'], + uncertainties=metrics_history['uncertainty'], + save_path=save_path + ) + + return { + 'loss': current_loss, + 'uncertainty': uncertainty, + 'step': step + } + + except Exception as e: + logger.error(f"Error in metrics update: {e}") + return None def generate_training_summary(self, save_dir: str = "training_summary"): """Generate comprehensive training summary and visualizations.""" @@ -654,6 +657,127 @@ def _ensure_tensor_dtype(self, tensor: torch.Tensor, is_index: bool = False) -> tensor = tensor.to(dtype=self.dtype) return tensor.to(self.device) + def adjust_learning_rate(self, current_loss: float, window_size: int = 5) -> None: + if len(self._uncertainty_buffer) >= window_size: + recent_losses = self._uncertainty_buffer[-window_size:] + loss_trend = np.mean(np.diff(recent_losses)) + + # Increase learning rate if loss is stagnating + if abs(loss_trend) < 0.01: + self.optimizer.param_groups[0]['lr'] *= 1.2 + # Decrease learning rate if loss is unstable + elif loss_trend > 0: + self.optimizer.param_groups[0]['lr'] *= 0.8 + + def enhanced_early_stopping(self, + current_loss: float, + patience: int = 10, + min_delta: float = 0.01 + ) -> bool: + if not hasattr(self, '_best_loss'): + self._best_loss = float('inf') + self._patience_counter = 0 + + if current_loss < (self._best_loss - min_delta): + self._best_loss = current_loss + self._patience_counter = 0 + return False + + self._patience_counter += 1 + return self._patience_counter >= patience + + def save_checkpoint_with_metrics(self, path: str, metrics: Dict[str, float]): + checkpoint_dir = Path(path) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + checkpoint = { + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'metrics': metrics, + 'uncertainty_buffer': self._uncertainty_buffer, + 'current_step': self.current_accumulation_step, + 'timestamp': datetime.datetime.now().isoformat() + } + + # Save with metrics in filename + filename = f"checkpoint_loss_{metrics['loss']:.4f}_step_{self.current_accumulation_step}.pt" + torch.save(checkpoint, checkpoint_dir / filename) + + # Keep only top N checkpoints + self._cleanup_old_checkpoints(checkpoint_dir, keep_top_n=3) + + def compute_adaptive_window_size(self, step: int, min_window: int = 3, max_window: int = 10) -> int: + """Compute adaptive window size based on training progress.""" + # Start small and increase window size as training progresses + progress_factor = min(1.0, step / 1000) # Normalize steps to [0,1] + window_size = min_window + int((max_window - min_window) * progress_factor) + return window_size + + def _cleanup_old_checkpoints(self, checkpoint_dir: Path, keep_top_n: int = 3): + """Keep only the top N checkpoints based on loss.""" + try: + # Get all checkpoint files + checkpoint_files = list(checkpoint_dir.glob("checkpoint_loss_*.pt")) + + if len(checkpoint_files) <= keep_top_n: + return + + # Extract loss values from filenames + def get_loss(filepath): + try: + return float(str(filepath).split("loss_")[1].split("_")[0]) + except: + return float('inf') + + # Sort by loss and keep only top N + sorted_files = sorted(checkpoint_files, key=get_loss) + files_to_remove = sorted_files[keep_top_n:] + + # Remove excess checkpoints + for file in files_to_remove: + file.unlink() + + except Exception as e: + logger.error(f"Error cleaning up checkpoints: {e}") + + def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]: + """Compute validation metrics on a subset of examples.""" + self.model.eval() + total_loss = 0.0 + total_perplexity = 0.0 + + with torch.no_grad(): + for example in validation_examples[:10]: # Limit to 10 examples for speed + try: + inputs = self.tokenizer( + example, + padding=True, + truncation=True, + max_length=self.config.max_length, + return_tensors="pt" + ) + + inputs = {k: self._ensure_tensor_dtype(v, is_index=True) + for k, v in inputs.items()} + + outputs = self.model(**inputs, labels=inputs["input_ids"]) + total_loss += outputs.loss.item() + total_perplexity += torch.exp(outputs.loss).item() + + except Exception as e: + logger.error(f"Error in validation: {e}") + continue + + n_examples = min(len(validation_examples), 10) + metrics = { + 'val_loss': total_loss / max(n_examples, 1), + 'val_perplexity': total_perplexity / max(n_examples, 1) + } + + self.model.train() + return metrics + def main(): """Test the SIFT trainer"""