optimizing VICReg #173

Merged: 57 commits, Mar 1, 2023

Commits
1c77a24  up (farakiko, Feb 6, 2023)
a333d92  up (farakiko, Feb 6, 2023)
22b6685  add evaluate script (farakiko, Feb 6, 2023)
0582ead  set VICReg to evaluation mode (farakiko, Feb 8, 2023)
bd2a3e8  print best validation loss on vicreg loss plot (farakiko, Feb 8, 2023)
dd32d13  up (farakiko, Feb 8, 2023)
130cd90  Merge branch 'main' into ssl-studies (farakiko, Feb 8, 2023)
47e2f80  multigpu compatibility (farakiko, Feb 13, 2023)
b9df70e  up (farakiko, Feb 13, 2023)
e0d8966  up (farakiko, Feb 13, 2023)
4576366  up (farakiko, Feb 13, 2023)
be41ee1  up (farakiko, Feb 13, 2023)
7efb462  up (farakiko, Feb 13, 2023)
a4da8fc  up (farakiko, Feb 13, 2023)
214226f  up (farakiko, Feb 13, 2023)
c4a9e79  up (farakiko, Feb 13, 2023)
47cca38  up (farakiko, Feb 13, 2023)
53f6eaa  up (farakiko, Feb 13, 2023)
6128565  Revert to bd2a3e8 (farakiko, Feb 13, 2023)
4d844cc  fix loading of mlpf models (farakiko, Feb 14, 2023)
feaa5c7  add plots for each component of each loss (farakiko, Feb 16, 2023)
edc8d31  oops (farakiko, Feb 16, 2023)
b617b8a  add notebook for ssl optimizations (farakiko, Feb 16, 2023)
cb8f425  fix quick datasplit mode to make sense (farakiko, Feb 17, 2023)
736c90c  add vicreg embeddings after gnn in mlpf (farakiko, Feb 17, 2023)
f0e6c36  for num_convs=0 (farakiko, Feb 21, 2023)
f6ee7ee  up (farakiko, Feb 27, 2023)
2713762  mdmm (farakiko, Feb 27, 2023)
db3d42f  up (farakiko, Feb 27, 2023)
ffd1c6b  Merge branch 'ssl-studies' of https://github.com/farakiko/particleflo… (farakiko, Feb 27, 2023)
748098e  up (farakiko, Feb 27, 2023)
d1118fa  add extra dnn layer for native (farakiko, Feb 27, 2023)
db7e6bc  oops (farakiko, Feb 27, 2023)
4aef12f  uff (farakiko, Feb 27, 2023)
a979f47  early results (farakiko, Feb 27, 2023)
9abde3e  up (farakiko, Feb 27, 2023)
160f553  early results (farakiko, Feb 27, 2023)
1d44783  Merge branch 'ssl-studies' of https://github.com/farakiko/particleflo… (farakiko, Feb 27, 2023)
5f7731f  pooling after decoder (farakiko, Feb 27, 2023)
a2d581b  up (farakiko, Feb 28, 2023)
e9e6bbb  update vicreg, loss and multigpu (farakiko, Feb 28, 2023)
732ce51  up (farakiko, Feb 28, 2023)
fdf6aae  up (farakiko, Feb 28, 2023)
a210c66  up (farakiko, Feb 28, 2023)
1412fe3  up (farakiko, Feb 28, 2023)
7b31586  test (farakiko, Feb 28, 2023)
214ea8d  up (farakiko, Feb 28, 2023)
27424ab  up early stopping (farakiko, Feb 28, 2023)
b4aab03  better gpu util (farakiko, Mar 1, 2023)
91458db  up (farakiko, Mar 1, 2023)
840a204  up (farakiko, Mar 1, 2023)
ca8efdd  up (farakiko, Mar 1, 2023)
d6f1a9b  Merge branch 'ssl-studies' of https://github.com/farakiko/particleflo… (farakiko, Mar 1, 2023)
6e0cabd  up for now (farakiko, Mar 1, 2023)
17ea9ee  up (farakiko, Mar 1, 2023)
703634d  up (farakiko, Mar 1, 2023)
96397bc  Merge branch 'main' into ssl-studies (jpata, Mar 1, 2023)
7,679 changes: 7,679 additions & 0 deletions mlpf/mdmm.ipynb

Large diffs are not rendered by default.
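The notebook itself is not rendered, but the "mdmm" commit suggests it experiments with the Modified Differential Multiplier Method for balancing loss terms. A minimal, hypothetical sketch of that pattern using the mdmm package (an assumption; the notebook's actual contents may differ, and the model and constraint below are stand-ins):

import mdmm  # pip install mdmm
import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the VICReg network
primary_loss_fn = nn.MSELoss()

# instead of hand-tuning a loss weight, constrain the auxiliary term to a target value
constraint = mdmm.EqConstraint(lambda: model.weight.pow(2).mean(), 0.1)
mdmm_module = mdmm.MDMM([constraint])
optimizer = mdmm_module.make_optimizer(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 16)
for _ in range(5):
    loss = primary_loss_fn(model(x), y)
    mdmm_return = mdmm_module(loss)  # augments the loss with multiplier terms
    optimizer.zero_grad()
    mdmm_return.value.backward()
    optimizer.step()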

62 changes: 59 additions & 3 deletions mlpf/pyg_ssl/VICReg.py
@@ -1,12 +1,66 @@
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.conv import GravNetConv

from .utils import CLUSTERS_X, TRACKS_X


# define VICReg, which learns latent representations of tracks and clusters
# these representations will be used by MLPF, which is the downstream task
class VICReg(nn.Module):
def __init__(self, encoder, decoder):
super(VICReg, self).__init__()
self.encoder = encoder
self.decoder = decoder

def distinguish_PFelements(self, batch):
"""Takes an event~Batch() and splits it into two Batch() objects representing the tracks/clusters."""

track_id = 1
cluster_id = 2

tracks = Batch(
x=batch.x[batch.x[:, 0] == track_id][:, 1:].float()[
:, :TRACKS_X
], # remove the first input feature which is not needed anymore
ygen=batch.ygen[batch.x[:, 0] == track_id],
ygen_id=batch.ygen_id[batch.x[:, 0] == track_id],
ycand=batch.ycand[batch.x[:, 0] == track_id],
ycand_id=batch.ycand_id[batch.x[:, 0] == track_id],
batch=batch.batch[batch.x[:, 0] == track_id],
)
clusters = Batch(
x=batch.x[batch.x[:, 0] == cluster_id][:, 1:].float()[
:, :CLUSTERS_X
], # remove the first input feature which is not needed anymore
ygen=batch.ygen[batch.x[:, 0] == cluster_id],
ygen_id=batch.ygen_id[batch.x[:, 0] == cluster_id],
ycand=batch.ycand[batch.x[:, 0] == cluster_id],
ycand_id=batch.ycand_id[batch.x[:, 0] == cluster_id],
batch=batch.batch[batch.x[:, 0] == cluster_id],
)
return tracks, clusters

def forward(self, event):

# separate tracks from clusters
tracks, clusters = self.distinguish_PFelements(event)

# encode to retrieve the representations
track_representations, cluster_representations = self.encoder(tracks, clusters)

# decode/expand to get the embeddings
embedding_tracks, embedding_clusters = self.decoder(track_representations, cluster_representations)

# global pooling to be able to compute a loss
pooled_tracks = global_mean_pool(embedding_tracks, tracks.batch)
pooled_clusters = global_mean_pool(embedding_clusters, clusters.batch)

return pooled_tracks, pooled_clusters


class ENCODER(nn.Module):
"""The Encoder part of VICReg which attempts to learns useful latent representations of tracks and clusters."""

def __init__(
self,
width=126,
@@ -66,8 +120,10 @@ def forward(self, tracks, clusters):
return embedding_tracks, embedding_clusters


# define the decoder that expands the latent representations of tracks and clusters
class DECODER(nn.Module):
"""The Decoder part of VICReg which attempts to expand the learned latent representations
of tracks and clusters into a space where a loss can be computed."""

def __init__(
self,
input_dim=34,
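The forward pass above returns one pooled embedding per event for tracks and one for clusters, precisely so that a VICReg loss can compare the two views. A minimal sketch of the standard VICReg objective (Bardes et al.), written with the lmbd/mu/nu weights exposed in args.py; the PR's actual loss code may differ in detail:

import torch
import torch.nn.functional as F

def off_diagonal(m):
    # all elements of a square matrix except the diagonal
    n = m.shape[0]
    return m.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def vicreg_loss(z_a, z_b, lmbd=1.0, mu=0.1, nu=1e-9, eps=1e-4):
    # invariance: the two pooled views of the same event should agree
    inv = F.mse_loss(z_a, z_b)
    # variance: keep the std of every embedding dimension above 1 (avoids collapse)
    var = torch.mean(F.relu(1 - torch.sqrt(z_a.var(dim=0) + eps))) + \
          torch.mean(F.relu(1 - torch.sqrt(z_b.var(dim=0) + eps)))
    # covariance: push off-diagonal covariances to zero (decorrelates dimensions)
    n, d = z_a.shape
    z_a, z_b = z_a - z_a.mean(dim=0), z_b - z_b.mean(dim=0)
    cov = off_diagonal((z_a.T @ z_a) / (n - 1)).pow(2).sum() / d + \
          off_diagonal((z_b.T @ z_b) / (n - 1)).pow(2).sum() / d
    return lmbd * inv + mu * var + nu * cov

# usage with the pooled outputs of VICReg.forward:
# loss = vicreg_loss(pooled_tracks, pooled_clusters)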
12 changes: 6 additions & 6 deletions mlpf/pyg_ssl/args.py
@@ -28,17 +28,17 @@ def parse_args():
parser.add_argument("--overwrite", dest="overwrite", action="store_true", help="overwrites the model if True")

# training hyperparameters
parser.add_argument("--lmbd", type=float, default=0.01, help="the lambda term in the VICReg loss")
parser.add_argument("--u", type=float, default=0.1, help="the mu term in the VICReg loss")
parser.add_argument("--v", type=float, default=0.1, help="the nu term in the VICReg loss")
parser.add_argument("--lmbd", type=float, default=1, help="the lambda term in the VICReg loss")
parser.add_argument("--mu", type=float, default=0.1, help="the mu term in the VICReg loss")
parser.add_argument("--nu", type=float, default=1e-9, help="the nu term in the VICReg loss")
parser.add_argument("--n_epochs_mlpf", type=int, default=3, help="number of training epochs for mlpf")
parser.add_argument("--n_epochs_VICReg", type=int, default=3, help="number of training epochs for VICReg")
parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
parser.add_argument("--batch_size_mlpf", type=int, default=500, help="number of events to process at once")
parser.add_argument("--batch_size_VICReg", type=int, default=2000, help="number of events to process at once")
parser.add_argument("--bs_mlpf", type=int, default=500, help="number of events to process at once")
parser.add_argument("--bs_VICReg", type=int, default=2000, help="number of events to process at once")
parser.add_argument("--patience", type=int, default=50, help="patience before early stopping")
parser.add_argument(
"--FineTune_VICReg", dest="FineTune_VICReg", action="store_true", help="FineTune VICReg during MLPFtraining"
"--FineTune_VICReg", dest="FineTune_VICReg", action="store_true", help="FineTune VICReg during MLPF training"
)

# VICReg encoder architecture
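Note the renames in this hunk: --u and --v become --mu and --nu (matching the standard VICReg notation), and --batch_size_* becomes --bs_*. A quick, hypothetical sanity check of the new flags (assumes the repository root is on the import path):

import sys
from mlpf.pyg_ssl.args import parse_args

sys.argv = ["train.py", "--lmbd", "1", "--mu", "0.1", "--nu", "1e-9",
            "--bs_VICReg", "2000", "--bs_mlpf", "500"]
args = parse_args()
print(args.lmbd, args.mu, args.nu, args.bs_VICReg, args.bs_mlpf)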
4 changes: 1 addition & 3 deletions mlpf/pyg_ssl/evaluate.py
@@ -42,7 +42,7 @@ def particle_array_to_awkward(batch_ids, arr_id, arr_p4):
return ret


-def evaluate(device, encoder, decoder, mlpf, batch_size_mlpf, mode, outpath, samples):
+def evaluate(device, encoder, mlpf, batch_size_mlpf, mode, outpath, samples):
import fastjet
import vector
from jet_utils import build_dummy_array, match_two_jet_collections
@@ -60,7 +60,6 @@ def evaluate(device, encoder, decoder, mlpf, batch_size_mlpf, mode, outpath, samples):

mlpf.eval()
encoder.eval()
-    decoder.eval()
for sample, data in samples.items():
print(f"Testing the {mode} model on the {sample}")

@@ -74,7 +73,6 @@ def evaluate(device, encoder, decoder, mlpf, batch_size_mlpf, mode, outpath, samples):

mlpf.eval()
encoder.eval()
-    decoder.eval()

conf_matrix = np.zeros((6, 6))
with torch.no_grad():
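The DECODER only exists to expand representations for the VICReg pre-training loss, so it is no longer passed to evaluate(): at test time only the encoder and the MLPF head run. A hedged skeleton of that inference pattern (the real wiring between encoder output and MLPF input lives in the full evaluate() body):

import torch

def inference_skeleton(encoder, mlpf, loader, device):
    encoder.eval()  # fix dropout/normalization layers for inference
    mlpf.eval()
    with torch.no_grad():  # no gradients needed at evaluation time
        for batch in loader:
            batch = batch.to(device)
            # ... split into tracks/clusters, encode, hand embeddings to mlpf ...
            preds = mlpf(batch)  # placeholder call; see the full evaluate()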
204 changes: 123 additions & 81 deletions mlpf/pyg_ssl/mlpf.py
@@ -51,25 +51,49 @@ def forward(self, x, mask):
return x


def ffn(input_dim, output_dim, width, act, dropout):
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)
def ffn(input_dim, output_dim, width, act, dropout, ssl):
if ssl:
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)
else:
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)


class MLPF(nn.Module):
@@ -80,99 +104,117 @@ def __init__(
width=126,
num_convs=2,
k=32,
native_mlpf=False,
propagate_dimensions=32,
space_dimensions=4,
dropout=0.4,
ssl=False,
VICReg_embedding_dim=0,
):
super(MLPF, self).__init__()

self.act = nn.ELU
self.native_mlpf = native_mlpf # boolean that is true for native mlpf and false for ssl
self.dropout = dropout
self.input_dim = input_dim
self.num_convs = num_convs
self.ssl = ssl # boolean that is True for ssl and False for native mlpf

# embedding of the inputs
self.nn0 = nn.Sequential(
nn.Linear(input_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, embedding_dim),
)

self.conv_type = "gravnet"
# GNN that uses the embeddings learnt by VICReg as the input features
if self.conv_type == "gravnet":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
self.conv_id.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
self.conv_reg.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
elif self.conv_type == "attention":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
self.conv_id.append(SelfAttentionLayer(embedding_dim))
self.conv_reg.append(SelfAttentionLayer(embedding_dim))
if num_convs != 0:
self.nn0 = nn.Sequential(
nn.Linear(input_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, embedding_dim),
)

self.conv_type = "gravnet"
# GNN that uses the embeddings learnt by VICReg as the input features
if self.conv_type == "gravnet":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
for i in range(num_convs):
self.conv_id.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
self.conv_reg.append(GravNetLayer(embedding_dim, space_dimensions, propagate_dimensions, k, dropout))
elif self.conv_type == "attention":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
self.conv_id.append(SelfAttentionLayer(embedding_dim))
self.conv_reg.append(SelfAttentionLayer(embedding_dim))

decoding_dim = input_dim + num_convs * embedding_dim
if ssl:
decoding_dim += VICReg_embedding_dim

# DNN that acts on the node level to predict the PID
self.nn_id = ffn(decoding_dim, NUM_CLASSES, width, self.act, dropout)
self.nn_id = ffn(decoding_dim, NUM_CLASSES, width, self.act, dropout, ssl)

# elementwise DNN for node momentum regression
self.nn_pt = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_eta = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_phi = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_energy = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_pt = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_eta = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_phi = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_energy = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)

# elementwise DNN for node charge regression, classes (-1, 0, 1)
self.nn_charge = ffn(decoding_dim + NUM_CLASSES, 3, width, self.act, dropout)
self.nn_charge = ffn(decoding_dim + NUM_CLASSES, 3, width, self.act, dropout, ssl)

def forward(self, batch):

# unfold the Batch object
input_ = batch.x.float()
batch_idx = batch.batch
if self.ssl:
input_ = batch.x.float()[:, : self.input_dim]
VICReg_embeddings = batch.x.float()[:, self.input_dim :]
else:
input_ = batch.x.float()

embedding = self.nn0(input_)
batch_idx = batch.batch

embeddings_id = []
embeddings_reg = []

if self.conv_type == "gravnet":
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
embeddings_id.append(conv(conv_input, batch_idx))
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
elif self.conv_type == "attention":
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_id.append(out_stacked)
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_reg.append(out_stacked)

embedding_id = torch.cat([input_] + embeddings_id, axis=-1)
if self.num_convs != 0:
embedding = self.nn0(input_)

if self.conv_type == "gravnet":
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
embeddings_id.append(conv(conv_input, batch_idx))
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
elif self.conv_type == "attention":
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_id.append(out_stacked)
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_reg.append(out_stacked)

if self.ssl:
embedding_id = torch.cat([input_] + embeddings_id + [VICReg_embeddings], axis=-1)
else:
embedding_id = torch.cat([input_] + embeddings_id, axis=-1)

# predict the PIDs
preds_id = self.nn_id(embedding_id)

embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1)
if self.ssl:
embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id] + [VICReg_embeddings], axis=-1)
else:
embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1)

# predict the 4-momentum, add it to the (pt, eta, phi, E) of the PFelement
preds_pt = self.nn_pt(embedding_reg) + input_[:, 1:2]
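With ssl=True, MLPF.forward assumes batch.x carries the raw PF-element features first and the VICReg embeddings after them, split at input_dim. A small sketch of that input contract (the dimensions are illustrative, not the repo's actual values):

import torch
from torch_geometric.data import Batch, Data

input_dim, vicreg_dim, n_elem = 12, 34, 100      # illustrative sizes
raw_x = torch.randn(n_elem, input_dim)           # raw PF-element features
vicreg_emb = torch.randn(n_elem, vicreg_dim)     # e.g. the DECODER output
batch = Batch.from_data_list([Data(x=torch.cat([raw_x, vicreg_emb], dim=-1))])

# what MLPF.forward does internally when self.ssl is True:
input_ = batch.x.float()[:, :input_dim]
VICReg_embeddings = batch.x.float()[:, input_dim:]
assert input_.shape[1] == input_dim
assert VICReg_embeddings.shape[1] == vicreg_dim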