"""For using a Graph Neural Network (GNN) to guide miniKanren"""
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.autograd import Variable
import lisp
import helper
from helper import MLP, MergeMLP
from interact import Interaction
from gnn_grammar import (GNN_TOKENS, GNN_RELATIONS, GNN_NODES, GNN_IDS,
                         parse_split_trees)


class GNNModel(nn.Module):
    """Graph Neural Network model for guiding miniKanren, using PyTorch."""

    def __init__(self,
                 embedding_size=64,
                 message_size=128,
                 msg_fn_layers=2,
                 merge_fn_extra_layers=2,
                 num_passes=1,
                 edge_embedding_size=32,
                 cuda=False):
        super(GNNModel, self).__init__()

        # set hyperparameters
        self.num_passes = num_passes  # number of up/down passes
        self.embedding_size = embedding_size
        self.message_size = message_size
        self.edge_embedding_size = edge_embedding_size

        # set known vocab embedding -- for now also the consts
        self.embedding = nn.Embedding(len(GNN_TOKENS), embedding_size)

        # leaf scoring function for outputting
        self.score = MLP([embedding_size, 1])

        # message functions for each class x child x direction
        self.msg_fn_keys = [k for Class in GNN_NODES
                            for k in Class.msg_fn_keys()]

        # edge embedding for each edge type
        self.edge_embedding = nn.Embedding(len(self.msg_fn_keys),
                                           self.edge_embedding_size)

        # create mapping of msg fn keys -> index
        self.msg_fn_dict = {}
        for i, k in enumerate(self.msg_fn_keys):
            self.msg_fn_dict[k] = Variable(torch.LongTensor([i]))
            if cuda:
                self.msg_fn_dict[k] = self.msg_fn_dict[k].cuda()

        # create the message functions:
        msg_fn_shape = [self.embedding_size + self.edge_embedding_size] + \
                       [self.message_size] * (msg_fn_layers - 1) + \
                       [self.message_size]
        self.msg_fn_shared = MLP(msg_fn_shape)

        # merge function for each class
        self.merge_fn = {}
        for Class in GNN_NODES:
            if Class.nmerge > 0:
                layers = [self.message_size * i
                          for i in range(Class.nmerge, 0, -1)] + \
                         [self.message_size] * merge_fn_extra_layers
                self.merge_fn[Class.name] = MergeMLP(layers)

        self.lvar_epsilon = torch.nn.Parameter(torch.FloatTensor([-10.0]))

        # gru for each class
        self.gru = {
            Class.name: nn.GRUCell(
                input_size=self.message_size,
                hidden_size=self.embedding_size,
                bias=True)
            for Class in GNN_NODES
        }

        # add modules in msg_fn, merge_fn, gru manually
        for k, module in self.gru.items():
            self.add_module("gru_%s" % k, module)
        for k, module in self.merge_fn.items():
            self.add_module("merge_%s" % k, module)

        self._cuda = cuda
        if self._cuda:
            self.cuda()
    # these functions below are compatible with a modified version
    # of pytorch fold for training

    def init_const(self, ids):
        """Initialize a constant embedding by looking up the ids."""
        if self._cuda:
            ids = ids.cuda()
        emb = self.embedding(ids)
        return emb

    def init_lvar(self, lvars):
        """Initialize a logic variable embedding by looking up the ids."""
        n = lvars.size()[0]
        id = Variable(torch.LongTensor([GNN_IDS["lvar"]]))
        if self._cuda:
            id = id.cuda()
        emb = self.embedding(id).repeat(n, 1)
        return emb

    def get_message(self, key, emb):
        """Get messages -- this might no longer work with pytorch fold
        after simplification...
        """
        k_index = self.msg_fn_dict[key]
        edge_emb = self.edge_embedding(k_index)
        edge_emb = edge_emb.repeat(emb.size(0), 1)
        merged_input = torch.cat([emb, edge_emb], 1)  # for batching
        return self.msg_fn_shared(merged_input)
    def get_merge(self, key, *msgs):
        """Merge messages together
        key -- the node type
        msgs -- messages (tensors) from other nodes
        """
        return self.merge_fn[key](*msgs)

    def get_gru(self, key, msg, old):
        """Apply GRU to combine messages
        key -- the node type
        msg -- message (tensor) to be added to the node
        old -- current embedding of the node
        """
        return self.gru[key](msg, old)

    def get_merge_cat(self, n, *msgs):  # n = len(msgs)
        """Merge messages together by averaging the stacked messages
        n -- number of messages
        msgs -- messages (tensors) from other nodes
        """
        messages = torch.stack(msgs, 0)
        merged = torch.mean(messages, 0)
        return merged

    def get_logits(self, n, *embs):  # n = len(embs)
        """Score the embeddings to produce logits
        n -- number of embeddings
        embs -- embeddings (tensors) to score
        """
        embs_cat = torch.stack(embs, 0)
        leaf_logit = self.score(embs_cat)
        leaf_logit = leaf_logit.transpose(1, 0).squeeze(-1)
        return leaf_logit

    def get_cat(self, n, *embs):  # n = len(embs)
        """Concatenate embeddings
        n -- number of embeddings
        embs -- embeddings (tensors) to concatenate
        """
        return torch.stack(embs, 1).squeeze(-1)

    def get_combine_min(self, n, *logits):  # n = len(logits)
        """Apply min pooling to logits
        n -- number of logit tensors
        logits -- logit scores to pool together
        """
        logits = torch.cat(logits, 1)
        return torch.min(logits, dim=1, keepdim=True)[0]

    def get_combine_max(self, n, *logits):  # n = len(logits)
        """Apply max pooling to logits
        n -- number of logit tensors
        logits -- logit scores to pool together
        """
        logits = torch.cat(logits, 1)
        return torch.max(logits, dim=1, keepdim=True)[0]

    def get_combine_mean(self, n, *logits):  # n = len(logits)
        """Apply average pooling to logits
        n -- number of logit tensors
        logits -- logit scores to pool together
        """
        logits = torch.cat(logits, 1)
        return torch.mean(logits, dim=1, keepdim=True)
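

# A minimal sketch (illustrative only, not used elsewhere in this module) of
# how the scoring and pooling helpers above compose: score a few node
# embeddings into logits, then max-pool them.  The (1, embedding_size) shape
# per embedding and the assumption that helper.MLP accepts batched input like
# nn.Linear are guesses, not guarantees from the original code.
def example_score_and_pool(model, n=3):
    """Score `n` random embeddings and max-pool the resulting logits."""
    embs = [torch.randn(1, model.embedding_size) for _ in range(n)]
    logits = model.get_logits(n, *embs)      # assumed shape: (1, n)
    return model.get_combine_max(1, logits)  # assumed shape: (1, 1)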


def gnn_forward(asts, model, num_passes=1, test_acc=None):
    """
    Forward pass for the Graph Neural Network, set up so that pytorch
    fold can be used for dynamic batching
    """
    # reset
    for ast, acc in asts:
        ast.reset(model)
        ast.annotate()

    # first upward pass
    for ast, acc in asts:  # the different conj in the disj
        for leaf in acc['constraints']:  # constraint in conj
            leaf.up_first(model)

    for passes in range(num_passes):
        # downward
        for ast, acc in asts:  # the different conj in the disj
            for leaf in acc['constraints']:  # constraint in conj
                leaf.down(model)
            # update logic variables
            for lvar in acc['lvar'].values():
                lvar.actually_update(model)
            for leaf in acc['constraints']:  # constraint in conj
                leaf.up(model)

    # read out the logit
    out = []
    for ast, acc in asts:  # conj
        leaf_logit = ast.logit(model)
        out.append(leaf_logit)
    out = model.get_cat(len(out), *out)
    return out
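

# Usage sketch (assumption, mirroring test_forward below): `asts` is the list
# of (ast, acc) pairs produced by parse_split_trees, where each `acc` dict
# carries the conjunction's 'constraints' list and its 'lvar' mapping:
#
#     model = GNNModel()
#     asts = parse_split_trees(state)   # `state` from an Interaction session
#     logits = gnn_forward(asts, model, num_passes=model.num_passes)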


def test_forward():
    """Simple example using interaction, parsing, and a sample call to
    gnn_forward(), where we still take the ground-truth path at each step."""
    problem = "(q-transform/hint (quote (lambda (cdr (cdr (var ()))))) (quote ((() y . 1) (#f y () . #t) (#f b () b . y) (x #f (#f . #f) . #t) (a #f y x s . a))))"
    step = 0
    model = GNNModel()  # small model, randomly initialized
    print("Starting problem:", problem)
    with Interaction(lisp.parse(problem)) as env:
        signal = None
        while signal != "solved":
            # parse & score
            acc = {'constraints': []}
            parsed_subtree = parse_split_trees(env.state)
            out = gnn_forward(parsed_subtree, model)
            prob = F.softmax(out, 1)
            print(prob)
            # ignore the score and take the ground-truth step
            signal = env.follow_path(env.good_path)
            step += 1
            print('Step', step, 'Signal:', signal)
    print("Completed.")


if __name__ == '__main__':
    test_forward()