-
Notifications
You must be signed in to change notification settings - Fork 3
/
batch.py
207 lines (186 loc) · 9.61 KB
/
batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import torch
import torch_geometric
from torch_geometric.data import Data
import pdb
# This is a copy of torch_geometric/data/batch.py,
# modified to support batch assignment at the subgraph level.
class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.

    On top of the generic :meth:`Data.__inc__` rule, a fixed set of attribute
    names (``node_to_subgraph``, ``subgraph_edge_index``,
    ``assignment_index_2``, ...) receive custom per-graph offsets; see
    :meth:`_increment` for the complete table.
    """

    def __init__(self, batch=None, **kwargs):
        """Create a batch object.

        Args:
            batch: Optional assignment vector stored as :obj:`self.batch`.
            **kwargs: Forwarded unchanged to the ``Data`` constructor.
        """
        super(Batch, self).__init__(**kwargs)
        self.batch = batch
        # Class used to rebuild the individual graphs in `to_data_list`.
        self.__data_class__ = Data
        # Per-key slice boundaries; only filled in by `from_data_list`.
        self.__slices__ = None

    @staticmethod
    def _initial_offsets(keys):
        """Return the starting offset for every key in ``keys``.

        All offsets start at ``0`` except ``assignment_index_2`` and
        ``assignment_index_3``, which start as a ``2 x 1`` zero tensor because
        their two rows are shifted by different amounts.
        """
        offsets = {key: 0 for key in keys}
        for key in ('assignment_index_2', 'assignment_index_3'):
            if key in offsets:
                offsets[key] = torch.LongTensor([[0], [0]])
        return offsets

    @staticmethod
    def _increment(data, key, item):
        """Return how far values stored under ``key`` must be shifted for the
        graphs that come after ``data`` in a batch.

        This is the single source of truth for both :meth:`from_data_list`
        (adding offsets) and :meth:`to_data_list` (removing them), so that the
        two directions cannot drift apart.

        Args:
            data: The graph whose contribution to the running offset is
                wanted. Several rules read other attributes of ``data``, so
                it has to be fully populated.
            key: Attribute name.
            item: Value of ``data[key]``; only forwarded to
                ``data.__inc__`` for keys without a custom rule.
        """
        if key in ('node_to_subgraph', 'original_edge_index',
                   'assignment2_to_subgraph', 'assignment3_to_subgraph',
                   'subgraph_node_to_subgraph'):
            return data.num_subgraphs
        if key in ('subgraph_to_graph', 'batch_2', 'batch_3',
                   'subgraph2_to_graph'):
            return 1
        if key == 'tree_edge_index':
            return data.num_cliques
        if key == 'atom2clique_index':
            return torch.tensor([[data.num_atoms], [data.num_cliques]])
        if key == 'edge_index_2':
            return data.iso_type_2.shape[0]
        if key == 'edge_index_3':
            return data.iso_type_3.shape[0]
        if key == 'assignment_index_2':
            return torch.LongTensor([[data.num_nodes],
                                     [data.iso_type_2.shape[0]]])
        if key == 'assignment_index_3':
            # The first row is shifted by the size of `iso_type_2` when
            # `assignment_index_2` is present, otherwise by the node count.
            if 'assignment_index_2' in data:
                inc = data.iso_type_2.shape[0]
            else:
                inc = data.num_nodes
            return torch.LongTensor([[inc], [data.iso_type_3.shape[0]]])
        if key in ('node_to_subgraph_node', 'subgraph_edge_index'):
            return torch.max(data.node_to_subgraph_node) + 1
        if key == 'node_to_subgraph2':
            return data.node_to_subgraph2[-1] + 1
        if key == 'subgraph2_to_subgraph':
            return data.subgraph2_to_subgraph[-1] + 1
        if key == 'center_idx':
            return data.num_nodes
        if key == 'node_to_original_node':
            return data.num_original_nodes
        return data.__inc__(key, item)

    @staticmethod
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Args:
            data_list: Non-empty sequence of data objects.
            follow_batch: Optional collection of keys for which a
                ``<key>_batch`` assignment vector is created as well.
                ``None`` (the default) means no extra vectors.

        Returns:
            The result of calling ``contiguous()`` on the assembled batch.

        Raises:
            ValueError: If ``data_list`` is empty.
        """
        if follow_batch is None:
            follow_batch = []
        if len(data_list) == 0:
            raise ValueError('Cannot build a batch from an empty data list')

        keys = list(set.union(*[set(data.keys) for data in data_list]))
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []
        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        cumsum = Batch._initial_offsets(keys)
        batch.batch = []
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                # Shift non-boolean tensors by the offset accumulated from
                # the graphs already placed in the batch.
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    item = item + cumsum[key]
                if torch.is_tensor(item):
                    size = item.size(data.__cat_dim__(key, data[key]))
                else:
                    size = 1
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                cumsum[key] = cumsum[key] + Batch._increment(data, key, item)
                batch[key].append(item)

                if key in follow_batch:
                    item = torch.full((size, ), i, dtype=torch.long)
                    batch['{}_batch'.format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # Only the node count of the last graph decides whether the
        # assignment vector is dropped.
        if num_nodes is None:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])

        # NOTE: custom functions defined on the data class are not copied
        # onto the batch (a previous attempt was marked "does not work yet").

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()

    def to_data_list(self):
        r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
        from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able reconstruct the initial objects.

        Returns:
            A list with one ``self.__data_class__`` instance per graph.
            Keys ending in ``batch`` are not restored.

        Raises:
            RuntimeError: If the batch was not created by
                :meth:`from_data_list` (no slice information available).
        """
        if self.__slices__ is None:
            raise RuntimeError(
                ('Cannot reconstruct data list from batch because the batch '
                 'object was not created using Batch.from_data_list()'))

        keys = [key for key in self.keys if key[-5:] != 'batch']
        cumsum = Batch._initial_offsets(keys)
        data_list = []
        for i in range(len(self.__slices__[keys[0]]) - 1):
            data = self.__data_class__()

            # Pass 1: cut this graph's slice out of every attribute and undo
            # the offset that was added when the batch was built.
            for key in keys:
                start = self.__slices__[key][i]
                stop = self.__slices__[key][i + 1]
                if torch.is_tensor(self[key]):
                    data[key] = self[key].narrow(
                        data.__cat_dim__(key, self[key]), start, stop - start)
                    if self[key].dtype != torch.bool:
                        data[key] = data[key] - cumsum[key]
                else:
                    data[key] = self[key][start:stop]

            # Pass 2: advance the offsets only once the graph is complete.
            # Several increments read other attributes of `data` (e.g.
            # `subgraph_edge_index` reads `node_to_subgraph_node`), so doing
            # this inside pass 1 would depend on the key iteration order.
            for key in keys:
                cumsum[key] = cumsum[key] + Batch._increment(
                    data, key, data[key])

            data_list.append(data)

        return data_list

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch.

        Computed as the last entry of :obj:`self.batch` plus one.
        """
        return self.batch[-1].item() + 1