How to implement batch training? #62
The memory usage is just too high to use batch training. The input with shape
I'm really looking forward to your solution. Best, Shuyue
I also ran into critical memory usage problems when working with a dataset with a large number of nodes, e.g., 5000-10000 nodes. I really hope you can solve the memory problem. Looking forward to seeing your progress. Best regards, Shuyue
Hello, Shuyue. My solution is under review now; you can see it there. Cheers,
Actually, implementing batch training is not hard if memory is sufficient during training. Here is an example: https://blog.csdn.net/pixian3729/article/details/110261140 (in Simplified Chinese). Good luck!
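For reference, here is a minimal sketch of the usual batching pattern, not code from this repository; `model`, `features`, `adjs`, and `labels` are hypothetical names, and the model is assumed to accept (B, N, F) features and (B, N, N) adjacency matrices:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical pre-built tensors: features (num_graphs, N, F),
# adjs (num_graphs, N, N), labels (num_graphs,)
dataset = TensorDataset(features, adjs, labels)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
model.train()
for x, adj, y in loader:
    optimizer.zero_grad()
    out = model(x, adj)              # assumed graph-level output: (batch_size, num_classes)
    loss = F.cross_entropy(out, y)
    loss.backward()
    optimizer.step()
```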
Dear Howard, Thanks a lot for your help! You are an extraordinary man! I really appreciate it. Another thing I would like to ask about is the error I get when trying multi-GPU training with nn.DataParallel. For details:

```
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jsy/gnn/gnn.py", line 32, in forward
    x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
  File "/home/jsy/gnn/gnn.py", line 32, in <listcomp>
    x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
  File "/home/jsy/.conda/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jsy/gnn/layers.py", line 32, in forward
    attention = torch.where(adj > 0, e, zero_vec)
RuntimeError: The size of tensor a (3000) must match the size of tensor b (1500) at non-singleton dimension 1
```

Thanks in advance and I'm looking forward to your reply. Best regards, Shuyue
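A likely cause, judging from the shapes in the traceback (an assumption, not a confirmed diagnosis): nn.DataParallel scatters every input tensor along dim 0, so with unbatched inputs the node dimension itself gets split across the GPUs.

```python
import torch

# Shapes per replica when DataParallel splits 3000 nodes across 2 GPUs
# (no batch dimension, so the node dimension itself is split):
adj_replica = torch.ones(1500, 3000)           # adj (3000, 3000) scattered along dim 0
e_replica = torch.randn(1500, 1500)            # e built from the replica's (1500, F) features
zero_vec = -9e15 * torch.ones_like(e_replica)
# torch.where(adj_replica > 0, e_replica, zero_vec)  # -> size mismatch at dim 1 (3000 vs 1500)
# A leading batch dimension, (B, N, F) and (B, N, N), would make DataParallel
# split along the batch instead of the node dimension.
```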
Shuyue, happy to see your reply. Multi-GPU training requires batch training to be implemented first; my solution only addresses the memory issue. Could you show me your code for batch training?
Dear Howard, Thanks a lot for your quick reply. As for batch training, please refer to issue #36. Best, Shuyue
Dear Howard, If you don't mind, could you upload your lower-memory batch training code and open a pull request? Thanks in advance. Best, Shuyue
Dear Howard, Thanks a lot for your help. I edited the batch training code; here is the modified GraphAttentionLayer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Batched version of the simple GAT layer from
    https://github.com/Diego999/pyGAT/blob/master/layers.py,
    similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """
        :param h: (batch_size, number_nodes, in_features)
        :param adj: (batch_size, number_nodes, number_nodes)
        :return: (batch_size, number_nodes, out_features)
        """
        # batch-wise matrix multiplication:
        # (batch_size, number_nodes, in_features) x (in_features, out_features)
        # -> (batch_size, number_nodes, out_features)
        Wh = torch.matmul(h, self.W)
        # attention logits: (batch_size, number_nodes, number_nodes)
        e = self.prepare_batch(Wh)

        # mask out non-edges with a large negative value
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        # normalize over the neighbour dimension
        attention = F.softmax(attention, dim=-1)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # batched matrix multiplication -> (batch_size, number_nodes, out_features)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def prepare_batch(self, Wh):
        """
        Compute the attention logits for a batch.
        :param Wh: (batch_size, number_nodes, out_features)
        :return: (batch_size, number_nodes, number_nodes)
        """
        # Wh.shape: (B, N, out_features)
        # self.a.shape: (2 * out_features, 1)
        # Wh1, Wh2.shape: (B, N, 1)
        # e.shape: (B, N, N)
        B, N, E = Wh.shape  # E == out_features

        # (B, N, out_features) x (out_features, 1) -> (B, N, 1)
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])

        # broadcast add: (B, N, 1) + (B, 1, N) -> (B, N, N)
        e = Wh1 + Wh2.permute(0, 2, 1)
        return self.leakyrelu(e)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
```

At the same time, we should change line 21 to:

```python
x = torch.cat([att(x, adj) for att in self.attentions], dim=-1)
```

Meanwhile, can you help me review and check the code to make sure it works correctly, I mean, mathematically? Thanks in advance and have a nice day! Best, Shuyue
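If it helps, here is a minimal shape check for the batched layer above; the sizes and the random adjacency are placeholders for illustration only.

```python
import torch

layer = GraphAttentionLayer(in_features=16, out_features=8,
                            dropout=0.6, alpha=0.2, concat=True)
B, N = 4, 100
h = torch.randn(B, N, 16)                    # (batch_size, number_nodes, in_features)
adj = (torch.rand(B, N, N) > 0.9).float()    # (batch_size, number_nodes, number_nodes)
out = layer(h, adj)
print(out.shape)                             # torch.Size([4, 100, 8])
```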
Dear Howard, I also found that memory usage surged while training with batches. I'm looking forward to your reply. Best, Shuyue
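For what it's worth, a rough back-of-envelope estimate of why this happens (an assumption about the dominant term, not a measured profile): the layer materializes several (B, N, N) tensors (`e`, `zero_vec`, `attention`), so memory grows roughly as B * N^2 per attention head.

```python
# Assumptions: float32, one head, and roughly 4 (B, N, N) tensors alive at once
# (before autograd's saved activations, which add more).
def attention_memory_gb(batch_size, num_nodes, tensors=4, bytes_per_el=4):
    return batch_size * num_nodes ** 2 * tensors * bytes_per_el / 1e9

print(attention_memory_gb(1, 5000))    # ~0.4 GB for a single graph with 5000 nodes
print(attention_memory_gb(32, 5000))   # ~12.8 GB for a batch of 32
```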
Dear Shuyue, I'm glad to see your reply.
Your implementation is consistent with the algorithm in the original paper, and so is mine; I don't think there is anything wrong.
I also suffered from this. In the end, I lowered the hyperparameters to avoid OOM errors. I will keep looking for ways to reduce the algorithm's memory usage. Cheers,
Great! Let me try to figure out how to train on multiple GPUs so the memory is distributed across devices. Thanks again for your help and guidance. Have a nice day! Best, Shuyue
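In case it is useful, a minimal sketch of that direction, assuming the batched model and tensors from the snippets above (`model`, `x`, `adj` are placeholders): once the inputs carry a leading batch dimension, nn.DataParallel splits the batch across the visible GPUs.

```python
import torch
import torch.nn as nn

device = torch.device("cuda")
parallel_model = nn.DataParallel(model).to(device)    # replicas on all visible GPUs
out = parallel_model(x.to(device), adj.to(device))    # each GPU processes B / n_gpus graphs
```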
Any ideas on this issue? #64
Thank you very much for sharing. I have also been working on batch training recently. Can you share your code?
Hi, Thanks for your interest! I'm currently working on cross-subject research rather than GNNs. However, after I finish the project, if the performance is good, I'll open-source the code. Best regards, Shuyue
I found that the multi-head attention implementation of pyGAT could be improved.
Dear Howard, Could you share your code or ideas? Thanks in advance. Best, Shuyue
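For completeness, here is one common way the heads could be fused, a sketch of my own assumption about the kind of improvement meant above, not code from this thread: keep a single (H, F_in, F_out) weight tensor and compute all heads in one batched matmul instead of looping over a Python list of per-head layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedMultiHeadGATLayer(nn.Module):
    """All heads computed at once; inputs are batched as (B, N, F) / (B, N, N)."""
    def __init__(self, in_features, out_features, num_heads, dropout=0.6, alpha=0.2):
        super().__init__()
        self.dropout = dropout
        self.W = nn.Parameter(torch.empty(num_heads, in_features, out_features))
        self.a_src = nn.Parameter(torch.empty(num_heads, out_features, 1))
        self.a_dst = nn.Parameter(torch.empty(num_heads, out_features, 1))
        for p in (self.W, self.a_src, self.a_dst):
            nn.init.xavier_uniform_(p, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, h, adj):
        B, N, _ = h.shape
        Wh = torch.einsum('bnf,hfo->bhno', h, self.W)               # (B, H, N, F_out)
        e = Wh @ self.a_src + (Wh @ self.a_dst).transpose(-1, -2)   # (B, H, N, N)
        e = self.leakyrelu(e).masked_fill(adj.unsqueeze(1) == 0, -9e15)
        att = F.dropout(F.softmax(e, dim=-1), self.dropout, self.training)
        out = att @ Wh                                              # (B, H, N, F_out)
        return F.elu(out.transpose(1, 2).reshape(B, N, -1))         # heads concatenated
```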
Hi,
Thanks a lot for your great work! I really appreciate it!
Is there any way to implement batch training?
Thanks in advance.
Best,
Shuyue