Skip to content

Commit

Permalink
Merge pull request ORNL#71 from allaffa/remap_all_fields_of_data_to_gpus
Browse files Browse the repository at this point in the history
data.batch is remapped to the same device as data.x
  • Loading branch information
streeve authored Dec 16, 2021
2 parents 56a581b + cae3816 commit 9e8cae6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def forward(self, data):

#### multi-head decoder part####
# shared dense layers for graph level output
x_graph = global_mean_pool(x, batch)
x_graph = global_mean_pool(x, batch.to(x.device))
outputs = []
for head_dim, headloc, type_head in zip(
self.head_dims, self.heads_NN, self.head_type
Expand Down

0 comments on commit 9e8cae6

Please sign in to comment.