Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input opt #346

Open
wants to merge 66 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
5314a07
up
farakiko Sep 20, 2024
e0ea079
up utils.edm.py with updated feature list
farakiko Sep 20, 2024
5096060
tag 2.2.0
farakiko Sep 20, 2024
dd6bad7
up
farakiko Sep 20, 2024
7e1ac71
process whole dir
farakiko Sep 20, 2024
c7d8a08
debug
farakiko Sep 20, 2024
d20de3a
up
farakiko Sep 20, 2024
a20d8e8
up
farakiko Sep 20, 2024
78fbf1a
up
farakiko Sep 20, 2024
1e2f3d1
up
farakiko Sep 20, 2024
f2e4181
remove break
farakiko Sep 20, 2024
df08d50
up
farakiko Sep 20, 2024
e13f323
up logging
farakiko Sep 20, 2024
e9930eb
up
farakiko Sep 20, 2024
34b89a0
up
farakiko Sep 20, 2024
99082ad
up
farakiko Sep 20, 2024
47c00b2
standardize inputs
farakiko Sep 20, 2024
97a3194
up elemtypes_nonzero
farakiko Sep 20, 2024
eb51a50
up
farakiko Sep 20, 2024
9520e20
up
farakiko Sep 20, 2024
1ec4e9b
up
farakiko Sep 23, 2024
963f812
up
farakiko Sep 23, 2024
c3dea3c
up
farakiko Sep 23, 2024
8057f0d
up
farakiko Sep 23, 2024
60994ce
up 26 feats
farakiko Sep 23, 2024
c985323
up
farakiko Sep 23, 2024
b41ba20
add 26 input_dim
farakiko Sep 23, 2024
31bd710
up
farakiko Sep 23, 2024
7381876
up vs 2.2.0 for standardization
farakiko Sep 23, 2024
a7f9a46
more configs
farakiko Sep 23, 2024
74e635f
better docs
farakiko Sep 23, 2024
dd4c43a
up standardization pipeline
farakiko Sep 23, 2024
80bd330
better docs
farakiko Sep 23, 2024
c040399
fix input dim for other datasets
farakiko Sep 23, 2024
bcfb277
pca
farakiko Sep 23, 2024
b2e7c2e
add standardize_inputs: False to all configs
farakiko Sep 23, 2024
0b90158
up
farakiko Sep 23, 2024
a253a4a
debug
farakiko Sep 23, 2024
bb92cb5
up
farakiko Sep 23, 2024
8e94559
revert
farakiko Sep 23, 2024
a852256
debug
farakiko Sep 23, 2024
4641557
up
farakiko Sep 23, 2024
4d56a63
oops
farakiko Sep 23, 2024
8f190b7
fixed
farakiko Sep 23, 2024
57e2924
up
farakiko Sep 23, 2024
cad00fa
logging
farakiko Sep 23, 2024
f91da81
up
farakiko Sep 23, 2024
010eef5
up
farakiko Sep 23, 2024
9156459
check
farakiko Sep 23, 2024
812f05c
remove unnecessary config
farakiko Sep 23, 2024
0968e51
up
farakiko Sep 23, 2024
42ff712
debug
farakiko Sep 23, 2024
fb9e68a
revert
farakiko Sep 23, 2024
181c534
up new config for all samples
farakiko Sep 23, 2024
8293990
oopsie
farakiko Sep 23, 2024
687b5d7
up
farakiko Sep 23, 2024
efdb489
pca
farakiko Sep 23, 2024
df1ecba
up configs
farakiko Sep 24, 2024
d6252c0
up
farakiko Sep 24, 2024
8d3d685
up
farakiko Sep 24, 2024
1aa21a9
up configs
farakiko Sep 24, 2024
832df1c
try new loss
farakiko Sep 24, 2024
e336230
up
farakiko Sep 24, 2024
84957d8
up
farakiko Sep 24, 2024
91c1fac
up
farakiko Sep 24, 2024
cc523e0
fix pca
farakiko Sep 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlpf/heptfds/clic_pf_edm4hep/ttbar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import tensorflow as tf
import tensorflow_datasets as tfds
from utils_edm import (
X_FEATURES_CL,
X_FEATURES_TRK,
Expand All @@ -9,8 +10,6 @@
split_sample,
)

import tensorflow_datasets as tfds

