Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speedup of the pytorch GNN-LSH model #245

Merged
merged 6 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions mlpf/pyg/gnn_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def point_wise_feed_forward_network(
return nn.Sequential(*layers)


# @torch.compile
def index_dim(a, b):
return a[b]

Expand Down Expand Up @@ -142,6 +143,27 @@ def forward(self, x_msg_binned, msk, training=False):
return dm


@torch.compile
def split_msk_and_msg(bins_split, cmul, x_msg, x_node, msk, n_bins, bin_size):
bins_split_2 = torch.reshape(bins_split, (bins_split.shape[0], bins_split.shape[1] * bins_split.shape[2]))

bins_split_3 = torch.unsqueeze(bins_split_2, axis=-1).expand(
bins_split_2.shape[0], bins_split_2.shape[1], x_msg.shape[-1]
)
x_msg_binned = torch.gather(x_msg, 1, bins_split_3)
x_msg_binned = torch.reshape(x_msg_binned, (cmul.shape[0], n_bins, bin_size, x_msg_binned.shape[-1]))

bins_split_3 = torch.unsqueeze(bins_split_2, axis=-1).expand(
bins_split_2.shape[0], bins_split_2.shape[1], x_node.shape[-1]
)
x_features_binned = torch.gather(x_node, 1, bins_split_3)
x_features_binned = torch.reshape(x_features_binned, (cmul.shape[0], n_bins, bin_size, x_features_binned.shape[-1]))

msk_f_binned = torch.gather(msk, 1, bins_split_2)
msk_f_binned = torch.reshape(msk_f_binned, (cmul.shape[0], n_bins, bin_size, 1))
return x_msg_binned, x_features_binned, msk_f_binned


class MessageBuildingLayerLSH(nn.Module):
def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=NodePairGaussianKernel(), **kwargs):
self.initializer = kwargs.pop("initializer", "random_normal")
Expand Down Expand Up @@ -179,15 +201,9 @@ def forward(self, x_msg, x_node, msk, training=False):
bins_split = split_indices_to_bins_batch(cmul, n_bins, self.bin_size, msk)

# replaced tf.gather with torch.vmap, indexing and reshape
bins_split_2 = torch.reshape(bins_split, (bins_split.shape[0], bins_split.shape[1] * bins_split.shape[2]))
x_msg_binned = torch.vmap(index_dim)(x_msg, bins_split_2)
x_features_binned = torch.vmap(index_dim)(x_node, bins_split_2)
msk_f_binned = torch.vmap(index_dim)(msk, bins_split_2)
x_msg_binned = torch.reshape(x_msg_binned, (cmul.shape[0], n_bins, self.bin_size, x_msg_binned.shape[-1]))
x_features_binned = torch.reshape(
x_features_binned, (cmul.shape[0], n_bins, self.bin_size, x_features_binned.shape[-1])
x_msg_binned, x_features_binned, msk_f_binned = split_msk_and_msg(
bins_split, cmul, x_msg, x_node, msk, n_bins, self.bin_size
)
msk_f_binned = torch.reshape(msk_f_binned, (cmul.shape[0], n_bins, self.bin_size, 1))
else:
x_msg_binned = torch.unsqueeze(x_msg, axis=1)
x_features_binned = torch.unsqueeze(x_node, axis=1)
Expand All @@ -211,6 +227,7 @@ def forward(self, x_msg, x_node, msk, training=False):
return bins_split, x_features_binned, dm, msk_f_binned


@torch.compile
def reverse_lsh(bins_split, points_binned_enc):
shp = points_binned_enc.shape
batch_dim = shp[0]
Expand All @@ -220,8 +237,10 @@ def reverse_lsh(bins_split, points_binned_enc):
bins_split_flat = torch.reshape(bins_split, (batch_dim, n_points))
points_binned_enc_flat = torch.reshape(points_binned_enc, (batch_dim, n_points, n_features))

ret = torch.zeros(batch_dim, n_points, n_features).to(device=points_binned_enc.device)
ret = torch.zeros(batch_dim, n_points, n_features, device=points_binned_enc.device)
for ibatch in range(batch_dim):
torch._assert(torch.min(bins_split_flat[ibatch]) >= 0, "reverse_lsh n_points min")
torch._assert(torch.max(bins_split_flat[ibatch]) < n_points, "reverse_lsh n_points max")
ret[ibatch][bins_split_flat[ibatch]] = points_binned_enc_flat[ibatch]
return ret

Expand Down
37 changes: 24 additions & 13 deletions mlpf/pyg/mlpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def ffn(input_dim, output_dim, width, act, dropout):
)


@torch.compile
def unpad(data_padded, mask):
A = data_padded[mask]
return A


class MLPF(nn.Module):
def __init__(
self,
Expand All @@ -79,7 +85,7 @@ def __init__(
# embedding of the inputs
if num_convs != 0:
self.nn0 = ffn(input_dim, embedding_dim, width, self.act, dropout)

self.bin_size = 640
if self.conv_type == "gravnet":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
Expand All @@ -95,11 +101,10 @@ def __init__(
elif self.conv_type == "gnn-lsh":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()

for i in range(num_convs):
gnn_conf = {
"inout_dim": embedding_dim,
"bin_size": 640,
"bin_size": self.bin_size,
"max_num_bins": 200,
"distance_dim": 128,
"layernorm": True,
Expand Down Expand Up @@ -133,6 +138,14 @@ def forward(self, event):
if self.num_convs != 0:
embedding = self.nn0(input_)

if self.conv_type != "gravnet":
_, num_nodes = torch.unique(batch_idx, return_counts=True)
max_num_nodes = torch.max(num_nodes).cpu()
max_num_nodes_padded = ((max_num_nodes // self.bin_size) + 1) * self.bin_size
embedding, mask = torch_geometric.utils.to_dense_batch(
embedding, batch_idx, max_num_nodes=max_num_nodes_padded
)

if self.conv_type == "gravnet":
# perform a series of graph convolutions
for num, conv in enumerate(self.conv_id):
Expand All @@ -144,18 +157,16 @@ def forward(self, event):
else:
for num, conv in enumerate(self.conv_id):
conv_input = embedding if num == 0 else embeddings_id[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
# assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_id.append(out_stacked)
out_padded = conv(conv_input, ~mask)
embeddings_id.append(out_padded)
for num, conv in enumerate(self.conv_reg):
conv_input = embedding if num == 0 else embeddings_reg[-1]
input_padded, mask = torch_geometric.utils.to_dense_batch(conv_input, batch_idx)
out_padded = conv(input_padded, ~mask)
out_stacked = torch.cat([out_padded[i][mask[i]] for i in range(out_padded.shape[0])])
# assert out_stacked.shape[0] == conv_input.shape[0]
embeddings_reg.append(out_stacked)
out_padded = conv(conv_input, ~mask)
embeddings_reg.append(out_padded)

if self.conv_type != "gravnet":
embeddings_id = [unpad(emb, mask) for emb in embeddings_id]
embeddings_reg = [unpad(emb, mask) for emb in embeddings_reg]

# classification
embedding_id = torch.cat([input_] + embeddings_id, axis=-1)
Expand Down
Loading