From 3334da974041bf82be49313bf87394816059ffa0 Mon Sep 17 00:00:00 2001 From: JakubPietrakIntel Date: Mon, 2 Oct 2023 17:57:21 +0000 Subject: [PATCH] revert changes in `parition.py` moved to #8083 --- torch_geometric/distributed/partition.py | 47 +----------------------- 1 file changed, 2 insertions(+), 45 deletions(-) diff --git a/torch_geometric/distributed/partition.py b/torch_geometric/distributed/partition.py index ac4c38a8f9c32..569686a62d710 100644 --- a/torch_geometric/distributed/partition.py +++ b/torch_geometric/distributed/partition.py @@ -7,15 +7,8 @@ import torch from torch_geometric.data import Data, HeteroData -from torch_geometric.loader.cluster import ClusterData -from torch_geometric.typing import ( - Dict, - EdgeType, - EdgeTypeStr, - NodeType, - Tuple, - as_str, -) +from torch_geometric.loader import ClusterData +from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType class Partitioner: @@ -258,39 +251,3 @@ def generate_partition(self): logging.info('Saving partition mapping info') torch.save(node_map, osp.join(self.root, 'node_map.pt')) torch.save(edge_map, osp.join(self.root, 'edge_map.pt')) - - -def load_partition_info( - root_dir: str, - partition_idx: int, -) -> Tuple[Dict, int, int, torch.Tensor, torch.Tensor]: - - # load the partition with PyG format (graphstore/featurestore) - with open(os.path.join(root_dir, 'META.json'), 'rb') as infile: - meta = json.load(infile) - num_partitions = meta['num_parts'] - assert partition_idx >= 0 - assert partition_idx < num_partitions - partition_dir = os.path.join(root_dir, f'part_{partition_idx}') - assert os.path.exists(partition_dir) - - if meta['is_hetero'] is False: - node_pb = torch.load(os.path.join(root_dir, 'node_map.pt')) - edge_pb = torch.load(os.path.join(root_dir, 'edge_map.pt')) - - return (meta, num_partitions, partition_idx, node_pb, edge_pb) - else: - node_pb_dict = {} - node_pb_dir = os.path.join(root_dir, 'node_map') - for ntype in meta['node_types']: - node_pb_dict[ntype] = torch.load( - os.path.join(node_pb_dir, f'{as_str(ntype)}.pt')) - - edge_pb_dict = {} - edge_pb_dir = os.path.join(root_dir, 'edge_map') - for etype in meta['edge_types']: - edge_pb_dict[tuple(etype)] = torch.load( - os.path.join(edge_pb_dir, f'{as_str(etype)}.pt')) - - return (meta, num_partitions, partition_idx, node_pb_dict, - edge_pb_dict)