Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytorch backend major update #240

Merged
merged 105 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
105 commits
Select commit Hold shift + click to select a range
6ffe3a3
push
farakiko Oct 16, 2023
367d54c
up for now
farakiko Oct 16, 2023
f807833
UP
farakiko Oct 16, 2023
631ea17
up
farakiko Oct 16, 2023
337a942
up
farakiko Oct 16, 2023
7bc8c4d
up
farakiko Oct 16, 2023
adcd695
up
farakiko Oct 16, 2023
55c102d
up
farakiko Oct 16, 2023
ea56b63
up
farakiko Oct 16, 2023
4b47d78
up
farakiko Oct 16, 2023
de51b48
up
farakiko Oct 16, 2023
37b956a
up
farakiko Oct 16, 2023
52e3dc0
up
farakiko Oct 16, 2023
833fc11
up
farakiko Oct 16, 2023
48ff426
up
farakiko Oct 16, 2023
1b0fc61
up
farakiko Oct 16, 2023
c274aa9
up
farakiko Oct 16, 2023
625fef7
up
farakiko Oct 16, 2023
b5b92a1
up
farakiko Oct 16, 2023
700fc30
up
farakiko Oct 16, 2023
3e8000e
up
farakiko Oct 16, 2023
34dc3fc
up
farakiko Oct 16, 2023
27e0c74
up
farakiko Oct 16, 2023
010ad51
up
farakiko Oct 16, 2023
54d8127
up
farakiko Oct 16, 2023
7fad80b
up
farakiko Oct 16, 2023
0c583a1
up
farakiko Oct 16, 2023
af67bb6
up
farakiko Oct 16, 2023
5d9f8d8
up
farakiko Oct 16, 2023
15ce200
up best configs
farakiko Oct 16, 2023
28bbf7f
up
farakiko Oct 17, 2023
6c1cb77
up
farakiko Oct 17, 2023
41d8db2
up
farakiko Oct 17, 2023
7613613
up
farakiko Oct 17, 2023
fd193ba
fix jetdef
farakiko Oct 17, 2023
590f5e3
up
farakiko Oct 17, 2023
b6174cd
up
farakiko Oct 17, 2023
d76a92f
up
farakiko Oct 17, 2023
6ceb232
fix val loss high values
farakiko Oct 18, 2023
c0362c2
up
farakiko Oct 18, 2023
cb2092a
up
farakiko Oct 18, 2023
3880848
up
farakiko Oct 18, 2023
9c2eb70
up
farakiko Oct 18, 2023
4268671
up
farakiko Oct 18, 2023
717ba8a
up
farakiko Oct 18, 2023
2f199ab
up
farakiko Oct 18, 2023
ab6268a
up
farakiko Oct 18, 2023
53b143c
up
farakiko Oct 18, 2023
6fc89a8
up
farakiko Oct 18, 2023
2007e5d
up
farakiko Oct 18, 2023
6c473c3
up
farakiko Oct 18, 2023
b9c3337
up
farakiko Oct 18, 2023
8f78a15
up
farakiko Oct 18, 2023
72d58be
up
farakiko Oct 18, 2023
30ae37e
up
farakiko Oct 18, 2023
96c2de4
up
farakiko Oct 18, 2023
762a681
up
farakiko Oct 18, 2023
23b34e8
up
farakiko Oct 18, 2023
daaaad7
up
farakiko Oct 18, 2023
14500a0
up
farakiko Oct 18, 2023
48a8f18
up
farakiko Oct 18, 2023
9a1d54c
up
farakiko Oct 18, 2023
13f6666
up
farakiko Oct 18, 2023
a5d4401
up
farakiko Oct 18, 2023
b842016
up
farakiko Oct 18, 2023
e525973
up
farakiko Oct 18, 2023
786aeca
up
farakiko Oct 18, 2023
a1bad80
up
farakiko Oct 18, 2023
9d4563b
up
farakiko Oct 18, 2023
d2787fe
up
farakiko Oct 18, 2023
654738f
up
farakiko Oct 18, 2023
ad582d5
up
farakiko Oct 18, 2023
0cbc991
up
farakiko Oct 18, 2023
ba67008
up
farakiko Oct 18, 2023
c3792fc
up
farakiko Oct 18, 2023
224ff2e
up
farakiko Oct 18, 2023
9a111dc
up
farakiko Oct 18, 2023
d54f523
up
farakiko Oct 18, 2023
1355fb4
up
farakiko Oct 18, 2023
1270929
up
farakiko Oct 18, 2023
0befed7
up
farakiko Oct 18, 2023
8c178ae
up
farakiko Oct 18, 2023
127f6e2
up
farakiko Oct 18, 2023
b3679e6
up
farakiko Oct 18, 2023
1cd4a93
up
farakiko Oct 18, 2023
e26a6dc
up
farakiko Oct 18, 2023
0eb799d
up
farakiko Oct 18, 2023
b2795b4
up
farakiko Oct 18, 2023
cf1f26d
up
farakiko Oct 18, 2023
e6f7e6f
up
farakiko Oct 18, 2023
1317c17
up
farakiko Oct 18, 2023
1300c30
up
farakiko Oct 19, 2023
1605880
up
farakiko Oct 19, 2023
557f47e
up
farakiko Oct 19, 2023
28d1252
up
farakiko Oct 19, 2023
e5296f2
up
farakiko Oct 19, 2023
4fc1f5e
up
farakiko Oct 19, 2023
d018574
up
farakiko Oct 19, 2023
76558d1
up
farakiko Oct 19, 2023
edc0e80
fix loss on cpu bottleneck
farakiko Oct 20, 2023
63d064d
add num-workers and prefetch factors to args
farakiko Oct 20, 2023
2e86f7d
up
farakiko Oct 20, 2023
d0d53e6
up
farakiko Oct 20, 2023
0453aee
up
farakiko Oct 20, 2023
11233d4
up
farakiko Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions mlpf/cuda_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Simple script that tests if CUDA is installed on the number of gpus specefied.

Author: Farouk Mokhtar
"""

import argparse
import logging
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import torch
from pyg.logger import _logger

logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()


parser.add_argument("--gpus", type=str, default="0", help="to use CPU set to empty string; else e.g., `0,1`")


def main():
args = parser.parse_args()
world_size = len(args.gpus.split(",")) # will be 1 for both cpu ("") and single-gpu ("0")

if args.gpus:
assert (
world_size <= torch.cuda.device_count()
), f"--gpus is too high (specefied {world_size} gpus but only {torch.cuda.device_count()} gpus are available)"

torch.cuda.empty_cache()
if world_size > 1:
_logger.info(f"Will use torch.nn.parallel.DistributedDataParallel() and {world_size} gpus", color="purple")
for rank in range(world_size):
_logger.info(torch.cuda.get_device_name(rank), color="purple")

elif world_size == 1:
rank = 0
_logger.info(f"Will use single-gpu: {torch.cuda.get_device_name(rank)}", color="purple")

else:
rank = "cpu"
_logger.info("Will use cpu", color="purple")


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions mlpf/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ def get_class_names(dataset_name):
"delphes_ttbar_pf": r"Delphes-CMS $pp \rightarrow \mathrm{t}\overline{\mathrm{t}}$",
"delphes_qcd_pf": r"Delphes-CMS $pp \rightarrow \mathrm{QCD}$",
"cms_pf_qcd": r"CMS QCD+PU events",
"cms_pf_ztt": r"CMS Ztt events",
"cms_pf_multi_particle_gun": r"CMS multi particle gun events",
"cms_pf_single_electron": r"CMS single electron particle gun events",
"cms_pf_single_gamma": r"CMS single photon gun events",
"cms_pf_single_mu": r"CMS single muon particle gun events",
"cms_pf_single_pi": r"CMS single pion particle gun events",
"cms_pf_single_pi0": r"CMS single neutral pion particle gun events",
"cms_pf_single_proton": r"CMS single proton particle gun events",
"cms_pf_single_tau": r"CMS single tau particle gun events",
"cms_pf_sms_t1tttt": r"CMS sms t1tttt events",
}


Expand Down
14 changes: 7 additions & 7 deletions mlpf/pyg/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@ The current pytorch backend shares the same dataset format as the tensorflow bac

# Supervised training or testing

First make sure to update the config yaml `../../parameters/pyg_config.yaml` to your desired model parameter configuration and choice of physics samples for training and testing.
First make sure to update the config yaml e.g. `../../parameters/pyg-cms-test.yaml` to your desired model parameter configuration and choice of physics samples for training and testing.

After that, the entry point to launch training or testing for either CMS, DELPHES or CLIC is the same.
After that, the entry point to launch training or testing for either CMS, DELPHES or CLIC is the same. From the main repo run,

```bash
cd ../
python -u pyg_pipeline.py --dataset=${} --data_dir=${} --model-prefix=${} --gpus=${}
python -u mlpf/pyg_pipeline.py --dataset=${} --data_dir=${} --prefix=${} --gpus=${} --ntrain 10 --nvalid 10 --ntest 10
```
where:
- `--dataset`: choices are `cms` or `delphes` or `clic`
- `--data_dir`: path to the tensorflow_datasets (e.g. `../data/tensorflow_datasets/`)
- `--model-prefix`: path pointing to the model directory that holds the results (e.g. `../experiments/MLPF_test`)
- `--prefix`: path pointing to the model directory (note: a unique hash will be appended to avoid overwrite)
- `--gpus`: to use CPU set to empty string ""; else to use gpus provide e.g. "0,1"
- `ntrain`, `nvalid`, `ntest`: specefies number of events (per sample) that will be used

Adding the arguments:
- `--load` will load a pre-trained model
- `--train` will run a training (may train a loaded model if `--load` is provided)
- `--load` will load a pre-trained model
- `--train` will run a training (may train a loaded model if `--load` is provided)
- `--test` will run inference and save the predictions as `.parquets`
- `--make-plots` will use the predictions stored after running with `--test` to make plots for evaluation
- `--export-onnx` will export the model to ONNX
Expand Down
4 changes: 0 additions & 4 deletions mlpf/pyg/model.py → mlpf/pyg/gnn_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def point_wise_feed_forward_network(
activation="ELU",
dropout=0.0,
):

layers = []
layers.append(
nn.Linear(
Expand Down Expand Up @@ -160,7 +159,6 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node
)

def forward(self, x_msg, x_node, msk, training=False):

shp = x_msg.shape
n_points = shp[1]

Expand Down Expand Up @@ -230,7 +228,6 @@ def reverse_lsh(bins_split, points_binned_enc):

class CombinedGraphLayer(nn.Module):
def __init__(self, *args, **kwargs):

self.inout_dim = kwargs.pop("inout_dim")
self.max_num_bins = kwargs.pop("max_num_bins")
self.bin_size = kwargs.pop("bin_size")
Expand Down Expand Up @@ -274,7 +271,6 @@ def __init__(self, *args, **kwargs):
self.dropout_layer = torch.nn.Dropout(self.dropout)

def forward(self, x, msk):

n_elems = x.shape[1]
bins_to_pad_to = -torch.floor_divide(-n_elems, self.bin_size)

Expand Down
177 changes: 67 additions & 110 deletions mlpf/pyg/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import os.path as osp
import time
from pathlib import Path

Expand All @@ -24,132 +23,98 @@
)

from .logger import _logger
from .utils import CLASS_NAMES

jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0)
jet_pt = 5.0
jet_match_dr = 0.1


def particle_array_to_awkward(batch_ids, arr_id, arr_p4):
ret = {
"cls_id": arr_id,
"pt": arr_p4[:, 1],
"eta": arr_p4[:, 2],
"sin_phi": arr_p4[:, 3],
"cos_phi": arr_p4[:, 4],
"energy": arr_p4[:, 5],
}
ret["phi"] = np.arctan2(ret["sin_phi"], ret["cos_phi"])
ret = awkward.from_iter([{k: ret[k][batch_ids == b] for k in ret.keys()} for b in np.unique(batch_ids)])
return ret
from .utils import CLASS_NAMES, unpack_predictions, unpack_target


@torch.no_grad()
def run_predictions(rank, mlpf, loader, sample, outpath):
def run_predictions(rank, model, loader, sample, outpath, jetdef, jet_ptcut=5.0, jet_match_dr=0.1):
"""Runs inference on the given sample and stores the output as .parquet files."""

if not osp.isdir(f"{outpath}/preds/{sample}"):
os.makedirs(f"{outpath}/preds/{sample}")
model.eval()

ti = time.time()
for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)):
ygen = unpack_target(batch.ygen)
ycand = unpack_target(batch.ycand)
ypred = unpack_predictions(model(batch.to(rank)))

for i, event in tqdm.tqdm(enumerate(loader), total=len(loader)):
event.X = event.X.to(rank)
event.batch = event.batch.to(rank)

# recall target ~ ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"]
target_ids = event.ygen[:, 0].long()
event.ygen = event.ygen[:, 1:]
for k, v in ypred.items():
ypred[k] = v.detach().cpu()

cand_ids = event.ycand[:, 0].long()
event.ycand = event.ycand[:, 1:]

# make mlpf forward pass
pred_ids_one_hot, pred_momentum, pred_charge = mlpf(event)
pred_ids_one_hot = pred_ids_one_hot.detach().cpu()
pred_momentum = pred_momentum.detach().cpu()
pred_charge = pred_charge.detach().cpu()

pred_ids = torch.argmax(pred_ids_one_hot, axis=-1)
pred_charge = torch.argmax(pred_charge, axis=1, keepdim=True) - 1
pred_p4 = torch.cat([pred_charge, pred_momentum], axis=-1)

batch_ids = event.batch.cpu().numpy()
awkvals = {
"gen": particle_array_to_awkward(batch_ids, target_ids.cpu().numpy(), event.ygen.cpu().numpy()),
"cand": particle_array_to_awkward(batch_ids, cand_ids.cpu().numpy(), event.ycand.cpu().numpy()),
"pred": particle_array_to_awkward(batch_ids, pred_ids.cpu().numpy(), pred_p4.cpu().numpy()),
}
# loop over the batch to disentangle the events
batch_ids = batch.batch.cpu().numpy()

gen_p4, cand_p4, pred_p4 = [], [], []
gen_cls, cand_cls, pred_cls = [], [], []
Xs = []
jets_coll = {}
Xs, p4s = [], {"gen": [], "cand": [], "pred": []}
for _ibatch in np.unique(batch_ids):
msk_batch = batch_ids == _ibatch
msk_gen = (target_ids[msk_batch] != 0).numpy()
msk_cand = (cand_ids[msk_batch] != 0).numpy()
msk_pred = (pred_ids[msk_batch] != 0).numpy()

Xs.append(event.X[msk_batch].cpu().numpy())
Xs.append(batch.X[msk_batch].cpu().numpy())

gen_p4.append(event.ygen[msk_batch, 1:][msk_gen].numpy())
gen_cls.append(target_ids[msk_batch][msk_gen].numpy())
# mask nulls for jet reconstruction
msk = (ygen["cls_id"][msk_batch] != 0).numpy()
p4s["gen"].append(ygen["p4"][msk_batch][msk].numpy())

cand_p4.append(event.ycand[msk_batch, 1:][msk_cand].numpy())
cand_cls.append(cand_ids[msk_batch][msk_cand].numpy())
msk = (ycand["cls_id"][msk_batch] != 0).numpy()
p4s["cand"].append(ycand["p4"][msk_batch][msk].numpy())

pred_p4.append(pred_momentum[msk_batch, :][msk_pred].numpy())
pred_cls.append(pred_ids[msk_batch][msk_pred].numpy())
msk = (ypred["cls_id"][msk_batch] != 0).numpy()
p4s["pred"].append(ypred["p4"][msk_batch][msk].numpy())

Xs = awkward.from_iter(Xs)
gen_p4 = awkward.from_iter(gen_p4)
gen_cls = awkward.from_iter(gen_cls)
gen_p4 = vector.awk(
awkward.zip({"pt": gen_p4[:, :, 0], "eta": gen_p4[:, :, 1], "phi": gen_p4[:, :, 2], "e": gen_p4[:, :, 3]})
)

cand_p4 = awkward.from_iter(cand_p4)
cand_cls = awkward.from_iter(cand_cls)
cand_p4 = vector.awk(
awkward.zip({"pt": cand_p4[:, :, 0], "eta": cand_p4[:, :, 1], "phi": cand_p4[:, :, 2], "e": cand_p4[:, :, 3]})
)
for typ in ["gen", "cand"]:
vec = vector.awk(
awkward.zip(
{
"pt": awkward.from_iter(p4s[typ])[:, :, 0],
"eta": awkward.from_iter(p4s[typ])[:, :, 1],
"phi": awkward.from_iter(p4s[typ])[:, :, 2],
"e": awkward.from_iter(p4s[typ])[:, :, 3],
}
)
)
cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
jets_coll[typ] = cluster.inclusive_jets(min_pt=jet_ptcut)

# in case of no predicted particles in the batch
if torch.sum(pred_ids != 0) == 0:
pt = build_dummy_array(len(pred_p4), np.float64)
eta = build_dummy_array(len(pred_p4), np.float64)
phi = build_dummy_array(len(pred_p4), np.float64)
pred_cls = build_dummy_array(len(pred_p4), np.float64)
energy = build_dummy_array(len(pred_p4), np.float64)
pred_p4 = vector.awk(awkward.zip({"pt": pt, "eta": eta, "phi": phi, "e": energy}))
if torch.sum(ypred["cls_id"] != 0) == 0:
vec = vector.awk(
awkward.zip(
{
"pt": build_dummy_array(len(p4s["pred"]), np.float64),
"eta": build_dummy_array(len(p4s["pred"]), np.float64),
"phi": build_dummy_array(len(p4s["pred"]), np.float64),
"e": build_dummy_array(len(p4s["pred"]), np.float64),
}
)
)
else:
pred_p4 = awkward.from_iter(pred_p4)
pred_cls = awkward.from_iter(pred_cls)
pred_p4 = vector.awk(
vec = vector.awk(
awkward.zip(
{
"pt": pred_p4[:, :, 0],
"eta": pred_p4[:, :, 1],
"phi": pred_p4[:, :, 2],
"e": pred_p4[:, :, 3],
"pt": awkward.from_iter(p4s["pred"])[:, :, 0],
"eta": awkward.from_iter(p4s["pred"])[:, :, 1],
"phi": awkward.from_iter(p4s["pred"])[:, :, 2],
"e": awkward.from_iter(p4s["pred"])[:, :, 3],
}
)
)

jets_coll = {}

cluster1 = fastjet.ClusterSequence(awkward.Array(gen_p4.to_xyzt()), jetdef)
jets_coll["gen"] = cluster1.inclusive_jets(min_pt=jet_pt)
cluster2 = fastjet.ClusterSequence(awkward.Array(cand_p4.to_xyzt()), jetdef)
jets_coll["cand"] = cluster2.inclusive_jets(min_pt=jet_pt)
cluster3 = fastjet.ClusterSequence(awkward.Array(pred_p4.to_xyzt()), jetdef)
jets_coll["pred"] = cluster3.inclusive_jets(min_pt=jet_pt)
cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
jets_coll["pred"] = cluster.inclusive_jets(min_pt=jet_ptcut)

gen_to_pred = match_two_jet_collections(jets_coll, "gen", "pred", jet_match_dr)
gen_to_cand = match_two_jet_collections(jets_coll, "gen", "cand", jet_match_dr)

matched_jets = awkward.Array({"gen_to_pred": gen_to_pred, "gen_to_cand": gen_to_cand})

awkvals = {
"gen": awkward.from_iter([{k: ygen[k][batch_ids == b] for k in ygen.keys()} for b in np.unique(batch_ids)]),
"cand": awkward.from_iter([{k: ycand[k][batch_ids == b] for k in ycand.keys()} for b in np.unique(batch_ids)]),
"pred": awkward.from_iter([{k: ypred[k][batch_ids == b] for k in ypred.keys()} for b in np.unique(batch_ids)]),
}

awkward.to_parquet(
awkward.Array(
{
Expand All @@ -163,9 +128,6 @@ def run_predictions(rank, mlpf, loader, sample, outpath):
)
_logger.info(f"Saved predictions at {outpath}/preds/{sample}/pred_{rank}_{i}.parquet")

if i == 100:
break

_logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min")


Expand All @@ -174,25 +136,20 @@ def make_plots(outpath, sample, dataset):

mplhep.set_style(mplhep.styles.CMS)

class_names = CLASS_NAMES[dataset]

_title = format_dataset_name(sample) # use the dataset names from the common nomenclature

if not os.path.isdir(f"{outpath}/plots/"):
os.makedirs(f"{outpath}/plots/")
os.system(f"mkdir -p {outpath}/plots/{sample}")

plots_path = Path(f"{outpath}/plots/")
plots_path = Path(f"{outpath}/plots/{sample}/")
pred_path = Path(f"{outpath}/preds/{sample}/")

yvals, X, _ = load_eval_data(str(pred_path / "*.parquet"), -1)

plot_num_elements(X, cp_dir=plots_path, title=_title)
plot_sum_energy(yvals, class_names, cp_dir=plots_path, title=_title)
plot_num_elements(X, cp_dir=plots_path, title=format_dataset_name(sample))
plot_sum_energy(yvals, CLASS_NAMES[dataset], cp_dir=plots_path, title=format_dataset_name(sample))

plot_jet_ratio(yvals, cp_dir=plots_path, title=_title)
plot_jet_ratio(yvals, cp_dir=plots_path, title=format_dataset_name(sample))

met_data = compute_met_and_ratio(yvals)
plot_met(met_data, cp_dir=plots_path, title=_title)
plot_met_ratio(met_data, cp_dir=plots_path, title=_title)
plot_met(met_data, cp_dir=plots_path, title=format_dataset_name(sample))
plot_met_ratio(met_data, cp_dir=plots_path, title=format_dataset_name(sample))

plot_particles(yvals, cp_dir=plots_path, title=_title)
plot_particles(yvals, cp_dir=plots_path, title=format_dataset_name(sample))
Loading
Loading