Commit 2523ce0

Merge pull request #5 from stsan9/rel

Rel

stsan9 authored Sep 26, 2021
2 parents 7ce7028 + 7b3b6c3
Showing 6 changed files with 212 additions and 154 deletions.
60 changes: 23 additions & 37 deletions graph_data.py
@@ -9,20 +9,19 @@
import logging

from torch_geometric.data import Dataset, Data
from process_util import jet_particles
from process_util import jet_particles, normalize
from natsort import natsorted
from sys import exit

ONE_HUNDRED_GEV = 100.0

class GraphDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None, n_jets=1000,
n_events_merge=100, n_events=1000, lhco=False, lhco_back=False):
n_events_merge=100, n_events=1000, lhco=False, lhco_back=False, R=1.0):
self.n_jets = n_jets
self.n_events_merge = n_events_merge
self.n_events = n_events
self.lhco = lhco
self.lhco_back = lhco_back
self.R = R
super(GraphDataset, self).__init__(root, transform, pre_transform)


@@ -49,33 +48,21 @@ def download(self):

def process(self):
Js = []
R = 0.4
for raw_path in self.raw_paths:
# load jet-particles dataset
if self.lhco or self.lhco_back:
print("Loading LHCO Dataset")
X = jet_particles(raw_path, self.n_events, self.lhco_back)
X = jet_particles(raw_path, self.n_events, self.lhco_back, self.R)
else:
print("Loading QG Dataset")
X, _ = ef.qg_jets.load(self.n_jets, pad=False, cache_dir=self.root+'/raw')

# clean and store list of jets as particles (pt, eta, phi)
# store list of jets as particles (pt_rel, eta_rel, phi_rel)
Js = []
jet_ctr = 0
for x in X:
if not (self.lhco or self.lhco_back):
# ignore padded particles and remove particle id information
x = x[x[:,0] > 0,:3]
# center jet according to pt-centroid
yphi_avg = np.average(x[:,1:3], weights=x[:,0], axis=0)
x[:,1:3] -= yphi_avg
# mask out any particles farther than R=0.4 away from center (rare)
x = x[np.linalg.norm(x[:,1:3], axis=1) <= R]
# add to list
if len(x) == 0: continue
for jet_ctr, x in enumerate(X):
x = normalize(x) # pt_rel, phi_rel, eta_rel
Js.append(x)
# stop when n_jets stored
jet_ctr += 1
if jet_ctr == self.n_jets: break

# calc emd between all jet pairs and save datum
@@ -84,33 +71,29 @@ def process(self):
for k, (i, j) in enumerate(jetpairs):
if k % (len(jetpairs) // 20) == 0:
print(f'Generated: {k}/{len(jetpairs)}')
emdval, G = ef.emd.emd(Js[i], Js[j], R=R, return_flow=True)
emdval = emdval/ONE_HUNDRED_GEV
G = G/ONE_HUNDRED_GEV
Ei = np.sum(Js[i][:,0])
Ej = np.sum(Js[j][:,0])
jiNorm = np.zeros((Js[i].shape[0],Js[i].shape[1]+1)) # add a field
jjNorm = np.zeros((Js[j].shape[0],Js[j].shape[1]+1)) # add a field
jiNorm[:,:3] = Js[i].copy()
jjNorm[:,:3] = Js[j].copy()
jiNorm[:,0] = jiNorm[:,0]/Ei
jjNorm[:,0] = jjNorm[:,0]/Ej
jiNorm[:,3] = -1*np.ones((Js[i].shape[0]))
jjNorm[:,3] = np.ones((Js[j].shape[0]))
jetpair = np.concatenate([jiNorm, jjNorm], axis=0)
emdval, G = ef.emd.emd(Js[i], Js[j], R=self.R, return_flow=True)

# differentiate 2 jets by column of 1 vs -1
ji = np.zeros((Js[i].shape[0],Js[i].shape[1]+1))
jj = np.zeros((Js[j].shape[0],Js[j].shape[1]+1))
ji[:,:3] = Js[i].copy()
jj[:,:3] = Js[j].copy()
ji[:,3] = -1*np.ones((Js[i].shape[0]))
jj[:,3] = np.ones((Js[j].shape[0]))
jetpair = np.concatenate([ji, jj], axis=0)

nparticles_i = len(Js[i])
nparticles_j = len(Js[j])
pairs = [[m, n] for (m, n) in itertools.product(range(0,nparticles_i),range(nparticles_i,nparticles_i+nparticles_j))]
edge_index = torch.tensor(pairs, dtype=torch.long)
edge_index = edge_index.t().contiguous()
u = torch.tensor([[Ei/ONE_HUNDRED_GEV, Ej/ONE_HUNDRED_GEV]], dtype=torch.float)
edge_y = torch.tensor([[G[m,n-nparticles_i] for m, n in pairs]], dtype=torch.float)
edge_y = edge_y.t().contiguous()

x = torch.tensor(jetpair, dtype=torch.float)
y = torch.tensor([[emdval]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index, y=y, u=u, edge_y=edge_y)
data = Data(x=x, edge_index=edge_index, y=y, edge_y=edge_y)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
@@ -121,6 +104,9 @@ def process(self):
torch.save(datas, osp.join(self.processed_dir, 'data_{}.pt'.format(k)))
datas=[]

def len(self):
return len(self.processed_file_names)

def get(self, idx):
data = torch.load(osp.join(self.processed_dir, self.processed_file_names[idx]))
return data
@@ -145,4 +131,4 @@ def get(self, idx):

gdata = GraphDataset(root=args.input_dir, n_jets=args.n_jets, n_events_merge=args.n_events_merge, lhco=args.lhco, lhco_back=args.lhco_back)

print("Done")
print("Done")
2 changes: 2 additions & 0 deletions loss_util.py
@@ -22,6 +22,8 @@ class LossFunction:
def __init__(self, lossname, lam1=1, lam2=100):
if lossname == 'mse':
loss = nn.MSELoss(reduction='mean')
elif lossname == 'huber':
loss = nn.HuberLoss(reduction='mean', delta=1.0)
else:
loss = getattr(self, lossname)
self.loss_ftn = loss
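The new 'huber' option wires in nn.HuberLoss, which is quadratic for residuals below delta and linear above it, so a few badly predicted jet pairs pull on training less than under MSE. A self-contained usage sketch (requires PyTorch >= 1.9, where nn.HuberLoss was introduced):

import torch
import torch.nn as nn

loss_ftn = nn.HuberLoss(reduction='mean', delta=1.0)
pred = torch.tensor([0.9, 2.5, 10.0])
target = torch.tensor([1.0, 2.0, 3.0])
print(loss_ftn(pred, target))  # quadratic below |residual| = 1.0, linear above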
93 changes: 66 additions & 27 deletions models.py
@@ -7,7 +7,7 @@
from torch_scatter import scatter_mean

class EdgeNet(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=128, global_dim=2, output_dim=1, aggr='mean'):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=128, output_dim=1, aggr='mean'):
super(EdgeNet, self).__init__()
convnn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
nn.ReLU(),
@@ -18,9 +18,7 @@ def __init__(self, input_dim=4, big_dim=32, bigger_dim=128, global_dim=2, output

self.batchnorm = nn.BatchNorm1d(input_dim)

self.batchnormglobal = nn.BatchNorm1d(global_dim)

self.outnn = nn.Sequential(nn.Linear(big_dim+global_dim, bigger_dim),
self.outnn = nn.Sequential(nn.Linear(big_dim, bigger_dim),
nn.ReLU(),
nn.Linear(bigger_dim, bigger_dim),
nn.ReLU(),
@@ -32,13 +30,11 @@ def __init__(self, input_dim=4, big_dim=32, bigger_dim=128, global_dim=2, output
def forward(self, data):
x = self.batchnorm(data.x)
x = self.conv(x,data.edge_index)
u1 = self.batchnormglobal(data.u)
u2 = scatter_mean(x, data.batch, dim=0)
data.u = torch.cat([u1, u2],dim=-1)
return self.outnn(data.u)
u = scatter_mean(x, data.batch, dim=0)
return self.outnn(u)

class DynamicEdgeNet(nn.Module):
def __init__(self, input_dim=4, big_dim=128, bigger_dim=256, global_dim=2, output_dim=1, k=16, aggr='mean'):
def __init__(self, input_dim=4, big_dim=128, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(DynamicEdgeNet, self).__init__()
convnn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
nn.ReLU(),
@@ -49,9 +45,7 @@ def __init__(self, input_dim=4, big_dim=128, bigger_dim=256, global_dim=2, outpu

self.batchnorm = nn.BatchNorm1d(input_dim)

self.batchnormglobal = nn.BatchNorm1d(global_dim)

self.outnn = nn.Sequential(nn.Linear(big_dim+global_dim, bigger_dim),
self.outnn = nn.Sequential(nn.Linear(big_dim, bigger_dim),
nn.ReLU(),
nn.Linear(bigger_dim, bigger_dim),
nn.ReLU(),
@@ -63,13 +57,11 @@ def __init__(self, input_dim=4, big_dim=128, bigger_dim=256, global_dim=2, outpu
def forward(self, data):
x = self.batchnorm(data.x)
x = self.conv(x, data.batch)
u1 = self.batchnormglobal(data.u)
u2 = scatter_mean(x, data.batch, dim=0)
data.u = torch.cat([u1, u2],dim=-1)
return self.outnn(data.u)
u = scatter_mean(x, data.batch, dim=0)
return self.outnn(u)

class DeeperDynamicEdgeNet(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, global_dim=2, output_dim=1, k=16, aggr='mean'):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(DeeperDynamicEdgeNet, self).__init__()
convnn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
nn.BatchNorm1d(big_dim),
@@ -97,8 +89,7 @@ def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, global_dim=2, output
)

self.batchnorm = nn.BatchNorm1d(input_dim)
self.batchnormglobal = nn.BatchNorm1d(global_dim)
self.outnn = nn.Sequential(nn.Linear(big_dim*4+input_dim+global_dim, bigger_dim),
self.outnn = nn.Sequential(nn.Linear(big_dim*4+input_dim, bigger_dim),
nn.BatchNorm1d(bigger_dim),
nn.ReLU(),
nn.Linear(bigger_dim, bigger_dim),
@@ -119,14 +110,12 @@ def forward(self, data):
x = torch.cat([x1, x2],dim=-1)
x2 = self.conv3(x, data.batch)
x = torch.cat([x1, x2],dim=-1)
u1 = self.batchnormglobal(data.u)
u2 = scatter_mean(x, data.batch, dim=0)
data.u = torch.cat([u1, u2],dim=-1)
return self.outnn(data.u)
u = scatter_mean(x, data.batch, dim=0)
return self.outnn(u)


class DeeperDynamicEdgeNetPredictFlow(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, global_dim=2, output_dim=1, k=16, aggr='mean'):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(DeeperDynamicEdgeNetPredictFlow, self).__init__()
convnn = nn.Sequential(nn.Linear(2*(input_dim), big_dim),
nn.BatchNorm1d(big_dim),
@@ -186,9 +175,9 @@ def forward(self, data):


class SymmetricDDEdgeNet(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, global_dim=2, output_dim=1, k=16, aggr='mean'):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(SymmetricDDEdgeNet, self).__init__()
self.EdgeNet = DeeperDynamicEdgeNet(input_dim, big_dim, bigger_dim, global_dim, output_dim, k, aggr)
self.EdgeNet = DeeperDynamicEdgeNet(input_dim, big_dim, bigger_dim, output_dim, k, aggr)

def forward(self, data):
# dual copies with different orderings
@@ -199,4 +188,54 @@ def forward(self, data):
emd_1 = self.EdgeNet(data_1)
emd_2 = self.EdgeNet(data_2)
loss = (emd_1 + emd_2) / 2
return loss, emd_1, emd_2

class SymmetricDDEdgeNetSqr(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(SymmetricDDEdgeNetSqr, self).__init__()
self.EdgeNet = DeeperDynamicEdgeNet(input_dim, big_dim, bigger_dim, output_dim, k, aggr)

def forward(self, data):
# dual copies with different orderings
data_1 = data
data_2 = copy.deepcopy(data)
data_2.x[:,-1] *= -1

emd_1 = torch.square(self.EdgeNet(data_1))
emd_2 = torch.square(self.EdgeNet(data_2))
loss = (emd_1 + emd_2) / 2
return loss, emd_1, emd_2

class SymmetricDDEdgeNetSpl(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(SymmetricDDEdgeNetSpl, self).__init__()
self.EdgeNet = DeeperDynamicEdgeNet(input_dim, big_dim, bigger_dim, output_dim, k, aggr)

def forward(self, data):
# dual copies with different orderings
data_1 = data
data_2 = copy.deepcopy(data)
data_2.x[:,-1] *= -1

spl = nn.Softplus()
emd_1 = spl(self.EdgeNet(data_1))
emd_2 = spl(self.EdgeNet(data_2))
loss = (emd_1 + emd_2) / 2
return loss, emd_1, emd_2

class SymmetricDDEdgeNetRel(nn.Module):
def __init__(self, input_dim=4, big_dim=32, bigger_dim=256, output_dim=1, k=16, aggr='mean'):
super(SymmetricDDEdgeNetRel, self).__init__()
self.EdgeNet = DeeperDynamicEdgeNet(input_dim, big_dim, bigger_dim, output_dim, k, aggr)

def forward(self, data):
# dual copies with different orderings
data_1 = data
data_2 = copy.deepcopy(data)
data_2.x[:,-1] *= -1

rel = nn.ReLU()
emd_1 = rel(self.EdgeNet(data_1))
emd_2 = rel(self.EdgeNet(data_2))
loss = (emd_1 + emd_2) / 2
return loss, emd_1, emd_2
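
The three new wrappers are identical apart from the map applied to the raw DeeperDynamicEdgeNet output before averaging over the two jet orderings (the orderings differ only in the sign of the last feature column, the +1/-1 jet indicator): squaring, softplus, or ReLU, each of which guarantees a non-negative EMD prediction. A minimal self-contained illustration of the three maps:

import torch
import torch.nn as nn

raw = torch.tensor([-2.0, 0.0, 3.0])  # stand-in for raw network outputs
print(torch.square(raw))   # SymmetricDDEdgeNetSqr: x**2, smooth, grows quadratically
print(nn.Softplus()(raw))  # SymmetricDDEdgeNetSpl: log(1 + e^x), smooth, strictly positive
print(nn.ReLU()(raw))      # SymmetricDDEdgeNetRel: max(0, x), exact zeros possible

Softplus never returns an exact zero, while ReLU clamps negative raw outputs to zero with zero gradient there, and the square grows fastest for large raw values.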