Skip to content

Commit

Permalink
Merge 4db10a9 into 7b3b992
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag authored Nov 27, 2024
2 parents 7b3b992 + 4db10a9 commit b252cde
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 24 deletions.
20 changes: 19 additions & 1 deletion docs/notebooks/train_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@
">> fibad train --runtime-config ./results/<timestamped_directory>/runtime_config.toml\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir ./results"
]
}
],
"metadata": {
Expand All @@ -89,7 +107,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ dependencies = [
"toml", # Used to load configuration files as dictionaries
"torch", # Used for CNN model and in train.py
"torchvision", # Used in hsc data loader, example autoencoder, and CNN model data set
"tensorboardX", # Used to log training metrics
"tensorboard", # Used to log training metrics
"GPUtil", # Used to gather GPU usage information
"schwimmbad", # Used to speedup hsc data loader file scans
]

Expand Down
27 changes: 25 additions & 2 deletions src/fibad/data_sets/example_cifar_data_set.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ruff: noqa: D101, D102
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import CIFAR10

from .data_set_registry import fibad_data_set
Expand All @@ -16,12 +18,33 @@ def __init__(self, config, split: str):
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

if split not in ["train", "test"]:
if split not in ["train", "validate", "test"]:
RuntimeError("CIFAR10 dataset only supports 'train' and 'test' splits.")

train = split == "train"
train = split != "test"

super().__init__(root=config["general"]["data_dir"], train=train, download=True, transform=transform)

if train:
num_train = len(self)
indices = list(range(num_train))
split = int(np.floor(config["data_set"]["validate_size"] * num_train))

random_seed = None
if config["data_set"]["seed"]:
random_seed = config["data_set"]["seed"]
np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]

#! These two "samplers" are used by PyTorch's DataLoader to split the
#! dataset into training and validation sets. Using Samplers is mutually
#! exclusive with using "shuffle" in the DataLoader.
#! If a user doesn't define a Sampler, the default behavior of pytorch-ignite
#! is to shuffle the data unless `shuffle = False` in the config.
self.train_sampler = SubsetRandomSampler(train_idx)
self.validation_sampler = SubsetRandomSampler(valid_idx)

def shape(self):
return (3, 32, 32)
8 changes: 4 additions & 4 deletions src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def __init__(self, config, split: Union[str, None]):
self._set_split(split)

def _create_splits(self, config):
seed = config["prepare"]["seed"] if config["prepare"]["seed"] else None
seed = config["data_set"]["seed"] if config["data_set"]["seed"] else None

# Init the splits based on config values
train_size = config["prepare"]["train_size"] if config["prepare"]["train_size"] else None
test_size = config["prepare"]["test_size"] if config["prepare"]["test_size"] else None
validate_size = config["prepare"]["validate_size"] if config["prepare"]["validate_size"] else None
train_size = config["data_set"]["train_size"] if config["data_set"]["train_size"] else None
test_size = config["data_set"]["test_size"] if config["data_set"]["test_size"] else None
validate_size = config["data_set"]["validate_size"] if config["data_set"]["validate_size"] else None

# Convert all values specified as counts into ratios of the underlying container
if isinstance(train_size, int):
Expand Down
16 changes: 9 additions & 7 deletions src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,6 @@ filters = false
# Implementation is dataset class dependent. Default is false, meaning no filtering.
filter_catalog = false

[data_loader]
# Default PyTorch DataLoader parameters
batch_size = 32
shuffle = true
num_workers = 2

[prepare]
# How to split the data between training and eval sets.
# The semantics are borrowed from scikit-learn's train-test-split, and HF Dataset's train-test-split function
# It is an error for these values to add to more than 1.0 as ratios or the size of the dataset if expressed
Expand Down Expand Up @@ -134,6 +127,15 @@ test_size = 0.6
# a system source at runtime.
seed = false

[data_loader]
# Default PyTorch DataLoader parameters
batch_size = 32

# We could remove this potentially - pytorch-ignite will default to shuffle=True
# If the user wanted to explicitly require no shuffling, they could set this to false.
shuffle = false
num_workers = 2

[predict]
model_weights_file = false
batch_size = 32
Expand Down
35 changes: 35 additions & 0 deletions src/fibad/gpu_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import time
from threading import Thread

import GPUtil


