-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: resolve training visualization and metrics issues #55
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,8 @@ Llama-3.2-1B-Instruct-Complaint/ | |
dataset/ | ||
cache/ | ||
checkpoints/ | ||
training_summary/ | ||
visualizations/ | ||
|
||
|
||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -411,6 +411,9 @@ def fine_tune_step(self, example: str) -> Optional[Dict[str, Any]]: | |||||||
labels=labels, | ||||||||
) | ||||||||
|
||||||||
# Add gradient clipping | ||||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) | ||||||||
|
||||||||
return {"loss": outputs.loss.item()} | ||||||||
|
||||||||
except Exception as e: | ||||||||
|
@@ -506,16 +509,19 @@ def should_stop_adaptive( | |||||||
if step < self.min_steps: | ||||||||
return False | ||||||||
|
||||||||
# Use adaptive window size | ||||||||
adaptive_window = self.compute_adaptive_window_size(step) | ||||||||
|
||||||||
# Add uncertainty to buffer | ||||||||
self._uncertainty_buffer.append(uncertainty) | ||||||||
|
||||||||
# Keep buffer size manageable | ||||||||
if len(self._uncertainty_buffer) > self.window_size: | ||||||||
if len(self._uncertainty_buffer) > adaptive_window: | ||||||||
self._uncertainty_buffer.pop(0) | ||||||||
|
||||||||
# Get recent uncertainties with outlier removal | ||||||||
recent_uncertainties = np.array(self._uncertainty_buffer[-self.window_size:]) | ||||||||
if len(recent_uncertainties) < self.window_size: | ||||||||
recent_uncertainties = np.array(self._uncertainty_buffer[-adaptive_window:]) | ||||||||
if len(recent_uncertainties) < adaptive_window: | ||||||||
return False | ||||||||
|
||||||||
# Remove outliers using IQR method | ||||||||
|
@@ -588,36 +594,33 @@ def update_and_visualize_metrics( | |||||||
save_path: Optional[str] = None | ||||||||
) -> Dict[str, Any]: | ||||||||
"""Update metrics and generate visualizations.""" | ||||||||
# Update EMA loss | ||||||||
ema_loss = self.update_ema_loss(current_loss) | ||||||||
|
||||||||
# Compute metrics | ||||||||
metrics = self.metrics_computer.compute_metrics( | ||||||||
{'loss': current_loss}, | ||||||||
uncertainty=uncertainty | ||||||||
) | ||||||||
|
||||||||
# Check stopping condition using both metrics | ||||||||
should_stop = self.adaptive_stopping.should_stop( | ||||||||
uncertainty=uncertainty, | ||||||||
step=step, | ||||||||
loss_history=self.metrics_computer.metrics_history['loss'] | ||||||||
) | ||||||||
|
||||||||
# Generate visualizations periodically | ||||||||
if step % 10 == 0: | ||||||||
self.visualizer.plot_adaptive_stopping( | ||||||||
metrics=self.get_training_summary(), | ||||||||
alpha=self.adaptive_stopping.alpha, | ||||||||
title=f"Training Progress - Step {step}", | ||||||||
save_path=save_path | ||||||||
) | ||||||||
|
||||||||
return { | ||||||||
'metrics': metrics, | ||||||||
'should_stop': should_stop, | ||||||||
'ema_loss': ema_loss | ||||||||
} | ||||||||
try: | ||||||||
# Update metrics history | ||||||||
self.metrics_computer.update({ | ||||||||
'loss': current_loss, | ||||||||
'uncertainty': uncertainty | ||||||||
}) | ||||||||
|
||||||||
# Get complete metrics history | ||||||||
metrics_history = self.metrics_computer.get_metrics_summary() | ||||||||
|
||||||||
# Generate visualization if save path is provided | ||||||||
if save_path: | ||||||||
self.visualizer.plot_training_summary( | ||||||||
losses=metrics_history['loss'], | ||||||||
uncertainties=metrics_history['uncertainty'], | ||||||||
save_path=save_path | ||||||||
) | ||||||||
|
||||||||
return { | ||||||||
'loss': current_loss, | ||||||||
'uncertainty': uncertainty, | ||||||||
'step': step | ||||||||
} | ||||||||
|
||||||||
except Exception as e: | ||||||||
logger.error(f"Error in metrics update: {e}") | ||||||||
return None | ||||||||
|
||||||||
def generate_training_summary(self, save_dir: str = "training_summary"): | ||||||||
"""Generate comprehensive training summary and visualizations.""" | ||||||||
|
@@ -654,6 +657,127 @@ def _ensure_tensor_dtype(self, tensor: torch.Tensor, is_index: bool = False) -> | |||||||
tensor = tensor.to(dtype=self.dtype) | ||||||||
return tensor.to(self.device) | ||||||||
|
||||||||
def adjust_learning_rate(self, current_loss: float, window_size: int = 5) -> None: | ||||||||
if len(self._uncertainty_buffer) >= window_size: | ||||||||
recent_losses = self._uncertainty_buffer[-window_size:] | ||||||||
loss_trend = np.mean(np.diff(recent_losses)) | ||||||||
|
||||||||
# Increase learning rate if loss is stagnating | ||||||||
if abs(loss_trend) < 0.01: | ||||||||
self.optimizer.param_groups[0]['lr'] *= 1.2 | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion: Learning rate adjustment factors should be configurable parameters. The 1.2 and 0.8 multipliers are magic numbers that should be configurable to allow tuning for different optimization scenarios. Suggested change: `self.optimizer.param_groups[0]['lr'] *= self.lr_increase_factor` |
||||||||
# Decrease learning rate if loss is unstable | ||||||||
elif loss_trend > 0: | ||||||||
self.optimizer.param_groups[0]['lr'] *= 0.8 | ||||||||
|
||||||||
def enhanced_early_stopping(self,
                            current_loss: float,
                            patience: int = 10,
                            min_delta: float = 0.01
                            ) -> bool:
    """Decide whether training should stop because the loss has stalled.

    The best loss seen so far is kept in ``self._best_loss`` and the number
    of consecutive non-improving calls in ``self._patience_counter``.

    Args:
        current_loss: Loss observed at the current step.
        patience: Number of consecutive non-improving calls after which
            the method returns ``True``.
        min_delta: Minimum decrease below the best loss that counts as an
            improvement.

    Returns:
        ``True`` once ``patience`` consecutive calls have failed to improve
        on the best loss by more than ``min_delta``; ``False`` otherwise.
    """
    # Default each piece of state independently: checking only one
    # attribute would raise AttributeError on an instance where the other
    # one has not been set yet.
    best_loss = getattr(self, '_best_loss', float('inf'))
    stalled_calls = getattr(self, '_patience_counter', 0)

    if current_loss < (best_loss - min_delta):
        # Meaningful improvement: record it and restart the patience count.
        self._best_loss = current_loss
        self._patience_counter = 0
        return False

    self._best_loss = best_loss
    self._patience_counter = stalled_calls + 1
    return self._patience_counter >= patience
|
||||||||
def save_checkpoint_with_metrics(self, path: str, metrics: Dict[str, float]): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. issue (complexity): Consider extracting checkpoint management into a dedicated class to improve code organization. The checkpoint management logic should be extracted into a dedicated class to reduce complexity while maintaining functionality. Here's an example:
class CheckpointManager:
def __init__(self, checkpoint_dir: str, keep_top_n: int = 3):
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.keep_top_n = keep_top_n
def save(self, model, optimizer, scheduler, metrics: Dict[str, float], **extra_state):
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'metrics': metrics,
**extra_state,
'timestamp': datetime.datetime.now().isoformat()
}
filename = f"checkpoint_loss_{metrics['loss']:.4f}.pt"
torch.save(checkpoint, self.checkpoint_dir / filename)
self._cleanup_old_checkpoints()
def _cleanup_old_checkpoints(self):
checkpoints = sorted(
self.checkpoint_dir.glob("checkpoint_loss_*.pt"),
key=lambda f: float(str(f).split("loss_")[1].split(".pt")[0])
)
for f in checkpoints[self.keep_top_n:]:
f.unlink()
Then in SIFTTrainer:
class SIFTTrainer:
def __init__(self, ...):
self.checkpoint_manager = CheckpointManager("checkpoints")
def save_checkpoint_with_metrics(self, metrics: Dict[str, float]):
self.checkpoint_manager.save(
self.model,
self.optimizer,
self.scheduler,
metrics,
uncertainty_buffer=self._uncertainty_buffer,
current_step=self.current_accumulation_step
) This refactoring:
Similar extractions should be done for validation metrics and training strategy logic. |
||||||||
checkpoint_dir = Path(path) | ||||||||
checkpoint_dir.mkdir(parents=True, exist_ok=True) | ||||||||
|
||||||||
checkpoint = { | ||||||||
'model_state_dict': self.model.state_dict(), | ||||||||
'optimizer_state_dict': self.optimizer.state_dict(), | ||||||||
'scheduler_state_dict': self.scheduler.state_dict(), | ||||||||
'metrics': metrics, | ||||||||
'uncertainty_buffer': self._uncertainty_buffer, | ||||||||
'current_step': self.current_accumulation_step, | ||||||||
'timestamp': datetime.datetime.now().isoformat() | ||||||||
} | ||||||||
|
||||||||
# Save with metrics in filename | ||||||||
filename = f"checkpoint_loss_{metrics['loss']:.4f}_step_{self.current_accumulation_step}.pt" | ||||||||
torch.save(checkpoint, checkpoint_dir / filename) | ||||||||
|
||||||||
# Keep only top N checkpoints | ||||||||
self._cleanup_old_checkpoints(checkpoint_dir, keep_top_n=3) | ||||||||
|
||||||||
def compute_adaptive_window_size(self, step: int, min_window: int = 3, max_window: int = 10) -> int:
    """Compute adaptive window size based on training progress.

    The window starts at ``min_window`` and grows linearly with ``step``,
    reaching ``max_window`` at step 1000 and staying there afterwards.

    Args:
        step: Current training step.
        min_window: Window size returned at (or before) step 0.
        max_window: Window size returned at step 1000 and beyond.

    Returns:
        The window size for this step, between ``min_window`` and
        ``max_window`` inclusive.
    """
    # Clamp progress to [0, 1]. Without the lower bound a negative step
    # would shrink the window below min_window (even below zero).
    progress_factor = min(1.0, max(0.0, step / 1000))
    return min_window + int((max_window - min_window) * progress_factor)
