-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
dist_neighbor_sampler.py
1074 lines (919 loc) · 41.4 KB
/
dist_neighbor_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import itertools
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.multiprocessing as mp
from torch import Tensor
from torch_geometric.distributed import (
DistContext,
LocalFeatureStore,
LocalGraphStore,
)
from torch_geometric.distributed.event_loop import (
ConcurrentEventLoop,
to_asyncio_future,
)
from torch_geometric.distributed.rpc import (
RPCCallBase,
RPCRouter,
rpc_async,
rpc_partition_to_workers,
rpc_register,
)
from torch_geometric.distributed.utils import (
BatchDict,
DistEdgeHeteroSamplerInput,
NodeDict,
remove_duplicates,
)
from torch_geometric.sampler import (
EdgeSamplerInput,
HeteroSamplerOutput,
NegativeSampling,
NeighborSampler,
NodeSamplerInput,
SamplerOutput,
)
from torch_geometric.sampler.base import NumNeighbors, SubgraphType
from torch_geometric.sampler.neighbor_sampler import neg_sample
from torch_geometric.sampler.utils import remap_keys
from torch_geometric.typing import EdgeType, NodeType
# Accepted forms of the `num_neighbors` argument: a `NumNeighbors` object,
# one per-hop list shared by all edge types, or a per-edge-type mapping of
# per-hop lists.
NumNeighborsType = Union[
    NumNeighbors,
    List[int],
    Dict[EdgeType, List[int]],
]
class RPCSamplingCallee(RPCCallBase):
    r"""A wrapper for RPC callee that will perform RPC sampling from remote
    processes.

    The wrapped object is the distributed sampler that registered this
    callee; both RPC entry points forward their arguments to its one-hop
    sampling routine :meth:`_sample_one_hop`.

    Args:
        sampler (DistNeighborSampler): The distributed sampler whose local
            partition is sampled on behalf of remote workers.
    """
    def __init__(self, sampler: 'DistNeighborSampler'):
        super().__init__()
        self.sampler = sampler

    def rpc_async(self, *args, **kwargs) -> Any:
        # Serves an asynchronous sampling request issued by a remote worker
        # for seed nodes that live in the local partition.
        return self.sampler._sample_one_hop(*args, **kwargs)

    def rpc_sync(self, *args, **kwargs) -> Any:
        # `_sample_one_hop` is itself synchronous, so a synchronous request
        # is served identically instead of silently returning `None`.
        return self.sampler._sample_one_hop(*args, **kwargs)
class DistNeighborSampler:
r"""An implementation of a distributed and asynchronised neighbor sampler
used by :class:`~torch_geometric.distributed.DistNeighborLoader` and
:class:`~torch_geometric.distributed.DistLinkNeighborLoader`.
"""
def __init__(
    self,
    current_ctx: DistContext,
    data: Tuple[LocalFeatureStore, LocalGraphStore],
    num_neighbors: NumNeighborsType,
    channel: Optional[mp.Queue] = None,
    replace: bool = False,
    subgraph_type: Union[SubgraphType, str] = 'directional',
    disjoint: bool = False,
    temporal_strategy: str = 'uniform',
    time_attr: Optional[str] = None,
    concurrency: int = 1,
    device: Optional[torch.device] = None,
    **kwargs,
):
    r"""Stores the sampling configuration of this worker.

    The underlying :class:`NeighborSampler` is not created here but in
    :meth:`init_sampler_instance`, and the event loop is created lazily
    in :meth:`init_event_loop`.

    Args:
        current_ctx (DistContext): The distributed context of this worker.
        data (Tuple[LocalFeatureStore, LocalGraphStore]): The local feature
            store and graph store of this partition.
        num_neighbors: The number of neighbors to sample per hop (and per
            edge type for heterogeneous graphs).
        channel (mp.Queue, optional): If given, sampling runs
            asynchronously and collated results are put into this queue.
            If :obj:`None`, results are returned directly.
        replace (bool): Forwarded to :class:`NeighborSampler` and to the
            one-hop sampling op.
        subgraph_type (SubgraphType or str): The type of subgraph to
            return; converted to :class:`SubgraphType`.
        disjoint (bool): Whether each seed node gets its own subgraph.
        temporal_strategy (str): Forwarded to the sampling ops.
        time_attr (str, optional): Name of the time attribute; temporal
            sampling is enabled iff this is set.
        concurrency (int): Passed to the :class:`ConcurrentEventLoop`.
        device (torch.device, optional): Device seed nodes and their times
            are moved to in :meth:`node_sample`.
        **kwargs: Unused.
    """
    self.current_ctx = current_ctx
    self.feature_store, self.graph_store = data
    assert isinstance(self.graph_store, LocalGraphStore)
    assert isinstance(self.feature_store, LocalFeatureStore)
    self.is_hetero = self.graph_store.meta['is_hetero']

    # Sampling configuration:
    self.num_neighbors = num_neighbors
    self.replace = replace
    self.subgraph_type = SubgraphType(subgraph_type)
    self.disjoint = disjoint

    # Execution configuration:
    self.channel = channel
    self.concurrency = concurrency
    self.device = device
    self.event_loop = None

    # Temporal configuration:
    self.temporal_strategy = temporal_strategy
    self.time_attr = time_attr
    self.temporal = self.time_attr is not None

    self.with_edge_attr = self.feature_store.has_edge_attr()
    # Sampling always operates on the CSC representation:
    self.csc = True