class GpuMonitor(Thread):
    """General GPU monitor that runs in a separate thread and logs GPU metrics
    to Tensorboard.

    The thread starts itself on construction; call `stop()` to end monitoring.
    """

    def __init__(self, tensorboard_logger, interval_seconds=1):
        """Create and immediately start the monitoring thread.

        Parameters
        ----------
        tensorboard_logger :
            Object exposing ``add_scalar(tag, value, step)`` (e.g. a
            tensorboardX ``SummaryWriter``); GPU metrics are written here.
        interval_seconds : int, optional
            Seconds to sleep between GPU polls, by default 1.
        """
        super().__init__()
        self.stopped = False
        self.delay = interval_seconds  # Seconds between calls to GPUtil
        self.start_time = time.time()
        self.tensorboard_logger = tensorboard_logger
        # Run as a daemon thread so that an exception in the main thread (or a
        # missed call to `stop()`) cannot keep the whole process alive forever.
        self.daemon = True
        self.start()

    def run(self):
        """Run loop that logs GPU metrics every `self.delay` seconds."""
        while not self.stopped:
            gpus = GPUtil.getGPUs()
            # Step is wall-clock seconds since monitor creation.
            step = time.time() - self.start_time
            for gpu in gpus:
                gpu_name = f"GPU_{gpu.id}"
                self.tensorboard_logger.add_scalar(f"{gpu_name}/load", gpu.load * 100, step)
                self.tensorboard_logger.add_scalar(
                    f"{gpu_name}/memory_utilization", gpu.memoryUtil * 100, step
                )
            time.sleep(self.delay)

    def stop(self):
        """Stop the monitoring thread and wait for it to exit.

        Joining guarantees no further scalars are written after this returns;
        the timeout guards against a hung GPUtil call blocking shutdown.
        """
        self.stopped = True
        self.join(timeout=2 * self.delay)
104 changes: 100 additions & 4 deletions src/fibad/pytorch_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from tensorboardX import SummaryWriter
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset

from fibad.config_utils import ConfigDict
from fibad.data_sets.data_set_registry import fetch_data_set_class
Expand Down Expand Up @@ -46,7 +47,7 @@ def setup_model_and_dataset(config: ConfigDict, split: str) -> tuple:
return model, data_set


def dist_data_loader(data_set: Dataset, config: ConfigDict):
def dist_data_loader(data_set: Dataset, config: ConfigDict, split: str):
"""Create a Pytorch Ignite distributed data loader
Parameters
Expand All @@ -55,13 +56,26 @@ def dist_data_loader(data_set: Dataset, config: ConfigDict):
A Pytorch Dataset object
config : ConfigDict
Fibad runtime configuration
split : str
The name of the split we want to use from the data set.
Returns
-------
Dataloader (or an ignite-wrapped equivalent)
This is the distributed dataloader, formed by calling ignite.distributed.auto_dataloader
"""
return idist.auto_dataloader(data_set, **config["data_loader"])

#! Here we are allowing `train_sampler` and `validation_sampler` to become magic words.
#! It would be nice to find a way to do this that doesn't require external libraries to
#! know about these.
if hasattr(data_set, "train_sampler") and split == "train":
sampler = data_set.train_sampler
elif hasattr(data_set, "validation_sampler") and split == "validate":
sampler = data_set.validation_sampler
else:
sampler = None

return idist.auto_dataloader(data_set, sampler=sampler, **config["data_loader"])


def create_engine(funcname: str, device: torch.device, model: torch.nn.Module):
Expand Down Expand Up @@ -160,7 +174,79 @@ def log_total_time(evaluator):
return evaluator


def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory: Path) -> Engine:
#! There will likely be a significant amount of code duplication between the
#! `create_trainer` and `create_validator` functions. We should find a way to
#! refactor this code to reduce duplication.
def create_validator(
    model: torch.nn.Module,
    config: ConfigDict,
    results_directory: Path,
    tensorboardx_logger: SummaryWriter,
    validation_data_loader: DataLoader,
    trainer: Engine,
) -> Engine:
    """Build a pytorch-ignite engine that evaluates the model once per training epoch.

    Parameters
    ----------
    model : torch.nn.Module
        The model to validate.
    config : ConfigDict
        Fibad runtime configuration.
    results_directory : Path
        The directory where training results will be saved.
    tensorboardx_logger : SummaryWriter
        The tensorboard logger object.
    validation_data_loader : DataLoader
        The data loader for the validation data.
    trainer : Engine
        The engine that trains the model; its EPOCH_COMPLETED event is used to
        trigger each validation run.

    Returns
    -------
    pytorch-ignite.Engine
        The validation engine.
    """

    device = idist.device()
    model = idist.auto_model(model)

    #! NOTE(review): this engine reuses "train_step", so the optimizer may still
    #! update weights during validation; the eval/train toggles below do not
    #! prevent that. A dedicated eval step is still needed.
    validator = create_engine("train_step", device, model)

    def enter_eval_mode():
        model.eval()

    def leave_eval_mode():
        model.train()

    def report_validation_epoch():
        logger.info(f"Validation run time: {validator.state.times['EPOCH_COMPLETED']:.2f}[s]")
        logger.info(f"Validation metrics: {validator.state.output}")

    # Registered in the same order the original decorators ran.
    validator.add_event_handler(Events.STARTED, enter_eval_mode)
    validator.add_event_handler(Events.COMPLETED, leave_eval_mode)
    validator.add_event_handler(Events.EPOCH_COMPLETED, report_validation_epoch)

    def run_validation():
        validator.run(validation_data_loader)

    # Each completed training epoch drives one validation pass.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_validation)

    def write_validation_loss(engine, train_engine):
        step = train_engine.state.get_event_attrib_value(Events.EPOCH_COMPLETED)
        tensorboardx_logger.add_scalar("training/validation/loss", engine.state.output["loss"], step)

    validator.add_event_handler(Events.EPOCH_COMPLETED, write_validation_loss, trainer)

    return validator


