Skip to content

Commit

Permalink
implement multi-type
Browse files Browse the repository at this point in the history
  • Loading branch information
Padarn committed Aug 9, 2022
1 parent c2ba93f commit 63b70b3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
41 changes: 22 additions & 19 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
to_csc,
to_hetero_csc,
)
from torch_geometric.typing import InputNodes, NumNeighbors
from torch_geometric.typing import HeteroNodeList, InputNodes, NumNeighbors


class NeighborSampler:
Expand Down Expand Up @@ -148,7 +148,11 @@ def _set_num_neighbors_and_num_hops(self, num_neighbors):
# Add at least one element to the list to ensure `max` is well-defined
self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])

def _sparse_neighbor_sample(self, index: Tensor):
def _sparse_neighbor_sample(self, index: Union[List[int], Tensor]):

if not isinstance(index, torch.LongTensor):
index = torch.LongTensor(index)

fn = torch.ops.torch_sparse.neighbor_sample
node, row, col, edge = fn(
self.colptr,
Expand All @@ -160,7 +164,16 @@ def _sparse_neighbor_sample(self, index: Tensor):
)
return node, row, col, edge

def _hetero_sparse_neighbor_sample(self, index_dict: Dict[str, Tensor]):
def _hetero_sparse_neighbor_sample(self, index: Union[List[int], Tensor,
HeteroNodeList]):

if isinstance(index, List) and isinstance(index[0], Tuple):
index_dict = deconvert_hetero(index)
else:
if not isinstance(index, torch.LongTensor):
index = torch.LongTensor(index)
index_dict = {self.input_type: index}

if self.node_time_dict is None:
fn = torch.ops.torch_sparse.hetero_neighbor_sample
node_dict, row_dict, col_dict, edge_dict = fn(
Expand Down Expand Up @@ -198,17 +211,13 @@ def _hetero_sparse_neighbor_sample(self, index_dict: Dict[str, Tensor]):
)
return node_dict, row_dict, col_dict, edge_dict

def __call__(self, index: Union[List[int], Tensor]):
if not isinstance(index, torch.LongTensor):
index = torch.LongTensor(index)
def __call__(self, index: Union[List[int], Tensor, HeteroNodeList]):

if self.data_cls != 'custom' and issubclass(self.data_cls, Data):
return self._sparse_neighbor_sample(index) + (index.numel(), )
if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData):
return self._hetero_sparse_neighbor_sample(index) + (
index.numel(), )

elif self.data_cls == 'custom' or issubclass(self.data_cls,
HeteroData):
return self._hetero_sparse_neighbor_sample(
{self.input_type: index}) + (index.numel(), )
return self._sparse_neighbor_sample(index) + (index.numel(), )


class NeighborLoader(torch.utils.data.DataLoader):
Expand Down Expand Up @@ -514,9 +523,6 @@ def to_index(tensor):
return node_type, input_nodes.index


HeteroNodeList = List[Tuple[str, Tensor]]


def convert_hetero(node_type: str, index: Tensor) -> HeteroNodeList:
return [(node_type, i) for i in index]

Expand All @@ -525,7 +531,4 @@ def deconvert_hetero(node_list: HeteroNodeList) -> Dict[str, Tensor]:
node_dicts = defaultdict(list)
for t, node in node_list:
node_dicts[t].append(node)
for k, v in node_dicts.items():
if isinstance(v[0], Tensor):
node_dicts[k] = torch.stack(v)
return node_dicts
return {k: torch.LongTensor(v) for k, v in node_dicts.items()}
4 changes: 3 additions & 1 deletion torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

# Types for sampling ##########################################################

InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor],
List[Tuple[NodeType, OptTensor]]]
InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]
NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]]
HeteroNodeList = List[Tuple[str, Tensor]]

0 comments on commit 63b70b3

Please sign in to comment.