Skip to content

Commit

Permalink
Allow for multi-dimensional edge_labels in LinkNeighborLoader (#5186
Browse files Browse the repository at this point in the history
)

* allow multi-dim edge_label

* changelog
  • Loading branch information
rusty1s authored Aug 10, 2022
1 parent 119318f commit 797e3d9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Allow for multi-dimensional `edge_labels` in `LinkNeighborLoader` ([#5186](https://github.com/pyg-team/pytorch_geometric/pull/5186)]
- Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154)]
- Improved error message ([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095))
- Fixed `HGTLoader` bug which produced outputs with missing edge types ([#5067](https://github.com/pyg-team/pytorch_geometric/pull/5067))
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def update_time_(node_time_dict, index, node_type, num_nodes):
return node_time_dict

def __call__(self, query: List[Tuple[Tensor]]):
query = [torch.tensor(s) for s in zip(*query)]
query = [torch.stack(s, dim=0) for s in zip(*query)]
edge_label_index = torch.stack(query[:2], dim=0)
edge_label = query[2]
edge_label_time = query[3] if len(query) == 4 else None
Expand Down

0 comments on commit 797e3d9

Please sign in to comment.