def create_trainer(
model: torch.nn.Module, config: ConfigDict, results_directory: Path, tensorboardx_logger: SummaryWriter
) -> Engine:
"""This function is originally copied from here:
https://github.com/pytorch-ignite/examples/blob/main/tutorials/intermediate/cifar10-distributed.py#L164
Expand All @@ -174,6 +260,8 @@ def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory
Fibad runtime configuration
results_directory : Path
The directory where training results will be saved
tensorboardx_logger : SummaryWriter
The tensorboard logger object
Returns
-------
Expand All @@ -193,6 +281,9 @@ def create_trainer(model: torch.nn.Module, config: ConfigDict, results_directory
"trainer": trainer,
}

#! We may want to move the checkpointing logic over to the `validator`.
#! It was created here initially because this was the only place where the
#! model training was happening.
latest_checkpoint = Checkpoint(
to_save,
DiskSaver(results_directory, require_empty=False),
Expand Down Expand Up @@ -227,6 +318,11 @@ def log_training_start(trainer):
def log_epoch_start(trainer):
logger.debug(f"Starting epoch {trainer.state.epoch}")

@trainer.on(Events.ITERATION_COMPLETED(every=10))
def log_training_loss_tensorboard(trainer):
step = trainer.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
tensorboardx_logger.add_scalar("training/training/loss", trainer.state.output["loss"], step)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(trainer):
logger.info(f"Epoch {trainer.state.epoch} run time: {trainer.state.times['EPOCH_COMPLETED']:.2f}[s]")
Expand Down
27 changes: 23 additions & 4 deletions src/fibad/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging

from tensorboardX import SummaryWriter

from fibad.config_utils import create_results_dir, log_runtime_config
from fibad.pytorch_ignite import create_trainer, dist_data_loader, setup_model_and_dataset
from fibad.gpu_monitor import GpuMonitor
from fibad.pytorch_ignite import create_trainer, create_validator, dist_data_loader, setup_model_and_dataset

logger = logging.getLogger(__name__)

Expand All @@ -19,16 +22,32 @@ def run(config):
results_dir = create_results_dir(config, "train")
log_runtime_config(config, results_dir)

# Create a tensorboardX logger
tensorboardx_logger = SummaryWriter(log_dir=results_dir)

# Instantiate the model and dataset
model, data_set = setup_model_and_dataset(config, split=config["train"]["split"])
data_loader = dist_data_loader(data_set, config)

# Create a data loader for the training set
train_data_loader = dist_data_loader(data_set, config, "train")

# Create validation_data_loader if a validation split is defined in data_set
validation_data_loader = dist_data_loader(data_set, config, "validate")

# Create trainer, a pytorch-ignite `Engine` object
trainer = create_trainer(model, config, results_dir)
trainer = create_trainer(model, config, results_dir, tensorboardx_logger)

# Create a validator if a validation data loader is available
if validation_data_loader is not None:
create_validator(model, config, results_dir, tensorboardx_logger, validation_data_loader, trainer)

monitor = GpuMonitor(tensorboard_logger=tensorboardx_logger)
# Run the training process
trainer.run(data_loader, max_epochs=config["train"]["epochs"])
trainer.run(train_data_loader, max_epochs=config["train"]["epochs"])

# Save the trained model
model.save(results_dir / config["train"]["weights_filepath"])
monitor.stop()

logger.info("Finished Training")
tensorboardx_logger.close()
2 changes: 0 additions & 2 deletions tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def mkconfig(
"crop_to": crop_to,
"filters": filters,
"filter_catalog": filter_catalog,
},
"prepare": {
"seed": seed,
"train_size": train_size,
"test_size": test_size,
Expand Down

0 comments on commit b252cde

Please sign in to comment.