Merge pull request #55 from leonvanbokhorst/sift
fix: resolve training visualization and metrics issues
leonvanbokhorst authored Nov 17, 2024
2 parents a62d0d9 + cfe5ce3 commit 3555d36
Showing 4 changed files with 225 additions and 37 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -21,6 +21,8 @@ Llama-3.2-1B-Instruct-Complaint/
dataset/
cache/
checkpoints/
+ training_summary/
+ visualizations/



40 changes: 37 additions & 3 deletions src/19_llm_fine_tuning_sift.py
@@ -61,6 +61,7 @@ def main():
n_test_prompts = 100
test_prompts = sampler.sample_test_prompts(n_prompts=n_test_prompts)
training_data = sampler.get_training_subset(size=1000)
+ validation_data = sampler.get_training_subset(size=100)  # Separate validation set

logger.info(
f"Sampled {len(test_prompts)} test prompts and {len(training_data)} training examples"
@@ -144,9 +145,16 @@ def format_step(step: int, metrics: dict) -> str:
'global_losses': [],
'prompt_losses': [],
'uncertainties': [],
- 'steps_per_prompt': []
+ 'steps_per_prompt': [],
+ 'validation_loss': [],
+ 'validation_perplexity': []
}

+ # Create directories for outputs
+ Path("checkpoints").mkdir(parents=True, exist_ok=True)
+ Path("visualizations").mkdir(parents=True, exist_ok=True)
+ Path("training_summary").mkdir(parents=True, exist_ok=True)

for prompt_idx, prompt in enumerate(test_prompts):
try:
# Disable all other loggers
@@ -202,7 +210,14 @@ def format_step(step: int, metrics: dict) -> str:
if current_loss < global_best_loss:
global_best_loss = current_loss
# Save best model checkpoint
- trainer.save_checkpoint(f"checkpoints/best_model_{prompt_idx}")
+ trainer.save_checkpoint_with_metrics(
+     f"checkpoints/best_model_{prompt_idx}",
+     {
+         "loss": current_loss,
+         "uncertainty": uncertainty,
+         "global_best": global_best_loss
+     }
+ )

# Log progress
if i % 1 == 0:
@@ -220,12 +235,29 @@ def format_step(step: int, metrics: dict) -> str:

# Enhanced stopping check with stability verification
if (i >= min_examples and
- trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) and
+ (trainer.should_stop_adaptive(uncertainty, i, alpha=0.1) or
+  trainer.enhanced_early_stopping(current_loss)) and
len(prompt_stats["losses"]) >= 3 and
np.std(prompt_stats["losses"][-3:]) < 0.1):
logger.info(f"Stopping early at step {i} due to convergence")
break

+ # Add after line 193 (after current_loss assignment)
+ trainer.adjust_learning_rate(current_loss)
+
+ metrics_tracker['steps_per_prompt'].append(i)
+ trainer.update_and_visualize_metrics(
+     current_loss,
+     uncertainty,
+     i,
+     save_path=f"visualizations/step_{prompt_idx}_{i}.png"
+ )
+
+ if i % 10 == 0:  # Validate every 10 steps
+     val_metrics = trainer.compute_validation_metrics(validation_data)
+     metrics_tracker['validation_loss'].append(val_metrics['val_loss'])
+     metrics_tracker['validation_perplexity'].append(val_metrics['val_perplexity'])

except Exception as e:
logger.error(f"Error in training step: {str(e)}")
continue
@@ -312,6 +344,8 @@ def format_step(step: int, metrics: dict) -> str:
save_path="adaptive_stopping.png",
)

+ trainer.generate_training_summary(save_dir="training_summary")


if __name__ == "__main__":
main()
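
A note on how these pieces combine: the loop now ends a prompt's fine-tuning only when either stopping signal fires and the last three losses have settled. Below is a minimal, self-contained sketch of that combined rule; `adaptive_signal` and `patience_signal` stand in for the trainer's `should_stop_adaptive` and `enhanced_early_stopping` calls, and the `min_examples` default of 5 is an assumption (the script defines its own value):

    import numpy as np

    def should_stop(losses, adaptive_signal, patience_signal, step, min_examples=5):
        # Never stop before a minimum number of examples have been seen.
        if step < min_examples:
            return False
        # Either stopping criterion may fire...
        if not (adaptive_signal or patience_signal):
            return False
        # ...but stopping also requires the last three losses to be stable.
        return len(losses) >= 3 and np.std(losses[-3:]) < 0.1

    # A noisy tail keeps training even when a signal fires:
    print(should_stop([0.9, 0.5, 0.8], True, False, step=10))      # False (std ~0.17)
    print(should_stop([0.51, 0.50, 0.50], True, False, step=10))   # True
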
30 changes: 29 additions & 1 deletion src/sift/sift_metrics.py
@@ -1,6 +1,7 @@
import numpy as np
from typing import Dict, List, Any, Optional
import logging
+ import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -133,4 +134,31 @@ def get_stopping_summary(self) -> Dict[str, List[int]]:
return {
'stopping_points': self.stopping_points,
'uncertainty_history': self.uncertainty_history
- }
+ }

+ def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]:
+     self.model.eval()
+     total_loss = 0.0
+     total_perplexity = 0.0
+
+     with torch.no_grad():
+         for example in validation_examples:
+             inputs = self.tokenizer(
+                 example,
+                 padding=True,
+                 truncation=True,
+                 max_length=self.config.max_length,
+                 return_tensors="pt"
+             ).to(self.device)
+
+             outputs = self.model(**inputs, labels=inputs["input_ids"])
+             total_loss += outputs.loss.item()
+             total_perplexity += torch.exp(outputs.loss).item()
+
+     avg_metrics = {
+         'val_loss': total_loss / len(validation_examples),
+         'val_perplexity': total_perplexity / len(validation_examples)
+     }
+
+     self.model.train()
+     return avg_metrics
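
One property of this metric worth noting: `val_perplexity` averages per-example perplexities (the mean of exp(loss)), which by Jensen's inequality is always at least exp of the mean loss, the more common corpus-level definition. A small illustration of the gap, using hypothetical loss values:

    import math

    losses = [1.0, 3.0]  # hypothetical per-example cross-entropy losses

    mean_of_exp = sum(math.exp(l) for l in losses) / len(losses)  # as computed above
    exp_of_mean = math.exp(sum(losses) / len(losses))             # corpus-level perplexity

    print(f"{mean_of_exp:.2f} vs {exp_of_mean:.2f}")  # -> 11.40 vs 7.39
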
190 changes: 157 additions & 33 deletions src/sift/sift_trainer.py
@@ -411,6 +411,9 @@ def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]:
labels=labels,
)

+ # Add gradient clipping
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

return {"loss": outputs.loss.item()}

except Exception as e:
@@ -506,16 +509,19 @@ def should_stop_adaptive(
if step < self.min_steps:
return False

+ # Use adaptive window size
+ adaptive_window = self.compute_adaptive_window_size(step)

# Add uncertainty to buffer
self._uncertainty_buffer.append(uncertainty)

# Keep buffer size manageable
- if len(self._uncertainty_buffer) > self.window_size:
+ if len(self._uncertainty_buffer) > adaptive_window:
self._uncertainty_buffer.pop(0)

# Get recent uncertainties with outlier removal
- recent_uncertainties = np.array(self._uncertainty_buffer[-self.window_size:])
- if len(recent_uncertainties) < self.window_size:
+ recent_uncertainties = np.array(self._uncertainty_buffer[-adaptive_window:])
+ if len(recent_uncertainties) < adaptive_window:
return False

# Remove outliers using IQR method
@@ -588,36 +594,33 @@ def update_and_visualize_metrics(
save_path: Optional[str] = None
) -> Dict[str, Any]:
"""Update metrics and generate visualizations."""
- # Update EMA loss
- ema_loss = self.update_ema_loss(current_loss)
-
- # Compute metrics
- metrics = self.metrics_computer.compute_metrics(
-     {'loss': current_loss},
-     uncertainty=uncertainty
- )
-
- # Check stopping condition using both metrics
- should_stop = self.adaptive_stopping.should_stop(
-     uncertainty=uncertainty,
-     step=step,
-     loss_history=self.metrics_computer.metrics_history['loss']
- )
-
- # Generate visualizations periodically
- if step % 10 == 0:
-     self.visualizer.plot_adaptive_stopping(
-         metrics=self.get_training_summary(),
-         alpha=self.adaptive_stopping.alpha,
-         title=f"Training Progress - Step {step}",
-         save_path=save_path
-     )
-
- return {
-     'metrics': metrics,
-     'should_stop': should_stop,
-     'ema_loss': ema_loss
- }
+ try:
+     # Update metrics history
+     self.metrics_computer.update({
+         'loss': current_loss,
+         'uncertainty': uncertainty
+     })
+
+     # Get complete metrics history
+     metrics_history = self.metrics_computer.get_metrics_summary()
+
+     # Generate visualization if save path is provided
+     if save_path:
+         self.visualizer.plot_training_summary(
+             losses=metrics_history['loss'],
+             uncertainties=metrics_history['uncertainty'],
+             save_path=save_path
+         )
+
+     return {
+         'loss': current_loss,
+         'uncertainty': uncertainty,
+         'step': step
+     }
+
+ except Exception as e:
+     logger.error(f"Error in metrics update: {e}")
+     return None

def generate_training_summary(self, save_dir: str = "training_summary"):
"""Generate comprehensive training summary and visualizations."""
@@ -654,6 +657,127 @@ def _ensure_tensor_dtype(self, tensor: torch.Tensor, is_index: bool = False) ->
tensor = tensor.to(dtype=self.dtype)
return tensor.to(self.device)

+ def adjust_learning_rate(self, current_loss: float, window_size: int = 5) -> None:
+     if len(self._uncertainty_buffer) >= window_size:
+         recent_losses = self._uncertainty_buffer[-window_size:]
+         loss_trend = np.mean(np.diff(recent_losses))
+
+         # Increase learning rate if loss is stagnating
+         if abs(loss_trend) < 0.01:
+             self.optimizer.param_groups[0]['lr'] *= 1.2
+         # Decrease learning rate if loss is unstable
+         elif loss_trend > 0:
+             self.optimizer.param_groups[0]['lr'] *= 0.8
+
+ def enhanced_early_stopping(self,
+                             current_loss: float,
+                             patience: int = 10,
+                             min_delta: float = 0.01
+                             ) -> bool:
+     if not hasattr(self, '_best_loss'):
+         self._best_loss = float('inf')
+         self._patience_counter = 0
+
+     if current_loss < (self._best_loss - min_delta):
+         self._best_loss = current_loss
+         self._patience_counter = 0
+         return False
+
+     self._patience_counter += 1
+     return self._patience_counter >= patience
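
# Illustrative sketch (not part of this commit): how the patience logic above
# behaves on a hypothetical loss sequence. An improvement must beat min_delta
# (strict inequality), so 0.90 -> 0.89 does not reset the counter, and with
# patience=3 the run below stops on the fifth value:
#
#     for step, loss in enumerate([1.00, 0.90, 0.89, 0.89, 0.89]):
#         if trainer.enhanced_early_stopping(loss, patience=3, min_delta=0.01):
#             print(f"stopping at step {step}")  # -> stopping at step 4
#             break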

+ def save_checkpoint_with_metrics(self, path: str, metrics: Dict[str, float]):
+     checkpoint_dir = Path(path)
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     checkpoint = {
+         'model_state_dict': self.model.state_dict(),
+         'optimizer_state_dict': self.optimizer.state_dict(),
+         'scheduler_state_dict': self.scheduler.state_dict(),
+         'metrics': metrics,
+         'uncertainty_buffer': self._uncertainty_buffer,
+         'current_step': self.current_accumulation_step,
+         'timestamp': datetime.datetime.now().isoformat()
+     }
+
+     # Save with metrics in filename
+     filename = f"checkpoint_loss_{metrics['loss']:.4f}_step_{self.current_accumulation_step}.pt"
+     torch.save(checkpoint, checkpoint_dir / filename)
+
+     # Keep only top N checkpoints
+     self._cleanup_old_checkpoints(checkpoint_dir, keep_top_n=3)

+ def compute_adaptive_window_size(self, step: int, min_window: int = 3, max_window: int = 10) -> int:
+     """Compute adaptive window size based on training progress."""
+     # Start small and increase window size as training progresses
+     progress_factor = min(1.0, step / 1000)  # Normalize steps to [0,1]
+     window_size = min_window + int((max_window - min_window) * progress_factor)
+     return window_size
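
# Illustrative values for the schedule above, assuming the defaults
# (min_window=3, max_window=10, 1000-step normalization):
#
#     step    0 -> window  3   (progress 0.0)
#     step  500 -> window  6   (progress 0.5: 3 + int(7 * 0.5))
#     step 1000 -> window 10   (progress capped at 1.0)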

+ def _cleanup_old_checkpoints(self, checkpoint_dir: Path, keep_top_n: int = 3):
+     """Keep only the top N checkpoints based on loss."""
+     try:
+         # Get all checkpoint files
+         checkpoint_files = list(checkpoint_dir.glob("checkpoint_loss_*.pt"))
+
+         if len(checkpoint_files) <= keep_top_n:
+             return
+
+         # Extract loss values from filenames
+         def get_loss(filepath):
+             try:
+                 return float(str(filepath).split("loss_")[1].split("_")[0])
+             except:
+                 return float('inf')
+
+         # Sort by loss and keep only top N
+         sorted_files = sorted(checkpoint_files, key=get_loss)
+         files_to_remove = sorted_files[keep_top_n:]
+
+         # Remove excess checkpoints
+         for file in files_to_remove:
+             file.unlink()
+
+     except Exception as e:
+         logger.error(f"Error cleaning up checkpoints: {e}")

+ def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]:
+     """Compute validation metrics on a subset of examples."""
+     self.model.eval()
+     total_loss = 0.0
+     total_perplexity = 0.0
+
+     with torch.no_grad():
+         for example in validation_examples[:10]:  # Limit to 10 examples for speed
+             try:
+                 inputs = self.tokenizer(
+                     example,
+                     padding=True,
+                     truncation=True,
+                     max_length=self.config.max_length,
+                     return_tensors="pt"
+                 )
+
+                 inputs = {k: self._ensure_tensor_dtype(v, is_index=True)
+                           for k, v in inputs.items()}
+
+                 outputs = self.model(**inputs, labels=inputs["input_ids"])
+                 total_loss += outputs.loss.item()
+                 total_perplexity += torch.exp(outputs.loss).item()
+
+             except Exception as e:
+                 logger.error(f"Error in validation: {e}")
+                 continue
+
+     n_examples = min(len(validation_examples), 10)
+     metrics = {
+         'val_loss': total_loss / max(n_examples, 1),
+         'val_perplexity': total_perplexity / max(n_examples, 1)
+     }
+
+     self.model.train()
+     return metrics


def main():
"""Test the SIFT trainer"""