def init_sampler_instance(self):
    r"""Creates the local :class:`NeighborSampler` on top of this worker's
    feature and graph store, and caches the graph metadata it exposes
    (number of hops, node/edge types and node/edge times).
    """
    sampler = NeighborSampler(
        data=(self.feature_store, self.graph_store),
        num_neighbors=self.num_neighbors,
        subgraph_type=self.subgraph_type,
        replace=self.replace,
        disjoint=self.disjoint,
        temporal_strategy=self.temporal_strategy,
        time_attr=self.time_attr,
    )
    self._sampler = sampler

    self.num_hops = sampler.num_neighbors.num_hops
    self.node_types, self.edge_types = sampler.node_types, sampler.edge_types
    self.node_time, self.edge_time = sampler.node_time, sampler.edge_time
def register_sampler_rpc(self) -> None:
    r"""Sets up RPC routing between partitions and registers this sampler
    as an RPC callee, so that remote workers can request one-hop sampling
    from the local partition.
    """
    graph_store = self.graph_store
    self.rpc_router = RPCRouter(
        rpc_partition_to_workers(
            current_ctx=self.current_ctx,
            num_partitions=graph_store.num_partitions,
            current_partition_idx=graph_store.partition_idx,
        ))
    self.feature_store.set_rpc_router(self.rpc_router)

    # The returned ID is used as the callee in `rpc_async` requests:
    self.rpc_sample_callee_id = rpc_register(RPCSamplingCallee(self))
def init_event_loop(self) -> None:
    r"""Creates and starts the event loop on first use; subsequent calls
    are no-ops.
    """
    if self.event_loop is not None:
        return
    self.event_loop = ConcurrentEventLoop(self.concurrency)
    self.event_loop.start_loop()
    logging.info(f'{self} uses {self.event_loop}')
# Node-based distributed sampling #########################################
def sample_from_nodes(
    self,
    inputs: NodeSamplerInput,
    **kwargs,
) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]:
    r"""Samples subgraphs around the seed nodes in :obj:`inputs`.

    If no :obj:`channel` was given, the sampled and collated output is
    returned directly. Otherwise the task is scheduled on the event loop
    (with the optional :obj:`callback` keyword argument) and :obj:`None`
    is returned; the result is delivered through the channel.
    """
    self.init_event_loop()
    coro = self._sample_from(self.node_sample, NodeSamplerInput.cast(inputs))

    if self.channel is not None:
        # Asynchronous sampling:
        self.event_loop.add_task(coro=coro,
                                 callback=kwargs.get("callback", None))
        return None

    # Synchronous sampling:
    return self.event_loop.run_task(coro=coro)
# Edge-based distributed sampling #########################################
def sample_from_edges(
    self,
    inputs: EdgeSamplerInput,
    neg_sampling: Optional[NegativeSampling] = None,
    **kwargs,
) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]:
    r"""Samples subgraphs around the endpoints of the edges in
    :obj:`inputs`, optionally adding negative samples.

    If no :obj:`channel` was given, the sampled and collated output is
    returned directly. Otherwise the task is scheduled on the event loop
    (with the optional :obj:`callback` keyword argument) and :obj:`None`
    is returned; the result is delivered through the channel.
    """
    self.init_event_loop()
    coro = self._sample_from(
        self.edge_sample,
        inputs,
        self.node_sample,
        self._sampler.num_nodes,
        self.disjoint,
        self.node_time,
        neg_sampling,
    )

    if self.channel is not None:
        # Asynchronous sampling:
        self.event_loop.add_task(coro=coro,
                                 callback=kwargs.get("callback", None))
        return None

    # Synchronous sampling:
    return self.event_loop.run_task(coro=coro)
async def _sample_from(
    self,
    async_func,
    *args,
    **kwargs,
) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]:
    r"""Runs the sampling coroutine :obj:`async_func`, converts its output
    to a bidirectional subgraph if requested, and collates it.

    Returns the collated result when no :obj:`channel` is configured;
    otherwise puts it into the channel and returns :obj:`None`.
    """
    output = await async_func(*args, **kwargs)

    if self.subgraph_type == SubgraphType.bidirectional:
        output = output.to_bidirectional()

    collated = await self._collate_fn(output)

    if self.channel is not None:
        self.channel.put(collated)
        return None
    return collated
