fix: resolve training visualization and metrics issues #55
Conversation
- Add missing _cleanup_old_checkpoints method to SIFTTrainer
- Create required directories for checkpoints and visualizations
- Fix compute_validation_metrics implementation
- Update .gitignore to properly handle output directories

The changes ensure proper handling of model checkpoints and training visualizations while maintaining a clean project structure. Directory creation is now handled proactively to prevent file write errors (see the sketch below).

Reference:
- SIFTTrainer cleanup implementation (lines 717-742)
- Validation metrics computation (lines 744-779)
- Directory structure in .gitignore (lines 23-25)
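For context, a minimal sketch of the proactive directory creation described above; the directory names are assumptions for illustration, not necessarily the exact paths used in the PR:

```python
from pathlib import Path

# Create output directories up front so checkpoint and visualization writes
# cannot fail on a missing path (directory names are illustrative)
for output_dir in ("checkpoints", "visualizations"):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
```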
Reviewer's Guide by Sourcery

This PR implements several improvements to the training infrastructure, focusing on checkpoint management, metrics computation, and training visualization. The changes include the addition of adaptive learning rate adjustment, enhanced early stopping mechanisms, proper checkpoint cleanup, and validation metrics computation. The implementation also ensures proper directory structure creation and adds gradient clipping for training stability.

Sequence diagram for training step with validation

```mermaid
sequenceDiagram
actor User
participant Trainer as SIFTTrainer
participant Metrics as MetricsComputer
participant Visualizer
participant Checkpoint
User->>Trainer: Start training step
Trainer->>Metrics: Update metrics
alt Save path provided
Trainer->>Visualizer: Plot training summary
end
alt Validation step
Trainer->>Metrics: Compute validation metrics
end
Trainer->>Checkpoint: Save checkpoint with metrics
Note right of Trainer: Adjust learning rate
Note right of Trainer: Enhanced early stopping check
User->>Trainer: End training step
```

ER diagram for checkpoint and metrics management

```mermaid
erDiagram
SIFTTrainer {
string model_state_dict
string optimizer_state_dict
string scheduler_state_dict
float loss
float uncertainty
int current_step
datetime timestamp
}
SIFTTrainer ||--o{ Checkpoint : manages
SIFTTrainer ||--o{ Metrics : computes
Checkpoint {
string path
float loss
int step
}
Metrics {
float val_loss
float val_perplexity
}
```

Updated class diagram for SIFTTrainer

```mermaid
classDiagram
class SIFTTrainer {
+fine_tune_step(example: str) Optional[Dict[str, Any]]
+should_stop_adaptive(step: int) bool
+update_and_visualize_metrics(current_loss: float, uncertainty: float, step: int, save_path: Optional[str]) Dict[str, Any]
+generate_training_summary(save_dir: str)
+adjust_learning_rate(current_loss: float, window_size: int)
+enhanced_early_stopping(current_loss: float, patience: int, min_delta: float) bool
+save_checkpoint_with_metrics(path: str, metrics: Dict[str, float])
+compute_adaptive_window_size(step: int, min_window: int, max_window: int) int
+_cleanup_old_checkpoints(checkpoint_dir: Path, keep_top_n: int)
+compute_validation_metrics(validation_examples: List[str]) Dict[str, float]
}
```
Hey @leonvanbokhorst - I've reviewed your changes - here's some feedback:
Overall Comments:
- The learning rate adjustment factors (1.2x increase, 0.8x decrease) in adjust_learning_rate() may be too aggressive and could lead to training instability. Consider using more conservative values like 1.05x and 0.95x.
Here's what I looked at during the review
- 🟡 General issues: 3 issues found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
```python
loss_trend = np.mean(np.diff(recent_losses))

# Increase learning rate if loss is stagnating
if abs(loss_trend) < 0.01:
```
issue: The loss trend threshold should be configurable and scale-aware
A fixed threshold of 0.01 may not be appropriate for all loss scales. Consider making this relative to the loss magnitude or configurable.
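A hedged sketch of one way to make the check scale-aware; the helper name and the `rel_threshold` default are illustrative, not part of this PR:

```python
import numpy as np

def is_stagnating(recent_losses, rel_threshold: float = 0.01) -> bool:
    """Scale-aware stagnation check: compare the loss trend to the loss
    magnitude instead of a fixed absolute threshold. `rel_threshold` would
    become a configurable trainer parameter."""
    loss_trend = float(np.mean(np.diff(recent_losses)))
    loss_scale = max(abs(float(np.mean(recent_losses))), 1e-8)
    return abs(loss_trend) < rel_threshold * loss_scale
```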
```python
total_perplexity = 0.0

with torch.no_grad():
    for example in validation_examples[:10]:  # Limit to 10 examples for speed
```
issue: Hardcoding validation set size to 10 examples could lead to unstable metrics
Consider making the validation sample size configurable and ensuring it's large enough for statistically significant validation metrics.
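A minimal sketch of a configurable sample size, assuming a `sample_size` parameter is added to the trainer (the name is hypothetical):

```python
from typing import List, Optional

def select_validation_sample(validation_examples: List[str],
                             sample_size: Optional[int] = None) -> List[str]:
    """Return the validation examples to evaluate.

    `sample_size` stands in for a configurable trainer parameter; None means
    evaluate the full validation set instead of a hardcoded 10 examples.
    """
    if sample_size is None:
        return validation_examples
    return validation_examples[:sample_size]
```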
```python
# Increase learning rate if loss is stagnating
if abs(loss_trend) < 0.01:
    self.optimizer.param_groups[0]['lr'] *= 1.2
```
suggestion: Learning rate adjustment factors should be configurable parameters
The 1.2 and 0.8 multipliers are magic numbers that should be configurable to allow tuning for different optimization scenarios.
Suggested change:

```python
self.optimizer.param_groups[0]['lr'] *= self.lr_increase_factor
```
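A broader sketch of configurable adjustment factors as a standalone helper; the conservative defaults follow the earlier 1.05x/0.95x suggestion, and the condition for applying the decrease factor is an assumption, not taken from the PR:

```python
def adjust_learning_rate(optimizer, loss_trend: float,
                         stagnation_threshold: float = 0.01,
                         lr_increase_factor: float = 1.05,
                         lr_decrease_factor: float = 0.95) -> None:
    """Nudge the learning rate based on the recent loss trend.

    The factor arguments are configurable counterparts of the hardcoded
    1.2 / 0.8 multipliers; when the decrease factor fires is assumed here.
    """
    if abs(loss_trend) < stagnation_threshold:
        # Loss is stagnating: gently increase the learning rate
        for group in optimizer.param_groups:
            group['lr'] *= lr_increase_factor
    elif loss_trend > 0:
        # Loss is rising: back off
        for group in optimizer.param_groups:
            group['lr'] *= lr_decrease_factor
```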
```python
        self._patience_counter += 1
        return self._patience_counter >= patience

    def save_checkpoint_with_metrics(self, path: str, metrics: Dict[str, float]):
```
issue (complexity): Consider extracting checkpoint management into a dedicated class to improve code organization.
The checkpoint management logic should be extracted into a dedicated class to reduce complexity while maintaining functionality. Here's an example:
```python
import datetime
from pathlib import Path
from typing import Dict

import torch


class CheckpointManager:
    def __init__(self, checkpoint_dir: str, keep_top_n: int = 3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_top_n = keep_top_n

    def save(self, model, optimizer, scheduler, metrics: Dict[str, float], **extra_state):
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics,
            **extra_state,
            'timestamp': datetime.datetime.now().isoformat(),
        }
        filename = f"checkpoint_loss_{metrics['loss']:.4f}.pt"
        torch.save(checkpoint, self.checkpoint_dir / filename)
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        # Sort by loss encoded in the filename; keep the best, delete the rest
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_loss_*.pt"),
            key=lambda f: float(str(f).split("loss_")[1].split(".pt")[0]),
        )
        for f in checkpoints[self.keep_top_n:]:
            f.unlink()
```
Then in SIFTTrainer:

```python
class SIFTTrainer:
    def __init__(self, ...):
        self.checkpoint_manager = CheckpointManager("checkpoints")

    def save_checkpoint_with_metrics(self, metrics: Dict[str, float]):
        self.checkpoint_manager.save(
            self.model,
            self.optimizer,
            self.scheduler,
            metrics,
            uncertainty_buffer=self._uncertainty_buffer,
            current_step=self.current_accumulation_step,
        )
```
This refactoring:
- Separates checkpoint logic into a focused class
- Simplifies filename handling and cleanup
- Makes checkpoint behavior easier to modify and test
- Reduces the main trainer class complexity
Similar extractions should be done for validation metrics and training strategy logic.
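As an illustration of that follow-up, a hedged sketch of a validation-metrics extraction, assuming a Hugging Face-style model/tokenizer interface; the class and attribute names are hypothetical, not part of this PR:

```python
from typing import Dict, List

import torch


class ValidationMetricsComputer:
    """Sketch of the analogous extraction for validation metrics."""

    def __init__(self, model, tokenizer, device: str = "cpu", sample_size: int = 10):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.sample_size = sample_size  # configurable instead of a hardcoded 10

    def compute(self, validation_examples: List[str]) -> Dict[str, float]:
        total_loss = 0.0
        with torch.no_grad():
            for example in validation_examples[: self.sample_size]:
                inputs = self.tokenizer(example, return_tensors="pt").to(self.device)
                outputs = self.model(**inputs, labels=inputs["input_ids"])
                total_loss += outputs.loss.item()
        n = max(min(self.sample_size, len(validation_examples)), 1)
        val_loss = total_loss / n
        return {
            "val_loss": val_loss,
            "val_perplexity": float(torch.exp(torch.tensor(val_loss))),
        }
```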
```python
        window_size = min_window + int((max_window - min_window) * progress_factor)
        return window_size
```
suggestion (code-quality): Inline variable that is immediately returned (`inline-immediately-returned-variable`)

Suggested change:

```diff
-        window_size = min_window + int((max_window - min_window) * progress_factor)
-        return window_size
+        return min_window + int((max_window - min_window) * progress_factor)
```
```python
        try:
            return float(str(filepath).split("loss_")[1].split("_")[0])
        except:
            return float('inf')
```
issue (code-quality): Use `except Exception:` rather than bare `except:` (`do-not-use-bare-except`)
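Applied to the snippet above, the fix is mechanical:

```python
        try:
            return float(str(filepath).split("loss_")[1].split("_")[0])
        except Exception:
            # Fall back to +inf so malformed filenames sort last during cleanup
            return float('inf')
```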
Summary by Sourcery
Resolve issues with training visualization and metrics by adding missing methods, creating necessary directories, and fixing validation metrics computation. Enhance training stability with gradient clipping, adaptive window size, and improved early stopping. Update .gitignore to manage output directories.
Bug Fixes:
Enhancements:
Build: