diff --git a/model/unet_model.py b/model/unet_model.py new file mode 100644 index 0000000..c216310 --- /dev/null +++ b/model/unet_model.py @@ -0,0 +1,116 @@ +import copy +import torch +from torch import nn +from torch.cuda import amp + +class UNetModel: + """Core model architecture implementation for diffusion models.""" + def __init__(self, conf): + """ + Initialize the UNet model. + + Args: + conf: Configuration object containing model parameters + """ + self.conf = conf + self.model = conf.make_model_conf().make_model() + self.ema_model = copy.deepcopy(self.model) + self.ema_model.requires_grad_(False) + self.ema_model.eval() + + # Calculate model size + model_size = 0 + for param in self.model.parameters(): + model_size += param.data.nelement() + print('Model params: %.2f M' % (model_size / 1024 / 1024)) + + # Initialize samplers + self.sampler = conf.make_diffusion_conf().make_sampler() + self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler() + self.T_sampler = conf.make_T_sampler() + + # Initialize latent samplers if needed + if conf.train_mode.use_latent_net(): + self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler() + self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler() + else: + self.latent_sampler = None + self.eval_latent_sampler = None + + def update_ema(self, decay): + """ + Update the exponential moving average model. + + Args: + decay: EMA decay rate + """ + self._ema(self.model, self.ema_model, decay) + + def _ema(self, source, target, decay): + """ + Apply exponential moving average update. + + Args: + source: Source model + target: Target model (EMA) + decay: EMA decay rate + """ + source_dict = source.state_dict() + target_dict = target.state_dict() + for key in source_dict.keys(): + target_dict[key].data.copy_(target_dict[key].data * decay + + source_dict[key].data * (1 - decay)) + + def encode(self, x): + """ + Encode input using the model's encoder. + + Args: + x: Input tensor + + Returns: + Encoded representation + """ + assert self.conf.model_type.has_autoenc() + cond = self.ema_model.encoder.forward(x) + return cond + + def encode_stochastic(self, x, cond, T=None): + """ + Stochastically encode input. + + Args: + x: Input tensor + cond: Conditioning tensor + T: Number of diffusion steps + + Returns: + Stochastically encoded sample + """ + if T is None: + sampler = self.eval_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() + out = sampler.ddim_reverse_sample_loop(self.ema_model, + x, + model_kwargs={'cond': cond}) + return out['sample'] + + def forward(self, noise=None, x_start=None, use_ema=False): + """ + Forward pass through the model. 
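+
+        Example (illustrative sketch; assumes ``conf`` is a fully built training
+        config exposing ``img_size``):
+
+            model = UNetModel(conf)
+            noise = torch.randn(4, 3, conf.img_size, conf.img_size)
+            imgs = model.forward(noise=noise, use_ema=True)  # raw samples, roughly in [-1, 1]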
+ + Args: + noise: Input noise + x_start: Starting point for diffusion + use_ema: Whether to use EMA model + + Returns: + Generated sample + """ + with amp.autocast(False): + model = self.ema_model if use_ema else self.model + gen = self.eval_sampler.sample(model=model, + noise=noise, + x_start=x_start) + return gen diff --git a/preprocessing/unet_preprocessing.py b/preprocessing/unet_preprocessing.py new file mode 100644 index 0000000..67bc6de --- /dev/null +++ b/preprocessing/unet_preprocessing.py @@ -0,0 +1,115 @@ +import torch +from torch.utils.data import DataLoader, TensorDataset, ConcatDataset +from dataset import * +from dist_utils import get_world_size, get_rank +import numpy as np + +class UNetPreprocessor: + """Handles data preprocessing and dataset creation for UNet models.""" + def __init__(self, conf): + """ + Initialize the preprocessor. + + Args: + conf: Configuration object + """ + self.conf = conf + self.train_data = None + self.val_data = None + + def setup(self, seed=None, global_rank=0): + """ + Set up datasets with proper seeding. + + Args: + seed: Random seed + global_rank: Current process rank + """ + # Set seed for each worker separately + if seed is not None: + seed_worker = seed * get_world_size() + global_rank + np.random.seed(seed_worker) + torch.manual_seed(seed_worker) + torch.cuda.manual_seed(seed_worker) + print('local seed:', seed_worker) + + # Create datasets + self.train_data = self.conf.make_dataset() + print('train data:', len(self.train_data)) + self.val_data = self.train_data + print('val data:', len(self.val_data)) + + def create_train_dataloader(self, batch_size, drop_last=True, shuffle=True): + """ + Create training dataloader. + + Args: + batch_size: Batch size + drop_last: Whether to drop the last incomplete batch + shuffle: Whether to shuffle the data + + Returns: + DataLoader for training + """ + if not hasattr(self, "train_data") or self.train_data is None: + self.setup() + + # Create a DataLoader directly + dataloader = torch.utils.data.DataLoader( + self.train_data, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + num_workers=0, # Use 0 to avoid pickling issues + persistent_workers=False + ) + return SizedIterableWrapper(dataloader, len(self.train_data)) + + def create_val_dataloader(self, batch_size, drop_last=False): + """ + Create validation dataloader. + + Args: + batch_size: Batch size + drop_last: Whether to drop the last incomplete batch + + Returns: + DataLoader for validation + """ + if not hasattr(self, "val_data") or self.val_data is None: + self.setup() + + dataloader = torch.utils.data.DataLoader( + self.val_data, + batch_size=batch_size, + shuffle=False, + drop_last=drop_last, + num_workers=0, + persistent_workers=False + ) + return dataloader + + def create_latent_dataset(self, conds): + """ + Create a dataset from latent conditions. 
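+
+        Example (illustrative; shapes are hypothetical):
+
+            conds = torch.randn(1000, 512)  # e.g. encoder outputs for 1000 images
+            latent_ds = preprocessor.create_latent_dataset(conds)
+            loader = torch.utils.data.DataLoader(latent_ds, batch_size=64, shuffle=True)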
+ + Args: + conds: Latent conditions tensor + + Returns: + TensorDataset containing the conditions + """ + return TensorDataset(conds) + + +class SizedIterableWrapper: + """Wrapper for iterables that provides a __len__ method.""" + def __init__(self, dataloader, length): + self.dataloader = dataloader + self._length = length + + def __iter__(self): + return iter(self.dataloader) + + def __len__(self): + return self._length diff --git a/source/experiment.py b/source/experiment.py index 29abee1..51646cd 100644 --- a/source/experiment.py +++ b/source/experiment.py @@ -1,1163 +1,125 @@ -import copy -import json import os -import re - -import numpy as np -import pandas as pd -import pytorch_lightning as pl +import argparse import torch -#from numpy.lib.function_base import flip -from numpy.lib._function_base_impl import flip -from pytorch_lightning import loggers as pl_loggers -from pytorch_lightning.callbacks import * -from torch import nn from torch.cuda import amp -from torch.distributions import Categorical -from torch.optim.optimizer import Optimizer -from torch.utils.data.dataset import ConcatDataset, TensorDataset -from torchvision.utils import make_grid, save_image - -from config import * -from dataset import * -from dist_utils import * -from lmdb_writer import * -from metrics import * -from renderer import * - - -class SizedIterableWrapper: - # The constructor accepts a dataloader and a length. - # 'dataloader' can be any iterable (like a list, generator, etc.), - # and 'length' represents the total number of items it is supposed to yield. - def __init__(self, dataloader, length): - self.dataloader = dataloader # Store the provided dataloader - self._length = length # Store the provided length - - # The __iter__ method makes the object iterable. - # It returns an iterator for the wrapped dataloader. - def __iter__(self): - return iter(self.dataloader) - - # The __len__ method returns the stored length. - # This is useful when you need to know how many items the dataloader should yield. - def __len__(self): - return self._length - - - -class LitModel(pl.LightningModule): - def __init__(self, conf: TrainConfig): - super().__init__() - assert conf.train_mode != TrainMode.manipulate - if conf.seed is not None: - pl.seed_everything(conf.seed) - - self.save_hyperparameters(conf.as_dict_jsonable()) - - self.conf = conf - - self.model = conf.make_model_conf().make_model() - self.ema_model = copy.deepcopy(self.model) - self.ema_model.requires_grad_(False) - self.ema_model.eval() - - model_size = 0 - for param in self.model.parameters(): - model_size += param.data.nelement() - print('Model params: %.2f M' % (model_size / 1024 / 1024)) - - self.sampler = conf.make_diffusion_conf().make_sampler() - self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler() - - # this is shared for both model and latent - self.T_sampler = conf.make_T_sampler() - - if conf.train_mode.use_latent_net(): - self.latent_sampler = conf.make_latent_diffusion_conf( - ).make_sampler() - self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf( - ).make_sampler() - else: - self.latent_sampler = None - self.eval_latent_sampler = None - - # initial variables for consistent sampling - self.register_buffer( - 'x_T', - torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size)) - - - #if conf.pretrain is not None: - # print(f'loading pretrain ... 
{conf.pretrain.name}') - # state = torch.load(conf.pretrain.path, map_location='cpu',weights_only=False) - # print('step:', state['global_step']) - # self.load_state_dict(state['state_dict'], strict=False) - - if conf.pretrain is not None: # Check if a pretrain configuration is provided - print( - f'loading pretrain ... {conf.pretrain.name}') # Print the name of the pretrain configuration being loaded - # Load the saved model state from the provided path. - # 'map_location' is set to 'cpu' to move the loaded tensors to CPU. - # 'weights_only=False' ensures the full state (not just the model weights) is loaded. - state = torch.load(conf.pretrain.path, map_location='cpu', weights_only=False) - print('step:', state['global_step']) # Print the current global step from the loaded checkpoint state - # Load the state dictionary into the model. - # 'strict=False' allows for some keys in the model's state dict to be missing or extra. - self.load_state_dict(state['state_dict'], strict=False) - - if conf.latent_infer_path is not None: - print('loading latent stats ...') - # same here, loading stuff - state = torch.load(conf.latent_infer_path, weights_only=False) - self.conds = state['conds'] - self.register_buffer('conds_mean', state['conds_mean'][None, :]) - self.register_buffer('conds_std', state['conds_std'][None, :]) - else: - self.conds_mean = None - self.conds_std = None - - def normalize(self, cond): - cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to( - self.device) - return cond - - def denormalize(self, cond): - cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to( - self.device) - return cond - - def sample(self, N, device, T=None, T_latent=None): - if T is None: - sampler = self.eval_sampler - latent_sampler = self.latent_sampler - else: - sampler = self.conf._make_diffusion_conf(T).make_sampler() - latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler() - - noise = torch.randn(N, - 3, - self.conf.img_size, - self.conf.img_size, - device=device) - pred_img = render_uncondition( - self.conf, - self.ema_model, - noise, - sampler=sampler, - latent_sampler=latent_sampler, - conds_mean=self.conds_mean, - conds_std=self.conds_std, - ) - pred_img = (pred_img + 1) / 2 - return pred_img - - def render(self, noise, cond=None, T=None): - if T is None: - sampler = self.eval_sampler - else: - sampler = self.conf._make_diffusion_conf(T).make_sampler() - - if cond is not None: - pred_img = render_condition(self.conf, - self.ema_model, - noise, - sampler=sampler, - cond=cond) - else: - pred_img = render_uncondition(self.conf, - self.ema_model, - noise, - sampler=sampler, - latent_sampler=None) - pred_img = (pred_img + 1) / 2 - return pred_img - - def encode(self, x): - # TODO: - assert self.conf.model_type.has_autoenc() - cond = self.ema_model.encoder.forward(x) - return cond - - def encode_stochastic(self, x, cond, T=None): - if T is None: - sampler = self.eval_sampler - else: - sampler = self.conf._make_diffusion_conf(T).make_sampler() - out = sampler.ddim_reverse_sample_loop(self.ema_model, - x, - model_kwargs={'cond': cond}) - return out['sample'] - - def forward(self, noise=None, x_start=None, ema_model: bool = False): - with amp.autocast(False): - if ema_model: - model = self.ema_model - else: - model = self.model - gen = self.eval_sampler.sample(model=model, - noise=noise, - x_start=x_start) - return gen - - def setup(self, stage=None) -> None: - """ - make datasets & seeding each worker separately - """ - 
############################################## - # NEED TO SET THE SEED SEPARATELY HERE - if self.conf.seed is not None: - seed = self.conf.seed * get_world_size() + self.global_rank - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - print('local seed:', seed) - ############################################## - self.train_data = self.conf.make_dataset() - print('train data:', len(self.train_data)) - self.val_data = self.train_data - print('val data:', len(self.val_data)) - - - #def _train_dataloader(self, drop_last=True): - """ - #really make the dataloader - """ - # make sure to use the fraction of batch size - # the batch size is global! - #conf = self.conf.clone() - #conf.batch_size = self.batch_size - - #dataloader = conf.make_loader(self.train_data, - # shuffle=True, - # drop_last=drop_last) - #return dataloader - - def _train_dataloader(self, drop_last=True): - """ - Really make the dataloader. - """ - if not hasattr(self, "train_data"): - self.setup('fit') - if not hasattr(self, "train_data"): - raise ValueError( - "train_data is not initialized even after setup() call. Please ensure setup() properly initializes train_data." - ) - - # Clone the configuration and set the correct batch size. - conf = self.conf.clone() - conf.batch_size = self.batch_size - - # Create a DataLoader directly instead of make loader, picke issues and multiprocessing - dataloader = torch.utils.data.DataLoader( - self.train_data, - batch_size=conf.batch_size, - shuffle=True, - drop_last=drop_last, - num_workers=0, # Use 0 on Windows to avoid pickling issues. - persistent_workers=False - ) - return dataloader - - #def train_dataloader(self): - # """ - # return the dataloader, if diffusion mode => return image dataset - # if latent mode => return the inferred latent dataset - # """ - # print('on train dataloader start ...') - #if self.conf.train_mode.require_dataset_infer(): - # if self.conds is None: - # usually we load self.conds from a file - # so we do not need to do this again! - # self.conds = self.infer_whole_dataset() - # need to use float32! unless the mean & std will be off! - # (1, c) - # self.conds_mean.data = self.conds.float().mean(dim=0, - # keepdim=True) - # self.conds_std.data = self.conds.float().std(dim=0, - # keepdim=True) - # print('mean:', self.conds_mean.mean(), 'std:', - # self.conds_std.mean()) - - # return the dataset with pre-calculated conds - # conf = self.conf.clone() - # conf.batch_size = self.batch_size - # data = TensorDataset(self.conds) - # return conf.make_loader(data, shuffle=True) - #else: - # return self._train_dataloader() - - def train_dataloader(self): - """ - Return the dataloader: - - If in diffusion mode, return an image dataset. - - If in latent mode, return the inferred latent dataset. - """ - print('on train dataloader start ...') - - # Check if the current training mode requires dataset inference. - if self.conf.train_mode.require_dataset_infer(): - # If conditions (self.conds) are not already available, compute them. - if self.conds is None: - # Infer and set the complete dataset conditions. - # Typically, self.conds might be loaded from a file, avoiding re-computation. - self.conds = self.infer_whole_dataset() - - # Compute the mean of conditions as float32 to prevent precision issues. - # This is done along dimension 0, preserving the dimension for later operations. - self.conds_mean.data = self.conds.float().mean(dim=0, keepdim=True) - - # Compute the standard deviation of conditions as float32. 
- self.conds_std.data = self.conds.float().std(dim=0, keepdim=True) - - # Log the mean and standard deviation values for verification. - print('mean:', self.conds_mean.mean(), 'std:', self.conds_std.mean()) - - # Clone the current configuration to avoid modifying the original. - conf = self.conf.clone() - - # Set the batch size in the cloned configuration. - conf.batch_size = self.batch_size - - # Create a TensorDataset from the inferred conditions. - data = TensorDataset(self.conds) - - # Use the configuration to create a data loader with shuffling enabled. - loader = conf.make_loader(data, shuffle=True) - - # Return a wrapped loader that includes the explicit length of the dataset. - return SizedIterableWrapper(loader, len(data)) # PyLightning stuff - - else: - # If dataset inference isn't required, use the default training dataloader. - return self._train_dataloader() - - @property - def batch_size(self): - """ - local batch size for each worker - """ - ws = get_world_size() - assert self.conf.batch_size % ws == 0 - return self.conf.batch_size // ws - - @property - def num_samples(self): - """ - (global) batch size * iterations - """ - # batch size here is global! - # global_step already takes into account the accum batches - return self.global_step * self.conf.batch_size_effective - - def is_last_accum(self, batch_idx): - """ - is it the last gradient accumulation loop? - used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not - """ - return (batch_idx + 1) % self.conf.accum_batches == 0 - - def infer_whole_dataset(self, - with_render=False, - T_render=None, - render_save_path=None): - """ - predicting the latents given images using the encoder - - Args: - both_flips: include both original and flipped images; no need, it's not an improvement - with_render: whether to also render the images corresponding to that latent - render_save_path: lmdb output for the rendered images - """ - data = self.conf.make_dataset() - if isinstance(data, CelebAlmdb) and data.crop_d2c: - # special case where we need the d2c crop - data.transform = make_transform(self.conf.img_size, - flip_prob=0, - crop_d2c=True) - else: - data.transform = make_transform(self.conf.img_size, flip_prob=0) - - # data = SubsetDataset(data, 21) - - loader = self.conf.make_loader( - data, - shuffle=False, - drop_last=False, - batch_size=self.conf.batch_size_eval, - parallel=True, +from model.unet_model import UNetModel +from preprocessing.unet_preprocessing import UNetPreprocessor +from utils.unet_loader import UNetLoader +from training.unet_trainer import UNetTrainer + +def parse_args(): + parser = argparse.ArgumentParser(description="Train or evaluate diffusion models") + parser.add_argument("--config", type=str, required=True, help="Path to configuration file") + parser.add_argument("--mode", type=str, default="train", choices=["train", "eval"], help="Mode to run") + parser.add_argument("--gpus", type=str, default="0", help="Comma-separated list of GPU IDs") + parser.add_argument("--nodes", type=int, default=1, help="Number of nodes") + parser.add_argument("--eval_path", type=str, default=None, help="Path to checkpoint for evaluation") + parser.add_argument("--eval_programs", type=str, default=None, help="Comma-separated list of evaluation programs") + return parser.parse_args() + +def load_config(config_path): + # This is a placeholder - you'll need to implement config loading based on your actual config format + from config import TrainConfig + import importlib.util + + spec = 
importlib.util.spec_from_file_location("config_module", config_path) + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + + # Assuming the config file defines a 'conf' variable + return config_module.conf + +def main(): + args = parse_args() + + # Set visible GPUs + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + gpus = list(range(len(args.gpus.split(",")))) + + # Load configuration + conf = load_config(args.config) + + # Override config with command line arguments + if args.eval_path: + conf.eval_path = args.eval_path + if args.eval_programs: + conf.eval_programs = args.eval_programs.split(",") + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Initialize components + model = UNetModel(conf) + preprocessor = UNetPreprocessor(conf) + loader = UNetLoader(conf) + trainer = UNetTrainer(model, preprocessor, loader, conf, device=device) + + # Run in specified mode + if args.mode == "train": + # Setup for training + preprocessor.setup(seed=conf.seed, global_rank=0) + + # Train the model + trainer.train( + num_epochs=conf.total_samples // conf.batch_size_effective // len(preprocessor.train_data), + batch_size=conf.batch_size // len(gpus), + fp16=conf.fp16 ) - model = self.ema_model - model.eval() - conds = [] - - if with_render: - sampler = self.conf._make_diffusion_conf( - T=T_render or self.conf.T_eval).make_sampler() - - if self.global_rank == 0: - writer = LMDBImageWriter(render_save_path, - format='webp', - quality=100) - else: - writer = nullcontext() - else: - writer = nullcontext() - - with writer: - for batch in tqdm(loader, total=len(loader), desc='infer'): - with torch.no_grad(): - # (n, c) - # print('idx:', batch['index']) - cond = model.encoder(batch['img'].to(self.device)) - - # used for reordering to match the original dataset - idx = batch['index'] - idx = self.all_gather(idx) - if idx.dim() == 2: - idx = idx.flatten(0, 1) - argsort = idx.argsort() - - if with_render: - noise = torch.randn(len(cond), - 3, - self.conf.img_size, - self.conf.img_size, - device=self.device) - render = sampler.sample(model, noise=noise, cond=cond) - render = (render + 1) / 2 - # print('render:', render.shape) - # (k, n, c, h, w) - render = self.all_gather(render) - if render.dim() == 5: - # (k*n, c) - render = render.flatten(0, 1) - - # print('global_rank:', self.global_rank) - - if self.global_rank == 0: - writer.put_images(render[argsort]) - - # (k, n, c) - cond = self.all_gather(cond) - - if cond.dim() == 3: - # (k*n, c) - cond = cond.flatten(0, 1) - - conds.append(cond[argsort].cpu()) - # break - model.train() - # (N, c) cpu - - conds = torch.cat(conds).float() - return conds - - def training_step(self, batch, batch_idx): - """ - given an input, calculate the loss function - no optimization at this stage. - """ - with amp.autocast(False): - # batch size here is local! - # forward - if self.conf.train_mode.require_dataset_infer(): - # this mode as pre-calculated cond - cond = batch[0] - if self.conf.latent_znormalize: - cond = (cond - self.conds_mean.to( - self.device)) / self.conds_std.to(self.device) - else: - imgs, idxs = batch['img'], batch['index'] - # print(f'(rank {self.global_rank}) batch size:', len(imgs)) - x_start = imgs - - if self.conf.train_mode == TrainMode.diffusion: - """ - main training mode!!! - """ - # with numpy seed we have the problem that the sample t's are related! 
- t, weight = self.T_sampler.sample(len(x_start), x_start.device) - losses = self.sampler.training_losses(model=self.model, - x_start=x_start, - t=t) - elif self.conf.train_mode.is_latent_diffusion(): - """ - training the latent variables! - """ - # diffusion on the latent - t, weight = self.T_sampler.sample(len(cond), cond.device) - latent_losses = self.latent_sampler.training_losses( - model=self.model.latent_net, x_start=cond, t=t) - # train only do the latent diffusion - losses = { - 'latent': latent_losses['loss'], - 'loss': latent_losses['loss'] - } - else: - raise NotImplementedError() - - loss = losses['loss'].mean() - # divide by accum batches to make the accumulated gradient exact! - for key in ['loss', 'vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']: - if key in losses: - losses[key] = self.all_gather(losses[key]).mean() - - if self.global_rank == 0: - self.logger.experiment.add_scalar('loss', losses['loss'], - self.num_samples) - for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']: - if key in losses: - self.logger.experiment.add_scalar( - f'loss/{key}', losses[key], self.num_samples) - - return {'loss': loss} - - # def on_train_batch_end(self, outputs, batch, batch_idx: int, - # dataloader_idx: int) -> None: - # """ - # after each training step ... - #""" - #if self.is_last_accum(batch_idx): - # only apply ema on the last gradient accumulation step, - # if it is the iteration that has optimizer.step() - # if self.conf.train_mode == TrainMode.latent_diffusion: - # it trains only the latent hence change only the latent - # ema(self.model.latent_net, self.ema_model.latent_net, - # self.conf.ema_decay) - # else: - # ema(self.model, self.ema_model, self.conf.ema_decay) - - # logging - # if self.conf.train_mode.require_dataset_infer(): - # imgs = None - # else: - # imgs = batch['img'] - # self.log_sample(x_start=imgs) - # self.evaluate_scores() - - #def on_before_optimizer_step(self, optimizer: Optimizer, - # optimizer_idx: int) -> None: - # fix the fp16 + clip grad norm problem with pytorch lightinng - # this is the currently correct way to do it - #if self.conf.grad_clip > 0: - # from trainer.params_grads import grads_norm, iter_opt_params - # params = [ - # p for group in optimizer.param_groups for p in group['params'] - # ] - # print('before:', grads_norm(iter_opt_params(optimizer))) - # torch.nn.utils.clip_grad_norm_(params, - # max_norm=self.conf.grad_clip) - # print('after:', grads_norm(iter_opt_params(optimizer))) - - - # Change in PyLightning framework - def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None: - """ - after each training step ... 
- """ - if self.is_last_accum(batch_idx): - # only apply ema on the last gradient accumulation step, - # if it is the iteration that has optimizer.step() - if self.conf.train_mode == TrainMode.latent_diffusion: - # it trains only the latent hence change only the latent - ema(self.model.latent_net, self.ema_model.latent_net, - self.conf.ema_decay) - else: - ema(self.model, self.ema_model, self.conf.ema_decay) - - # logging - if self.conf.train_mode.require_dataset_infer(): - imgs = None - else: - imgs = batch['img'] - self.log_sample(x_start=imgs) - self.evaluate_scores() - - # Change in PyLightning framework - def on_before_optimizer_step(self, optimizer: Optimizer,**kwargs) -> None: - # fix the fp16 + clip grad norm problem with pytorch lightning - # this is the currently correct way to do it - if self.conf.grad_clip > 0: - params = [p for group in optimizer.param_groups for p in group['params']] - torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip) - - def log_sample(self, x_start): - """ - put images to the tensorboard - """ - def do(model, - postfix, - use_xstart, - save_real=False, - no_latent_diff=False, - interpolate=False): - model.eval() - with torch.no_grad(): - all_x_T = self.split_tensor(self.x_T) - batch_size = min(len(all_x_T), self.conf.batch_size_eval) - # allow for superlarge models - loader = DataLoader(all_x_T, batch_size=batch_size) - - Gen = [] - for x_T in loader: - if use_xstart: - _xstart = x_start[:len(x_T)] - else: - _xstart = None - - if self.conf.train_mode.is_latent_diffusion( - ) and not use_xstart: - # diffusion of the latent first - gen = render_uncondition( - conf=self.conf, - model=model, - x_T=x_T, - sampler=self.eval_sampler, - latent_sampler=self.eval_latent_sampler, - conds_mean=self.conds_mean, - conds_std=self.conds_std) - else: - if not use_xstart and self.conf.model_type.has_noise_to_cond( - ): - model: BeatGANsAutoencModel - # special case, it may not be stochastic, yet can sample - cond = torch.randn(len(x_T), - self.conf.style_ch, - device=self.device) - cond = model.noise_to_cond(cond) - else: - if interpolate: - with amp.autocast(self.conf.fp16): - cond = model.encoder(_xstart) - i = torch.randperm(len(cond)) - cond = (cond + cond[i]) / 2 - else: - cond = None - gen = self.eval_sampler.sample(model=model, - noise=x_T, - cond=cond, - x_start=_xstart) - Gen.append(gen) - - gen = torch.cat(Gen) - gen = self.all_gather(gen) - if gen.dim() == 5: - # (n, c, h, w) - gen = gen.flatten(0, 1) - - if save_real and use_xstart: - # save the original images to the tensorboard - real = self.all_gather(_xstart) - if real.dim() == 5: - real = real.flatten(0, 1) - - if self.global_rank == 0: - grid_real = (make_grid(real) + 1) / 2 - self.logger.experiment.add_image( - f'sample{postfix}/real', grid_real, - self.num_samples) - - if self.global_rank == 0: - # save samples to the tensorboard - grid = (make_grid(gen) + 1) / 2 - sample_dir = os.path.join(self.conf.logdir, - f'sample{postfix}') - if not os.path.exists(sample_dir): - os.makedirs(sample_dir) - path = os.path.join(sample_dir, - '%d.png' % self.num_samples) - save_image(grid, path) - self.logger.experiment.add_image(f'sample{postfix}', grid, - self.num_samples) - model.train() - - if self.conf.sample_every_samples > 0 and is_time( - self.num_samples, self.conf.sample_every_samples, - self.conf.batch_size_effective): - - if self.conf.train_mode.require_dataset_infer(): - do(self.model, '', use_xstart=False) - do(self.ema_model, '_ema', use_xstart=False) - else: - if 
self.conf.model_type.has_autoenc( - ) and self.conf.model_type.can_sample(): - do(self.model, '', use_xstart=False) - do(self.ema_model, '_ema', use_xstart=False) - # autoencoding mode - do(self.model, '_enc', use_xstart=True, save_real=True) - do(self.ema_model, - '_enc_ema', - use_xstart=True, - save_real=True) - elif self.conf.train_mode.use_latent_net(): - do(self.model, '', use_xstart=False) - do(self.ema_model, '_ema', use_xstart=False) - # autoencoding mode - do(self.model, '_enc', use_xstart=True, save_real=True) - do(self.model, - '_enc_nodiff', - use_xstart=True, - save_real=True, - no_latent_diff=True) - do(self.ema_model, - '_enc_ema', - use_xstart=True, - save_real=True) - else: - do(self.model, '', use_xstart=True, save_real=True) - do(self.ema_model, '_ema', use_xstart=True, save_real=True) - - def evaluate_scores(self): - """ - evaluate FID and other scores during training (put to the tensorboard) - For, FID. It is a fast version with 5k images (gold standard is 50k). - Don't use its results in the paper! - """ - def fid(model, postfix): - score = evaluate_fid(self.eval_sampler, - model, - self.conf, - device=self.device, - train_data=self.train_data, - val_data=self.val_data, - latent_sampler=self.eval_latent_sampler, - conds_mean=self.conds_mean, - conds_std=self.conds_std) - if self.global_rank == 0: - self.logger.experiment.add_scalar(f'FID{postfix}', score, - self.num_samples) - if not os.path.exists(self.conf.logdir): - os.makedirs(self.conf.logdir) - with open(os.path.join(self.conf.logdir, 'eval.txt'), - 'a') as f: - metrics = { - f'FID{postfix}': score, - 'num_samples': self.num_samples, - } - f.write(json.dumps(metrics) + "\n") - - def lpips(model, postfix): - if self.conf.model_type.has_autoenc( - ) and self.conf.train_mode.is_autoenc(): - # {'lpips', 'ssim', 'mse'} - score = evaluate_lpips(self.eval_sampler, - model, - self.conf, - device=self.device, - val_data=self.val_data, - latent_sampler=self.eval_latent_sampler) - - if self.global_rank == 0: - for key, val in score.items(): - self.logger.experiment.add_scalar( - f'{key}{postfix}', val, self.num_samples) - - if self.conf.eval_every_samples > 0 and self.num_samples > 0 and is_time( - self.num_samples, self.conf.eval_every_samples, - self.conf.batch_size_effective): - print(f'eval fid @ {self.num_samples}') - lpips(self.model, '') - fid(self.model, '') - - if self.conf.eval_ema_every_samples > 0 and self.num_samples > 0 and is_time( - self.num_samples, self.conf.eval_ema_every_samples, - self.conf.batch_size_effective): - print(f'eval fid ema @ {self.num_samples}') - fid(self.ema_model, '_ema') - # it's too slow - # lpips(self.ema_model, '_ema') - - def configure_optimizers(self): - out = {} - if self.conf.optimizer == OptimizerType.adam: - optim = torch.optim.Adam(self.model.parameters(), - lr=self.conf.lr, - weight_decay=self.conf.weight_decay) - elif self.conf.optimizer == OptimizerType.adamw: - optim = torch.optim.AdamW(self.model.parameters(), - lr=self.conf.lr, - weight_decay=self.conf.weight_decay) + elif args.mode == "eval": + # Setup for evaluation + preprocessor.setup(seed=conf.seed, global_rank=0) + + # Load checkpoint + if conf.eval_path: + loader.load_checkpoint(model.model, filename=conf.eval_path) else: - raise NotImplementedError() - out['optimizer'] = optim - if self.conf.warmup > 0: - sched = torch.optim.lr_scheduler.LambdaLR(optim, - lr_lambda=WarmupLR( - self.conf.warmup)) - out['lr_scheduler'] = { - 'scheduler': sched, - 'interval': 'step', - } - return out - - def split_tensor(self, x): 
- """ - extract the tensor for a corresponding "worker" in the batch dimension - - Args: - x: (n, c) - - Returns: x: (n_local, c) - """ - n = len(x) - rank = self.global_rank - world_size = get_world_size() - # print(f'rank: {rank}/{world_size}') - per_rank = n // world_size - return x[rank * per_rank:(rank + 1) * per_rank] - - def test_step(self, batch, *args, **kwargs): - """ - for the "eval" mode. - We first select what to do according to the "conf.eval_programs". - test_step will only run for "one iteration" (it's a hack!). + loader.load_checkpoint(model.model) - We just want the multi-gpu support. - """ - # make sure you seed each worker differently! - self.setup() - - # it will run only one step! - print('global step:', self.global_step) - """ - "infer" = predict the latent variables using the encoder on the whole dataset - """ - if 'infer' in self.conf.eval_programs: - if 'infer' in self.conf.eval_programs: - print('infer ...') - conds = self.infer_whole_dataset().float() - # NOTE: always use this path for the latent.pkl files - save_path = f'checkpoints/{self.conf.name}/latent.pkl' - else: - raise NotImplementedError() - - if self.global_rank == 0: - conds_mean = conds.mean(dim=0) - conds_std = conds.std(dim=0) - if not os.path.exists(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path)) - torch.save( - { - 'conds': conds, - 'conds_mean': conds_mean, - 'conds_std': conds_std, - }, save_path) - """ - "infer+render" = predict the latent variables using the encoder on the whole dataset - THIS ALSO GENERATE CORRESPONDING IMAGES - """ - # infer + reconstruction quality of the input - for each in self.conf.eval_programs: - if each.startswith('infer+render'): - m = re.match(r'infer\+render([0-9]+)', each) - if m is not None: - T = int(m[1]) - self.setup() - print(f'infer + reconstruction T{T} ...') - conds = self.infer_whole_dataset( + # Run evaluation programs + for program in conf.eval_programs: + if program.startswith("fid"): + # Extract parameters from program string + if "(" in program and ")" in program: + # Format: fid(T,T_latent) + params = program[program.find("(")+1:program.find(")")].split(",") + T = int(params[0]) + T_latent = int(params[1]) if len(params) > 1 else None + else: + # Format: fidT + T = int(program[3:]) + T_latent = None + + # Evaluate FID + print(f"Evaluating FID with T={T}, T_latent={T_latent}") + score = trainer.evaluate_fid(T=T, T_latent=T_latent) + print(f"FID score: {score}") + + elif program.startswith("recon"): + # Format: reconT + T = int(program[5:]) + + # Evaluate reconstruction + print(f"Evaluating reconstruction with T={T}") + scores = trainer.evaluate_lpips(T=T) + for k, v in scores.items(): + print(f"{k}: {v}") + + elif program.startswith("infer"): + # Infer latents + if "+" in program: + # Format: infer+renderT + T = int(program[12:]) + print(f"Inferring latents and rendering with T={T}") + trainer.infer_whole_dataset( with_render=True, T_render=T, - render_save_path= - f'latent_infer_render{T}/{self.conf.name}.lmdb', + render_save_path=f'latent_infer_render{T}/{conf.name}.lmdb' ) - save_path = f'latent_infer_render{T}/{self.conf.name}.pkl' - conds_mean = conds.mean(dim=0) - conds_std = conds.std(dim=0) - if not os.path.exists(os.path.dirname(save_path)): - os.makedirs(os.path.dirname(save_path)) - torch.save( - { - 'conds': conds, - 'conds_mean': conds_mean, - 'conds_std': conds_std, - }, save_path) - - # evals those "fidXX" - """ - "fid" = unconditional generation (conf.train_mode = diffusion). - Note: Diff. 
autoenc will still receive real images in this mode. - "fid," = unconditional generation for latent models (conf.train_mode = latent_diffusion). - Note: Diff. autoenc will still NOT receive real images in this made. - but you need to make sure that the train_mode is latent_diffusion. - """ - for each in self.conf.eval_programs: - if each.startswith('fid'): - m = re.match(r'fid\(([0-9]+),([0-9]+)\)', each) - clip_latent_noise = False - if m is not None: - # eval(T1,T2) - T = int(m[1]) - T_latent = int(m[2]) - print(f'evaluating FID T = {T}... latent T = {T_latent}') - else: - m = re.match(r'fidclip\(([0-9]+),([0-9]+)\)', each) - if m is not None: - # fidclip(T1,T2) - T = int(m[1]) - T_latent = int(m[2]) - clip_latent_noise = True - print( - f'evaluating FID (clip latent noise) T = {T}... latent T = {T_latent}' - ) - else: - # evalT - _, T = each.split('fid') - T = int(T) - T_latent = None - print(f'evaluating FID T = {T}...') - - self.train_dataloader() - sampler = self.conf._make_diffusion_conf(T=T).make_sampler() - if T_latent is not None: - latent_sampler = self.conf._make_latent_diffusion_conf( - T=T_latent).make_sampler() - else: - latent_sampler = None - - conf = self.conf.clone() - conf.eval_num_images = 50_000 - score = evaluate_fid( - sampler, - self.ema_model, - conf, - device=self.device, - train_data=self.train_data, - val_data=self.val_data, - latent_sampler=latent_sampler, - conds_mean=self.conds_mean, - conds_std=self.conds_std, - remove_cache=False, - clip_latent_noise=clip_latent_noise, - ) - if T_latent is None: - self.log(f'fid_ema_T{T}', score) else: - name = 'fid' - if clip_latent_noise: - name += '_clip' - name += f'_ema_T{T}_Tlatent{T_latent}' - self.log(name, score) - """ - "recon" = reconstruction & autoencoding (without noise inversion) - """ - for each in self.conf.eval_programs: - if each.startswith('recon'): - self.model: BeatGANsAutoencModel - _, T = each.split('recon') - T = int(T) - print(f'evaluating reconstruction T = {T}...') - - sampler = self.conf._make_diffusion_conf(T=T).make_sampler() - - conf = self.conf.clone() - # eval whole val dataset - conf.eval_num_images = len(self.val_data) - # {'lpips', 'mse', 'ssim'} - score = evaluate_lpips(sampler, - self.ema_model, - conf, - device=self.device, - val_data=self.val_data, - latent_sampler=None) - for k, v in score.items(): - self.log(f'{k}_ema_T{T}', v) - """ - "inv" = reconstruction with noise inversion - """ - for each in self.conf.eval_programs: - if each.startswith('inv'): - self.model: BeatGANsAutoencModel - _, T = each.split('inv') - T = int(T) - print( - f'evaluating reconstruction with noise inversion T = {T}...' 
- ) - - sampler = self.conf._make_diffusion_conf(T=T).make_sampler() - - conf = self.conf.clone() - # eval whole val dataset - conf.eval_num_images = len(self.val_data) - # {'lpips', 'mse', 'ssim'} - score = evaluate_lpips(sampler, - self.ema_model, - conf, - device=self.device, - val_data=self.val_data, - latent_sampler=None, - use_inverted_noise=True) - for k, v in score.items(): - self.log(f'{k}_inv_ema_T{T}', v) - - -def ema(source, target, decay): - source_dict = source.state_dict() - target_dict = target.state_dict() - for key in source_dict.keys(): - target_dict[key].data.copy_(target_dict[key].data * decay + - source_dict[key].data * (1 - decay)) - - -class WarmupLR: - def __init__(self, warmup) -> None: - self.warmup = warmup - - def __call__(self, step): - return min(step, self.warmup) / self.warmup - - -def is_time(num_samples, every, step_size): - closest = (num_samples // every) * every - return num_samples - closest < step_size - - -def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'): - print('conf:', conf.name) - # assert not (conf.fp16 and conf.grad_clip > 0 - # ), 'pytorch lightning has bug with amp + gradient clipping' - model = LitModel(conf) - - if not os.path.exists(conf.logdir): - os.makedirs(conf.logdir) - checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}', - save_last=True, - save_top_k=1, - every_n_train_steps=conf.save_every_samples // - conf.batch_size_effective) - checkpoint_path = f'{conf.logdir}/last.ckpt' - print('ckpt path:', checkpoint_path) - if os.path.exists(checkpoint_path): - resume = checkpoint_path - print('resume!') - else: - if conf.continue_from is not None: - # continue from a checkpoint - resume = conf.continue_from.path - else: - resume = None - - tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir, - name=None, - version='') - - # from pytorch_lightning. 
- """ - plugins = [] - if len(gpus) == 1 and nodes == 1: - accelerator = None - else: - accelerator = 'ddp' - from pytorch_lightning.plugins import DDPPlugin - - # important for working with gradient checkpoint - plugins.append(DDPPlugin(find_unused_parameters=False)) - - trainer = pl.Trainer( - max_steps=conf.total_samples // conf.batch_size_effective, - resume_from_checkpoint=resume, - gpus=gpus, - num_nodes=nodes, - accelerator=accelerator, - precision=16 if conf.fp16 else 32, - callbacks=[ - checkpoint, - LearningRateMonitor(), - ], - # clip in the model instead - # gradient_clip_val=conf.grad_clip, - replace_sampler_ddp=True, - logger=tb_logger, - accumulate_grad_batches=conf.accum_batches, - plugins=plugins, - ) - """ - - if len(gpus) == 1 and nodes == 1: - accelerator = 'cuda' - trainer_kwargs = {} - plugins = None - - else: - accelerator = 'ddp' - # For PyTorch Lightning 2.x - from pytorch_lightning.strategies import DDPStrategy - - # important for working with gradient checkpoint - plugins = [] # Keep your existing plugins list initialization - trainer_kwargs = { - 'strategy': DDPStrategy(find_unused_parameters=False) - } - - use_dist_sampler = True if (len(gpus) >= 2 or nodes >= 2) else False - - trainer = pl.Trainer( - max_steps=conf.total_samples // conf.batch_size_effective, - devices=gpus, - num_nodes=nodes, - accelerator=accelerator, - precision=16 if conf.fp16 else 32, - callbacks=[ - checkpoint, - LearningRateMonitor(), - ], - use_distributed_sampler=use_dist_sampler, - logger=tb_logger, - accumulate_grad_batches=conf.accum_batches, - plugins=plugins if len(gpus) > 1 else None, - **trainer_kwargs if 'trainer_kwargs' in locals() else {}, # Add this line - ) - - if mode == 'train': - - # Get the train dataloader from your model - train_loader = model.train_dataloader() - # If multiple loaders are returned, manually create a CombinedLoader and prime it. - if isinstance(train_loader, (list, dict)): - from pytorch_lightning.utilities.data import CombinedLoader - - combined = CombinedLoader(train_loader, mode="max_size_cycle") - _ = iter(combined) # This ensures internal state is set - - trainer.fit(model, ckpt_path=resume) - - #trainer.fit(model) - - elif mode == 'eval': - # load the latest checkpoint - # perform lpips - # dummy loader to allow calling "test_step" - dummy = DataLoader(TensorDataset(torch.tensor([0.] 
* conf.batch_size)), - batch_size=conf.batch_size) - eval_path = conf.eval_path or checkpoint_path - # conf.eval_num_images = 50 - print('loading from:', eval_path) - state = torch.load(eval_path, map_location='cpu') - print('step:', state['global_step']) - model.load_state_dict(state['state_dict']) - # trainer.fit(model) - out = trainer.test(model, dataloaders=dummy) - # first (and only) loader - out = out[0] - print(out) - - if get_rank() == 0: - # save to tensorboard - for k, v in out.items(): - tb_logger.experiment.add_scalar( - k, v, state['global_step'] * conf.batch_size_effective) + # Format: infer + print("Inferring latents") + trainer.infer_whole_dataset() - # # save to file - # # make it a dict of list - # for k, v in out.items(): - # out[k] = [v] - tgt = f'evals/{conf.name}.txt' - dirname = os.path.dirname(tgt) - if not os.path.exists(dirname): - os.makedirs(dirname) - with open(tgt, 'a') as f: - f.write(json.dumps(out) + "\n") - # pd.DataFrame(out).to_csv(tgt) - else: - raise NotImplementedError() +if __name__ == "__main__": + main() diff --git a/training/unet_trainer.py b/training/unet_trainer.py new file mode 100644 index 0000000..7d94bba --- /dev/null +++ b/training/unet_trainer.py @@ -0,0 +1,711 @@ +import os +import json +import numpy as np +import torch +from torch import nn +from torch.cuda import amp +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader, TensorDataset +from torchvision.utils import make_grid, save_image +from tqdm import tqdm + +from dist_utils import get_world_size, get_rank, all_gather +from metrics import evaluate_fid, evaluate_lpips +from renderer import render_uncondition, render_condition + +class UNetTrainer: + """Handles the training process for UNet diffusion models.""" + def __init__(self, model, preprocessor, loader, conf, device='cuda'): + """ + Initialize the trainer. + + Args: + model: UNetModel instance + preprocessor: UNetPreprocessor instance + loader: UNetLoader instance + conf: Configuration object + device: Device to use for training + """ + self.model = model + self.preprocessor = preprocessor + self.loader = loader + self.conf = conf + self.device = device + + # Initialize training state + self.global_step = 0 + self.num_samples = 0 + self.global_rank = get_rank() + + # Register buffer for consistent sampling + self.x_T = torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size, device=device) + + # Initialize latent normalization stats + self.conds = None + self.conds_mean = None + self.conds_std = None + + # Load latent stats if path is provided + if conf.latent_infer_path is not None: + print('Loading latent stats...') + stats = self.loader.load_latent_stats(conf.latent_infer_path) + if stats: + self.conds = stats['conds'] + self.conds_mean = stats['conds_mean'][None, :].to(device) + self.conds_std = stats['conds_std'][None, :].to(device) + + def configure_optimizers(self): + """ + Configure optimizers and learning rate schedulers. 
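+
+        Builds Adam or AdamW according to ``conf.optimizer`` and, when
+        ``conf.warmup > 0``, a linear-warmup ``LambdaLR`` schedule that is
+        stepped once per optimizer step.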
+
+        Returns:
+            Tuple of (optimizer, scheduler)
+        """
+        # conf.optimizer is an OptimizerType enum in the original TrainConfig;
+        # fall back to its string value so both the enum and plain strings work.
+        opt_name = getattr(self.conf.optimizer, 'value', self.conf.optimizer)
+        if opt_name == 'adam':
+            optimizer = torch.optim.Adam(
+                self.model.model.parameters(),
+                lr=self.conf.lr,
+                weight_decay=self.conf.weight_decay
+            )
+        elif opt_name == 'adamw':
+            optimizer = torch.optim.AdamW(
+                self.model.model.parameters(),
+                lr=self.conf.lr,
+                weight_decay=self.conf.weight_decay
+            )
+        else:
+            raise NotImplementedError(f"Optimizer {self.conf.optimizer} not implemented")
+
+        scheduler = None
+        if self.conf.warmup > 0:
+            scheduler = torch.optim.lr_scheduler.LambdaLR(
+                optimizer,
+                lr_lambda=self._warmup_lr(self.conf.warmup)
+            )
+
+        return optimizer, scheduler
+
+    def _warmup_lr(self, warmup):
+        """
+        Create a warmup learning rate function.
+
+        Args:
+            warmup: Number of warmup steps
+
+        Returns:
+            Learning rate lambda function
+        """
+        def lr_lambda(step):
+            return min(step, warmup) / warmup
+        return lr_lambda
+
+    def normalize(self, cond):
+        """
+        Normalize latent conditions.
+
+        Args:
+            cond: Conditions to normalize
+
+        Returns:
+            Normalized conditions
+        """
+        if self.conds_mean is None or self.conds_std is None:
+            return cond
+        return (cond - self.conds_mean) / self.conds_std
+
+    def denormalize(self, cond):
+        """
+        Denormalize latent conditions.
+
+        Args:
+            cond: Normalized conditions
+
+        Returns:
+            Denormalized conditions
+        """
+        if self.conds_mean is None or self.conds_std is None:
+            return cond
+        return (cond * self.conds_std) + self.conds_mean
+
+    def is_last_accum(self, batch_idx):
+        """
+        Check if this is the last gradient accumulation step.
+
+        Args:
+            batch_idx: Current batch index
+
+        Returns:
+            Boolean indicating if this is the last accumulation step
+        """
+        return (batch_idx + 1) % self.conf.accum_batches == 0
+
+    def train_step(self, batch, batch_idx):
+        """
+        Perform a single training step.
+
+        Args:
+            batch: Batch of data
+            batch_idx: Index of the current batch
+
+        Returns:
+            Tuple of (loss, gathered_losses)
+        """
+        with amp.autocast(False):
+            # Handle different training modes
+            if self.conf.train_mode.require_dataset_infer():
+                # This mode has pre-calculated cond
+                cond = batch[0]
+                if self.conf.latent_znormalize:
+                    cond = self.normalize(cond)
+                x_start = None
+            else:
+                imgs, idxs = batch['img'], batch['index']
+                x_start = imgs
+                cond = None
+
+            # Dispatch on the training mode via the TrainMode helper methods
+            # (a raw string comparison against the enum would never match).
+            if self.conf.train_mode.is_latent_diffusion():
+                # Training the latent variables
+                t, weight = self.model.T_sampler.sample(len(cond), cond.device)
+                latent_losses = self.model.latent_sampler.training_losses(
+                    model=self.model.model.latent_net, x_start=cond, t=t)
+                # Train only the latent diffusion model
+                losses = {
+                    'latent': latent_losses['loss'],
+                    'loss': latent_losses['loss']
+                }
+            elif x_start is not None:
+                # Main training mode: image-space diffusion
+                t, weight = self.model.T_sampler.sample(len(x_start), x_start.device)
+                losses = self.model.sampler.training_losses(model=self.model.model,
+                                                            x_start=x_start,
+                                                            t=t)
+            else:
+                raise NotImplementedError(f"Training mode {self.conf.train_mode} not implemented")
+
+            loss = losses['loss'].mean()
+
+            # Gather losses from all processes
+            gathered_losses = {}
+            for key in ['loss', 'vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
+                if key in losses:
+                    gathered_losses[key] = all_gather(losses[key]).mean()
+
+        return loss, gathered_losses
+
+    def train_epoch(self, dataloader, optimizer, scheduler=None, scaler=None, log_interval=10):
+        """
+        Train for one epoch.
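+
+        Handles gradient accumulation (``conf.accum_batches``), optional mixed
+        precision via ``scaler``, gradient clipping, EMA updates, periodic
+        sample logging, and metric evaluation.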
+ + Args: + dataloader: DataLoader for training data + optimizer: Optimizer + scheduler: Learning rate scheduler + scaler: Gradient scaler for mixed precision + log_interval: How often to log + + Returns: + Average loss for the epoch + """ + self.model.model.train() + total_loss = 0 + num_batches = 0 + + for batch_idx, batch in enumerate(tqdm(dataloader, desc="Training")): + # Move batch to device + if isinstance(batch, dict): + batch = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in batch.items()} + elif isinstance(batch, list) or isinstance(batch, tuple): + batch = [b.to(self.device) if torch.is_tensor(b) else b for b in batch] + + # Forward and backward pass + if scaler is not None: + with amp.autocast(True): + loss, gathered_losses = self.train_step(batch, batch_idx) + loss = loss / self.conf.accum_batches # Normalize for gradient accumulation + + scaler.scale(loss).backward() + + if self.is_last_accum(batch_idx): + # Apply gradient clipping + if self.conf.grad_clip > 0: + scaler.unscale_(optimizer) + nn.utils.clip_grad_norm_(self.model.model.parameters(), self.conf.grad_clip) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if scheduler is not None: + scheduler.step() + else: + loss, gathered_losses = self.train_step(batch, batch_idx) + loss = loss / self.conf.accum_batches # Normalize for gradient accumulation + + loss.backward() + + if self.is_last_accum(batch_idx): + # Apply gradient clipping + if self.conf.grad_clip > 0: + nn.utils.clip_grad_norm_(self.model.model.parameters(), self.conf.grad_clip) + + optimizer.step() + optimizer.zero_grad() + + if scheduler is not None: + scheduler.step() + + # Update EMA model + if self.is_last_accum(batch_idx): + if self.conf.train_mode == 'latent_diffusion': + # Only update latent part for latent diffusion + self.model._ema(self.model.model.latent_net, self.model.ema_model.latent_net, self.conf.ema_decay) + else: + self.model.update_ema(self.conf.ema_decay) + + # Log samples + if batch_idx % log_interval == 0: + if self.conf.train_mode.require_dataset_infer(): + imgs = None + else: + imgs = batch['img'] if isinstance(batch, dict) else None + self.log_sample(x_start=imgs) + + # Update global step and samples + self.global_step += 1 + self.num_samples = self.global_step * self.conf.batch_size_effective + + # Evaluate metrics periodically + self.evaluate_scores() + + total_loss += loss.item() * self.conf.accum_batches + num_batches += 1 + + # Log losses + if batch_idx % log_interval == 0 and self.global_rank == 0: + print(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item() * self.conf.accum_batches:.4f}") + for key, value in gathered_losses.items(): + print(f" {key}: {value.item():.4f}") + + avg_loss = total_loss / num_batches + return avg_loss + + def train(self, num_epochs, batch_size=None, fp16=False): + """ + Train the model for multiple epochs. 
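+
+        Typical usage (illustrative; assumes ``preprocessor.setup()`` has already
+        been called):
+
+            trainer = UNetTrainer(model, preprocessor, loader, conf, device='cuda')
+            trainer.train(num_epochs=10, fp16=conf.fp16)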
+ + Args: + num_epochs: Number of epochs to train + batch_size: Batch size (uses conf.batch_size if None) + fp16: Whether to use mixed precision training + + Returns: + Final model + """ + # Setup + if batch_size is None: + batch_size = self.conf.batch_size // get_world_size() # Local batch size + + # Create dataloaders + train_loader = self.preprocessor.create_train_dataloader(batch_size) + + # Configure optimizers + optimizer, scheduler = self.configure_optimizers() + + # Load checkpoint if exists + start_epoch = 0 + if os.path.exists(f'{self.conf.logdir}/last.ckpt'): + self.global_step = self.loader.load_checkpoint( + self.model.model, optimizer, scheduler + ) + start_epoch = self.global_step // len(train_loader) + self.num_samples = self.global_step * self.conf.batch_size_effective + + # Setup mixed precision if needed + scaler = amp.GradScaler() if fp16 else None + + # Training loop + for epoch in range(start_epoch, num_epochs): + print(f"Epoch {epoch+1}/{num_epochs}") + + # Train for one epoch + avg_loss = self.train_epoch(train_loader, optimizer, scheduler, scaler) + + # Save checkpoint + if self.global_rank == 0: + self.loader.save_last_checkpoint( + self.model.model, optimizer, scheduler, self.global_step + ) + + if epoch % self.conf.save_epoch_interval == 0: + self.loader.save_checkpoint( + self.model.model, optimizer, scheduler, self.global_step, + f'{self.conf.logdir}/checkpoint_epoch{epoch+1}.ckpt' + ) + + print(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}") + + return self.model + + def sample(self, N, T=None, T_latent=None): + """ + Generate samples from the model. + + Args: + N: Number of samples to generate + T: Number of diffusion steps + T_latent: Number of latent diffusion steps + + Returns: + Generated images + """ + if T is None: + sampler = self.model.eval_sampler + latent_sampler = self.model.latent_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() + latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler() + + noise = torch.randn(N, + 3, + self.conf.img_size, + self.conf.img_size, + device=self.device) + + pred_img = render_uncondition( + self.conf, + self.model.ema_model, + noise, + sampler=sampler, + latent_sampler=latent_sampler, + conds_mean=self.conds_mean, + conds_std=self.conds_std, + ) + pred_img = (pred_img + 1) / 2 + return pred_img + + def render(self, noise, cond=None, T=None): + """ + Render images from noise with optional conditioning. + + Args: + noise: Input noise + cond: Conditioning information + T: Number of diffusion steps + + Returns: + Rendered images + """ + if T is None: + sampler = self.model.eval_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() + + if cond is not None: + pred_img = render_condition(self.conf, + self.model.ema_model, + noise, + sampler=sampler, + cond=cond) + else: + pred_img = render_uncondition(self.conf, + self.model.ema_model, + noise, + sampler=sampler, + latent_sampler=None) + pred_img = (pred_img + 1) / 2 + return pred_img + + def infer_whole_dataset(self, with_render=False, T_render=None, render_save_path=None): + """ + Infer latents for the entire dataset. 
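+
+        Encodes the whole dataset with the EMA encoder and, when ``with_render``
+        is set, also writes rendered reconstructions to an LMDB file.
+        Illustrative call:
+
+            conds = trainer.infer_whole_dataset()  # (N, latent_dim) float tensor on CPU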
+ + Args: + with_render: Whether to also render images + T_render: Number of diffusion steps for rendering + render_save_path: Path to save rendered images + + Returns: + Inferred conditions + """ + from tqdm import tqdm + from contextlib import nullcontext + from source.lmdb_writer import LMDBImageWriter + + data = self.conf.make_dataset() + if isinstance(data, CelebAlmdb) and data.crop_d2c: + # Special case where we need the d2c crop + data.transform = make_transform(self.conf.img_size, + flip_prob=0, + crop_d2c=True) + else: + data.transform = make_transform(self.conf.img_size, flip_prob=0) + + loader = self.conf.make_loader( + data, + shuffle=False, + drop_last=False, + batch_size=self.conf.batch_size_eval, + parallel=True, + ) + model = self.model.ema_model + model.eval() + conds = [] + + if with_render: + sampler = self.conf._make_diffusion_conf( + T=T_render or self.conf.T_eval).make_sampler() + + if self.global_rank == 0: + writer = LMDBImageWriter(render_save_path, + format='webp', + quality=100) + else: + writer = nullcontext() + else: + writer = nullcontext() + + with writer: + for batch in tqdm(loader, total=len(loader), desc='infer'): + with torch.no_grad(): + # (n, c) + cond = model.encoder(batch['img'].to(self.device)) + + # Used for reordering to match the original dataset + idx = batch['index'] + idx = all_gather(idx) + if idx.dim() == 2: + idx = idx.flatten(0, 1) + argsort = idx.argsort() + + if with_render: + noise = torch.randn(len(cond), + 3, + self.conf.img_size, + self.conf.img_size, + device=self.device) + render = sampler.sample(model, noise=noise, cond=cond) + render = (render + 1) / 2 + # (k, n, c, h, w) + render = all_gather(render) + if render.dim() == 5: + # (k*n, c) + render = render.flatten(0, 1) + + if self.global_rank == 0: + writer.put_images(render[argsort]) + + # (k, n, c) + cond = all_gather(cond) + + if cond.dim() == 3: + # (k*n, c) + cond = cond.flatten(0, 1) + + conds.append(cond[argsort].cpu()) + + model.train() + # (N, c) cpu + conds = torch.cat(conds).float() + + # Calculate and save statistics + if self.global_rank == 0: + self.conds = conds + self.conds_mean = conds.mean(dim=0, keepdim=True).to(self.device) + self.conds_std = conds.std(dim=0, keepdim=True).to(self.device) + + self.loader.save_latent_stats(conds, self.conds_mean.cpu(), self.conds_std.cpu()) + + return conds + + def log_sample(self, x_start): + """ + Log generated samples to tensorboard. 
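+
+        Note: in this refactor the sample grids (and, optionally, the real
+        images) are saved as PNG files under ``{conf.logdir}/sample{postfix}/``
+        rather than being sent to a TensorBoard logger.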
+
+        Args:
+            x_start: Batch of real images; used for conditioning/reconstruction
+                and saved for reference in autoencoding modes
+        """
+        def do(model, postfix, use_xstart, save_real=False, no_latent_diff=False, interpolate=False):
+            model.eval()
+            with torch.no_grad():
+                all_x_T = self._split_tensor(self.x_T)
+                batch_size = min(len(all_x_T), self.conf.batch_size_eval)
+                # Sample in batches so very large models still fit in memory
+                loader = DataLoader(all_x_T, batch_size=batch_size)
+
+                Gen = []
+                for x_T in loader:
+                    if use_xstart:
+                        _xstart = x_start[:len(x_T)]
+                    else:
+                        _xstart = None
+                    if self.conf.train_mode.is_latent_diffusion() and not use_xstart:
+                        # Diffusion of the latent first
+                        gen = render_uncondition(
+                            conf=self.conf,
+                            model=model,
+                            x_T=x_T,
+                            sampler=self.model.eval_sampler,
+                            latent_sampler=self.model.eval_latent_sampler,
+                            conds_mean=self.conds_mean,
+                            conds_std=self.conds_std)
+                    else:
+                        if not use_xstart and self.conf.model_type.has_noise_to_cond():
+                            # Special case, it may not be stochastic, yet can sample
+                            cond = torch.randn(len(x_T),
+                                               self.conf.style_ch,
+                                               device=self.device)
+                            cond = model.noise_to_cond(cond)
+                        else:
+                            if interpolate:
+                                with amp.autocast(self.conf.fp16):
+                                    cond = model.encoder(_xstart)
+                                    i = torch.randperm(len(cond))
+                                    cond = (cond + cond[i]) / 2
+                            else:
+                                cond = None
+                        gen = self.model.eval_sampler.sample(model=model,
+                                                             noise=x_T,
+                                                             cond=cond,
+                                                             x_start=_xstart)
+                    Gen.append(gen)
+
+                gen = torch.cat(Gen)
+                gen = all_gather(gen)
+                if gen.dim() == 5:
+                    # (k*n, c, h, w)
+                    gen = gen.flatten(0, 1)
+
+                if save_real and use_xstart:
+                    # Save the corresponding real images alongside the samples
+                    real = all_gather(_xstart)
+                    if real.dim() == 5:
+                        real = real.flatten(0, 1)
+
+                    if self.global_rank == 0:
+                        grid_real = (make_grid(real) + 1) / 2
+                        # Save real images
+                        sample_dir = os.path.join(self.conf.logdir, f'sample{postfix}')
+                        if not os.path.exists(sample_dir):
+                            os.makedirs(sample_dir)
+                        save_image(grid_real, os.path.join(sample_dir, f'real_{self.num_samples}.png'))
+
+                if self.global_rank == 0:
+                    # Save samples to disk
+                    grid = (make_grid(gen) + 1) / 2
+                    sample_dir = os.path.join(self.conf.logdir, f'sample{postfix}')
+                    if not os.path.exists(sample_dir):
+                        os.makedirs(sample_dir)
+                    path = os.path.join(sample_dir, f'{self.num_samples}.png')
+                    save_image(grid, path)
+            model.train()
+
+        if self.conf.sample_every_samples > 0 and self._is_time(
+                self.num_samples, self.conf.sample_every_samples,
+                self.conf.batch_size_effective):
+
+            if self.conf.train_mode.require_dataset_infer():
+                do(self.model.model, '', use_xstart=False)
+                do(self.model.ema_model, '_ema', use_xstart=False)
+            else:
+                if self.conf.model_type.has_autoenc() and self.conf.model_type.can_sample():
+                    do(self.model.model, '', use_xstart=False)
+                    do(self.model.ema_model, '_ema', use_xstart=False)
+                    # Autoencoding mode
+                    do(self.model.model, '_enc', use_xstart=True, save_real=True)
+                    do(self.model.ema_model, '_enc_ema', use_xstart=True, save_real=True)
+                elif self.conf.train_mode.use_latent_net():
+                    do(self.model.model, '', use_xstart=False)
+                    do(self.model.ema_model, '_ema', use_xstart=False)
+                    # Autoencoding mode
+                    do(self.model.model, '_enc', use_xstart=True, save_real=True)
+                    do(self.model.model, '_enc_nodiff', use_xstart=True, save_real=True, no_latent_diff=True)
+                    do(self.model.ema_model, '_enc_ema', use_xstart=True, save_real=True)
+                else:
+                    do(self.model.model, '', use_xstart=True, save_real=True)
+                    do(self.model.ema_model, '_ema', use_xstart=True, save_real=True)
+
+    def evaluate_scores(self):
+        """
+        Evaluate FID and other scores during training.
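+
+        FID (and, for autoencoding models, LPIPS/SSIM/MSE) is computed for
+        the online and EMA models at the intervals given by
+        conf.eval_every_samples and conf.eval_ema_every_samples; FID results
+        are appended to <logdir>/eval.txt as JSON lines.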
+ """ + def fid(model, postfix): + score = evaluate_fid(self.model.eval_sampler, + model, + self.conf, + device=self.device, + train_data=self.preprocessor.train_data, + val_data=self.preprocessor.val_data, + latent_sampler=self.model.eval_latent_sampler, + conds_mean=self.conds_mean, + conds_std=self.conds_std) + if self.global_rank == 0: + print(f"FID{postfix}: {score}") + if not os.path.exists(self.conf.logdir): + os.makedirs(self.conf.logdir) + with open(os.path.join(self.conf.logdir, 'eval.txt'), 'a') as f: + metrics = { + f'FID{postfix}': score, + 'num_samples': self.num_samples, + } + f.write(json.dumps(metrics) + "\n") + + def lpips(model, postfix): + if self.conf.model_type.has_autoenc() and self.conf.train_mode.is_autoenc(): + # {'lpips', 'ssim', 'mse'} + score = evaluate_lpips(self.model.eval_sampler, + model, + self.conf, + device=self.device, + val_data=self.preprocessor.val_data, + latent_sampler=self.model.eval_latent_sampler) + + if self.global_rank == 0: + for key, val in score.items(): + print(f"{key}{postfix}: {val}") + + if self.conf.eval_every_samples > 0 and self.num_samples > 0 and self._is_time( + self.num_samples, self.conf.eval_every_samples, + self.conf.batch_size_effective): + print(f'Evaluating FID @ {self.num_samples}') + lpips(self.model.model, '') + fid(self.model.model, '') + + if self.conf.eval_ema_every_samples > 0 and self.num_samples > 0 and self._is_time( + self.num_samples, self.conf.eval_ema_every_samples, + self.conf.batch_size_effective): + print(f'Evaluating FID EMA @ {self.num_samples}') + fid(self.model.ema_model, '_ema') + + def _split_tensor(self, x): + """ + Split tensor across workers. + + Args: + x: Tensor to split + + Returns: + Local portion of the tensor + """ + n = len(x) + rank = self.global_rank + world_size = get_world_size() + per_rank = n // world_size + return x[rank * per_rank:(rank + 1) * per_rank] + + def _is_time(self, num_samples, every, step_size): + """ + Check if it's time to perform an action based on number of samples. + + Args: + num_samples: Current number of samples + every: Frequency in samples + step_size: Step size in samples + + Returns: + Boolean indicating if it's time + """ + closest = (num_samples // every) * every + return num_samples - closest < step_size + + diff --git a/utils/unet_loader.py b/utils/unet_loader.py new file mode 100644 index 0000000..c7db35f --- /dev/null +++ b/utils/unet_loader.py @@ -0,0 +1,158 @@ +import os +import torch +import json +import numpy as np + +class UNetLoader: + """Handles model loading, saving and checkpoint management.""" + def __init__(self, conf, logdir=None): + """ + Initialize the loader. + + Args: + conf: Configuration object + logdir: Directory for logs and checkpoints + """ + self.conf = conf + self.logdir = logdir or conf.logdir + + # Create log directory if it doesn't exist + if not os.path.exists(self.logdir): + os.makedirs(self.logdir) + + def save_checkpoint(self, model, optimizer, scheduler=None, global_step=0, filename=None): + """ + Save a checkpoint. 
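+        The checkpoint bundles the model state_dict, the optimizer state,
+        the scheduler state (if a scheduler is given), and the global step.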
+ + Args: + model: Model to save + optimizer: Optimizer state to save + scheduler: Learning rate scheduler to save + global_step: Current training step + filename: Filename for the checkpoint + """ + if filename is None: + filename = f'{self.logdir}/checkpoint_{global_step}.ckpt' + + checkpoint = { + 'state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'global_step': global_step, + } + + if scheduler is not None: + checkpoint['scheduler_state_dict'] = scheduler.state_dict() + + torch.save(checkpoint, filename) + print(f"Saved checkpoint to {filename}") + + def save_last_checkpoint(self, model, optimizer, scheduler=None, global_step=0): + """ + Save the latest checkpoint. + + Args: + model: Model to save + optimizer: Optimizer state to save + scheduler: Learning rate scheduler to save + global_step: Current training step + """ + self.save_checkpoint(model, optimizer, scheduler, global_step, f'{self.logdir}/last.ckpt') + + def load_checkpoint(self, model, optimizer=None, scheduler=None, filename=None, map_location='cpu'): + """ + Load a checkpoint. + + Args: + model: Model to load weights into + optimizer: Optimizer to load state into + scheduler: Learning rate scheduler to load state into + filename: Checkpoint filename + map_location: Device to load tensors onto + + Returns: + global_step from the checkpoint + """ + if filename is None: + filename = f'{self.logdir}/last.ckpt' + + if not os.path.exists(filename): + print(f"No checkpoint found at {filename}") + return 0 + + print(f"Loading checkpoint from {filename}") + checkpoint = torch.load(filename, map_location=map_location) + + model.load_state_dict(checkpoint['state_dict']) + + if optimizer is not None and 'optimizer_state_dict' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + if scheduler is not None and 'scheduler_state_dict' in checkpoint: + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + global_step = checkpoint.get('global_step', 0) + print(f"Loaded checkpoint from step {global_step}") + + return global_step + + def load_pretrained(self, model, pretrain_path, map_location='cpu'): + """ + Load pretrained weights. + + Args: + model: Model to load weights into + pretrain_path: Path to pretrained weights + map_location: Device to load tensors onto + """ + if pretrain_path is None: + return + + print(f'Loading pretrained model from {pretrain_path}') + state = torch.load(pretrain_path, map_location=map_location, weights_only=False) + print('step:', state['global_step']) + model.load_state_dict(state['state_dict'], strict=False) + + def save_latent_stats(self, conds, conds_mean, conds_std, path=None): + """ + Save latent statistics. + + Args: + conds: Latent conditions + conds_mean: Mean of conditions + conds_std: Standard deviation of conditions + path: Save path + """ + if path is None: + path = f'checkpoints/{self.conf.name}/latent.pkl' + + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + torch.save({ + 'conds': conds, + 'conds_mean': conds_mean, + 'conds_std': conds_std, + }, path) + print(f"Saved latent stats to {path}") + + def load_latent_stats(self, path=None, map_location='cpu'): + """ + Load latent statistics. 
+ + Args: + path: Load path + map_location: Device to load tensors onto + + Returns: + Dictionary containing conds, conds_mean, and conds_std + """ + if path is None: + path = f'checkpoints/{self.conf.name}/latent.pkl' + + if not os.path.exists(path): + print(f"No latent stats found at {path}") + return None + + print(f"Loading latent stats from {path}") + stats = torch.load(path, map_location=map_location) + return stats
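+
+
+if __name__ == '__main__':
+    # Illustrative sketch only: exercises the checkpoint round-trip with a
+    # toy model and a minimal stand-in config. Real training code constructs
+    # UNetLoader from the project's config object; 'toy-conf' and the tiny
+    # Linear model below are placeholders, not part of the actual pipeline.
+    from types import SimpleNamespace
+
+    conf = SimpleNamespace(name='toy-conf', logdir='checkpoints/toy-conf')
+    loader = UNetLoader(conf)
+
+    model = torch.nn.Linear(4, 4)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+    # Save, then restore into fresh instances and recover the step counter.
+    loader.save_last_checkpoint(model, optimizer, scheduler=None, global_step=123)
+    restored = torch.nn.Linear(4, 4)
+    step = loader.load_checkpoint(restored, torch.optim.SGD(restored.parameters(), lr=0.1))
+    assert step == 123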