ModelTrainer Module #27

Closed
gitttt-1234 opened this issue Nov 29, 2023 · 1 comment

Comments

gitttt-1234 (Collaborator) commented Nov 29, 2023

The ModelTrainer module is needed to train different types of models with PyTorch Lightning, with optional logging to wandb (Weights & Biases). We define the ModelTrainer module as follows:

ModelTrainer module: this class trains a PyTorch model with Lightning and saves the trained model. Its inputs are config files for the data pipeline, the model (backbone and head configs), and training. Each model type (SingleInstance, TopDownCenteredInstance, etc.) has its own subclass of pl.LightningModule, which is trained using Lightning's Trainer class.
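One way for ModelTrainer to pick the right `pl.LightningModule` subclass for a given model type is a registry keyed by model-type name. The sketch below is a minimal, stdlib-only illustration under stated assumptions: the `model_type` config key, `build_model`, `MODEL_REGISTRY`, and the `*Model` classes are all hypothetical placeholders standing in for the real LightningModule subclasses, not the sleap_nn API.

```python
from typing import Any, Dict, Type


class BaseModel:
    """Hypothetical placeholder for a pl.LightningModule subclass."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config


class SingleInstanceModel(BaseModel):
    """Would hold the single-instance training/validation steps."""


class TopDownCenteredInstanceModel(BaseModel):
    """Would hold the top-down centered-instance training/validation steps."""


# Hypothetical registry: model-type name -> model class.
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {
    "SingleInstance": SingleInstanceModel,
    "TopDownCenteredInstance": TopDownCenteredInstanceModel,
}


def build_model(config: Dict[str, Any]) -> BaseModel:
    """Instantiate the class selected by the (hypothetical) config["model_type"]."""
    model_type = config["model_type"]
    if model_type not in MODEL_REGISTRY:
        raise ValueError(
            f"unknown model_type {model_type!r}; expected one of {sorted(MODEL_REGISTRY)}"
        )
    return MODEL_REGISTRY[model_type](config)
```

With this layout, adding a new model type only needs a new subclass and one registry entry; the trainer code that calls `build_model` stays unchanged.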

The ModelTrainer module performs the following functions:

  1. Creates DataLoaders from the data config
  2. Trains the Lightning model with the Trainer class
  3. Optionally logs to wandb to track training progress
  4. Saves checkpoints throughout training

```python
from omegaconf import OmegaConf
from sleap_nn.models import ModelTrainer

config = OmegaConf.create()  # data config, model config and trainer config

m = ModelTrainer(config)
m.train()  # calls trainer.fit() from Lightning
# the checkpoints are saved in the specified directory
```
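The four functions listed above map onto standard PyTorch/Lightning building blocks. To keep the mapping runnable without PyTorch installed, the sketch below only assembles the keyword arguments ModelTrainer might pass to each object. The config keys (`data_config`, `trainer_config`, `use_wandb`, `wandb`, `save_ckpt_path`, `monitor`) and the names `plan_training` and `LightningSetup` are hypothetical; the target parameter names (`batch_size`, `shuffle`, `num_workers`, `max_epochs`, `accelerator`, `project`, `name`, `dirpath`, `monitor`, `save_top_k`) are real arguments of `DataLoader`, `Trainer`, `WandbLogger`, and `ModelCheckpoint`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class LightningSetup:
    """Keyword arguments ModelTrainer could pass to each PyTorch/Lightning object."""

    dataloader_kwargs: Dict[str, Any]       # 1. torch.utils.data.DataLoader(...)
    trainer_kwargs: Dict[str, Any]          # 2. lightning.pytorch.Trainer(...)
    wandb_kwargs: Optional[Dict[str, Any]]  # 3. WandbLogger(...), or None if disabled
    checkpoint_kwargs: Dict[str, Any]       # 4. ModelCheckpoint(...)


def plan_training(config: Dict[str, Any]) -> LightningSetup:
    """Translate the (hypothetical) config sections into per-object kwargs."""
    data_cfg = config["data_config"]
    trainer_cfg = config["trainer_config"]

    dataloader_kwargs = {
        "batch_size": data_cfg.get("batch_size", 4),
        "shuffle": data_cfg.get("shuffle", True),
        "num_workers": data_cfg.get("num_workers", 0),
    }

    # Logging is optional: only build WandbLogger kwargs when it is enabled.
    wandb_kwargs = None
    if trainer_cfg.get("use_wandb", False):
        wandb_cfg = trainer_cfg.get("wandb", {})
        wandb_kwargs = {
            "project": wandb_cfg.get("project"),
            "name": wandb_cfg.get("name"),
        }

    checkpoint_kwargs = {
        "dirpath": trainer_cfg["save_ckpt_path"],
        # Assumes the LightningModule logs a metric named "val_loss".
        "monitor": trainer_cfg.get("monitor", "val_loss"),
        "save_top_k": trainer_cfg.get("save_top_k", 1),
    }

    trainer_kwargs = {
        "max_epochs": trainer_cfg.get("max_epochs", 10),
        "accelerator": trainer_cfg.get("accelerator", "auto"),
    }

    return LightningSetup(
        dataloader_kwargs=dataloader_kwargs,
        trainer_kwargs=trainer_kwargs,
        wandb_kwargs=wandb_kwargs,
        checkpoint_kwargs=checkpoint_kwargs,
    )
```

With PyTorch and Lightning installed, the pieces would be consumed roughly as `DataLoader(train_dataset, **setup.dataloader_kwargs)` (a validation loader would normally use `shuffle=False`), `Trainer(callbacks=[ModelCheckpoint(**setup.checkpoint_kwargs)], logger=WandbLogger(**setup.wandb_kwargs) if setup.wandb_kwargs else True, **setup.trainer_kwargs)`, and finally `trainer.fit(model, train_loader, val_loader)`.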
gitttt-1234 (Collaborator, Author) commented:

PR #29
