Skip to content

Commit

Permalink
revert changes in parition.py moved to #8083
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubPietrakIntel authored and kgajdamo committed Oct 10, 2023
1 parent 99ce405 commit 3334da9
Showing 1 changed file with 2 additions and 45 deletions.
47 changes: 2 additions & 45 deletions torch_geometric/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,8 @@
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader.cluster import ClusterData
from torch_geometric.typing import (
Dict,
EdgeType,
EdgeTypeStr,
NodeType,
Tuple,
as_str,
)
from torch_geometric.loader import ClusterData
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType


class Partitioner:
Expand Down Expand Up @@ -258,39 +251,3 @@ def generate_partition(self):
logging.info('Saving partition mapping info')
torch.save(node_map, osp.join(self.root, 'node_map.pt'))
torch.save(edge_map, osp.join(self.root, 'edge_map.pt'))


def load_partition_info(
root_dir: str,
partition_idx: int,
) -> Tuple[Dict, int, int, torch.Tensor, torch.Tensor]:

# load the partition with PyG format (graphstore/featurestore)
with open(os.path.join(root_dir, 'META.json'), 'rb') as infile:
meta = json.load(infile)
num_partitions = meta['num_parts']
assert partition_idx >= 0
assert partition_idx < num_partitions
partition_dir = os.path.join(root_dir, f'part_{partition_idx}')
assert os.path.exists(partition_dir)

if meta['is_hetero'] is False:
node_pb = torch.load(os.path.join(root_dir, 'node_map.pt'))
edge_pb = torch.load(os.path.join(root_dir, 'edge_map.pt'))

return (meta, num_partitions, partition_idx, node_pb, edge_pb)
else:
node_pb_dict = {}
node_pb_dir = os.path.join(root_dir, 'node_map')
for ntype in meta['node_types']:
node_pb_dict[ntype] = torch.load(
os.path.join(node_pb_dir, f'{as_str(ntype)}.pt'))

edge_pb_dict = {}
edge_pb_dir = os.path.join(root_dir, 'edge_map')
for etype in meta['edge_types']:
edge_pb_dict[tuple(etype)] = torch.load(
os.path.join(edge_pb_dir, f'{as_str(etype)}.pt'))

return (meta, num_partitions, partition_idx, node_pb_dict,
edge_pb_dict)

0 comments on commit 3334da9

Please sign in to comment.