enable temporal distributed sampling + unit tests
kgajdamo committed Oct 30, 2023
1 parent 7eee3d5 commit b7acc09
Showing 3 changed files with 148 additions and 15 deletions.
146 changes: 137 additions & 9 deletions test/distributed/test_dist_neighbor_sampler.py
@@ -17,25 +17,30 @@
 from torch_geometric.testing import withPackage
 
 
-def create_data(rank, world_size):
+def create_data(rank, world_size, temporal=False):
+    num_nodes = 10
     # create dist data
     if rank == 0:
         # partition 0
         node_id = torch.tensor([0, 1, 2, 3, 4, 5, 6])
+        # sorted by dst
         edge_index = torch.tensor([
             [1, 2, 3, 4, 5, 0, 0],
             [0, 1, 2, 3, 4, 4, 9],
         ])
     else:
         # partition 1
         node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])
+        # sorted by dst
         edge_index = torch.tensor([
-            [5, 6, 7, 8, 9, 0, 5],
+            [5, 6, 7, 8, 9, 5, 0],
             [4, 5, 6, 7, 8, 9, 9],
         ])
 
     feature_store = LocalFeatureStore.from_data(node_id)
-    graph_store = LocalGraphStore.from_data(None, edge_index, num_nodes=10)
+    graph_store = LocalGraphStore.from_data(None, edge_index,
+                                            num_nodes=num_nodes,
+                                            is_sorted=True)
 
     graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
     graph_store.meta.update({'num_parts': 2})
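
Aside: is_sorted=True promises the graph store that the edges it receives are already ordered by destination node, which the temporal sampling path relies on. A quick standalone check on partition 0's dst row (tensor copied from the hunk above; not part of the commit):

    import torch

    dst = torch.tensor([0, 1, 2, 3, 4, 4, 9])  # second row of partition 0
    assert bool((dst[1:] >= dst[:-1]).all())  # non-decreasing: sorted by dst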
@@ -44,17 +44,25 @@ def create_data(rank, world_size):

     dist_data = (feature_store, graph_store)
 
-    # create reference data
+    # create reference data sorted by dst
     edge_index = torch.tensor([
         [1, 2, 3, 4, 5, 0, 0, 6, 7, 8, 9, 5],
         [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9],
     ])
-    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)
+    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=num_nodes)
+
+    if temporal:
+        # create time data sorted by edge_index srcs
+        data_time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
+        feature_store.put_tensor(data_time, group_name=None, attr_name="time")
+
+        data.time = data_time
 
     return (dist_data, data)
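
Aside: the two partitions above are overlapping views of one global graph; the edges (5 -> 4) and (0 -> 9) appear in both as halo edges. A standalone sanity check (tensors copied from this function; not part of the commit) that their union reproduces the reference edge_index:

    import torch

    part0 = torch.tensor([[1, 2, 3, 4, 5, 0, 0],
                          [0, 1, 2, 3, 4, 4, 9]])
    part1 = torch.tensor([[5, 6, 7, 8, 9, 5, 0],
                          [4, 5, 6, 7, 8, 9, 9]])
    ref = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 6, 7, 8, 9, 5],
                        [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9]])

    union = {tuple(e) for e in torch.cat([part0, part1], dim=1).t().tolist()}
    assert union == {tuple(e) for e in ref.t().tolist()}  # 12 unique edges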


-def dist_neighbor_sampler_homo(
+def dist_neighbor_sampler(
     world_size: int,
     rank: int,
     master_port: int,
@@ -137,9 +150,97 @@ def dist_neighbor_sampler_homo(
     torch.distributed.barrier()
 
 
+def dist_neighbor_sampler_temporal(
+    world_size: int,
+    rank: int,
+    master_port: int,
+    seed_time: torch.Tensor = None,
+    temporal_strategy: str = 'uniform',
+):
+    dist_data, data = create_data(rank, world_size, temporal=True)
+
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name="dist-sampler-test",
+    )
+
+    # Initialize training process group of PyTorch.
+    torch.distributed.init_process_group(
+        backend="gloo",
+        rank=current_ctx.rank,
+        world_size=current_ctx.world_size,
+        init_method="tcp://{}:{}".format('localhost', master_port),
+    )
+
+    num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]
+    dist_sampler = DistNeighborSampler(
+        data=dist_data,
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        num_neighbors=num_neighbors,
+        shuffle=False,
+        disjoint=True,
+        temporal_strategy=temporal_strategy,
+        time_attr='time',
+    )
+
+    init_rpc(
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        master_addr='localhost',
+        master_port=master_port,
+    )
+
+    dist_sampler.register_sampler_rpc()
+    dist_sampler.init_event_loop()
+
+    # close RPC & worker group at exit:
+    atexit.register(close_sampler, 0, dist_sampler)
+    torch.distributed.barrier()
+
+    # seed nodes
+    if rank == 0:
+        input_node = torch.tensor([1, 6], dtype=torch.int64)
+    else:
+        input_node = torch.tensor([4, 9], dtype=torch.int64)
+
+    inputs = NodeSamplerInput(
+        input_id=None,
+        node=input_node,
+        time=seed_time,
+    )
+
+    # evaluate distributed node sample function
+    out_dist = dist_sampler.event_loop.run_task(
+        coro=dist_sampler.node_sample(inputs))
+
+    torch.distributed.barrier()
+
+    sampler = NeighborSampler(data=data, num_neighbors=num_neighbors,
+                              disjoint=True,
+                              temporal_strategy=temporal_strategy,
+                              time_attr='time')
+
+    # evaluate node sample function
+    out = node_sample(inputs, sampler._sample)
+
+    # compare distributed output with single machine output
+    assert torch.equal(out_dist.node, out.node)
+    assert torch.equal(out_dist.row, out.row)
+    assert torch.equal(out_dist.col, out.col)
+    assert torch.equal(out_dist.batch, out.batch)
+    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
+    assert out_dist.num_sampled_edges == out.num_sampled_edges
+
+    torch.distributed.barrier()
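
Aside: the invariant this worker checks is that distributed temporal sampling matches the single-machine NeighborSampler. The underlying node-level rule, sketched below free of PyG internals, is that a neighbor v of seed u may only be sampled if time[v] <= seed_time[u]; when seed_time is None, the seed's own timestamp is used instead. Values are taken from create_data(); admissible() is an illustrative helper, not PyG API:

    import torch

    node_time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])

    def admissible(neighbors, seed_time):
        # Keep only neighbors that do not lie in the seed's future.
        neighbors = torch.as_tensor(neighbors)
        return neighbors[node_time[neighbors] <= seed_time]

    # In-neighbors of seed node 9 in the reference graph are {0, 5}:
    print(admissible([0, 5], seed_time=6))  # tensor([0, 5]): both kept
    print(admissible([0, 5], seed_time=4))  # tensor([5]): node 0 (t=5) dropped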


 @withPackage('pyg_lib')
 @pytest.mark.parametrize("disjoint", [True, False])
-def test_dist_neighbor_sampler_homo(disjoint):
+def test_dist_neighbor_sampler(disjoint):
     mp_context = torch.multiprocessing.get_context("spawn")
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     s.bind(("127.0.0.1", 0))
@@ -148,16 +249,43 @@ def test_dist_neighbor_sampler_homo(disjoint):

     world_size = 2
     w0 = mp_context.Process(
-        target=dist_neighbor_sampler_homo,
+        target=dist_neighbor_sampler,
         args=(world_size, 0, port, disjoint),
     )
 
     w1 = mp_context.Process(
-        target=dist_neighbor_sampler_homo,
+        target=dist_neighbor_sampler,
         args=(world_size, 1, port, disjoint),
     )
 
     w0.start()
     w1.start()
     w0.join()
     w1.join()
 
 
+@withPackage('pyg_lib')
+@pytest.mark.parametrize("seed_time", [None, torch.tensor([3, 6])])
+@pytest.mark.parametrize("temporal_strategy", ['uniform', 'last'])
+def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
+    mp_context = torch.multiprocessing.get_context("spawn")
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.bind(("127.0.0.1", 0))
+    port = s.getsockname()[1]
+    s.close()
+
+    world_size = 2
+    w0 = mp_context.Process(
+        target=dist_neighbor_sampler_temporal,
+        args=(world_size, 0, port, seed_time, temporal_strategy),
+    )
+
+    w1 = mp_context.Process(
+        target=dist_neighbor_sampler_temporal,
+        args=(world_size, 1, port, seed_time, temporal_strategy),
+    )
+
+    w0.start()
+    w1.start()
+    w0.join()
+    w1.join()
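
Aside: both tests repeat the same bind-to-port-0 trick to reserve a free RPC port. A hypothetical helper (a refactoring sketch, not part of this commit) would fold the four lines into one call:

    import socket

    def find_free_port() -> int:
        # Port 0 lets the OS pick an unused port; closing the socket frees
        # it for the spawned workers to bind shortly afterwards.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
        s.close()
        return port

There is a small race between closing the socket and the workers re-binding the port, but for a two-process test this is generally acceptable.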
15 changes: 10 additions & 5 deletions torch_geometric/distributed/dist_neighbor_sampler.py
@@ -347,7 +347,7 @@ async def node_sample(
         num_sampled_edges = []
 
         # loop over the layers
-        for one_hop_num in self.num_neighbors:
+        for i, one_hop_num in enumerate(self.num_neighbors):
             out = await self.sample_one_hop(src, one_hop_num, seed_time,
                                             src_batch)
             if out.node.numel() == 0:
@@ -364,6 +364,12 @@ async def node_sample(
             if self.disjoint:
                 batch_with_dupl.append(out.batch)
 
+            if seed_time is not None and i < self.num_hops - 1:
+                # get seed_time for the next layer based on the previous
+                # seed_time and sampled neighbors per node info
+                seed_time = torch.repeat_interleave(
+                    seed_time, torch.as_tensor(out.metadata[0]))
+
             num_sampled_nodes.append(len(src))
             num_sampled_edges.append(len(out.node))
             sampled_nbrs_per_node += out.metadata[0]
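
Aside: the seed_time propagation added above works because torch.repeat_interleave hands each newly sampled node the timestamp of the seed it was reached from. A toy run with made-up values:

    import torch

    seed_time = torch.tensor([10, 20, 30])  # one timestamp per source node
    nbrs_per_node = [2, 0, 3]               # as in out.metadata[0]

    next_seed_time = torch.repeat_interleave(
        seed_time, torch.as_tensor(nbrs_per_node))
    print(next_seed_time)  # tensor([10, 10, 30, 30, 30])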

Check warning on line 375 in torch_geometric/distributed/dist_neighbor_sampler.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/distributed/dist_neighbor_sampler.py#L375

Added line #L375 was not covered by tests
@@ -610,17 +616,16 @@ def _sample_one_hop(
         )
         node, edge, cumsum_neighbors_per_node = out
 
-        batch = None
-        # return batch only during temporal sampling
         if self.disjoint and node_time is not None:
-            batch, node = node.t().contiguous()
+            # We create a batch during the step of merging sampler outputs.
+            _, node = node.t().contiguous()
 
         return SamplerOutput(
             node=node,
             row=None,
             col=None,
             edge=edge,
-            batch=batch,
+            batch=None,
             metadata=(cumsum_neighbors_per_node, ),
         )
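
Aside: in disjoint temporal mode the low-level sampler returns node as (batch, node) pairs, and the transpose unpack above splits them row-wise; after this change the batch row is simply discarded here and recreated when the partial outputs from all partitions are merged. A toy illustration of the unpack (the pair layout is assumed to match pyg-lib's disjoint output):

    import torch

    node = torch.tensor([[0, 5],   # (subgraph id, global node id)
                         [0, 7],
                         [1, 5]])
    batch, node = node.t().contiguous()
    print(batch)  # tensor([0, 0, 1])
    print(node)   # tensor([5, 7, 5])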

2 changes: 1 addition & 1 deletion torch_geometric/sampler/neighbor_sampler.py
@@ -162,7 +162,7 @@ def __init__(
             if time_attr is not None and len(time_attrs) != 1:
                 raise ValueError("Temporal sampling specified but did "
                                  "not find any temporal data")
-
+            else:
                 time_attrs[0].index = None  # Reset index for full data.
                 time_tensor = feature_store.get_tensor(time_attrs[0])
                 self.node_time = time_tensor
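
Aside: a minimal standalone sketch of the branch above, resolving node-level time from a feature store populated the same way as in the test file (assumes the LocalFeatureStore API used there; an illustration only, not the sampler's exact code path):

    import torch
    from torch_geometric.distributed import LocalFeatureStore

    store = LocalFeatureStore.from_data(torch.arange(10))
    store.put_tensor(torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4]),
                     group_name=None, attr_name='time')

    time_attrs = [
        attr for attr in store.get_all_tensor_attrs()
        if attr.attr_name == 'time'
    ]
    assert len(time_attrs) == 1
    time_attrs[0].index = None  # Reset index for full data.
    node_time = store.get_tensor(time_attrs[0])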