async def node_sample(
    self,
    inputs: Union[NodeSamplerInput, DistEdgeHeteroSamplerInput],
) -> Union[SamplerOutput, HeteroSamplerOutput]:
    r"""Performs layer-by-layer distributed sampling from a
    :class:`NodeSamplerInput` or :class:`DistEdgeHeteroSamplerInput` and
    returns the output of the sampling procedure.

    .. note::
        In case of distributed training it is required to synchronize the
        results between machines after each layer.

    Args:
        inputs: The seed nodes. A :class:`NodeSamplerInput` holds seeds of
            a single node type (or of a homogeneous graph). A
            :class:`DistEdgeHeteroSamplerInput`, as created by
            :meth:`edge_sample`, holds seeds of several node types in
            :obj:`node_dict` (and optionally their times in
            :obj:`time_dict`).

    Returns:
        A :class:`HeteroSamplerOutput` for heterogeneous graphs, otherwise
        a :class:`SamplerOutput`.

    Raises:
        ValueError: If temporal sampling is enabled but no seed time can be
            determined, if :obj:`input_type` is missing for a
            heterogeneous graph, or if seed times need to be propagated to
            a further hop without :obj:`disjoint=True` (seed times are
            inherited per subgraph).
    """
    input_type = inputs.input_type
    self.input_type = input_type

    if isinstance(inputs, NodeSamplerInput):
        seed = inputs.node.to(self.device)
        batch_size = len(inputs.node)
        seed_batch = torch.arange(batch_size) if self.disjoint else None

        metadata = (inputs.input_id, inputs.time, batch_size)

        seed_time: Optional[Tensor] = None
        if self.temporal:
            if inputs.time is not None:
                seed_time = inputs.time.to(self.device)
            elif self.node_time is not None:
                if not self.is_hetero:
                    seed_time = self.node_time[seed]
                else:
                    seed_time = self.node_time[input_type][seed]
            else:
                raise ValueError("Seed time needs to be specified")
    else:  # `DistEdgeHeteroSamplerInput`:
        metadata = None  # Metadata is added during `edge_sample`.

    # Heterogeneous Neighborhood Sampling #################################
    if self.is_hetero:
        if input_type is None:
            raise ValueError("Input type should be defined")

        node_dict = NodeDict(self.node_types, self.num_hops)
        batch_dict = BatchDict(self.node_types, self.num_hops)

        if isinstance(inputs, NodeSamplerInput):
            seed_dict: Dict[NodeType, Tensor] = {input_type: seed}
            if self.temporal:
                node_dict.seed_time[input_type][0] = seed_time.clone()
        else:  # `DistEdgeHeteroSamplerInput`:
            seed_dict = inputs.node_dict
            if self.temporal:
                for k, v in inputs.node_dict.items():
                    if inputs.time_dict is not None:
                        node_dict.seed_time[k][0] = inputs.time_dict[k]
                    elif self.node_time is not None:
                        node_dict.seed_time[k][0] = self.node_time[k][v]
                    else:
                        raise ValueError("Seed time needs to be specified")

        edge_dict: Dict[EdgeType, Tensor] = {
            k: torch.empty(0, dtype=torch.int64)
            for k in self.edge_types
        }
        sampled_nbrs_per_node_dict: Dict[EdgeType, List[List]] = {
            k: [[] for _ in range(self.num_hops)]
            for k in self.edge_types
        }
        num_sampled_edges_dict: Dict[EdgeType, List[int]] = {
            k: []
            for k in self.edge_types
        }
        num_sampled_nodes_dict: Dict[NodeType, List[int]] = {
            k: [0]
            for k in self.node_types
        }

        # Fill in node_dict and batch_dict with input data. In disjoint
        # mode every seed node roots its own subgraph, so seeds of
        # different node types receive consecutive, non-overlapping
        # subgraph IDs:
        batch_len = 0
        for k, v in seed_dict.items():
            node_dict.src[k][0] = v
            node_dict.out[k] = v
            num_sampled_nodes_dict[k][0] = len(v)

            if self.disjoint:
                src_batch = torch.arange(batch_len, batch_len + len(v))
                batch_dict.src[k][0] = src_batch
                batch_dict.out[k] = src_batch

                batch_len += len(v)

        # Seed time of every subgraph, indexed by subgraph ID. Sampled nodes
        # inherit the seed time of the subgraph they belong to, which can
        # then be gathered directly via their subgraph IDs, regardless of
        # which seed node type rooted the subgraph:
        batch_seed_time: Optional[Tensor] = None
        if self.temporal and self.disjoint:
            batch_seed_time = torch.empty(batch_len, dtype=torch.int64)
            for k in seed_dict.keys():
                batch_seed_time[batch_dict.src[k][0]] = (
                    node_dict.seed_time[k][0].to(device='cpu',
                                                 dtype=torch.int64))

        # Loop over the layers:
        for i in range(self.num_hops):
            # Sample neighbors per edge type:
            for edge_type in self.edge_types:
                # `src` is the node type whose nodes are expanded in this
                # hop (the destination node type of `edge_type` in CSC).
                src = edge_type[0] if not self.csc else edge_type[2]

                if node_dict.src[src][i].numel() == 0:
                    # No source nodes of this type in the current layer.
                    num_sampled_edges_dict[edge_type].append(0)
                    continue

                if isinstance(self.num_neighbors, list):
                    one_hop_num = self.num_neighbors[i]
                else:
                    one_hop_num = self.num_neighbors[edge_type][i]

                # Sample neighbors:
                out = await self.sample_one_hop(
                    node_dict.src[src][i],
                    one_hop_num,
                    node_dict.seed_time[src][i],
                    batch_dict.src[src][i],
                    edge_type,
                )

                if out.node.numel() == 0:  # No neighbors were sampled.
                    num_sampled_edges_dict[edge_type].append(0)
                    continue

                # `dst` is the node type of the sampled neighbors (the
                # source node type of `edge_type` in CSC).
                dst = edge_type[2] if not self.csc else edge_type[0]

                # Remove duplicates:
                (
                    src_node,
                    node_dict.out[dst],
                    src_batch,
                    batch_dict.out[dst],
                ) = remove_duplicates(
                    out,
                    node_dict.out[dst],
                    batch_dict.out[dst],
                    self.disjoint,
                )

                # Create src nodes for the next layer:
                node_dict.src[dst][i + 1] = torch.cat(
                    [node_dict.src[dst][i + 1], src_node])
                if self.disjoint:
                    batch_dict.src[dst][i + 1] = torch.cat(
                        [batch_dict.src[dst][i + 1], src_batch])

                # Save sampled nodes with duplicates to be able to create
                # local edge indices:
                node_dict.with_dupl[dst] = torch.cat(
                    [node_dict.with_dupl[dst], out.node])

                edge_dict[edge_type] = torch.cat(
                    [edge_dict[edge_type], out.edge])

                if self.disjoint:
                    batch_dict.with_dupl[dst] = torch.cat(
                        [batch_dict.with_dupl[dst], out.batch])

                if self.temporal and i < self.num_hops - 1:
                    # Assign seed time based on source node subgraph ID:
                    if batch_seed_time is None:
                        raise ValueError(
                            "Temporal sampling over more than one hop "
                            "requires 'disjoint=True'")
                    src_seed_time = batch_seed_time[src_batch]
                    node_dict.seed_time[dst][i + 1] = torch.cat(
                        [node_dict.seed_time[dst][i + 1], src_seed_time])

                # Collect sampled neighbors per node for each layer:
                sampled_nbrs_per_node_dict[edge_type][i] += out.metadata[0]

                num_sampled_edges_dict[edge_type].append(len(out.node))

            for node_type in self.node_types:
                num_sampled_nodes_dict[node_type].append(
                    len(node_dict.src[node_type][i + 1]))

        sampled_nbrs_per_node_dict = remap_keys(sampled_nbrs_per_node_dict,
                                                self._sampler.to_rel_type)

        # Create local edge indices for a batch:
        row_dict, col_dict = torch.ops.pyg.hetero_relabel_neighborhood(
            self.node_types,
            self.edge_types,
            seed_dict,
            node_dict.with_dupl,
            sampled_nbrs_per_node_dict,
            self._sampler.num_nodes,
            batch_dict.with_dupl,
            self.csc,
            self.disjoint,
        )

        sampler_output = HeteroSamplerOutput(
            node=node_dict.out,
            row=remap_keys(row_dict, self._sampler.to_edge_type),
            col=remap_keys(col_dict, self._sampler.to_edge_type),
            edge=edge_dict,
            batch=batch_dict.out if self.disjoint else None,
            num_sampled_nodes=num_sampled_nodes_dict,
            num_sampled_edges=num_sampled_edges_dict,
            metadata=metadata,
        )
    # Homogeneous Neighborhood Sampling ###################################
    else:
        src = seed
        node = src.clone()

        src_batch = seed_batch.clone() if self.disjoint else None
        batch = seed_batch.clone() if self.disjoint else None

        src_seed_time = seed_time.clone() if self.temporal else None

        # Seed time indexed by subgraph ID. Since `seed_batch` is
        # `arange(batch_size)`, subgraph `b` is rooted at `seed[b]`:
        batch_seed_time: Optional[Tensor] = None
        if self.temporal and self.disjoint:
            batch_seed_time = seed_time.to(device='cpu', dtype=torch.int64)

        node_with_dupl = [torch.empty(0, dtype=torch.int64)]
        batch_with_dupl = [torch.empty(0, dtype=torch.int64)]
        edge = [torch.empty(0, dtype=torch.int64)]

        sampled_nbrs_per_node = []
        num_sampled_nodes = [seed.numel()]
        num_sampled_edges = []

        # Loop over the layers:
        for i, one_hop_num in enumerate(self.num_neighbors):
            out = await self.sample_one_hop(src, one_hop_num,
                                            src_seed_time, src_batch)
            if out.node.numel() == 0:
                # No neighbors were sampled:
                num_zero_layers = self.num_hops - i
                num_sampled_nodes += num_zero_layers * [0]
                num_sampled_edges += num_zero_layers * [0]
                break

            # Remove duplicates:
            src, node, src_batch, batch = remove_duplicates(
                out, node, batch, self.disjoint)

            node_with_dupl.append(out.node)
            edge.append(out.edge)

            if self.disjoint:
                batch_with_dupl.append(out.batch)

            if self.temporal and i < self.num_hops - 1:
                # Assign seed time based on src nodes subgraph IDs:
                if batch_seed_time is None:
                    raise ValueError(
                        "Temporal sampling over more than one hop "
                        "requires 'disjoint=True'")
                src_seed_time = batch_seed_time[src_batch]

            num_sampled_nodes.append(len(src))
            num_sampled_edges.append(len(out.node))
            sampled_nbrs_per_node += out.metadata[0]

        row, col = torch.ops.pyg.relabel_neighborhood(
            seed,
            torch.cat(node_with_dupl),
            sampled_nbrs_per_node,
            self._sampler.num_nodes,
            torch.cat(batch_with_dupl) if self.disjoint else None,
            self.csc,
            self.disjoint,
        )

        sampler_output = SamplerOutput(
            node=node,
            row=row,
            col=col,
            edge=torch.cat(edge),
            batch=batch if self.disjoint else None,
            num_sampled_nodes=num_sampled_nodes,
            num_sampled_edges=num_sampled_edges,
            metadata=metadata,
        )

    return sampler_output
