Multistep training with batch_size >=1 per GPU #139

Merged · 19 commits · Jan 1, 2025
Changes from all commits
ae5bf40
Initial commit of loading sequence for the new datasets and dataloade…
jsschreck Dec 23, 2024
8755d28
Updated train_multi + bug fixes
jsschreck Dec 24, 2024
8232708
Bug updates post multi-step training tests
jsschreck Dec 24, 2024
428cc1d
Still working out daemon issues main vs imported
jsschreck Dec 24, 2024
f59ff3c
Added and tested single-step within the new scheme; added train_unive…
jsschreck Dec 26, 2024
d50849e
Fixed tqdm bug and tested this trainer against grad-accum for single …
jsschreck Dec 26, 2024
6b6093a
Adding (deprecated) singlestep dataset to datasets directory
jsschreck Dec 26, 2024
cb3d266
Cleaning up redundant method calls, adding logging details
jsschreck Dec 27, 2024
fd62daa
Updating logging messages for edge cases
jsschreck Dec 27, 2024
9113420
Fixed import error
jsschreck Dec 27, 2024
3c181c0
Fixed the batch size * history len bug
jsschreck Dec 28, 2024
cf8427f
Fixed bug in MultiprocessingBatcher with indices assigned to workers
jsschreck Dec 29, 2024
e83484d
Removed prefetch option from ERA5_MultiStep_Batcher dataloader b/c of…
jsschreck Dec 29, 2024
6e80b86
Added pseudo-sampler to enable prefetch with ERA5_MultiStep_Batcher, …
jsschreck Dec 30, 2024
08420fb
Linting
jsschreck Dec 30, 2024
c85c987
Fixed a few bugs related to dataset len and batches per epoch
jsschreck Dec 31, 2024
4721f30
Added universal key for the trainer to use grad_accum
jsschreck Dec 31, 2024
c37eafb
Added example config for version 2.0 which will support the changes i…
jsschreck Dec 31, 2024
0be3d9a
Final update of the config before merging
jsschreck Jan 1, 2025
269 changes: 11 additions & 258 deletions applications/train_multistep.py
@@ -11,27 +11,26 @@
import shutil
import logging
import warnings
from glob import glob

from pathlib import Path
from argparse import ArgumentParser
from echo.src.base_objective import BaseObjective

import torch
from torch.cuda.amp import GradScaler
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from credit.distributed import distributed_model_wrapper, setup, get_rank_info

from credit.seed import seed_everything
from credit.loss import VariableTotalLoss2D

# from credit.datasets.sequential_multistep import DistributedSequentialDataset
from credit.datasets.era5_multistep import ERA5_and_Forcing_MultiStep
from credit.transforms import load_transforms
from credit.scheduler import load_scheduler
from credit.trainers import load_trainer
from credit.parser import credit_main_parser, training_data_check
from credit.datasets.load_dataset_and_dataloader import (
load_dataset,
load_dataloader
)

from credit.metrics import LatWeightedMetrics
from credit.pbs import launch_script, launch_script_mpi
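This PR replaces the per-script dataset/sampler/dataloader setup with the single entry point imported above, `credit.datasets.load_dataset_and_dataloader`; its `load_dataset` and `load_dataloader` calls appear further down in this file. The module's internals are not part of this diff, so the following is only a sketch of what such a dispatcher could look like, assuming a `conf["data"]["dataset_type"]` key (mentioned in a comment below) and the `ERA5_MultiStep_Batcher` / `MultiprocessingBatcher` classes named in the commit messages; import paths and constructor signatures are hypothetical.

```python
# Hypothetical sketch only: the real credit.datasets.load_dataset_and_dataloader
# module is not shown in this diff. Class names come from the commit messages;
# import paths and constructor signatures are assumptions.
import torch


def load_dataset(conf, rank=0, world_size=1, is_train=True):
    """Build the dataset named by conf["data"]["dataset_type"]."""
    dataset_type = conf["data"].get("dataset_type", "ERA5_MultiStep_Batcher")
    if dataset_type == "ERA5_MultiStep_Batcher":
        from credit.datasets.era5_multistep_batcher import ERA5_MultiStep_Batcher  # assumed path
        cls = ERA5_MultiStep_Batcher
    elif dataset_type == "MultiprocessingBatcher":
        from credit.datasets.era5_multistep_batcher import MultiprocessingBatcher  # assumed path
        cls = MultiprocessingBatcher
    else:
        raise ValueError(f"Unknown dataset_type: {dataset_type}")
    return cls(conf, rank=rank, world_size=world_size, is_train=is_train)


def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True):
    """Wrap a batcher dataset in a DataLoader; batches are built inside the dataset."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=None,   # automatic batching off: the batcher yields full batches
        num_workers=1,     # multiprocessing is handled inside the dataset
        pin_memory=is_train,
    )
```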
@@ -49,88 +49,6 @@
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def load_dataset_and_sampler(
conf,
all_ERA_files,
surface_files,
dyn_forcing_files,
diagnostic_files,
world_size,
rank,
is_train=True,
):
"""
Load the dataset and sampler for training or validation.

Args:
conf (dict): Configuration dictionary containing dataset and training parameters.
all_ERA_files (list): List of file paths for the dataset.
surface_files (list): List of file paths for the surface data.
dyn_forcing_files (list): List of file paths for the dyn_forcing data.
diagnostic_files (list): List of file paths for the diagnostic data.
world_size (int): Number of processes participating in the job.
rank (int): Rank of the current process.
is_train (bool): Flag indicating whether the dataset is for training or validation.

Returns:
tuple: A tuple containing the dataset and the distributed sampler.
"""
seed = conf["seed"]
# --------------------------------------------------- #
# separate training set and validation set cases
if is_train:
history_len = conf["data"]["history_len"]
forecast_len = conf["data"]["forecast_len"]
name = "training"
else:
history_len = conf["data"]["valid_history_len"]
forecast_len = conf["data"]["valid_forecast_len"]
name = "validation"

# transforms
transforms = load_transforms(conf)

# Z-score
dataset = ERA5_and_Forcing_MultiStep(
varname_upper_air=conf["data"]["variables"],
varname_surface=conf["data"]["surface_variables"],
varname_dyn_forcing=conf["data"]["dynamic_forcing_variables"],
varname_forcing=conf["data"]["forcing_variables"],
varname_static=conf["data"]["static_variables"],
varname_diagnostic=conf["data"]["diagnostic_variables"],
filenames=all_ERA_files,
filename_surface=surface_files,
filename_dyn_forcing=dyn_forcing_files,
filename_forcing=conf["data"]["save_loc_forcing"],
filename_static=conf["data"]["save_loc_static"],
filename_diagnostic=diagnostic_files,
history_len=history_len,
forecast_len=forecast_len,
skip_periods=conf["data"]["skip_periods"],
max_forecast_len=conf["data"]["max_forecast_len"],
transform=transforms,
rank=rank,
world_size=world_size,
seed=seed,
)

# Pytorch sampler
sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
seed=seed,
shuffle=is_train,
drop_last=True,
)

logging.info(
f" Loaded a {name} ERA dataset, and a distributed sampler (forecast length = {forecast_len + 1})"
)

return dataset, sampler


def load_model_states_and_optimizer(conf, model, device):
"""
Load the model states, optimizer, scheduler, and gradient scaler.
@@ -345,7 +262,6 @@ def main(rank, world_size, conf, backend, trial=False):
setup(rank, world_size, conf["trainer"]["mode"], backend)

# infer device id from rank

device = (
torch.device(f"cuda:{rank % torch.cuda.device_count()}")
if torch.cuda.is_available()
@@ -357,207 +273,44 @@
seed = conf["seed"]
seed_everything(seed)
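`credit.seed.seed_everything` is collapsed in this diff; a typical implementation of such a helper (an assumption, not the package's actual code) seeds the Python, NumPy, and PyTorch RNGs together:

```python
# Sketch of a typical seed_everything helper (assumed; credit.seed's
# implementation is not part of this diff).
import os
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds CPU and CUDA generators
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
```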

train_batch_size = conf["trainer"]["train_batch_size"]
valid_batch_size = conf["trainer"]["valid_batch_size"]

# get file names
all_ERA_files = sorted(glob(conf["data"]["save_loc"]))

# <------------------------------------------ std_new or 'std_cached'
if conf["data"]["scaler_type"] == "std_new" or "std_cached":
# check and glob surface files
if ("surface_variables" in conf["data"]) and (
len(conf["data"]["surface_variables"]) > 0
):
surface_files = sorted(glob(conf["data"]["save_loc_surface"]))

else:
surface_files = None

# check and glob dyn forcing files
if ("dynamic_forcing_variables" in conf["data"]) and (
len(conf["data"]["dynamic_forcing_variables"]) > 0
):
dyn_forcing_files = sorted(glob(conf["data"]["save_loc_dynamic_forcing"]))

else:
dyn_forcing_files = None

# check and glob diagnostic files
if ("diagnostic_variables" in conf["data"]) and (
len(conf["data"]["diagnostic_variables"]) > 0
):
diagnostic_files = sorted(glob(conf["data"]["save_loc_diagnostic"]))

else:
diagnostic_files = None
# Load the dataset using the provided dataset_type
train_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=True)
valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)

# -------------------------------------------------- #
# import training / validation years from conf

if "train_years" in conf["data"]:
train_years_range = conf["data"]["train_years"]
else:
train_years_range = [1979, 2014]

if "valid_years" in conf["data"]:
valid_years_range = conf["data"]["valid_years"]
else:
valid_years_range = [2014, 2018]

# convert year info to str for file name search
train_years = [
str(year) for year in range(train_years_range[0], train_years_range[1])
]
valid_years = [
str(year) for year in range(valid_years_range[0], valid_years_range[1])
]

# Filter the files for training / validation
train_files = [
file for file in all_ERA_files if any(year in file for year in train_years)
]
valid_files = [
file for file in all_ERA_files if any(year in file for year in valid_years)
]

# <----------------------------------- std_new or 'std_cached'
if conf["data"]["scaler_type"] == "std_new" or "std_cached":
if surface_files is not None:
train_surface_files = [
file
for file in surface_files
if any(year in file for year in train_years)
]
valid_surface_files = [
file
for file in surface_files
if any(year in file for year in valid_years)
]

else:
train_surface_files = None
valid_surface_files = None

if dyn_forcing_files is not None:
train_dyn_forcing_files = [
file
for file in dyn_forcing_files
if any(year in file for year in train_years)
]
valid_dyn_forcing_files = [
file
for file in dyn_forcing_files
if any(year in file for year in valid_years)
]

else:
train_dyn_forcing_files = None
valid_dyn_forcing_files = None

if diagnostic_files is not None:
train_diagnostic_files = [
file
for file in diagnostic_files
if any(year in file for year in train_years)
]
valid_diagnostic_files = [
file
for file in diagnostic_files
if any(year in file for year in valid_years)
]

else:
train_diagnostic_files = None
valid_diagnostic_files = None
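A note on the two `scaler_type` checks in the removed code above: `conf["data"]["scaler_type"] == "std_new" or "std_cached"` is always truthy in Python, because the non-empty string `"std_cached"` stands on its own as the right operand of `or`. The comparison presumably intended a membership test:

```python
# The removed condition is always truthy, whatever scaler_type holds:
scaler_type = "quantile"
print(scaler_type == "std_new" or "std_cached")    # -> "std_cached" (truthy)

# A membership test expresses the intended check:
print(scaler_type in ("std_new", "std_cached"))    # -> False
```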

# load dataset and sampler
train_dataset, train_sampler = load_dataset_and_sampler(
conf,
train_files,
train_surface_files,
train_dyn_forcing_files,
train_diagnostic_files,
world_size,
rank,
is_train=True,
)
# validation set and sampler
valid_dataset, valid_sampler = load_dataset_and_sampler(
conf,
valid_files,
valid_surface_files,
valid_dyn_forcing_files,
valid_diagnostic_files,
world_size,
rank,
is_train=False,
)

# setup the dataloader for this process

train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=False,
sampler=train_sampler,
pin_memory=True,
persistent_workers=False,
num_workers=1, # multiprocessing is handled in the dataset
drop_last=True,
prefetch_factor=4,
)

valid_loader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=valid_batch_size,
shuffle=False,
sampler=valid_sampler,
pin_memory=False,
num_workers=1, # multiprocessing is handled in the dataset
drop_last=True,
prefetch_factor=4,
)
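The removed loaders above pass `prefetch_factor=4` alongside `num_workers=1`; PyTorch only accepts `prefetch_factor` when `num_workers > 0`, with each worker keeping that many batches in flight (the commit history notes prefetch was later re-enabled for `ERA5_MultiStep_Batcher` through a pseudo-sampler). A minimal runnable illustration with a toy dataset:

```python
# Toy DataLoader showing the prefetch_factor / num_workers relationship.
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":  # guard needed because worker processes may be spawned
    ds = TensorDataset(torch.randn(32, 4))
    loader = DataLoader(
        ds,
        batch_size=8,
        num_workers=1,       # loading happens in a worker process
        prefetch_factor=4,   # 4 batches prefetched per worker
        pin_memory=True,
        drop_last=True,
    )
    for (batch,) in loader:
        pass  # a training step would consume `batch` here
```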
# Load the dataloader
train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)

# model

m = load_model(conf)

# have to send the module to the correct device first

m.to(device)

# move out of eager-mode
if conf["trainer"].get("compile", False):
m = torch.compile(m)

# Wrap in DDP or FSDP module, or none

model = distributed_model_wrapper(conf, m, device)

# Load model weights (if any), an optimizer, scheduler, and gradient scaler

conf, model, optimizer, scheduler, scaler = load_model_states_and_optimizer(
conf, model, device
)
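The `scaler` returned here comes from one of the two imports at the top of the file. The internals of `load_model_states_and_optimizer` are collapsed in this diff, but the usual selection logic looks roughly like the sketch below (the `"amp"` config key is an assumption; `mode` is used elsewhere in this file):

```python
# Sketch of the typical grad-scaler choice for mixed precision (assumed
# internals; not shown in this diff). FSDP shards gradients, so it needs
# the sharded variant of the scaler.
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def select_grad_scaler(conf):
    use_amp = conf["trainer"].get("amp", False)   # hypothetical config key
    if conf["trainer"]["mode"] == "fsdp":
        return ShardedGradScaler(enabled=use_amp)
    return GradScaler(enabled=use_amp)
```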

# Train and validation losses

train_criterion = VariableTotalLoss2D(conf)
valid_criterion = VariableTotalLoss2D(conf, validation=True)

# Optional load stopping probability annealer

# Set up some metrics

metrics = LatWeightedMetrics(conf)

# Initialize a trainer object
trainer_cls = load_trainer(conf)
trainer = trainer_cls(model, rank, module=(conf["trainer"]["mode"] == "ddp"))
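The `module=(conf["trainer"]["mode"] == "ddp")` flag presumably tells the trainer whether the network is wrapped, since `DistributedDataParallel` keeps the original model under `.module`; that matters when saving or inspecting weights:

```python
# Why the trainer cares whether the model is DDP-wrapped: checkpoints are
# usually taken from the unwrapped module.
def unwrap(model, is_wrapped: bool):
    return model.module if is_wrapped else model

# e.g. torch.save(unwrap(model, conf["trainer"]["mode"] == "ddp").state_dict(), save_path)
```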

# Fit the model

result = trainer.fit(
conf,
train_loader=train_loader,
@@ -721,7 +474,7 @@ def train(self, trial, conf):
# track hyperparameters and run metadata
config=conf,
)

seed = conf["seed"]
seed_everything(seed)
