Skip to content

Commit

Permalink
optimizing VICReg (#173)
Browse files Browse the repository at this point in the history
* add evaluate script

* set VICReg to evaluation mode

* print best validation loss on vicreg loss plot

* multigpu compatibility

* Revert to bd2a3e8

* fix loading of mlpf models

* add plots for each component of each loss

* oops

* add notebook for ssl optimizations

* fix quick datasplit mode to make sense

* add vicreg embeddings after gnn in mlpf

* for num_convs=0

* mdmm

* add extra dnn layer for native

* early results

* pooling after decoder

* update vicreg, loss and multigpu

* up early stopping

* better gpu util

---------

Co-authored-by: Joosep Pata <joosep.pata@gmail.com>
  • Loading branch information
farakiko and jpata authored Mar 1, 2023
1 parent 359b892 commit 9d84f26
Show file tree
Hide file tree
Showing 11 changed files with 21,953 additions and 423 deletions.
7,679 changes: 7,679 additions & 0 deletions mlpf/mdmm.ipynb

Large diffs are not rendered by default.

62 changes: 59 additions & 3 deletions mlpf/pyg_ssl/VICReg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,66 @@
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.conv import GravNetConv

from .utils import CLUSTERS_X, TRACKS_X


# define the Encoder that learns latent representations of tracks and clusters
# these representations will be used by MLPF which is the downstream task
class VICReg(nn.Module):
def __init__(self, encoder, decoder):
super(VICReg, self).__init__()
self.encoder = encoder
self.decoder = decoder

def distinguish_PFelements(self, batch):
"""Takes an event~Batch() and splits it into two Batch() objects representing the tracks/clusters."""

track_id = 1
cluster_id = 2

tracks = Batch(
x=batch.x[batch.x[:, 0] == track_id][:, 1:].float()[
:, :TRACKS_X
], # remove the first input feature which is not needed anymore
ygen=batch.ygen[batch.x[:, 0] == track_id],
ygen_id=batch.ygen_id[batch.x[:, 0] == track_id],
ycand=batch.ycand[batch.x[:, 0] == track_id],
ycand_id=batch.ycand_id[batch.x[:, 0] == track_id],
batch=batch.batch[batch.x[:, 0] == track_id],
)
clusters = Batch(
x=batch.x[batch.x[:, 0] == cluster_id][:, 1:].float()[
:, :CLUSTERS_X
], # remove the first input feature which is not needed anymore
ygen=batch.ygen[batch.x[:, 0] == cluster_id],
ygen_id=batch.ygen_id[batch.x[:, 0] == cluster_id],
ycand=batch.ycand[batch.x[:, 0] == cluster_id],
ycand_id=batch.ycand_id[batch.x[:, 0] == cluster_id],
batch=batch.batch[batch.x[:, 0] == cluster_id],
)
return tracks, clusters

def forward(self, event):

# seperate tracks from clusters
tracks, clusters = self.distinguish_PFelements(event)

# encode to retrieve the representations
track_representations, cluster_representations = self.encoder(tracks, clusters)

# decode/expand to get the embeddings
embedding_tracks, embedding_clusters = self.decoder(track_representations, cluster_representations)

# global pooling to be able to compute a loss
pooled_tracks = global_mean_pool(embedding_tracks, tracks.batch)
pooled_clusters = global_mean_pool(embedding_clusters, clusters.batch)

return pooled_tracks, pooled_clusters


class ENCODER(nn.Module):
"""The Encoder part of VICReg which attempts to learns useful latent representations of tracks and clusters."""

def __init__(
self,
width=126,
Expand Down Expand Up @@ -66,8 +120,10 @@ def forward(self, tracks, clusters):
return embedding_tracks, embedding_clusters


# define the decoder that expands the latent representations of tracks and clusters
class DECODER(nn.Module):
"""The Decoder part of VICReg which attempts to expand the learned latent representations
of tracks and clusters into a space where a loss can be computed."""

def __init__(
self,
input_dim=34,
Expand Down
12 changes: 6 additions & 6 deletions mlpf/pyg_ssl/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def parse_args():
parser.add_argument("--overwrite", dest="overwrite", action="store_true", help="overwrites the model if True")

# training hyperparameters
parser.add_argument("--lmbd", type=float, default=0.01, help="the lambda term in the VICReg loss")
parser.add_argument("--u", type=float, default=0.1, help="the mu term in the VICReg loss")
parser.add_argument("--v", type=float, default=0.1, help="the nu term in the VICReg loss")
parser.add_argument("--lmbd", type=float, default=1, help="the lambda term in the VICReg loss")
parser.add_argument("--mu", type=float, default=0.1, help="the mu term in the VICReg loss")
parser.add_argument("--nu", type=float, default=1e-9, help="the nu term in the VICReg loss")
parser.add_argument("--n_epochs_mlpf", type=int, default=3, help="number of training epochs for mlpf")
parser.add_argument("--n_epochs_VICReg", type=int, default=3, help="number of training epochs for VICReg")
parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
parser.add_argument("--batch_size_mlpf", type=int, default=500, help="number of events to process at once")
parser.add_argument("--batch_size_VICReg", type=int, default=2000, help="number of events to process at once")
parser.add_argument("--bs_mlpf", type=int, default=500, help="number of events to process at once")
parser.add_argument("--bs_VICReg", type=int, default=2000, help="number of events to process at once")
parser.add_argument("--patience", type=int, default=50, help="patience before early stopping")
parser.add_argument(
"--FineTune_VICReg", dest="FineTune_VICReg", action="store_true", help="FineTune VICReg during MLPFtraining"
"--FineTune_VICReg", dest="FineTune_VICReg", action="store_true", help="FineTune VICReg during MLPF training"
)

# VICReg encoder architecture
Expand Down
4 changes: 1 addition & 3 deletions mlpf/pyg_ssl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def particle_array_to_awkward(batch_ids, arr_id, arr_p4):
return ret


def evaluate(device, encoder, decoder, mlpf, batch_size_mlpf, mode, outpath, samples):
def evaluate(device, encoder, mlpf, batch_size_mlpf, mode, outpath, samples):
import fastjet
import vector
from jet_utils import build_dummy_array, match_two_jet_collections
Expand All @@ -60,7 +60,6 @@ def evaluate(device, encoder, decoder, mlpf, batch_size_mlpf, mode, outpath, sam

mlpf.eval()
encoder.eval()
decoder.eval()
for sample, data in samples.items():
print(f"Testing the {mode} model on the {sample}")

Expand All @@ -74,7 +73,6 @@ def evaluate(device, encoder, decoder, mlpf, batch_size_mlpf, mode, outpath, sam

mlpf.eval()
encoder.eval()
decoder.eval()

conf_matrix = np.zeros((6, 6))
with torch.no_grad():
Expand Down
204 changes: 123 additions & 81 deletions mlpf/pyg_ssl/mlpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,49 @@ def forward(self, x, mask):
return x


def ffn(input_dim, output_dim, width, act, dropout):
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)
def ffn(input_dim, output_dim, width, act, dropout, ssl):
if ssl:
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)
else:
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)


class MLPF(nn.Module):
Expand All @@ -80,99 +104,117 @@ def __init__(
width=126,
num_convs=2,
k=32,
native_mlpf=False,
propagate_dimensions=32,
space_dimensions=4,
dropout=0.4,
ssl=False,
VICReg_embedding_dim=0,
):
super(MLPF, self).__init__()

self.act = nn.ELU
self.native_mlpf = native_mlpf # boolean that is true for native mlpf and false for ssl
self.dropout = dropout
self.input_dim = input_dim
self.num_convs = num_convs
self.ssl = ssl # boolean that is True for ssl and False for native mlpf

# embedding of the inputs
self.nn0 = nn.Sequential(
nn.Linear(input_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, embedding_dim),
)

self.conv_type = "gravnet"
# GNN that uses the embeddings learnt by VICReg as the input features
if self.conv_type == "gravnet":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
self.conv_id.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
self.conv_reg.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
elif self.conv_type == "attention":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
self.conv_id.append(SelfAttentionLayer(embedding_dim))
self.conv_reg.append(SelfAttentionLayer(embedding_dim))
if num_convs != 0:
self.nn0 = nn.Sequential(
nn.Linear(input_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, embedding_dim),
)

self.conv_type = "gravnet"
# GNN that uses the embeddings learnt by VICReg as the input features
if self.conv_type == "gravnet":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
self.conv_id.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
self.conv_reg.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
elif self.conv_type == "attention":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
self.conv_id.append(SelfAttentionLayer(embedding_dim))
self.conv_reg.append(SelfAttentionLayer(embedding_dim))

decoding_dim = input_dim + num_convs * embedding_dim
if ssl:
decoding_dim += VICReg_embedding_dim

# DNN that acts on the node level to predict the PID
self.nn_id = ffn(decoding_dim, NUM_CLASSES, width, self.act, dropout)
self.nn_id = ffn(decoding_dim, NUM_CLASSES, width, self.act, dropout, ssl)

# elementwise DNN for node momentum regression
self.nn_pt = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_eta = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_phi = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_energy = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_pt = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_eta = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_phi = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_energy = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)

# elementwise DNN for node charge regression, classes (-1, 0, 1)
self.nn_charge = ffn(decoding_dim + NUM_CLASSES, 3, width, self.act, dropout)
self.nn_charge = ffn(decoding_dim + NUM_CLASSES, 3, width, self.act, dropout, ssl)

def forward(self, batch):

# unfold the Batch object
input_ = batch.x.float()
batch_idx = batch.batch
if self.ssl:
input_ = batch.x.float()[:, : self.input_dim]
VICReg_embeddings = batch.x.float()[:, self.input_dim :]
else:
input_ = batch.x.float()

embedding = self.nn0(input_)
batch_idx = batch.batch

embeddings_id = []
embeddings_reg = []

if self.conv_type == "gravnet":
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
embeddings_id.append(conv(conv_input, batch_idx))
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
elif self.conv_type == "attention":
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_id.append(out_stacked)
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_reg.append(out_stacked)

embedding_id = torch.cat([input_] + embeddings_id, axis=-1)
if self.num_convs != 0:
embedding = self.nn0(input_)

if self.conv_type == "gravnet":
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
embeddings_id.append(conv(conv_input, batch_idx))
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
elif self.conv_type == "attention":
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_id.append(out_stacked)
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_reg.append(out_stacked)

if self.ssl:
embedding_id = torch.cat([input_] + embeddings_id + [VICReg_embeddings], axis=-1)
else:
embedding_id = torch.cat([input_] + embeddings_id, axis=-1)

# predict the PIDs
preds_id = self.nn_id(embedding_id)

embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1)
if self.ssl:
embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id] + [VICReg_embeddings], axis=-1)
else:
embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1)

# predict the 4-momentum, add it to the (pt, eta, phi, E) of the PFelement
preds_pt = self.nn_pt(embedding_reg) + input_[:, 1:2]
Expand Down
Loading

0 comments on commit 9d84f26

Please sign in to comment.