async def edge_sample(
    self,
    inputs: EdgeSamplerInput,
    sample_fn: Callable,
    num_nodes: Union[int, Dict[NodeType, int]],
    disjoint: bool,
    node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None,
    neg_sampling: Optional[NegativeSampling] = None,
) -> Union[SamplerOutput, HeteroSamplerOutput]:
    r"""Performs layer-by-layer distributed sampling from an
    :class:`EdgeSamplerInput` and returns the output of the sampling
    procedure.

    .. note::
        In case of distributed training it is required to synchronize the
        results between machines after each layer.

    Args:
        inputs (EdgeSamplerInput): The seed edges (:obj:`row`, :obj:`col`)
            with optional labels, label times and input edge type.
        sample_fn (Callable): Coroutine function that samples around the
            seed nodes (:meth:`node_sample`).
        num_nodes (int or Dict[NodeType, int]): Number of nodes (per node
            type for heterogeneous graphs), used for negative sampling.
        disjoint (bool): Whether each seed node gets its own subgraph.
            Required when :obj:`inputs.time` is given.
        node_time (torch.Tensor or Dict[str, torch.Tensor], optional):
            Node times used to restrict negative samples.
        neg_sampling (NegativeSampling, optional): The negative sampling
            configuration (binary or triplet).

    Returns:
        The output of :obj:`sample_fn`, with :obj:`metadata` replaced by
        the label information: :obj:`(input_id, edge_label_index,
        edge_label, src_time)` in the binary/no-negatives case, and
        :obj:`(input_id, src_index, dst_pos_index, dst_neg_index,
        src_time)` in the triplet case.
    """
    input_id = inputs.input_id
    src = inputs.row
    dst = inputs.col
    edge_label = inputs.label
    edge_label_time = inputs.time
    input_type = inputs.input_type

    src_time = dst_time = edge_label_time
    assert edge_label_time is None or disjoint

    assert isinstance(num_nodes, (dict, int))
    if not isinstance(num_nodes, dict):
        num_src_nodes = num_dst_nodes = num_nodes
    else:
        num_src_nodes = num_nodes[input_type[0]]
        num_dst_nodes = num_nodes[input_type[-1]]

    num_pos = src.numel()
    num_neg = 0

    # Negative Sampling ###################################################

    if neg_sampling is not None:
        # When we are doing negative sampling, we append negative
        # information of nodes/edges to `src`, `dst`, `src_time`,
        # `dst_time`. Later on, we can easily reconstruct what belongs to
        # positive and negative examples by slicing via `num_pos`.
        num_neg = math.ceil(num_pos * neg_sampling.amount)
        # Number of full copies of the positive times needed to cover
        # both positive and negative examples:
        num_time_repeats = 1 + math.ceil(neg_sampling.amount)

        if neg_sampling.is_binary():
            # In the "binary" case, we randomly sample negative pairs of
            # nodes.
            if isinstance(node_time, dict):
                src_node_time = node_time.get(input_type[0])
            else:
                src_node_time = node_time

            src_neg = neg_sample(src, neg_sampling, num_src_nodes,
                                 src_time, src_node_time)
            src = torch.cat([src, src_neg], dim=0)

            if isinstance(node_time, dict):
                dst_node_time = node_time.get(input_type[-1])
            else:
                dst_node_time = node_time

            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes,
                                 dst_time, dst_node_time)
            dst = torch.cat([dst, dst_neg], dim=0)

            if edge_label is None:
                edge_label = torch.ones(num_pos)
            size = (num_neg, ) + edge_label.size()[1:]
            edge_neg_label = edge_label.new_zeros(size)
            edge_label = torch.cat([edge_label, edge_neg_label])

            if edge_label_time is not None:
                src_time = dst_time = edge_label_time.repeat(
                    num_time_repeats)[:num_pos + num_neg]

        elif neg_sampling.is_triplet():
            # In the "triplet" case, we randomly sample negative
            # destinations.
            if isinstance(node_time, dict):
                dst_node_time = node_time.get(input_type[-1])
            else:
                dst_node_time = node_time

            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes,
                                 dst_time, dst_node_time)
            dst = torch.cat([dst, dst_neg], dim=0)

            assert edge_label is None

            if edge_label_time is not None:
                # `Tensor.repeat` needs an integer count, hence the
                # `math.ceil` (as in the binary case):
                dst_time = edge_label_time.repeat(num_time_repeats)

    # Heterogeneous Neighborhood Sampling #################################

    if input_type is not None:
        if input_type[0] != input_type[-1]:  # Two distinct node types:

            if not disjoint:
                src, inverse_src = src.unique(return_inverse=True)
                dst, inverse_dst = dst.unique(return_inverse=True)

            seed_dict = {input_type[0]: src, input_type[-1]: dst}

            seed_time_dict = None
            if edge_label_time is not None:  # Always disjoint.
                seed_time_dict = {
                    input_type[0]: src_time,
                    input_type[-1]: dst_time,
                }

            out = await sample_fn(
                DistEdgeHeteroSamplerInput(
                    input_id=inputs.input_id,
                    node_dict=seed_dict,
                    time_dict=seed_time_dict,
                    input_type=input_type,
                ))

        else:
            # Only a single node type: Merge both source and destination.
            seed = torch.cat([src, dst], dim=0)

            if not disjoint:
                seed, inverse_seed = seed.unique(return_inverse=True)

            seed_dict = {input_type[0]: seed}

            seed_time = None
            if edge_label_time is not None:  # Always disjoint.
                seed_time = torch.cat([src_time, dst_time], dim=0)

            out = await sample_fn(
                NodeSamplerInput(
                    input_id=inputs.input_id,
                    node=seed,
                    time=seed_time,
                    input_type=input_type[0],
                ))

        # Enhance `out` by label information ##############################
        if disjoint:
            for key, batch in out.batch.items():
                out.batch[key] = batch % num_pos

        if neg_sampling is None or neg_sampling.is_binary():
            if disjoint:
                if input_type[0] != input_type[-1]:
                    edge_label_index = torch.arange(num_pos + num_neg)
                    edge_label_index = edge_label_index.repeat(2)
                    edge_label_index = edge_label_index.view(2, -1)
                else:
                    num_labels = num_pos + num_neg
                    edge_label_index = torch.arange(2 * (num_labels))
                    edge_label_index = edge_label_index.view(2, -1)
            else:
                if input_type[0] != input_type[-1]:
                    edge_label_index = torch.stack([
                        inverse_src,
                        inverse_dst,
                    ], dim=0)
                else:
                    edge_label_index = inverse_seed.view(2, -1)

            out.metadata = (input_id, edge_label_index, edge_label,
                            src_time)

        elif neg_sampling.is_triplet():
            if disjoint:
                src_index = torch.arange(num_pos)
                if input_type[0] != input_type[-1]:
                    dst_pos_index = torch.arange(num_pos)
                    # `dst_neg_index` needs to be offset such that indices
                    # with offset `num_pos` belong to the same triplet:
                    dst_neg_index = torch.arange(
                        num_pos, seed_dict[input_type[-1]].numel())
                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
                else:
                    dst_pos_index = torch.arange(num_pos, 2 * num_pos)
                    dst_neg_index = torch.arange(
                        2 * num_pos, seed_dict[input_type[-1]].numel())
                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
            else:
                if input_type[0] != input_type[-1]:
                    src_index = inverse_src
                    dst_pos_index = inverse_dst[:num_pos]
                    dst_neg_index = inverse_dst[num_pos:]
                else:
                    src_index = inverse_seed[:num_pos]
                    dst_pos_index = inverse_seed[num_pos:2 * num_pos]
                    dst_neg_index = inverse_seed[2 * num_pos:]

            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

            out.metadata = (
                input_id,
                src_index,
                dst_pos_index,
                dst_neg_index,
                src_time,
            )

    # Homogeneous Neighborhood Sampling ###################################

    else:
        seed = torch.cat([src, dst], dim=0)
        seed_time = None

        if not disjoint:
            seed, inverse_seed = seed.unique(return_inverse=True)

        if edge_label_time is not None:  # Always disjoint.
            seed_time = torch.cat([src_time, dst_time])

        out = await sample_fn(
            NodeSamplerInput(
                input_id=inputs.input_id,
                node=seed,
                time=seed_time,
                input_type=None,
            ))

        # Enhance `out` by label information ##############################
        if neg_sampling is None or neg_sampling.is_binary():
            if disjoint:
                out.batch = out.batch % num_pos
                edge_label_index = torch.arange(seed.numel()).view(2, -1)
            else:
                edge_label_index = inverse_seed.view(2, -1)

            out.metadata = (input_id, edge_label_index, edge_label,
                            src_time)

        elif neg_sampling.is_triplet():
            if disjoint:
                out.batch = out.batch % num_pos
                src_index = torch.arange(num_pos)
                dst_pos_index = torch.arange(num_pos, 2 * num_pos)
                # `dst_neg_index` needs to be offset such that indices with
                # offset `num_pos` belong to the same triplet:
                dst_neg_index = torch.arange(2 * num_pos, seed.numel())
                dst_neg_index = dst_neg_index.view(-1, num_pos).t()
            else:
                src_index = inverse_seed[:num_pos]
                dst_pos_index = inverse_seed[num_pos:2 * num_pos]
                dst_neg_index = inverse_seed[2 * num_pos:]
            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

            out.metadata = (
                input_id,
                src_index,
                dst_pos_index,
                dst_neg_index,
                src_time,
            )

    return out
