
fix: resolve training visualization and metrics issues #55

Merged
merged 1 commit into main on Nov 17, 2024

Conversation

@leonvanbokhorst leonvanbokhorst (Owner) commented Nov 17, 2024

  • Add missing _cleanup_old_checkpoints method to SIFTTrainer
  • Create required directories for checkpoints and visualizations
  • Fix compute_validation_metrics implementation
  • Update .gitignore to properly handle output directories

The changes ensure proper handling of model checkpoints and training
visualizations while maintaining a clean project structure. Directory
creation is now handled proactively to prevent file write errors.

Reference:

  • SIFTTrainer cleanup implementation (lines 717-742)
  • Validation metrics computation (lines 744-779)
  • Directory structure in .gitignore (lines 23-25)

Summary by Sourcery

Resolve issues with training visualization and metrics by adding missing methods, creating necessary directories, and fixing validation metrics computation. Enhance training stability with gradient clipping, adaptive window size, and improved early stopping. Update .gitignore to manage output directories.

Bug Fixes:

  • Fix the implementation of compute_validation_metrics to correctly compute and return validation metrics.

Enhancements:

  • Add gradient clipping to the fine-tuning step to improve training stability.
  • Implement adaptive window size for uncertainty buffer management in the should_stop_adaptive method.
  • Enhance early stopping mechanism with stability verification and learning rate adjustment based on loss trends.
  • Improve checkpoint management by adding a method to clean up old checkpoints and save checkpoints with metrics in the filename.
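
The gradient-clipping bullet corresponds to PyTorch's `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`, called between `backward()` and `step()`. A dependency-free sketch of the same L2-norm clipping, with `clip_grad_norm` as an illustrative helper rather than the PR's code:

```python
import math

def clip_grad_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a gradient vector so its L2 norm is at most max_norm,
    mirroring what torch.nn.utils.clip_grad_norm_ does in-place."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / (total_norm + eps)
        return [g * scale for g in grads]
    return grads
```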

Build:

  • Update .gitignore to handle output directories for checkpoints, visualizations, and training summaries.

sourcery-ai bot commented Nov 17, 2024

Reviewer's Guide by Sourcery

This PR implements several improvements to the training infrastructure, focusing on checkpoint management, metrics computation, and training visualization. The changes include the addition of adaptive learning rate adjustment, enhanced early stopping mechanisms, proper checkpoint cleanup, and validation metrics computation. The implementation also ensures proper directory structure creation and adds gradient clipping for training stability.

Sequence diagram for training step with validation

sequenceDiagram
    actor User
    participant Trainer as SIFTTrainer
    participant Metrics as MetricsComputer
    participant Visualizer
    participant Checkpoint
    User->>Trainer: Start training step
    Trainer->>Metrics: Update metrics
    alt Save path provided
        Trainer->>Visualizer: Plot training summary
    end
    alt Validation step
        Trainer->>Metrics: Compute validation metrics
    end
    Trainer->>Checkpoint: Save checkpoint with metrics
    Note right of Trainer: Adjust learning rate
    Note right of Trainer: Enhanced early stopping check
    User->>Trainer: End training step

ER diagram for checkpoint and metrics management

erDiagram
    SIFTTrainer {
        string model_state_dict
        string optimizer_state_dict
        string scheduler_state_dict
        float loss
        float uncertainty
        int current_step
        datetime timestamp
    }
    SIFTTrainer ||--o{ Checkpoint : manages
    SIFTTrainer ||--o{ Metrics : computes
    Checkpoint {
        string path
        float loss
        int step
    }
    Metrics {
        float val_loss
        float val_perplexity
    }

Updated class diagram for SIFTTrainer

classDiagram
    class SIFTTrainer {
        +fine_tune_step(example: str) Optional[Dict[str, Any]]
        +should_stop_adaptive(step: int) bool
        +update_and_visualize_metrics(current_loss: float, uncertainty: float, step: int, save_path: Optional[str]) Dict[str, Any]
        +generate_training_summary(save_dir: str)
        +adjust_learning_rate(current_loss: float, window_size: int)
        +enhanced_early_stopping(current_loss: float, patience: int, min_delta: float) bool
        +save_checkpoint_with_metrics(path: str, metrics: Dict[str, float])
        +compute_adaptive_window_size(step: int, min_window: int, max_window: int) int
        +_cleanup_old_checkpoints(checkpoint_dir: Path, keep_top_n: int)
        +compute_validation_metrics(validation_examples: List[str]) Dict[str, float]
    }

File-Level Changes

Change Details Files
Implemented enhanced model checkpoint management system
  • Added checkpoint cleanup mechanism to retain only top N checkpoints based on loss
  • Implemented checkpoint saving with metrics in filename
  • Added timestamp and comprehensive state information to checkpoints
src/sift/sift_trainer.py
Added training optimization and stability improvements
  • Implemented gradient clipping with max_norm=1.0
  • Added adaptive learning rate adjustment based on loss trends
  • Implemented enhanced early stopping with patience and minimum delta parameters
  • Added adaptive window size computation for uncertainty calculations
src/sift/sift_trainer.py
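
The adaptive-window bullet matches `compute_adaptive_window_size` in the class diagram above. A sketch assuming a linear schedule over a hypothetical `total_steps` parameter (the PR's actual progress definition isn't shown on this page):

```python
def compute_adaptive_window_size(step, min_window=5, max_window=20,
                                 total_steps=1000):
    """Grow the uncertainty-buffer window linearly with training progress:
    min_window at step 0, max_window once total_steps is reached."""
    progress_factor = min(step / total_steps, 1.0)
    return min_window + int((max_window - min_window) * progress_factor)
```
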
Enhanced metrics computation and validation
  • Implemented validation metrics computation with loss and perplexity tracking
  • Added proper error handling in metrics computation
  • Updated metrics tracking to include validation metrics
  • Improved metrics visualization and summary generation
src/sift/sift_trainer.py
src/sift/sift_metrics.py
src/19_llm_fine_tuning_sift.py
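
Validation loss and perplexity are related by `perplexity = exp(loss)`. A minimal sketch of the aggregation step, assuming per-example cross-entropy losses are already computed (the real `compute_validation_metrics` also runs the model forward pass):

```python
import math

def aggregate_validation_metrics(example_losses):
    """Average per-example losses and derive perplexity as exp(mean loss)."""
    if not example_losses:
        return {"val_loss": float("inf"), "val_perplexity": float("inf")}
    val_loss = sum(example_losses) / len(example_losses)
    return {"val_loss": val_loss, "val_perplexity": math.exp(val_loss)}
```
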
Improved project structure and organization
  • Added automatic creation of required directories for outputs
  • Updated directory structure handling in training script
  • Improved error handling and logging throughout the codebase
src/19_llm_fine_tuning_sift.py
.gitignore

@leonvanbokhorst leonvanbokhorst self-assigned this Nov 17, 2024
@leonvanbokhorst leonvanbokhorst added documentation Improvements or additions to documentation enhancement New feature or request labels Nov 17, 2024
@leonvanbokhorst leonvanbokhorst merged commit 3555d36 into main Nov 17, 2024
1 check passed
@sourcery-ai sourcery-ai bot left a comment


Hey @leonvanbokhorst - I've reviewed your changes - here's some feedback:

Overall Comments:

  • The learning rate adjustment factors (1.2x increase, 0.8x decrease) in adjust_learning_rate() may be too aggressive and could lead to training instability. Consider using more conservative values like 1.05x and 0.95x.
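
The reviewer's point can be sketched as a trend-based adjustment with the factors exposed as parameters, defaulting to the suggested conservative values (`adjust_learning_rate` here is a standalone illustration, not the PR's method):

```python
def adjust_learning_rate(lr, recent_losses,
                         increase_factor=1.05, decrease_factor=0.95,
                         stagnation_threshold=0.01):
    """Raise lr when the loss trend is flat, lower it when loss is rising.
    The mean of successive differences equals (last - first) / (n - 1)."""
    loss_trend = (recent_losses[-1] - recent_losses[0]) / (len(recent_losses) - 1)
    if abs(loss_trend) < stagnation_threshold:
        return lr * increase_factor      # stagnating: explore faster
    if loss_trend > 0:
        return lr * decrease_factor      # diverging: cool down
    return lr                            # improving: leave lr alone
```
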

Here's what I looked at during the review:
  • 🟡 General issues: 3 issues found
  • 🟢 Security: all looks good
  • 🟢 Testing: all looks good
  • 🟡 Complexity: 1 issue found
  • 🟢 Documentation: all looks good


loss_trend = np.mean(np.diff(recent_losses))

# Increase learning rate if loss is stagnating
if abs(loss_trend) < 0.01:

issue: The loss trend threshold should be configurable and scale-aware

A fixed threshold of 0.01 may not be appropriate for all loss scales. Consider making this relative to the loss magnitude or configurable.
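
A scale-aware variant, as the comment suggests, compares the trend against a fraction of the current loss magnitude rather than an absolute 0.01 (`is_stagnating` is a hypothetical helper):

```python
def is_stagnating(recent_losses, relative_threshold=0.01, eps=1e-8):
    """Treat the loss as stagnant when its per-step trend is small
    relative to the current loss magnitude, not in absolute terms."""
    trend = (recent_losses[-1] - recent_losses[0]) / (len(recent_losses) - 1)
    scale = max(abs(recent_losses[-1]), eps)  # guard against a zero threshold
    return abs(trend) < relative_threshold * scale
```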

total_perplexity = 0.0

with torch.no_grad():
    for example in validation_examples[:10]:  # Limit to 10 examples for speed

issue: Hardcoding validation set size to 10 examples could lead to unstable metrics

Consider making the validation sample size configurable and ensuring it's large enough for statistically significant validation metrics.
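
Making the sample size configurable and the draw reproducible could look like this (`sample_validation_examples` is illustrative; the PR itself slices the first 10 examples):

```python
import random

def sample_validation_examples(examples, sample_size=50, seed=42):
    """Draw a configurable, seeded validation subset instead of always
    taking the first 10 examples."""
    rng = random.Random(seed)
    return rng.sample(examples, min(sample_size, len(examples)))
```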


# Increase learning rate if loss is stagnating
if abs(loss_trend) < 0.01:
    self.optimizer.param_groups[0]['lr'] *= 1.2

suggestion: Learning rate adjustment factors should be configurable parameters

The 1.2 and 0.8 multipliers are magic numbers that should be configurable to allow tuning for different optimization scenarios.

                self.optimizer.param_groups[0]['lr'] *= self.lr_increase_factor

self._patience_counter += 1
return self._patience_counter >= patience

def save_checkpoint_with_metrics(self, path: str, metrics: Dict[str, float]):

issue (complexity): Consider extracting checkpoint management into a dedicated class to improve code organization.

The checkpoint management logic should be extracted into a dedicated class to reduce complexity while maintaining functionality. Here's an example:

from pathlib import Path
from typing import Dict
import datetime

import torch

class CheckpointManager:
    def __init__(self, checkpoint_dir: str, keep_top_n: int = 3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_top_n = keep_top_n

    def save(self, model, optimizer, scheduler, metrics: Dict[str, float], **extra_state):
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics,
            **extra_state,
            'timestamp': datetime.datetime.now().isoformat()
        }

        filename = f"checkpoint_loss_{metrics['loss']:.4f}.pt"
        torch.save(checkpoint, self.checkpoint_dir / filename)
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_loss_*.pt"),
            key=lambda f: float(str(f).split("loss_")[1].split(".pt")[0])
        )
        for f in checkpoints[self.keep_top_n:]:
            f.unlink()

Then in SIFTTrainer:

class SIFTTrainer:
    def __init__(self, ...):
        self.checkpoint_manager = CheckpointManager("checkpoints")

    def save_checkpoint_with_metrics(self, metrics: Dict[str, float]):
        self.checkpoint_manager.save(
            self.model, 
            self.optimizer,
            self.scheduler,
            metrics,
            uncertainty_buffer=self._uncertainty_buffer,
            current_step=self.current_accumulation_step
        )

This refactoring:

  1. Separates checkpoint logic into a focused class
  2. Simplifies filename handling and cleanup
  3. Makes checkpoint behavior easier to modify and test
  4. Reduces the main trainer class complexity

Similar extractions should be done for validation metrics and training strategy logic.

Comment on lines +714 to +715
window_size = min_window + int((max_window - min_window) * progress_factor)
return window_size

suggestion (code-quality): Inline variable that is immediately returned (inline-immediately-returned-variable)

Suggested change:

- window_size = min_window + int((max_window - min_window) * progress_factor)
- return window_size
+ return min_window + int((max_window - min_window) * progress_factor)

Comment on lines +728 to +731
try:
    return float(str(filepath).split("loss_")[1].split("_")[0])
except:
    return float('inf')

issue (code-quality): Use except Exception: rather than bare except: (do-not-use-bare-except)
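
A corrected version of the parse, catching only the exceptions the expression can actually raise (`checkpoint_loss` is a hypothetical name for the sort key):

```python
from pathlib import Path

def checkpoint_loss(filepath) -> float:
    """Extract the loss embedded in a checkpoint filename such as
    'checkpoint_loss_0.1234_step_10.pt'; unparseable names sort last."""
    try:
        return float(str(filepath).split("loss_")[1].split("_")[0])
    except (IndexError, ValueError):
        return float("inf")
```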
