Improve comments and code formatting (#296)
ashleve authored May 17, 2022
1 parent 062f6f3 commit 30bfcf0
Showing 7 changed files with 57 additions and 57 deletions.
39 changes: 24 additions & 15 deletions src/datamodules/mnist_datamodule.py
@@ -11,11 +11,22 @@ class MNISTDataModule(LightningDataModule):
"""Example of LightningDataModule for MNIST dataset.
A DataModule implements 5 key methods:
- prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
- setup (things to do on every accelerator in distributed mode)
- train_dataloader (the training dataloader)
- val_dataloader (the validation dataloader(s))
- test_dataloader (the test dataloader(s))
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, split, etc...
def setup(self, stage):
# things to do on every process in DDP
# load data, split, set variables, etc...
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
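For readers new to these hooks, a minimal sketch of a LightningDataModule that wires them together might look as follows (the random tensors, split sizes, and batch size are illustrative placeholders, not part of this repository):

from typing import Optional

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset, random_split


class ToyDataModule(LightningDataModule):
    """Toy datamodule that fills in the hooks listed above."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.data_train: Optional[TensorDataset] = None
        self.data_val: Optional[TensorDataset] = None

    def prepare_data(self):
        # download or generate raw data here; runs on a single process
        pass

    def setup(self, stage: Optional[str] = None):
        # runs on every process in DDP; assign datasets here
        if self.data_train is None and self.data_val is None:
            full = TensorDataset(torch.randn(1000, 28 * 28), torch.randint(0, 10, (1000,)))
            self.data_train, self.data_val = random_split(full, [800, 200])

    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=self.hparams.batch_size)

    def teardown(self, stage: Optional[str] = None):
        # clean up after fit or test
        pass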
@@ -35,6 +46,7 @@ def __init__(
super().__init__()

# this line allows accessing init params with the 'self.hparams' attribute
# it also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)

# data transformations
@@ -51,23 +63,17 @@ def num_classes(self) -> int:
return 10

def prepare_data(self):
"""Download data if needed.
This method is called only from a single GPU.
Do not use it to assign state (self.x = y).
"""
"""Download data if needed."""
MNIST(self.hparams.data_dir, train=True, download=True)
MNIST(self.hparams.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None):
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
so be careful not to execute the random split twice! The `stage` can be used to
differentiate whether it's called before `trainer.fit()` or `trainer.test()`.
This method is called by Lightning with both `trainer.fit()` and `trainer.test()`,
so be careful not to execute the random split twice!
"""

# load datasets only if they're not loaded already
# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:
trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
@@ -104,3 +110,6 @@ def test_dataloader(self):
pin_memory=self.hparams.pin_memory,
shuffle=False,
)

def teardown(self):
pass
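
Outside of a Trainer, this datamodule can be exercised directly; a rough usage sketch, assuming the template's default constructor arguments such as `data_dir` and `batch_size`:

from src.datamodules.mnist_datamodule import MNISTDataModule

# `data_dir` and `batch_size` are assumed argument names from the template's config
dm = MNISTDataModule(data_dir="data/", batch_size=64)
dm.prepare_data()    # downloads MNIST once
dm.setup()           # loads and splits; the guard above makes repeated calls harmless
images, targets = next(iter(dm.train_dataloader()))
print(images.shape, targets.shape)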
7 changes: 4 additions & 3 deletions src/models/mnist_module.py
@@ -11,12 +11,13 @@
class MNISTLitModule(LightningModule):
"""Example of LightningModule for MNIST classification.
A LightningModule organizes your PyTorch code into 5 sections:
A LightningModule organizes your PyTorch code into 6 sections:
- Computations (init).
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)
- Prediction Loop (predict_step)
- Optimizers and LR Schedulers (configure_optimizers)
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
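
As a generic illustration of those six sections (a sketch, not code from this commit; the tiny linear net and loss are placeholders):

import torch
from pytorch_lightning import LightningModule


class TinyLitModule(LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.net = torch.nn.Linear(28 * 28, 10)       # computations (init)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):        # train loop
        x, y = batch
        return self.criterion(self(x), y)

    def validation_step(self, batch, batch_idx):      # validation loop
        x, y = batch
        self.log("val/loss", self.criterion(self(x), y))

    def test_step(self, batch, batch_idx):            # test loop
        x, y = batch
        self.log("test/loss", self.criterion(self(x), y))

    def predict_step(self, batch, batch_idx):         # prediction loop
        x, _ = batch
        return self(x).argmax(dim=-1)

    def configure_optimizers(self):                   # optimizers and LR schedulers
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)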
@@ -118,7 +119,7 @@ def configure_optimizers(self):
"""Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
See examples here:
Examples:
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
"""
return torch.optim.Adam(
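
Since the updated docstring mentions LR schedulers, here is a hedged sketch of a `configure_optimizers` that returns both an optimizer and a scheduler (the learning rate, weight decay, and scheduler choice are illustrative, not what this module ships with):

def configure_optimizers(self):
    # assumes `import torch` at module level, as in mnist_module.py
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}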
13 changes: 5 additions & 8 deletions src/testing_pipeline.py
@@ -23,29 +23,26 @@ def test(config: DictConfig) -> None:

assert config.ckpt_path

# Init lightning datamodule
# init lightning datamodule
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

# Init lightning model
# init lightning model
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(config.model)

# Init lightning loggers
# init lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
for _, lg_conf in config.logger.items():
if "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))

# Init lightning trainer
# init lightning trainer
log.info(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger)

# Log hyperparameters
if trainer.logger:
trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})

# test the model
log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path)
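
The `_target_` pattern used throughout these pipelines can be reproduced in isolation; a small sketch (the `batch_size` kwarg is an assumption about the datamodule's signature):

import hydra
from omegaconf import OmegaConf

# `_target_` names the class to build; every other key is passed as a keyword argument
cfg = OmegaConf.create(
    {"_target_": "src.datamodules.mnist_datamodule.MNISTDataModule", "batch_size": 64}
)
datamodule = hydra.utils.instantiate(cfg)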
28 changes: 13 additions & 15 deletions src/training_pipeline.py
@@ -22,37 +22,37 @@ def train(config: DictConfig) -> Optional[float]:
Optional[float]: Metric score for hyperparameter optimization.
"""

# Init lightning datamodule
# init lightning datamodule
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

# Init lightning model
# init lightning model
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(config.model)

# Init lightning callbacks
# init lightning callbacks
callbacks: List[Callback] = []
if "callbacks" in config:
for _, cb_conf in config.callbacks.items():
if "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))

# Init lightning loggers
# init lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
for _, lg_conf in config.logger.items():
if "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))

# Init lightning trainer
# init lightning trainer
log.info(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)

# Send some parameters from config to all lightning loggers
# send hyperparameters to loggers
log.info("Logging hyperparameters!")
utils.log_hyperparameters(
config=config,
@@ -63,24 +63,22 @@ def train(config: DictConfig) -> Optional[float]:
logger=logger,
)

# Train the model
# train the model
if config.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=config.get("ckpt_path"))

# Get metric score for hyperparameter optimization
# get metric score for hyperparameter optimization
metric_name = config.get("optimized_metric")
score = utils.get_metric_value(metric_name, trainer) if metric_name else None

# Test the model
# test the model
ckpt_path = "best" if config.get("train") and not config.trainer.get("fast_dev_run") else None
if config.get("test"):
ckpt_path = "best"
if not config.get("train") or config.trainer.get("fast_dev_run"):
ckpt_path = None
log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

# Make sure everything closed properly
# make sure everything closed properly
log.info("Finalizing!")
utils.finish(
config=config,
@@ -91,9 +89,9 @@ def train(config: DictConfig) -> Optional[float]:
logger=logger,
)

# Print path to best checkpoint
# print path to best checkpoint
if not config.trainer.get("fast_dev_run") and config.get("train"):
log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")

# Return metric score for hyperparameter optimization
# return metric score for hyperparameter optimization
return score
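
The returned score feeds Hydra's sweeper during hyperparameter search; `utils.get_metric_value` presumably looks the metric up in `trainer.callback_metrics`, roughly like this hedged sketch (not the repository's exact code):

from typing import Optional

from pytorch_lightning import Trainer


def get_metric_value(metric_name: str, trainer: Trainer) -> Optional[float]:
    # hedged reconstruction of the helper's idea
    metric = trainer.callback_metrics.get(metric_name)
    return metric.item() if metric is not None else None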
15 changes: 5 additions & 10 deletions src/utils/__init__.py
@@ -18,15 +18,8 @@ def get_logger(name=__name__) -> logging.Logger:

# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
for level in (
"debug",
"info",
"warning",
"error",
"exception",
"fatal",
"critical",
):
logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
for level in logging_levels:
setattr(logger, level, rank_zero_only(getattr(logger, level)))

return logger
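
In practice the wrapped logger is used like this; under multi-GPU DDP each message then appears once per run instead of once per process (a usage sketch, assuming the helper is exported from `src.utils`):

from src.utils import get_logger

log = get_logger(__name__)
log.info("printed once per run, even with multiple GPU processes")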
@@ -94,7 +87,7 @@ def print_config(

for field in print_order:
queue.append(field) if field in config else log.info(
f"Field '{field}' not found in config."
f"Field '{field}' not found in config. Skipping '{field}' config printing..."
)

for field in config:
@@ -157,6 +150,8 @@ def log_hyperparameters(
hparams["seed"] = config["seed"]
if "callbacks" in config:
hparams["callbacks"] = config["callbacks"]
if "ckpt_path" in config:
hparams["ckpt_path"] = config.ckpt_path

# send hparams to all loggers
trainer.logger.log_hyperparams(hparams)
6 changes: 3 additions & 3 deletions test.py
@@ -10,15 +10,15 @@
@hydra.main(config_path="configs/", config_name="test.yaml")
def main(config: DictConfig):

# Imports can be nested inside @hydra.main to optimize tab completion
# imports can be nested inside @hydra.main to optimize tab completion
# https://github.com/facebookresearch/hydra/issues/934
from src import utils
from src.testing_pipeline import test

# Applies optional utilities
# applies optional utilities
utils.extras(config)

# Evaluate model
# evaluate model
return test(config)


6 changes: 3 additions & 3 deletions train.py
@@ -10,15 +10,15 @@
@hydra.main(config_path="configs/", config_name="train.yaml")
def main(config: DictConfig):

# Imports can be nested inside @hydra.main to optimize tab completion
# imports can be nested inside @hydra.main to optimize tab completion
# https://github.com/facebookresearch/hydra/issues/934
from src import utils
from src.training_pipeline import train

# Applies optional utilities
# applies optional utilities
utils.extras(config)

# Train model
# train model
return train(config)


