fix: resolve training visualization and metrics issues #55

Merged 1 commit on Nov 17, 2024

2 changes: 2 additions & 0 deletions .gitignore
@@ -21,6 +21,8 @@ Llama-3.2-1B-Instruct-Complaint/
dataset/
cache/
checkpoints/
training_summary/
visualizations/



40 changes: 37 additions & 3 deletions src/19_llm_fine_tuning_sift.py
@@ -61,6 +61,7 @@ def main():
n_test_prompts = 100
test_prompts = sampler.sample_test_prompts(n_prompts=n_test_prompts)
training_data = sampler.get_training_subset(size=1000)
validation_data = sampler.get_training_subset(size=100) # Separate validation set

logger.info(
f"Sampled {len(test_prompts)} test prompts and {len(training_data)} training examples"
@@ -144,9 +145,16 @@ def format_step(step: int, metrics: dict) -> str:
'global_losses': [],
'prompt_losses': [],
'uncertainties': [],
'steps_per_prompt': []
'steps_per_prompt': [],
'validation_loss': [],
'validation_perplexity': []
}

# Create directories for outputs
Path("checkpoints").mkdir(parents=True, exist_ok=True)
Path("visualizations").mkdir(parents=True, exist_ok=True)
Path("training_summary").mkdir(parents=True, exist_ok=True)

for prompt_idx, prompt in enumerate(test_prompts):
try:
# Disable all other loggers
@@ -202,7 +210,14 @@ def format_step(step: int, metrics: dict) -> str:
if current_loss < global_best_loss:
global_best_loss = current_loss
# Save best model checkpoint
trainer.save_checkpoint(f"checkpoints/best_model_{prompt_idx}")
trainer.save_checkpoint_with_metrics(
f"checkpoints/best_model_{prompt_idx}",
{
"loss": current_loss,
"uncertainty": uncertainty,
"global_best": global_best_loss
}
)

# Log progress
if i % 1 == 0:
@@ -220,12 +235,29 @@ def format_step(step: int, metrics: dict) -> str:

# Enhanced stopping check with stability verification
if (i >= min_examples and
trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and
(trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) or
trainer.enhanced_early_stopping(current_loss)) and
len(prompt_stats["losses"]) >= 3 and
np.std(prompt_stats["losses"][-3:]) < 0.1):
logger.info(f"Stopping early at step {i} due to convergence")
break

# Add after line 193 (after current_loss assignment)
trainer.adjust_learning_rate(current_loss)

metrics_tracker['steps_per_prompt'].append(i)
trainer.update_and_visualize_metrics(
current_loss,
uncertainty,
i,
save_path=f"visualizations/step_{prompt_idx}_{i}.png"
)

if i % 10 == 0: # Validate every 10 steps
val_metrics = trainer.compute_validation_metrics(validation_data)
metrics_tracker['validation_loss'].append(val_metrics['val_loss'])
metrics_tracker['validation_perplexity'].append(val_metrics['val_perplexity'])

except Exception as e:
logger.error(f"Error in training step: {str(e)}")
continue
@@ -312,6 +344,8 @@ def format_step(step: int, metrics: dict) -> str:
save_path="adaptive_stopping.png",
)

trainer.generate_training_summary(save_dir="training_summary")


if __name__ == "__main__":
main()
30 changes: 29 additions & 1 deletion src/sift/sift_metrics.py
@@ -1,6 +1,7 @@
import numpy as np
from typing import Dict, List, Any, Optional
import logging
import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -133,4 +134,31 @@ def get_stopping_summary(self) -> Dict[str, List[int]]:
return {
'stopping_points': self.stopping_points,
'uncertainty_history': self.uncertainty_history
}
}

def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]:
self.model.eval()
total_loss = 0.0
total_perplexity = 0.0

with torch.no_grad():
for example in validation_examples:
inputs = self.tokenizer(
example,
padding=True,
truncation=True,
max_length=self.config.max_length,
return_tensors="pt"
).to(self.device)

outputs = self.model(**inputs, labels=inputs["input_ids"])
total_loss += outputs.loss.item()
total_perplexity += torch.exp(outputs.loss).item()

avg_metrics = {
'val_loss': total_loss / len(validation_examples),
'val_perplexity': total_perplexity / len(validation_examples)
}

self.model.train()
return avg_metrics
190 changes: 157 additions & 33 deletions src/sift/sift_trainer.py
@@ -411,6 +411,9 @@ def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]:
labels=labels,
)

# Add gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

return {"loss": outputs.loss.item()}

except Exception as e:
@@ -506,16 +509,19 @@ def should_stop_adaptive(
if step < self.min_steps:
return False

# Use adaptive window size
adaptive_window = self.compute_adaptive_window_size(step)

# Add uncertainty to buffer
self._uncertainty_buffer.append(uncertainty)

# Keep buffer size manageable
if len(self._uncertainty_buffer) > self.window_size:
if len(self._uncertainty_buffer) > adaptive_window:
self._uncertainty_buffer.pop(0)

# Get recent uncertainties with outlier removal
recent_uncertainties = np.array(self._uncertainty_buffer[-self.window_size:])
if len(recent_uncertainties) < self.window_size:
recent_uncertainties = np.array(self._uncertainty_buffer[-adaptive_window:])
if len(recent_uncertainties) < adaptive_window:
return False

# Remove outliers using IQR method
@@ -588,36 +594,33 @@ def update_and_visualize_metrics(
save_path: Optional[str] = None
) -> Dict[str, Any]:
"""Update metrics and generate visualizations."""
# Update EMA loss
ema_loss = self.update_ema_loss(current_loss)

# Compute metrics
metrics = self.metrics_computer.compute_metrics(
{'loss': current_loss},
uncertainty=uncertainty
)

# Check stopping condition using both metrics
should_stop = self.adaptive_stopping.should_stop(
uncertainty=uncertainty,
step=step,
loss_history=self.metrics_computer.metrics_history['loss']
)

# Generate visualizations periodically
if step % 10 == 0:
self.visualizer.plot_adaptive_stopping(
metrics=self.get_training_summary(),
alpha=self.adaptive_stopping.alpha,
title=f"Training Progress - Step {step}",
save_path=save_path
)

return {
'metrics': metrics,
'should_stop': should_stop,
'ema_loss': ema_loss
}
try:
# Update metrics history
self.metrics_computer.update({
'loss': current_loss,
'uncertainty': uncertainty
})

# Get complete metrics history
metrics_history = self.metrics_computer.get_metrics_summary()

# Generate visualization if save path is provided
if save_path:
self.visualizer.plot_training_summary(
losses=metrics_history['loss'],
uncertainties=metrics_history['uncertainty'],
save_path=save_path
)

return {
'loss': current_loss,
'uncertainty': uncertainty,
'step': step
}

except Exception as e:
logger.error(f"Error in metrics update: {e}")
return None

def generate_training_summary(self, save_dir: str = "training_summary"):
"""Generate comprehensive training summary and visualizations."""
@@ -654,6 +657,127 @@ def _ensure_tensor_dtype(self, tensor: torch.Tensor, is_index: bool = False) ->
tensor = tensor.to(dtype=self.dtype)
return tensor.to(self.device)

def adjust_learning_rate(self, current_loss: float, window_size: int = 5) -> None:
if len(self._uncertainty_buffer) >= window_size:
recent_losses = self._uncertainty_buffer[-window_size:]
loss_trend = np.mean(np.diff(recent_losses))

# Increase learning rate if loss is stagnating
if abs(loss_trend) < 0.01:
Contributor:

issue: The loss trend threshold should be configurable and scale-aware

A fixed threshold of 0.01 may not be appropriate for all loss scales. Consider making this relative to the loss magnitude or configurable.
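
One way to make this scale-aware, as a rough sketch (loss_trend_rel_tol is a hypothetical config field, not part of this PR):

    rel_tol = getattr(self.config, "loss_trend_rel_tol", 0.01)
    loss_scale = max(abs(float(np.mean(recent_losses))), 1e-8)
    # Treat the trend as "stagnating" only relative to the current loss magnitude
    if abs(loss_trend) < rel_tol * loss_scale:
        self.optimizer.param_groups[0]['lr'] *= 1.2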

self.optimizer.param_groups[0]['lr'] *= 1.2
Contributor:

suggestion: Learning rate adjustment factors should be configurable parameters

The 1.2 and 0.8 multipliers are magic numbers that should be configurable to allow tuning for different optimization scenarios.

Suggested change
                self.optimizer.param_groups[0]['lr'] *= self.lr_increase_factor
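
For example, the factors could be pulled from the trainer configuration with safe defaults (illustrative sketch; lr_increase_factor and lr_decrease_factor are hypothetical config fields, not part of this PR):

    # In SIFTTrainer.__init__, alongside the other hyperparameters
    self.lr_increase_factor = getattr(self.config, "lr_increase_factor", 1.2)
    self.lr_decrease_factor = getattr(self.config, "lr_decrease_factor", 0.8)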

# Decrease learning rate if loss is unstable
elif loss_trend > 0:
self.optimizer.param_groups[0]['lr'] *= 0.8

def enhanced_early_stopping(self,
current_loss: float,
patience: int = 10,
min_delta: float = 0.01
) -> bool:
if not hasattr(self, '_best_loss'):
self._best_loss = float('inf')
self._patience_counter = 0

if current_loss < (self._best_loss - min_delta):
self._best_loss = current_loss
self._patience_counter = 0
return False

self._patience_counter += 1
return self._patience_counter >= patience

def save_checkpoint_with_metrics(self, path: str, metrics: Dict[str, float]):
Contributor:

issue (complexity): Consider extracting checkpoint management into a dedicated class to improve code organization.

The checkpoint management logic should be extracted into a dedicated class to reduce complexity while maintaining functionality. Here's an example:

class CheckpointManager:
    def __init__(self, checkpoint_dir: str, keep_top_n: int = 3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_top_n = keep_top_n

    def save(self, model, optimizer, scheduler, metrics: Dict[str, float], **extra_state):
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics,
            **extra_state,
            'timestamp': datetime.datetime.now().isoformat()
        }

        filename = f"checkpoint_loss_{metrics['loss']:.4f}.pt"
        torch.save(checkpoint, self.checkpoint_dir / filename)
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_loss_*.pt"),
            key=lambda f: float(str(f).split("loss_")[1].split(".pt")[0])
        )
        for f in checkpoints[self.keep_top_n:]:
            f.unlink()

Then in SIFTTrainer:

class SIFTTrainer:
    def __init__(self, ...):
        self.checkpoint_manager = CheckpointManager("checkpoints")

    def save_checkpoint_with_metrics(self, metrics: Dict[str, float]):
        self.checkpoint_manager.save(
            self.model, 
            self.optimizer,
            self.scheduler,
            metrics,
            uncertainty_buffer=self._uncertainty_buffer,
            current_step=self.current_accumulation_step
        )

This refactoring:

  1. Separates checkpoint logic into a focused class
  2. Simplifies filename handling and cleanup
  3. Makes checkpoint behavior easier to modify and test
  4. Reduces the main trainer class complexity

Similar extractions should be done for validation metrics and training strategy logic.

checkpoint_dir = Path(path)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

checkpoint = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'metrics': metrics,
'uncertainty_buffer': self._uncertainty_buffer,
'current_step': self.current_accumulation_step,
'timestamp': datetime.datetime.now().isoformat()
}

# Save with metrics in filename
filename = f"checkpoint_loss_{metrics['loss']:.4f}_step_{self.current_accumulation_step}.pt"
torch.save(checkpoint, checkpoint_dir / filename)

# Keep only top N checkpoints
self._cleanup_old_checkpoints(checkpoint_dir, keep_top_n=3)

def compute_adaptive_window_size(self, step: int, min_window: int = 3, max_window: int = 10) -> int:
"""Compute adaptive window size based on training progress."""
# Start small and increase window size as training progresses
progress_factor = min(1.0, step / 1000) # Normalize steps to [0,1]
window_size = min_window + int((max_window - min_window) * progress_factor)
return window_size
Comment on lines +714 to +715
Contributor:

suggestion (code-quality): Inline variable that is immediately returned (inline-immediately-returned-variable)

Suggested change
window_size = min_window + int((max_window - min_window) * progress_factor)
return window_size
return min_window + int((max_window - min_window) * progress_factor)
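
For reference, with the defaults (min_window=3, max_window=10) the window grows linearly with progress and saturates at step 1000; assuming a trainer instance:

    trainer.compute_adaptive_window_size(0)     # -> 3
    trainer.compute_adaptive_window_size(500)   # -> 3 + int(7 * 0.5) = 6
    trainer.compute_adaptive_window_size(1000)  # -> 3 + int(7 * 1.0) = 10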


def _cleanup_old_checkpoints(self, checkpoint_dir: Path, keep_top_n: int = 3):
"""Keep only the top N checkpoints based on loss."""
try:
# Get all checkpoint files
checkpoint_files = list(checkpoint_dir.glob("checkpoint_loss_*.pt"))

if len(checkpoint_files) <= keep_top_n:
return

# Extract loss values from filenames
def get_loss(filepath):
try:
return float(str(filepath).split("loss_")[1].split("_")[0])
except:
return float('inf')
Comment on lines +728 to +731
Contributor:

issue (code-quality): Use except Exception: rather than bare except: (do-not-use-bare-except)
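
A narrower handler covering the exceptions this parsing can actually raise (IndexError from the split, ValueError from the float conversion) would address it; a minimal sketch:

    def get_loss(filepath):
        try:
            return float(str(filepath).split("loss_")[1].split("_")[0])
        except (IndexError, ValueError):
            return float('inf')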


# Sort by loss and keep only top N
sorted_files = sorted(checkpoint_files, key=get_loss)
files_to_remove = sorted_files[keep_top_n:]

# Remove excess checkpoints
for file in files_to_remove:
file.unlink()

except Exception as e:
logger.error(f"Error cleaning up checkpoints: {e}")

def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]:
"""Compute validation metrics on a subset of examples."""
self.model.eval()
total_loss = 0.0
total_perplexity = 0.0

with torch.no_grad():
for example in validation_examples[:10]: # Limit to 10 examples for speed
Contributor:

issue: Hardcoding validation set size to 10 examples could lead to unstable metrics

Consider making the validation sample size configurable and ensuring it's large enough for statistically significant validation metrics.
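
A configurable cap would keep the speed/stability trade-off explicit; an illustrative sketch, where max_val_samples is a hypothetical config field:

    max_samples = getattr(self.config, "max_val_samples", 10)  # hypothetical config field
    eval_examples = validation_examples[:max_samples]
    # ...then iterate eval_examples instead of validation_examples[:10]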

try:
inputs = self.tokenizer(
example,
padding=True,
truncation=True,
max_length=self.config.max_length,
return_tensors="pt"
)

inputs = {k: self._ensure_tensor_dtype(v, is_index=True)
for k, v in inputs.items()}

outputs = self.model(**inputs, labels=inputs["input_ids"])
total_loss += outputs.loss.item()
total_perplexity += torch.exp(outputs.loss).item()

except Exception as e:
logger.error(f"Error in validation: {e}")
continue

n_examples = min(len(validation_examples), 10)
metrics = {
'val_loss': total_loss / max(n_examples, 1),
'val_perplexity': total_perplexity / max(n_examples, 1)
}

self.model.train()
return metrics


def main():
"""Test the SIFT trainer"""