Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 11, 2022
1 parent 6175158 commit a9eed49
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,6 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
transform (Callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
neg_sampling_ratio (float, optional): The ratio of sampled negative
edges to the number of positive edges.
If :obj:`edge_label` does not exist, it will be automatically
Expand All @@ -219,6 +216,13 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
:meth:`F.binary_cross_entropy`) and of type
:obj:`torch.long` for multi-class classification (to facilitate the
ease-of-use of :meth:`F.cross_entropy`). (default: :obj:`0.0`).
transform (Callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, assumes that
:obj:`edge_index` is sorted by column. This avoids internal
re-sorting of the data and can improve runtime and memory
efficiency. (default: :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
Expand All @@ -231,9 +235,10 @@ def __init__(
edge_label: OptTensor = None,
replace: bool = False,
directed: bool = True,
neg_sampling_ratio: float = 0.0,
transform: Callable = None,
is_sorted: bool = False,
neighbor_sampler: Optional[LinkNeighborSampler] = None,
neg_sampling_ratio: float = 0.0,
**kwargs,
):
# Remove for PyTorch Lightning:
Expand All @@ -259,9 +264,15 @@ def __init__(

if neighbor_sampler is None:
self.neighbor_sampler = LinkNeighborSampler(
data, num_neighbors, replace, directed, edge_type,
data,
num_neighbors,
replace,
directed,
input_type=edge_type,
is_sorted=is_sorted,
neg_sampling_ratio=self.neg_sampling_ratio,
share_memory=kwargs.get('num_workers', 0) > 0,
neg_sampling_ratio=self.neg_sampling_ratio)
)

super().__init__(Dataset(edge_label_index, edge_label),
collate_fn=self.neighbor_sampler, **kwargs)
Expand Down

0 comments on commit a9eed49

Please sign in to comment.