Skip to content

Commit

Permalink
tune the pytorch mlpf model to be more similar to tf (#165)
Browse files Browse the repository at this point in the history
Former-commit-id: 96d053e
  • Loading branch information
jpata authored Jan 26, 2023
1 parent 71fc273 commit 8c90b73
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 25 deletions.
62 changes: 44 additions & 18 deletions mlpf/pyg_ssl/mlpf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GravNetConv

Expand All @@ -9,10 +10,12 @@ def __init__(
self,
input_dim=34,
width=126,
num_convs=2,
num_convs=3,
k=8,
embedding_dim=34,
embedding_dim=128,
native_mlpf=False,
propagate_dimensions=32,
space_dimensions=4,
):
super(MLPF, self).__init__()

Expand All @@ -22,31 +25,41 @@ def __init__(
if native_mlpf:
# embedding of the inputs that is necessary for native mlpf training
self.nn0 = nn.Sequential(
nn.Linear(input_dim, 126),
nn.Linear(input_dim, width),
self.act(),
nn.Linear(126, 126),
nn.Linear(width, width),
self.act(),
nn.Linear(126, 126),
nn.Linear(width, width),
self.act(),
nn.Linear(126, embedding_dim),
nn.Linear(width, embedding_dim),
)

# GNN that uses the embeddings learnt by VICReg as the input features
self.conv = nn.ModuleList()
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
self.conv.append(
self.conv_id.append(
GravNetConv(
embedding_dim,
embedding_dim,
space_dimensions=4,
propagate_dimensions=22,
space_dimensions=space_dimensions,
propagate_dimensions=propagate_dimensions,
k=k,
)
)
self.conv_reg.append(
GravNetConv(
embedding_dim,
embedding_dim,
space_dimensions=space_dimensions,
propagate_dimensions=propagate_dimensions,
k=k,
)
)

# DNN that acts on the node level to predict the PID
self.nn_id = nn.Sequential(
nn.Linear(embedding_dim, width),
nn.Linear(num_convs * embedding_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
Expand All @@ -55,8 +68,9 @@ def __init__(
nn.Linear(width, NUM_CLASSES),
)

# elementwise DNN for node momentum regression
self.nn_reg = nn.Sequential(
nn.Linear(embedding_dim, width),
nn.Linear(num_convs * embedding_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
Expand All @@ -65,8 +79,9 @@ def __init__(
nn.Linear(width, 4),
)

# elementwise DNN for node charge regression
self.nn_charge = nn.Sequential(
nn.Linear(embedding_dim, width),
nn.Linear(num_convs * embedding_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
Expand All @@ -87,16 +102,27 @@ def forward(self, batch):
else:
embedding = input_

embeddings_id = []
embeddings_reg = []

# perform a series of graph convolutions
for num, conv in enumerate(self.conv):
embedding = conv(embedding, batch)
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
embeddings_id.append(conv(conv_input, batch))

for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch))

embedding_id = torch.cat(embeddings_id, axis=-1)
embedding_reg = torch.cat(embeddings_reg, axis=-1)

# predict the PIDs
preds_id = self.nn_id(embedding)
preds_id = self.nn_id(embedding_id)

# predict the 4-momentum, add it to the (pt, eta, phi, E) of the PFelement
preds_momentum = self.nn_reg(embedding) + input_[:, 1:5]
preds_momentum = self.nn_reg(embedding_reg) + input_[:, 1:5]

pred_charge = self.nn_charge(embedding)
pred_charge = self.nn_charge(embedding_reg)

return preds_id, preds_momentum, pred_charge
24 changes: 19 additions & 5 deletions mlpf/pyg_ssl/training_mlpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def train(device, encoder, mlpf, train_loader, valid_loader, optimizer, optimize
# initialize loss counters
losses = 0

epoch_loss_id = 0.0
epoch_loss_momentum = 0.0
epoch_loss_charge = 0.0

for i, batch in enumerate(loader):

if mode == "ssl":
Expand All @@ -86,15 +90,16 @@ def train(device, encoder, mlpf, train_loader, valid_loader, optimizer, optimize
target_momentum = event_on_device.ygen[:, 1:].to(dtype=torch.float32)
target_charge = event_on_device.ygen[:, 0:1].to(dtype=torch.float32)

weights = compute_weights(device, target_ids, num_classes=6) # to accomodate class imbalance
loss_id = torch.nn.functional.cross_entropy(pred_ids_one_hot, target_ids, weight=weights) # for classifying PID
loss_id = torch.nn.functional.cross_entropy(pred_ids_one_hot, target_ids) # for classifying PID

# for regression, mask the loss in cases there is no true particle
msk_true_particle = torch.unsqueeze((target_ids != 0).to(dtype=torch.float32), axis=-1)
loss_momentum = torch.nn.functional.huber_loss(
pred_momentum * msk_true_particle, target_momentum * msk_true_particle
) # for regressing p4
loss_charge = torch.nn.functional.huber_loss(
pred_charge * msk_true_particle, target_charge * msk_true_particle
) # for regressing p4
) # for regressing charge
loss = loss_id + loss_momentum + loss_charge

# update parameters
Expand All @@ -109,8 +114,17 @@ def train(device, encoder, mlpf, train_loader, valid_loader, optimizer, optimize

losses += loss.detach()

epoch_loss_id += loss_id.detach().cpu().item() / len(loader)
epoch_loss_momentum += loss_momentum.detach().cpu().item() / len(loader)
epoch_loss_charge += loss_charge.detach().cpu().item() / len(loader)

losses = losses.cpu().item() / len(loader)

print(
"loss_id={:.2f} loss_momentum={:.2f} loss_charge={:.2f}".format(
epoch_loss_id, epoch_loss_momentum, epoch_loss_charge
)
)
return losses


Expand Down Expand Up @@ -195,8 +209,8 @@ def training_loop_mlpf(
)

fig, ax = plt.subplots()
ax.plot(range(len(losses_train)), losses_train, label="training")
ax.plot(range(len(losses_valid)), losses_valid, label="validation")
ax.plot(range(len(losses_train)), losses_train, label="training ({:.2f})".format(losses_train[-1]))
ax.plot(range(len(losses_valid)), losses_valid, label="validation ({:.2f})".format(losses_valid[-1]))
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
if mode == "ssl":
Expand Down
8 changes: 6 additions & 2 deletions mlpf/ssl_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path as osp

import matplotlib
import mplhep
import numpy as np
import torch
import torch_geometric
Expand All @@ -13,7 +14,7 @@
from pyg_ssl.VICReg import DECODER, ENCODER

matplotlib.use("Agg")

mplhep.set_style(mplhep.styles.CMS)

"""
Developing a PyTorch Geometric semi-supervised (VICReg-based https://arxiv.org/abs/2105.04906) pipeline
Expand Down Expand Up @@ -41,7 +42,10 @@

world_size = torch.cuda.device_count()

torch.backends.cudnn.benchmark = True
# our data size varies from batch to batch, because each set of N_batch events has a different number of particles
torch.backends.cudnn.benchmark = False

# torch.autograd.set_detect_anomaly(True)

# load the clic dataset
data_train_VICReg, data_valid_VICReg, data_train_mlpf, data_valid_mlpf = data_split(args.dataset, args.data_split_mode)
Expand Down

0 comments on commit 8c90b73

Please sign in to comment.