def _get_sampler_output(
    self,
    outputs: List[SamplerOutput],
    seed_size: int,
    p_id: int,
    src_batch: Optional[Tensor] = None,
) -> SamplerOutput:
    r"""Used when all seed nodes belong to a single partition. Its purpose
    is to remove the seed nodes from the sampled nodes and to calculate how
    many neighbors were sampled by each src node based on the
    :obj:`cumsum_neighbors_per_node`. Returns the updated sampler output.

    Args:
        outputs (List[SamplerOutput]): Sampler outputs indexed by
            partition ID; only :obj:`outputs[p_id]` is read and updated in
            place.
        seed_size (int): Number of seed nodes.
        p_id (int): The partition holding all seed nodes.
        src_batch (torch.Tensor, optional): The subgraph ID of each seed
            node. Required in disjoint mode.
    """
    out = outputs[p_id]
    cumsum_neighbors_per_node = out.metadata[0]

    # Do not include the seed nodes (the first `seed_size` entries):
    out.node = out.node[seed_size:]

    # Consecutive differences of the cumulative counts give the number of
    # neighbors sampled by every seed node:
    sampled_nbrs_per_node = np.diff(
        np.asarray(cumsum_neighbors_per_node, dtype=np.int64)).tolist()
    out.metadata = (sampled_nbrs_per_node, )

    if self.disjoint:
        # Every sampled neighbor belongs to the subgraph of the seed node it
        # was sampled from. `repeat_interleave` keeps the subgraph IDs in
        # int64 (no round trip through a float tensor):
        out.batch = torch.repeat_interleave(
            src_batch.to(torch.int64),
            torch.tensor(sampled_nbrs_per_node, dtype=torch.int64),
        )

    return out