_DESCRIPTION = """
CLIC EDM4HEP dataset with ee -> ttbar at 380GeV.
- X: reconstructed tracks and clusters, variable number N per event
Expand All @@ -26,7 +25,7 @@


class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.1.0")
VERSION = tfds.core.Version("2.2.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "update stats, move to 380 GeV",
Expand All @@ -36,6 +35,7 @@ class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
"1.5.0": "Regenerate with ARRAY_RECORD",
"2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.1.0": "Bump dataset size",
"2.2.0": "Additional cluster input features",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand Down
13 changes: 12 additions & 1 deletion mlpf/heptfds/clic_pf_edm4hep/utils_edm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random

import awkward as ak
import numpy as np
import random

# from fcc/postprocessing.py
X_FEATURES_TRK = [
Expand Down Expand Up @@ -39,6 +40,16 @@
"sigma_x",
"sigma_y",
"sigma_z",
# additional cluster input features
"energyError",
"sigma_energy",
"sigma_x_weighted",
"sigma_y_weighted",
"sigma_z_weighted",
"energy_weighted_width",
"pos_shower_max",
"width_shower_max",
"energy_shower_max",
]

Y_FEATURES = ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "ispu"]
Expand Down
41 changes: 35 additions & 6 deletions mlpf/pyg/mlpf.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import math

import numpy as np
import torch
import torch.nn as nn

from .gnn_lsh import CombinedGraphLayer

from pyg.logger import _logger
import math
import numpy as np
from torch.nn.attention import SDPBackend, sdpa_kernel

from .gnn_lsh import CombinedGraphLayer


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# From https://github.com/rwightman/pytorch-image-models/blob/
Expand Down Expand Up @@ -57,6 +57,32 @@ def norm_cdf(x):
return tensor


def standardize_input(X, elemtypes_nonzero, standardization_dict):
    """Standardize the input features per PF-element type.

    For each element type in ``elemtypes_nonzero``, features 1..N of the rows whose
    element-type id (feature 0) matches are shifted/scaled by that type's mean/std
    from ``standardization_dict``; all other rows are zeroed for that pass. The
    per-type results are summed, so every input row ends up standardized according
    to its own element type (rows whose type is not listed come out all-zero).

    Args:
        X: input tensor of shape (..., num_features); feature 0 is the element-type id.
        elemtypes_nonzero: iterable of element-type ids present in this dataset.
        standardization_dict: maps "PFelement{id}" -> {"mean": [...], "std": [...]},
            one entry per feature (slot 0 is present but unused).

    Returns:
        Tensor of the same shape as ``X`` with features 1..N standardized per type.
        NaNs produced by zero-valued stds are replaced with 0.
    """
    # Pre-allocated accumulator: also makes the empty-elemtypes case well-defined
    # (returns zeros) instead of raising UnboundLocalError.
    Xfeat_normed = torch.zeros_like(X)

    for ielem in elemtypes_nonzero:
        stats = standardization_dict[f"PFelement{ielem}"]
        # as_tensor avoids an extra copy and places/pins dtype in one step.
        mean = torch.as_tensor(stats["mean"], dtype=X.dtype, device=X.device)
        std = torch.as_tensor(stats["std"], dtype=X.dtype, device=X.device)

        # Standardize all features except slot 0 (the element-type id itself).
        Xfeat = X.clone()
        Xfeat[..., 1:] = (Xfeat[..., 1:] - mean[..., 1:]) / std[..., 1:]

        # Keep only the rows belonging to this element type.
        msk = X[..., 0:1] == ielem
        # nan=0.0 handles 0/0 from zero-std features and inf*0 from masking.
        Xfeat_normed += torch.nan_to_num(Xfeat * msk, nan=0.0)

    return Xfeat_normed


def get_activation(activation):
if activation == "elu":
act = nn.ELU
Expand Down Expand Up @@ -372,9 +398,12 @@ def __init__(
self.final_norm_reg = torch.nn.LayerNorm(embed_dim)

# @torch.compile
def forward(self, X_features, mask):
def forward(self, X_features, mask, standardization_dict=None):
Xfeat_normed = X_features

if standardization_dict is not None:
Xfeat_normed = standardize_input(X_features, self.elemtypes_nonzero, standardization_dict)

embeddings_id, embeddings_reg = [], []
if self.num_convs != 0:
if self.input_encoding == "joint":
Expand Down
Loading
Loading