Comment on lines +714 to +715
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): Inline variable that is immediately returned (`inline-immediately-returned-variable`)
Suggested change: `return min_window + int((max_window - min_window) * progress_factor)`
|
||||||||
|
||||||||
def _cleanup_old_checkpoints(self, checkpoint_dir: Path, keep_top_n: int = 3): | ||||||||
"""Keep only the top N checkpoints based on loss.""" | ||||||||
try: | ||||||||
# Get all checkpoint files | ||||||||
checkpoint_files = list(checkpoint_dir.glob("checkpoint_loss_*.pt")) | ||||||||
|
||||||||
if len(checkpoint_files) <= keep_top_n: | ||||||||
return | ||||||||
|
||||||||
# Extract loss values from filenames | ||||||||
def get_loss(filepath): | ||||||||
try: | ||||||||
return float(str(filepath).split("loss_")[1].split("_")[0]) | ||||||||
except: | ||||||||
return float('inf') | ||||||||
Comment on lines +728 to +731
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. issue (code-quality): Use `except Exception:` rather than bare `except:` (`do-not-use-bare-except`) |
||||||||
|
||||||||
# Sort by loss and keep only top N | ||||||||
sorted_files = sorted(checkpoint_files, key=get_loss) | ||||||||
files_to_remove = sorted_files[keep_top_n:] | ||||||||
|
||||||||
# Remove excess checkpoints | ||||||||
for file in files_to_remove: | ||||||||
file.unlink() | ||||||||
|
||||||||
except Exception as e: | ||||||||
logger.error(f"Error cleaning up checkpoints: {e}") | ||||||||
|
||||||||
def compute_validation_metrics(self, validation_examples: List[str]) -> Dict[str, float]: | ||||||||
"""Compute validation metrics on a subset of examples.""" | ||||||||
self.model.eval() | ||||||||
total_loss = 0.0 | ||||||||
total_perplexity = 0.0 | ||||||||
|
||||||||
with torch.no_grad(): | ||||||||
for example in validation_examples[:10]: # Limit to 10 examples for speed | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. issue: Hardcoding validation set size to 10 examples could lead to unstable metrics. Consider making the validation sample size configurable and ensuring it's large enough for statistically significant validation metrics. |
||||||||
try: | ||||||||
inputs = self.tokenizer( | ||||||||
example, | ||||||||
padding=True, | ||||||||
truncation=True, | ||||||||
max_length=self.config.max_length, | ||||||||
return_tensors="pt" | ||||||||
) | ||||||||
|
||||||||
inputs = {k: self._ensure_tensor_dtype(v, is_index=True) | ||||||||
for k, v in inputs.items()} | ||||||||
|
||||||||
outputs = self.model(**inputs, labels=inputs["input_ids"]) | ||||||||
total_loss += outputs.loss.item() | ||||||||
total_perplexity += torch.exp(outputs.loss).item() | ||||||||
|
||||||||
except Exception as e: | ||||||||
logger.error(f"Error in validation: {e}") | ||||||||
continue | ||||||||
|
||||||||
n_examples = min(len(validation_examples), 10) | ||||||||
metrics = { | ||||||||
'val_loss': total_loss / max(n_examples, 1), | ||||||||
'val_perplexity': total_perplexity / max(n_examples, 1) | ||||||||
} | ||||||||
|
||||||||
self.model.train() | ||||||||
return metrics | ||||||||
|
||||||||
|
||||||||
def main(): | ||||||||
"""Test the SIFT trainer""" | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue: The loss trend threshold should be configurable and scale-aware
A fixed threshold of 0.01 may not be appropriate for all loss scales. Consider making this relative to the loss magnitude or configurable.