def _merge_sampler_outputs(
    self,
    partition_ids: Tensor,
    partition_orders: Tensor,
    outputs: List[SamplerOutput],
    one_hop_num: int,
    src_batch: Optional[Tensor] = None,
) -> SamplerOutput:
    r"""Merges samplers outputs from different partitions, so that they
    are sorted according to the sampling order. Removes seed nodes from
    sampled nodes and calculates how many neighbors were sampled by each
    src node based on the :obj:`cumsum_neighbors_per_node`. Leverages the
    :obj:`pyg-lib` :meth:`merge_sampler_outputs` function.

    Args:
        partition_ids (torch.Tensor): Contains information on which
            partition seeds nodes are located on.
        partition_orders (torch.Tensor): Contains information about the
            order of seed nodes in each partition.
        outputs (List[SamplerOutput]): List of all samplers outputs, with
            :obj:`None` for partitions that were not sampled.
        one_hop_num (int): Max number of neighbors sampled in the current
            layer.
        src_batch (torch.Tensor, optional): The batch assignment of seed
            nodes. (default: :obj:`None`)

    Returns :obj:`SamplerOutput` containing all merged outputs.
    """
    nodes_per_partition = []
    edges_per_partition = []
    cumsums_per_partition = []
    for output in outputs:
        if output is None:
            # Partitions without seed nodes contribute empty results:
            nodes_per_partition.append(torch.empty(0, dtype=torch.int64))
            edges_per_partition.append(torch.empty(0, dtype=torch.int64))
            cumsums_per_partition.append([])
        else:
            nodes_per_partition.append(output.node)
            edges_per_partition.append(output.edge)
            cumsums_per_partition.append(output.metadata[0])

    (
        merged_node,
        merged_edge,
        merged_batch,
        merged_nbrs_per_node,
    ) = torch.ops.pyg.merge_sampler_outputs(
        nodes_per_partition,
        edges_per_partition,
        cumsums_per_partition,
        partition_ids.tolist(),
        partition_orders.tolist(),
        self.graph_store.meta["num_parts"],
        one_hop_num,
        src_batch,
        self.disjoint,
    )

    return SamplerOutput(
        merged_node,
        None,
        None,
        merged_edge,
        merged_batch if self.disjoint else None,
        metadata=(merged_nbrs_per_node, ),
    )
