-
Notifications
You must be signed in to change notification settings - Fork 13
/
beam_search.py
71 lines (59 loc) · 2.26 KB
/
beam_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
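'''
Performs a beam search with the given parameters.
Takes as input:
interface: a BeamSearchInterface object; interface.instance(params) creates the initial state
params: a set of parameters passed to interface.instance
width: the beam width, i.e. the maximum number of states kept at each step
n: the number of complete results to collect before stopping (num_out in BeamSearcher.best;
   the final step may yield more)
k: the maximum number of children expanded from each state at each step
'''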
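''' The state object returned by interface.instance(params) must provide:
instance.next_child(must_be_better_than): returns the next child and presteps it, or None
    when there are no more children; must_be_better_than is either None or the worst
    state currently kept in the (full) beam
instance.step(): prepares the state for the next step (and may mark it complete)
instance.finalize(): returns the corresponding trees
instance.__cmp__(other): compares the current states of two instances (in Python 3,
    __lt__ is required instead); used for heap ordering and sorting
instance.value: the current value of the state
instance.complete: whether the state is finished
'''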
import heapq
class BeamSearcher:
    """Run beam searches over the states produced by a search interface.

    The initial state is kept on the object; the original design intended to
    allow fetching additional results later, although the author noted this
    would probably never be needed and it is not implemented here.
    """

    def __init__(self, interface, params):
        """Create the initial search state via ``interface.instance(params)``."""
        self.instance = interface.instance(params)

    def _expand(self, parents, width, k):
        """Return a heap of at most ``width`` children drawn from ``parents``.

        Parents are consumed in the given order, and each is asked for at
        most ``k`` children. Once the beam is full, the worst kept state
        (``beam[0]`` of the min-heap) is passed to ``next_child`` as the bar
        to beat. Each accepted child is pushed and the worst state is popped.
        """
        beam = []
        for parent in parents:
            for _ in range(k):
                full = len(beam) >= width
                child = parent.next_child(beam[0] if full else None)
                if child is None:
                    break  # this parent has no more (good enough) children
                if full:
                    heapq.heappushpop(beam, child)
                else:
                    heapq.heappush(beam, child)
        return beam

    def best(self, width, k, num_out):
        """Perform a simple, non-extensible beam search.

        Args:
            width: maximum number of states kept in the beam at each step.
            k: maximum number of children requested from each state per step.
            num_out: the search stops once at least this many complete states
                have been collected (the final step may collect more).

        Returns:
            The ``finalize()`` result of every collected complete state,
            ordered best first.
        """
        assert width > 0
        assert k > 0
        assert num_out > 0
        out = []
        beam = [self.instance]
        while len(out) < num_out and beam:
            # After step() mutated every state and completed states were
            # filtered out, ``beam`` is no longer a valid heap, so popping it
            # would not yield sorted order. Sort explicitly so the best
            # parents are expanded first (presumably so the pruning bar
            # passed to next_child rises as early as possible).
            parents = sorted(beam, reverse=True)
            beam = self._expand(parents, width, k)
            for state in beam:
                assert not state.complete
                state.step()
                if state.complete:
                    out.append(state)
            beam = [state for state in beam if not state.complete]
        out.sort(reverse=True)
        return [state.finalize() for state in out]