From ae0425875fd438c909e54651dc641c64df2bdfd8 Mon Sep 17 00:00:00 2001 From: Joosep Pata Date: Thu, 26 Oct 2023 18:06:50 +0300 Subject: [PATCH] making the 3d-padded models more efficient in pytorch (#256) * initial wip * move padding to collate * avoid compile * specify default conv type * train and valid separately * add saving weight * update submission script * valid only on rank0 * use_cuda --- mlpf/pyg/PFDataset.py | 35 +++-- mlpf/pyg/gnn_lsh.py | 41 +++--- mlpf/pyg/mlpf.py | 36 ++--- mlpf/pyg/training.py | 254 +++++++++++--------------------- mlpf/pyg/utils.py | 34 +++-- mlpf/pyg_pipeline.py | 68 ++++++++- scripts/tallinn/a100/pytorch.sh | 15 ++ scripts/tallinn/rtx/pytorch.sh | 15 ++ 8 files changed, 248 insertions(+), 250 deletions(-) create mode 100755 scripts/tallinn/a100/pytorch.sh create mode 100755 scripts/tallinn/rtx/pytorch.sh diff --git a/mlpf/pyg/PFDataset.py b/mlpf/pyg/PFDataset.py index c2b8653f9..5ad17d937 100644 --- a/mlpf/pyg/PFDataset.py +++ b/mlpf/pyg/PFDataset.py @@ -1,17 +1,18 @@ from typing import List, Optional +from types import SimpleNamespace import tensorflow_datasets as tfds import torch import torch.utils.data from torch import Tensor +import torch_geometric from torch_geometric.data import Batch, Data -from torch_geometric.data.data import BaseData class PFDataset: """Builds a DataSource from tensorflow datasets.""" - def __init__(self, data_dir, name, split, keys_to_get, num_samples=None): + def __init__(self, data_dir, name, split, keys_to_get, pad_3d=True, num_samples=None): """ Args data_dir: path to tensorflow_datasets (e.g. `../data/tensorflow_datasets/`) @@ -29,7 +30,6 @@ def __init__(self, data_dir, name, split, keys_to_get, num_samples=None): # to make dataset_info pickable tmp = self.ds.dataset_info - from types import SimpleNamespace self.ds.dataset_info = SimpleNamespace() self.ds.dataset_info.name = tmp.name @@ -39,6 +39,8 @@ def __init__(self, data_dir, name, split, keys_to_get, num_samples=None): # any selection of ["X", "ygen", "ycand"] to retrieve self.keys_to_get = keys_to_get + self.pad_3d = pad_3d + if num_samples: self.ds = torch.utils.data.Subset(self.ds, range(num_samples)) @@ -50,7 +52,7 @@ def get_distributed_sampler(self): sampler = torch.utils.data.distributed.DistributedSampler(self.ds) return sampler - def get_loader(self, batch_size, world_size, num_workers=0, prefetch_factor=None): + def get_loader(self, batch_size, world_size, rank, use_cuda=False, num_workers=0, prefetch_factor=None): if (num_workers > 0) and (prefetch_factor is None): prefetch_factor = 2 # default prefetch_factor when num_workers>0 @@ -62,10 +64,12 @@ def get_loader(self, batch_size, world_size, num_workers=0, prefetch_factor=None return DataLoader( self.ds, batch_size=batch_size, - collate_fn=Collater(self.keys_to_get), + collate_fn=Collater(self.keys_to_get, pad_3d=self.pad_3d), sampler=sampler, num_workers=num_workers, prefetch_factor=prefetch_factor, + pin_memory=use_cuda, + pin_memory_device="cuda:{}".format(rank) if use_cuda else "", ) def __len__(self): @@ -109,10 +113,12 @@ def __init__( class Collater: """Based on the Collater found on torch_geometric docs we build our own.""" - def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None): + def __init__(self, keys_to_get, follow_batch=None, exclude_keys=None, pad_bin_size=640, pad_3d=True): self.follow_batch = follow_batch self.exclude_keys = exclude_keys self.keys_to_get = keys_to_get + self.pad_bin_size = pad_bin_size + self.pad_3d = pad_3d def __call__(self, inputs): num_samples_in_batch = len(inputs) @@ -125,12 +131,21 @@ def __call__(self, inputs): batch[ev][elem_key] = Tensor(inputs[ev][elem_key]) batch[ev]["batch"] = torch.tensor([ev] * len(inputs[ev][elem_key])) - elem = batch[0] + ret = Batch.from_data_list(batch, self.follow_batch, self.exclude_keys) + + if not self.pad_3d: + return ret + else: + ret = {k: torch_geometric.utils.to_dense_batch(getattr(ret, k), ret.batch) for k in elem_keys} + + ret["mask"] = ret["X"][1] - if isinstance(elem, BaseData): - return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys) + # remove the mask from each element + for k in elem_keys: + ret[k] = ret[k][0] - raise TypeError(f"DataLoader found invalid type: {type(elem)}") + ret = Batch(**ret) + return ret def my_getitem(self, vals): diff --git a/mlpf/pyg/gnn_lsh.py b/mlpf/pyg/gnn_lsh.py index 7239efca6..f3a375bda 100644 --- a/mlpf/pyg/gnn_lsh.py +++ b/mlpf/pyg/gnn_lsh.py @@ -33,11 +33,6 @@ def point_wise_feed_forward_network( return nn.Sequential(*layers) -# @torch.compile -def index_dim(a, b): - return a[b] - - def split_indices_to_bins_batch(cmul, nbins, bin_size, msk): a = torch.argmax(cmul, axis=-1) @@ -143,7 +138,6 @@ def forward(self, x_msg_binned, msk, training=False): return dm -@torch.compile def split_msk_and_msg(bins_split, cmul, x_msg, x_node, msk, n_bins, bin_size): bins_split_2 = torch.reshape(bins_split, (bins_split.shape[0], bins_split.shape[1] * bins_split.shape[2])) @@ -164,6 +158,23 @@ def split_msk_and_msg(bins_split, cmul, x_msg, x_node, msk, n_bins, bin_size): return x_msg_binned, x_features_binned, msk_f_binned +def reverse_lsh(bins_split, points_binned_enc): + shp = points_binned_enc.shape + batch_dim = shp[0] + n_points = shp[1] * shp[2] + n_features = shp[-1] + + bins_split_flat = torch.reshape(bins_split, (batch_dim, n_points)) + points_binned_enc_flat = torch.reshape(points_binned_enc, (batch_dim, n_points, n_features)) + + ret = torch.zeros(batch_dim, n_points, n_features, device=points_binned_enc.device) + for ibatch in range(batch_dim): + # torch._assert(torch.min(bins_split_flat[ibatch]) >= 0, "reverse_lsh n_points min") + # torch._assert(torch.max(bins_split_flat[ibatch]) < n_points, "reverse_lsh n_points max") + ret[ibatch][bins_split_flat[ibatch]] = points_binned_enc_flat[ibatch] + return ret + + class MessageBuildingLayerLSH(nn.Module): def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=NodePairGaussianKernel(), **kwargs): self.initializer = kwargs.pop("initializer", "random_normal") @@ -227,24 +238,6 @@ def forward(self, x_msg, x_node, msk, training=False): return bins_split, x_features_binned, dm, msk_f_binned -@torch.compile -def reverse_lsh(bins_split, points_binned_enc): - shp = points_binned_enc.shape - batch_dim = shp[0] - n_points = shp[1] * shp[2] - n_features = shp[-1] - - bins_split_flat = torch.reshape(bins_split, (batch_dim, n_points)) - points_binned_enc_flat = torch.reshape(points_binned_enc, (batch_dim, n_points, n_features)) - - ret = torch.zeros(batch_dim, n_points, n_features, device=points_binned_enc.device) - for ibatch in range(batch_dim): - torch._assert(torch.min(bins_split_flat[ibatch]) >= 0, "reverse_lsh n_points min") - torch._assert(torch.max(bins_split_flat[ibatch]) < n_points, "reverse_lsh n_points max") - ret[ibatch][bins_split_flat[ibatch]] = points_binned_enc_flat[ibatch] - return ret - - class CombinedGraphLayer(nn.Module): def __init__(self, *args, **kwargs): self.inout_dim = kwargs.pop("inout_dim") diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index a5a813f78..caefbec0e 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -1,7 +1,5 @@ import torch import torch.nn as nn -import torch_geometric -import torch_geometric.utils from torch_geometric.nn.conv import GravNetConv from .gnn_lsh import CombinedGraphLayer @@ -53,12 +51,6 @@ def ffn(input_dim, output_dim, width, act, dropout): ) -@torch.compile -def unpad(data_padded, mask): - A = data_padded[mask] - return A - - class MLPF(nn.Module): def __init__( self, @@ -132,21 +124,14 @@ def __init__( def forward(self, event): # unfold the Batch object input_ = event.X.float() - batch_idx = event.batch embeddings_id, embeddings_reg = [], [] if self.num_convs != 0: - embedding = self.nn0(input_) - - if self.conv_type != "gravnet": - _, num_nodes = torch.unique(batch_idx, return_counts=True) - max_num_nodes = torch.max(num_nodes).cpu() - max_num_nodes_padded = ((max_num_nodes // self.bin_size) + 1) * self.bin_size - embedding, mask = torch_geometric.utils.to_dense_batch( - embedding, batch_idx, max_num_nodes=max_num_nodes_padded - ) if self.conv_type == "gravnet": + embedding = self.nn0(input_) + + batch_idx = event.batch # perform a series of graph convolutions for num, conv in enumerate(self.conv_id): conv_input = embedding if num == 0 else embeddings_id[-1] @@ -155,6 +140,8 @@ def forward(self, event): conv_input = embedding if num == 0 else embeddings_reg[-1] embeddings_reg.append(conv(conv_input, batch_idx)) else: + mask = event.mask + embedding = self.nn0(input_) for num, conv in enumerate(self.conv_id): conv_input = embedding if num == 0 else embeddings_id[-1] out_padded = conv(conv_input, ~mask) @@ -164,11 +151,6 @@ def forward(self, event): out_padded = conv(conv_input, ~mask) embeddings_reg.append(out_padded) - if self.conv_type != "gravnet": - embeddings_id = [unpad(emb, mask) for emb in embeddings_id] - embeddings_reg = [unpad(emb, mask) for emb in embeddings_reg] - - # classification embedding_id = torch.cat([input_] + embeddings_id, axis=-1) preds_id = self.nn_id(embedding_id) @@ -183,10 +165,10 @@ def forward(self, event): # predict the 4-momentum, add it to the (pt, eta, sin phi, cos phi, E) of the input PFelement # the feature order is defined in fcc/postprocessing.py -> track_feature_order, cluster_feature_order - preds_pt = self.nn_pt(embedding_reg) + input_[:, 1:2] - preds_eta = self.nn_eta(embedding_reg) + input_[:, 2:3] - preds_phi = self.nn_phi(embedding_reg) + input_[:, 3:5] - preds_energy = self.nn_energy(embedding_reg) + input_[:, 5:6] + preds_pt = self.nn_pt(embedding_reg) + input_[..., 1:2] + preds_eta = self.nn_eta(embedding_reg) + input_[..., 2:3] + preds_phi = self.nn_phi(embedding_reg) + input_[..., 3:5] + preds_energy = self.nn_energy(embedding_reg) + input_[..., 5:6] preds_momentum = torch.cat([preds_pt, preds_eta, preds_phi, preds_energy], axis=-1) pred_charge = self.nn_charge(embedding_reg) diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index ada6c7510..7827cf68a 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -16,7 +16,7 @@ from .logger import _logger from .utils import unpack_predictions, unpack_target -# from torch.profiler import profile, record_function, ProfilerActivity +from torch.profiler import profile, record_function, ProfilerActivity # Ignore divide by 0 errors @@ -33,16 +33,23 @@ def mlpf_loss(y, ypred): """ loss = {} loss_obj_id = FocalLoss(gamma=2.0) - loss["Classification"] = 100 * loss_obj_id(ypred["cls_id_onehot"], y["cls_id"]) msk_true_particle = torch.unsqueeze((y["cls_id"] != 0).to(dtype=torch.float32), axis=-1) - loss["Regression"] = 10 * torch.nn.functional.huber_loss( - ypred["momentum"] * msk_true_particle, y["momentum"] * msk_true_particle - ) - loss["Charge"] = torch.nn.functional.cross_entropy( - ypred["charge"] * msk_true_particle, (y["charge"] * msk_true_particle[:, 0]).to(dtype=torch.int64) - ) + ypred["momentum"] = ypred["momentum"] * msk_true_particle + ypred["charge"] = ypred["charge"] * msk_true_particle + y["momentum"] = y["momentum"] * msk_true_particle + y["charge"] = y["charge"] * msk_true_particle[..., 0] + + # pytorch expects (N, C, ...) + if ypred["cls_id_onehot"].ndim > 2: + ypred["cls_id_onehot"] = ypred["cls_id_onehot"].permute((0, 2, 1)) + ypred["charge"] = ypred["charge"].permute((0, 2, 1)) + + loss["Classification"] = 100 * loss_obj_id(ypred["cls_id_onehot"], y["cls_id"]) + + loss["Regression"] = 10 * torch.nn.functional.huber_loss(ypred["momentum"], y["momentum"]) + loss["Charge"] = torch.nn.functional.cross_entropy(ypred["charge"], y["charge"].to(dtype=torch.int64)) loss["Total"] = loss["Classification"] + loss["Regression"] + loss["Charge"] return loss @@ -79,13 +86,12 @@ def __init__( super().__init__() self.alpha = alpha self.gamma = gamma - self.ignore_index = ignore_index self.reduction = reduction - self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none", ignore_index=ignore_index) + self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none") def __repr__(self): - arg_keys = ["alpha", "gamma", "ignore_index", "reduction"] + arg_keys = ["alpha", "gamma", "reduction"] arg_vals = [self.__dict__[k] for k in arg_keys] arg_strs = [f"{k}={v!r}" for k, v in zip(arg_keys, arg_vals)] arg_str = ", ".join(arg_strs) @@ -99,20 +105,16 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,) y = y.view(-1) - unignored_mask = y != self.ignore_index - y = y[unignored_mask] - if len(y) == 0: - return torch.tensor(0.0) - x = x[unignored_mask] - # compute weighted cross entropy term: -alpha * log(pt) # (alpha is already part of self.nll_loss) log_p = F.log_softmax(x, dim=-1) ce = self.nll_loss(log_p, y) # get true class column from each row - all_rows = torch.arange(len(x)) - log_pt = log_p[all_rows, y] + # this is slow due to indexing + # all_rows = torch.arange(len(x)) + # log_pt = log_p[all_rows, y] + log_pt = torch.gather(log_p, 1, y.unsqueeze(axis=-1)).squeeze(axis=-1) # compute focal term: (1 - pt)^gamma pt = log_pt.exp() @@ -129,155 +131,58 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: return loss -def train( - rank, - world_size, - model, - optimizer, - train_loader, - valid_loader, - best_val_loss, - stale_epochs, - patience, - outdir, - tensorboard_writer=None, -): +def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train): """ Performs training over a given epoch. Will run a validation step every N_STEPS and after the last training batch. """ - global ISTEP_GLOBAL - - N_STEPS = 1000 # number of steps before running validation - - _logger.info(f"Initiating a training run on device {rank}", color="red") - # initialize loss counters (note: these will be reset after N_STEPS) - train_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} - valid_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + _logger.info(f"Initiating a train={is_train} run on device rank={rank}", color="red") # this one will keep accumulating `train_loss` and then return the average epoch_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} - istep = 0 - model.train() - for itrain, batch in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)): - istep += 1 - - ygen = unpack_target(batch.to(rank).ygen) - ypred = unpack_predictions(model(batch.to(rank))) - - # JP: need to debug this - # assert np.all(target_charge.unique().cpu().numpy() == [0, 1, 2]) - loss = mlpf_loss(ygen, ypred) + if is_train: + model.train() + else: + model.eval() - for param in model.parameters(): - param.grad = None - loss["Total"].backward() - optimizer.step() + for itrain, batch in tqdm.tqdm(enumerate(data_loader), total=len(data_loader)): + batch = batch.to(rank, non_blocking=True) - for loss_ in train_loss: - train_loss[loss_] += loss[loss_].detach() - for loss_ in epoch_loss: - epoch_loss[loss_] += loss[loss_].detach() + ygen = unpack_target(batch.ygen) - # run a quick validation run at intervals of N_STEPS or at the last step - if (((itrain % N_STEPS) == 0) and (itrain != 0)) or (itrain == (len(train_loader) - 1)): - if itrain == (len(train_loader) - 1): - nsteps = istep + if is_train: + ypred = model(batch) + else: + if world_size > 1: # validation is only run on a single machine + ypred = model.module(batch) else: - nsteps = N_STEPS - istep = 0 + ypred = model(batch) - if tensorboard_writer: - for loss_ in train_loss: - tensorboard_writer.add_scalar(f"step_train/loss_{loss_}", train_loss[loss_] / nsteps, ISTEP_GLOBAL) - tensorboard_writer.flush() + ypred = unpack_predictions(ypred) - _logger.info( - f"Rank {rank}: " - + f"train_loss_tot={train_loss['Total']/nsteps:.2f} " - + f"train_loss_id={train_loss['Classification']/nsteps:.2f} " - + f"train_loss_momentum={train_loss['Regression']/nsteps:.2f} " - + f"train_loss_charge={train_loss['Charge']/nsteps:.2f} " - ) - train_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + if is_train: + loss = mlpf_loss(ygen, ypred) + for param in model.parameters(): + param.grad = None + else: + with torch.no_grad(): + loss = mlpf_loss(ygen, ypred) - if world_size > 1: - dist.barrier() # wait until training run is finished on all ranks before running the validation + if is_train: + loss["Total"].backward() + optimizer.step() - if (rank == 0) or (rank == "cpu"): - _logger.info(f"Initiating a quick validation run on device {rank}", color="red") - model.eval() - - valid_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} - with torch.no_grad(): - for ival, batch in tqdm.tqdm(enumerate(valid_loader), total=len(valid_loader)): - ygen = unpack_target(batch.to(rank).ygen) - if world_size > 1: # validation is only run on a single machine - ypred = unpack_predictions(model.module(batch.to(rank))) - else: - ypred = unpack_predictions(model(batch.to(rank))) - - loss = mlpf_loss(ygen, ypred) - - for loss_ in valid_loss: - valid_loss[loss_] += loss[loss_].detach() - - for loss_ in valid_loss: - valid_loss[loss_] = valid_loss[loss_].cpu().item() / len(valid_loader) - - if tensorboard_writer: - for loss_ in valid_loss: - tensorboard_writer.add_scalar(f"step_valid/loss_{loss_}", valid_loss[loss_], ISTEP_GLOBAL) - - if valid_loss["Total"] < best_val_loss: - best_val_loss = valid_loss["Total"] - - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - model_state_dict = model.module.state_dict() - else: - model_state_dict = model.state_dict() - - torch.save( - {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()}, - f"{outdir}/best_weights.pth", - ) - _logger.info( - f"finished {itrain+1}/{len(train_loader)} iterations and saved the model at {outdir}/best_weights.pth" # noqa - ) - stale_epochs = torch.tensor(0, device=rank) - else: - _logger.info(f"finished {itrain}/{len(train_loader)} iterations") - stale_epochs += 1 - - _logger.info( - f"Rank {rank}: " - + f"val_loss_tot={valid_loss['Total']:.2f} " - + f"val_loss_id={valid_loss['Classification']:.2f} " - + f"val_loss_momentum={valid_loss['Regression']:.2f} " - + f"val_loss_charge={valid_loss['Charge']:.2f} " - + f"best_val_loss={best_val_loss:.2f} " - + f"stale={stale_epochs} " - ) - ISTEP_GLOBAL += 1 - - model.train() # prepare for next training loop - - if world_size > 1: - dist.barrier() # wait until validation run on rank 0 is finished before going to the next epoch - dist.broadcast(stale_epochs, src=0) # broadcast stale_epochs to all gpus - - if stale_epochs > patience: - _logger.info("breaking due to stale epochs") - return None, None, None, stale_epochs + for loss_ in epoch_loss: + epoch_loss[loss_] += loss[loss_].detach() - if tensorboard_writer: - tensorboard_writer.flush() + if world_size > 1: + dist.barrier() for loss_ in epoch_loss: - epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / len(train_loader) + epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / len(data_loader) - return epoch_loss, valid_loss, best_val_loss, stale_epochs + return epoch_loss def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, num_epochs, patience, outdir, hpo=False): @@ -307,7 +212,7 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n for loss in losses_of_interest: losses["train"][loss], losses["valid"][loss] = [], [] - stale_epochs, best_val_loss = torch.tensor(0, device=rank), 99999.9 + stale_epochs, best_val_loss = torch.tensor(0, device=rank), float("inf") start_epoch = 0 if hpo: @@ -329,19 +234,35 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n t0 = time.time() # training step - losses_t, losses_v, best_val_loss, stale_epochs = train( - rank, - world_size, - model, - optimizer, - train_loader, - valid_loader, - best_val_loss, - stale_epochs, - patience, - outdir, - tensorboard_writer, - ) + if epoch == -1: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True + ) as prof: + with record_function("model_train"): + losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True) + prof.export_chrome_trace("trace.json") + else: + losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True) + + if (rank == 0) or (rank == "cpu"): + losses_v = train_and_valid(rank, world_size, model, optimizer, valid_loader, False) + if losses_v["Total"] < best_val_loss: + best_val_loss = losses_v["Total"] + stale_epochs = 0 + else: + stale_epochs += 1 + + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_state_dict = model.module.state_dict() + else: + model_state_dict = model.state_dict() + + torch.save( + {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()}, + # "{outdir}/weights-{epoch:02d}-{val_loss:.6f}.pth".format( + # outdir=outdir, epoch=epoch+1, val_loss=losses_v["Total"]), + f"{outdir}/best_weights.pth", + ) if hpo: # save model, optimizer and epoch number for HPO-supported checkpointing @@ -366,9 +287,8 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n if stale_epochs > patience: break - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - # with record_function("model_train"): - # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + for k, v in losses_t.items(): + tensorboard_writer.add_scalar(f"epoch/train_loss_rank_{rank}_" + k, v, epoch) if (rank == 0) or (rank == "cpu"): for k, v in losses_t.items(): @@ -380,8 +300,6 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n for k, v in losses_v.items(): tensorboard_writer.add_scalar("epoch/valid_loss_" + k, v, epoch) - tensorboard_writer.flush() - t1 = time.time() epochs_remaining = num_epochs - (epoch + 1) @@ -422,4 +340,6 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n with open(f"{outdir}/mlpf_losses.pkl", "wb") as f: pkl.dump(losses, f) + tensorboard_writer.flush() + _logger.info(f"Done with training. Total training time on device {rank} is {round((time.time() - t0_initial)/60,3)}min") diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py index d95bc656a..451e0cdc5 100644 --- a/mlpf/pyg/utils.py +++ b/mlpf/pyg/utils.py @@ -113,22 +113,22 @@ def unpack_target(y): ret = {} - ret["cls_id"] = y[:, 0].long() - ret["charge"] = torch.clamp((y[:, 1] + 1).to(dtype=torch.float32), 0, 2) # -1, 0, 1 -> 0, 1, 2 + ret["cls_id"] = y[..., 0].long() + ret["charge"] = torch.clamp((y[..., 1] + 1).to(dtype=torch.float32), 0, 2) # -1, 0, 1 -> 0, 1, 2 for i, feat in enumerate(Y_FEATURES): if i >= 2: # skip the cls and charge as they are defined above - ret[feat] = y[:, i].to(dtype=torch.float32) + ret[feat] = y[..., i].to(dtype=torch.float32) ret["phi"] = torch.atan2(ret["sin_phi"], ret["cos_phi"]) # do some sanity checks - assert torch.all(ret["pt"] >= 0.0) # pt - assert torch.all(torch.abs(ret["sin_phi"]) <= 1.0) # sin_phi - assert torch.all(torch.abs(ret["cos_phi"]) <= 1.0) # cos_phi - assert torch.all(ret["energy"] >= 0.0) # energy + # assert torch.all(ret["pt"] >= 0.0) # pt + # assert torch.all(torch.abs(ret["sin_phi"]) <= 1.0) # sin_phi + # assert torch.all(torch.abs(ret["cos_phi"]) <= 1.0) # cos_phi + # assert torch.all(ret["energy"] >= 0.0) # energy # note ~ momentum = ["pt", "eta", "sin_phi", "cos_phi", "energy"] - ret["momentum"] = y[:, 2:-1].to(dtype=torch.float32) + ret["momentum"] = y[..., 2:-1].to(dtype=torch.float32) ret["p4"] = torch.cat( [ret["pt"].unsqueeze(1), ret["eta"].unsqueeze(1), ret["phi"].unsqueeze(1), ret["energy"].unsqueeze(1)], axis=1 ) @@ -143,17 +143,23 @@ def unpack_predictions(preds): # ret["charge"] = torch.argmax(ret["charge"], axis=1, keepdim=True) - 1 # unpacking - ret["pt"] = ret["momentum"][:, 0] - ret["eta"] = ret["momentum"][:, 1] - ret["sin_phi"] = ret["momentum"][:, 2] - ret["cos_phi"] = ret["momentum"][:, 3] - ret["energy"] = ret["momentum"][:, 4] + ret["pt"] = ret["momentum"][..., 0] + ret["eta"] = ret["momentum"][..., 1] + ret["sin_phi"] = ret["momentum"][..., 2] + ret["cos_phi"] = ret["momentum"][..., 3] + ret["energy"] = ret["momentum"][..., 4] # new variables ret["cls_id"] = torch.argmax(ret["cls_id_onehot"], axis=-1) ret["phi"] = torch.atan2(ret["sin_phi"], ret["cos_phi"]) ret["p4"] = torch.cat( - [ret["pt"].unsqueeze(1), ret["eta"].unsqueeze(1), ret["phi"].unsqueeze(1), ret["energy"].unsqueeze(1)], axis=1 + [ + ret["pt"].unsqueeze(axis=-1), + ret["eta"].unsqueeze(axis=-1), + ret["phi"].unsqueeze(axis=-1), + ret["energy"].unsqueeze(axis=-1), + ], + axis=-1, ) return ret diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index 6732728a0..317566d78 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -51,7 +51,9 @@ parser.add_argument("--num-epochs", type=int, default=None, help="number of training epochs") parser.add_argument("--patience", type=int, default=None, help="patience before early stopping") parser.add_argument("--lr", type=float, default=None, help="learning rate") -parser.add_argument("--conv-type", type=str, default=None, help="choices are ['gnn_lsh', 'gravnet', 'attention']") +parser.add_argument( + "--conv-type", type=str, default="gravnet", help="which graph layer to use", choices=["gravnet", "attention", "gnn_lsh"] +) parser.add_argument("--make-plots", action="store_true", default=None, help="make plots of the test predictions") parser.add_argument("--export-onnx", action="store_true", default=None, help="exports the model to onnx") parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") @@ -66,6 +68,9 @@ def run(rank, world_size, config, args, outdir, logfile): """Demo function that will be passed to each gpu if (world_size > 1) else will run normally on the given device.""" + pad_3d = args.conv_type != "gravnet" + use_cuda = rank != "cpu" + if world_size > 1: os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" @@ -123,10 +128,26 @@ def run(rank, world_size, config, args, outdir, logfile): version = config["train_dataset"][config["dataset"]][sample]["version"] batch_size = config["train_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"] - ds = PFDataset(config["data_dir"], f"{sample}:{version}", "train", ["X", "ygen"], num_samples=config["ntrain"]) + ds = PFDataset( + config["data_dir"], + f"{sample}:{version}", + "train", + ["X", "ygen"], + pad_3d=pad_3d, + num_samples=config["ntrain"], + ) _logger.info(f"train_dataset: {ds}, {len(ds)}", color="blue") - train_loaders.append(ds.get_loader(batch_size, world_size, config["num_workers"], config["prefetch_factor"])) + train_loaders.append( + ds.get_loader( + batch_size, + world_size, + rank, + use_cuda=use_cuda, + num_workers=config["num_workers"], + prefetch_factor=config["prefetch_factor"], + ) + ) train_loader = InterleavedIterator(train_loaders) @@ -139,11 +160,25 @@ def run(rank, world_size, config, args, outdir, logfile): ) ds = PFDataset( - config["data_dir"], f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=config["nvalid"] + config["data_dir"], + f"{sample}:{version}", + "test", + ["X", "ygen"], + pad_3d=pad_3d, + num_samples=config["nvalid"], ) _logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue") - valid_loaders.append(ds.get_loader(batch_size, 1, config["num_workers"], config["prefetch_factor"])) + valid_loaders.append( + ds.get_loader( + batch_size, + 1, + rank, + use_cuda=use_cuda, + num_workers=config["num_workers"], + prefetch_factor=config["prefetch_factor"], + ) + ) valid_loader = InterleavedIterator(valid_loaders) else: @@ -177,12 +212,26 @@ def run(rank, world_size, config, args, outdir, logfile): batch_size = config["test_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"] ds = PFDataset( - config["data_dir"], f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=config["ntest"] + config["data_dir"], + f"{sample}:{version}", + "test", + ["X", "ygen", "ycand"], + pad_3d=pad_3d, + num_samples=config["ntest"], ) _logger.info(f"test_dataset: {ds}, {len(ds)}", color="blue") test_loaders[sample] = InterleavedIterator( - [ds.get_loader(batch_size, world_size, config["num_workers"], config["prefetch_factor"])] + [ + ds.get_loader( + batch_size, + world_size, + 0, + use_cuda=use_cuda, + num_workers=config["num_workers"], + prefetch_factor=config["prefetch_factor"], + ) + ] ) if not osp.isdir(f"{outdir}/preds/{sample}"): @@ -267,7 +316,7 @@ def device_agnostic_run(config, args, world_size, outdir): if config["gpus"]: assert ( world_size <= torch.cuda.device_count() - ), f"--gpus is too high (specefied {world_size} gpus but only {torch.cuda.device_count()} gpus are available)" + ), f"--gpus is too high (specified {world_size} gpus but only {torch.cuda.device_count()} gpus are available)" torch.cuda.empty_cache() if world_size > 1: @@ -293,6 +342,9 @@ def device_agnostic_run(config, args, world_size, outdir): def main(): + + # torch.multiprocessing.set_start_method('spawn') + args = parser.parse_args() world_size = len(args.gpus.split(",")) # will be 1 for both cpu ("") and single-gpu ("0") diff --git a/scripts/tallinn/a100/pytorch.sh b/scripts/tallinn/a100/pytorch.sh new file mode 100755 index 000000000..d1cd8e7dd --- /dev/null +++ b/scripts/tallinn/a100/pytorch.sh @@ -0,0 +1,15 @@ +#!/bin/bash +#SBATCH --partition gpu +#SBATCH --gres gpu:a100:1 +#SBATCH --mem-per-gpu 40G +#SBATCH -o logs/slurm-%x-%j-%N.out + +IMG=/home/software/singularity/pytorch.simg +cd ~/particleflow + +#TF training +singularity exec -B /scratch/persistent --nv \ + --env PYTHONPATH=hep_tfds \ + $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0 \ + --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pyg-cms.yaml \ + --train --conv-type gnn_lsh --num-epochs 10 --gpu-batch-multiplier 10 --num-workers 1 --prefetch-factor 10 --ntrain 200 diff --git a/scripts/tallinn/rtx/pytorch.sh b/scripts/tallinn/rtx/pytorch.sh new file mode 100755 index 000000000..5553ce416 --- /dev/null +++ b/scripts/tallinn/rtx/pytorch.sh @@ -0,0 +1,15 @@ +#!/bin/bash +#SBATCH --partition gpu +#SBATCH --gres gpu:rtx:2 +#SBATCH --mem-per-gpu 40G +#SBATCH -o logs/slurm-%x-%j-%N.out + +IMG=/home/software/singularity/pytorch.simg +cd ~/particleflow + +#TF training +singularity exec -B /scratch/persistent --nv \ + --env PYTHONPATH=hep_tfds \ + $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1 \ + --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pyg-cms-small.yaml \ + --train --conv-type gnn_lsh --num-epochs 10 --ntrain 1000 --ntest 1000 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10