async def sample_one_hop(
    self,
    srcs: Tensor,
    one_hop_num: int,
    seed_time: Optional[Tensor] = None,
    src_batch: Optional[Tensor] = None,
    edge_type: Optional[EdgeType] = None,
) -> SamplerOutput:
    r"""Samples one-hop neighbors for a set of seed nodes in :obj:`srcs`.
    If seed nodes are located on a local partition, evaluates the sampling
    function on the current machine. If seed nodes are from a remote
    partition, sends a request to a remote machine that contains this
    partition.

    Remote requests are issued before the local partition is sampled, so
    that remote and local sampling overlap. Each remote result is stored
    under the partition it was requested from.

    Args:
        srcs (torch.Tensor): The seed node IDs.
        one_hop_num (int): Max number of neighbors to sample per seed.
        seed_time (torch.Tensor, optional): Seed time per node in
            :obj:`srcs`; read only when temporal sampling is enabled.
        src_batch (torch.Tensor, optional): Subgraph ID per node in
            :obj:`srcs` (disjoint mode).
        edge_type (EdgeType, optional): The edge type to sample along
            (heterogeneous graphs only).
    """
    src_node_type = None if not self.is_hetero else edge_type[2]
    partition_ids = self.graph_store.get_partition_ids_from_nids(
        srcs, src_node_type)
    partition_orders = torch.zeros(len(partition_ids), dtype=torch.long)

    p_outputs: List[Optional[SamplerOutput]] = [
        None
    ] * self.graph_store.meta["num_parts"]

    local_p_id = self.graph_store.partition_idx
    num_partitions = self.graph_store.num_partitions

    local_input: Optional[Tuple[Tensor, Optional[Tensor]]] = None
    futs: List[torch.futures.Future] = []
    # Target partition of each entry in `futs`. Partitions without seed
    # nodes are skipped, so the request index alone does not identify the
    # partition:
    fut_p_ids: List[int] = []

    for i in range(num_partitions):
        p_id = (local_p_id + i) % num_partitions
        p_mask = partition_ids == p_id
        p_srcs = torch.masked_select(srcs, p_mask)
        p_seed_time = (torch.masked_select(seed_time, p_mask)
                       if self.temporal else None)

        # Position of every seed node within its partition's request:
        partition_orders[p_mask] = torch.arange(len(p_srcs),
                                                dtype=torch.long)

        if p_srcs.shape[0] == 0:
            continue

        if p_id == local_p_id:
            # Sampled below, once all remote requests are in flight.
            local_input = (p_srcs, p_seed_time)
        else:  # Sample on a remote machine:
            to_worker = self.rpc_router.get_to_worker(p_id)
            futs.append(
                rpc_async(
                    to_worker,
                    self.rpc_sample_callee_id,
                    args=(p_srcs, one_hop_num, p_seed_time, edge_type),
                ))
            fut_p_ids.append(p_id)

    if local_input is not None:
        # Sample for one hop on a local machine:
        local_srcs, local_seed_time = local_input
        p_outputs[local_p_id] = self._sample_one_hop(
            local_srcs, one_hop_num, local_seed_time, edge_type)

    if futs:
        # Src nodes are (partially) remote:
        res_fut_list = await to_asyncio_future(
            torch.futures.collect_all(futs))
        for p_id, res_fut in zip(fut_p_ids, res_fut_list):
            p_outputs[p_id] = res_fut.wait()

    # All src nodes are in the same partition:
    if len(set(partition_ids.tolist())) == 1:
        return self._get_sampler_output(p_outputs, len(srcs),
                                        int(partition_ids[0]), src_batch)

    return self._merge_sampler_outputs(partition_ids, partition_orders,
                                       p_outputs, one_hop_num, src_batch)
