Skip to content

Commit

Permalink
Fix dtype conversion for Index|EdgeIndex in multi-process `DataLo…
Browse files Browse the repository at this point in the history
…ader` (#9295)
  • Loading branch information
rusty1s authored May 5, 2024
1 parent 8eeb27e commit 8bb44ed
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test/loader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def test_index_dataloader(num_workers):

for batch in loader:
assert isinstance(batch.index, Index)
assert batch.index.dtype == torch.long
assert batch.index.dim_size == 7
assert batch.index.is_sorted

Expand Down Expand Up @@ -236,6 +237,7 @@ def test_edge_index_dataloader(num_workers, sort_order):

for batch in loader:
assert isinstance(batch.edge_index, EdgeIndex)
assert batch.edge_index.dtype == torch.long
assert batch.edge_index.sparse_size() == (6, 6)
assert batch.edge_index.sort_order == sort_order
assert batch.edge_index.is_undirected
Expand Down
5 changes: 5 additions & 0 deletions torch_geometric/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,11 @@ def is_undirected(self) -> bool:
r"""Returns whether indices are bidirectional."""
return self._is_undirected

@property
def dtype(self) -> torch.dtype: # type: ignore
# TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
return self._data.dtype

# Cache Interface #########################################################

@overload
Expand Down
5 changes: 5 additions & 0 deletions torch_geometric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ def is_sorted(self) -> bool:
r"""TODO."""
return self._is_sorted

@property
def dtype(self) -> torch.dtype: # type: ignore
# TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
return self._data.dtype

# Cache Interface #########################################################

def get_dim_size(self) -> int:
Expand Down

0 comments on commit 8bb44ed

Please sign in to comment.