Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement GNN-LSH model in torch #211

Merged
merged 17 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 45 additions & 18 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@ jobs:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: pip3 install torch==1.13.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
- run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html

tf-clic-pipeline:
tf-unittest:
runs-on: ubuntu-22.04
needs: [deps]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: ./scripts/local_test_cms_pipeline.sh
- run: PYTHONPATH=. python3 -m unittest tests/test_tf.py

tf-clic-pipeline:
runs-on: ubuntu-22.04
needs: [tf-unittest]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand All @@ -45,7 +58,7 @@ jobs:

tf-clic-hits-pipeline:
runs-on: ubuntu-22.04
needs: [deps]
needs: [tf-unittest]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand All @@ -57,7 +70,7 @@ jobs:

tf-delphes-pipeline:
runs-on: ubuntu-22.04
needs: [deps]
needs: [tf-unittest]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand All @@ -69,7 +82,7 @@ jobs:

tf-cms-pipeline:
runs-on: ubuntu-22.04
needs: [deps]
needs: [tf-unittest]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand All @@ -79,7 +92,7 @@ jobs:
- run: pip install -r requirements.txt
- run: ./scripts/local_test_cms_pipeline.sh

pyg-cms-pipeline:
pyg-unittests:
runs-on: ubuntu-22.04
needs: [deps-pyg]
steps:
Expand All @@ -89,48 +102,62 @@ jobs:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: pip3 install torch==1.13.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
- run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
- run: PYTHONPATH=. python3 -m unittest tests/test_torch_and_tf.py

pyg-cms-pipeline:
runs-on: ubuntu-22.04
needs: [pyg-unittests]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
- run: ./scripts/local_test_pyg_cms.sh

pyg-delphes-pipeline:
runs-on: ubuntu-22.04
needs: [deps-pyg]
needs: [pyg-unittests]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: pip3 install torch==1.13.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
- run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
- run: ./scripts/local_test_pyg_delphes.sh

pyg-clic-pipeline:
runs-on: ubuntu-22.04
needs: [deps-pyg]
needs: [pyg-unittests]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: pip3 install torch==1.13.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
- run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
- run: ./scripts/local_test_pyg_clic.sh

pyg-ssl-pipeline:
runs-on: ubuntu-22.04
needs: [deps-pyg]
needs: [pyg-unittests]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- run: pip install -r requirements.txt
- run: pip3 install torch==1.13.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
- run: pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- run: pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
- run: ./scripts/local_test_pyg_ssl.sh
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ data/*
experiments/*
prp/*
*.pth
test/__pycache__/

*/__pycache__/*
.DS_Store
Expand Down
2 changes: 1 addition & 1 deletion mlpf/pyg/PFGraphDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def processed_file_names(self):
proc_list = glob(osp.join(self.processed_dir, "*.pt"))
return sorted([processed_path.replace(self.processed_dir, ".") for processed_path in proc_list])

def __len__(self):
def len(self):
return len(self.processed_file_names)

def download(self):
Expand Down
4 changes: 2 additions & 2 deletions mlpf/pyg/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def parse_args():
parser.add_argument("--prefix", type=str, default="MLPF_model", help="directory to hold the model and all plots")

# for loading the data
parser.add_argument("--dataset", type=str, required=True, help="CMS or DELPHES?")
parser.add_argument("--dataset", type=str, required=True, help="CLIC, CMS or DELPHES")
parser.add_argument("--data_path", type=str, default="../data/", help="path which contains the samples")
parser.add_argument("--sample", type=str, default="QCD", help="sample to test on")
parser.add_argument("--n_train", type=int, default=2, help="number of files to use for training")
Expand All @@ -31,7 +31,7 @@ def parse_args():
parser.add_argument("--width", type=int, default=256, help="hidden dimension of mlpf")
parser.add_argument("--embedding_dim", type=int, default=256, help="first embedding of mlpf")
parser.add_argument("--num_convs", type=int, default=3, help="number of graph layers for mlpf")
parser.add_argument("--dropout", type=float, default=0.4, help="dropout for MLPF model")
parser.add_argument("--dropout", type=float, default=0.0, help="dropout for MLPF model")
parser.add_argument("--space_dim", type=int, default=4, help="Gravnet hyperparameter")
parser.add_argument("--propagate_dim", type=int, default=22, help="Gravnet hyperparameter")
parser.add_argument("--nearest", type=int, default=32, help="k nearest neighbors in gravnet layer")
Expand Down
101 changes: 38 additions & 63 deletions mlpf/pyg/mlpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import torch.nn as nn
import torch_geometric
import torch_geometric.utils
from torch_geometric.nn.conv import GravNetConv # also returns edge index
from torch_geometric.nn.conv import GravNetConv

# from pyg_ssl.gravnet import GravNetConv # also returns edge index
# from pyg_ssl.gravnet import GravNetConv # this version also returns edge index

from mlpf.pyg.model import CombinedGraphLayer


class GravNetLayer(nn.Module):
Expand Down Expand Up @@ -46,49 +48,14 @@ def forward(self, x, mask):
return x


def ffn(input_dim, output_dim, width, act, dropout, ssl):
if ssl:
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)
else:
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, width),
act(),
torch.nn.LayerNorm(width),
nn.Linear(width, output_dim),
)
def ffn(input_dim, output_dim, width, act, dropout):
return nn.Sequential(
nn.Linear(input_dim, width),
act(),
torch.nn.LayerNorm(width),
nn.Dropout(dropout),
nn.Linear(width, output_dim),
)


class MLPF(nn.Module):
Expand All @@ -97,7 +64,7 @@ def __init__(
input_dim=34,
NUM_CLASSES=8,
embedding_dim=128,
width=126,
width=128,
num_convs=2,
k=32,
propagate_dimensions=32,
Expand All @@ -116,15 +83,7 @@ def __init__(

# embedding of the inputs
if num_convs != 0:
self.nn0 = nn.Sequential(
nn.Linear(input_dim, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, width),
self.act(),
nn.Linear(width, embedding_dim),
)
self.nn0 = ffn(input_dim, embedding_dim, width, self.act, dropout)

self.conv_type = "gravnet"
# GNN that uses the embeddings learnt by VICReg as the input features
Expand All @@ -137,26 +96,42 @@ def __init__(
elif self.conv_type == "attention":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
self.conv_id.append(SelfAttentionLayer(embedding_dim))
self.conv_reg.append(SelfAttentionLayer(embedding_dim))
elif self.conv_type == "gnn-lsh":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
gnn_conf = {
"inout_dim": embedding_dim,
"bin_size": 256,
"max_num_bins": 200,
"distance_dim": 128,
"layernorm": True,
"num_node_messages": 2,
"dropout": 0.0,
"ffn_dist_hidden_dim": 64,
}
self.conv_id.append(CombinedGraphLayer(**gnn_conf))
self.conv_reg.append(CombinedGraphLayer(**gnn_conf))

decoding_dim = input_dim + num_convs * embedding_dim
if ssl:
decoding_dim += VICReg_embedding_dim

# DNN that acts on the node level to predict the PID
self.nn_id = ffn(decoding_dim, NUM_CLASSES, width, self.act, dropout, ssl)
self.nn_id = ffn(decoding_dim, NUM_CLASSES, width, self.act, dropout)

# elementwise DNN for node momentum regression
self.nn_pt = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_eta = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_phi = ffn(decoding_dim + NUM_CLASSES, 2, width, self.act, dropout, ssl)
self.nn_energy = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout, ssl)
self.nn_pt = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_eta = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)
self.nn_phi = ffn(decoding_dim + NUM_CLASSES, 2, width, self.act, dropout)
self.nn_energy = ffn(decoding_dim + NUM_CLASSES, 1, width, self.act, dropout)

# elementwise DNN for node charge regression, classes (-1, 0, 1)
self.nn_charge = ffn(decoding_dim + NUM_CLASSES, 3, width, self.act, dropout, ssl)
self.nn_charge = ffn(decoding_dim + NUM_CLASSES, 3, width, self.act, dropout)

def forward(self, batch):

Expand All @@ -183,7 +158,7 @@ def forward(self, batch):
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
embeddings_reg.append(conv(conv_input, batch_idx))
elif self.conv_type == "attention":
else:
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
Expand Down
Loading