def _sample_one_hop(
    self,
    input_nodes: Tensor,
    num_neighbors: int,
    seed_time: Optional[Tensor] = None,
    edge_type: Optional[EdgeType] = None,
) -> SamplerOutput:
    r"""Implements one-hop neighbor sampling for a set of input nodes for a
    specific edge type on the local partition.

    Returns a :class:`SamplerOutput` holding the sampled nodes and edge
    IDs. Its :obj:`metadata` holds the cumulative neighbor counts returned
    by :obj:`torch.ops.pyg.dist_neighbor_sample`. :obj:`row`, :obj:`col`
    and :obj:`batch` are left unset.
    """
    if self.is_hetero:
        # Given edge type, get input data and evaluate sample function:
        rel_type = '__'.join(edge_type)
        colptr = self._sampler.colptr_dict[rel_type]
        row = self._sampler.row_dict[rel_type]
        # Node time of the sampled neighbors' type (`edge_type[0]`):
        node_time = (self.node_time or {}).get(edge_type[0], None)
        edge_time = (self.edge_time or {}).get(edge_type, None)
    else:
        colptr, row = self._sampler.colptr, self._sampler.row
        node_time, edge_time = self.node_time, self.edge_time

    temporal_disjoint = self.disjoint and self.temporal

    node, edge, cumsum_neighbors_per_node = torch.ops.pyg.dist_neighbor_sample(
        colptr,
        row,
        input_nodes.to(colptr.dtype),
        num_neighbors,
        node_time,
        edge_time,
        seed_time,
        None,  # TODO: edge_weight
        True,  # csc
        self.replace,
        self.subgraph_type != SubgraphType.induced,
        temporal_disjoint,
        self.temporal_strategy,
    )

    if temporal_disjoint:
        # Only the second row is kept; the batch is created later, during
        # the step of merging sampler outputs.
        _discarded, node = node.t().contiguous()

    return SamplerOutput(
        node=node,
        row=None,
        col=None,
        edge=edge,
        batch=None,
        metadata=(cumsum_neighbors_per_node, ),
    )
async def _collate_fn(
self, output: Union[SamplerOutput, HeteroSamplerOutput]
) -> Union[SamplerOutput, HeteroSamplerOutput]:
r"""Collect labels and features for the sampled subgrarph if necessary,
and put them into a sample message.
"""
if self.is_hetero: