-
Notifications
You must be signed in to change notification settings - Fork 0
/
beam_search.py
193 lines (161 loc) · 8.66 KB
/
beam_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Modifications Copyright 2017 Abigail See
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains code to run beam search decoding"""
import tensorflow as tf
import numpy as np
import data
# Process-wide command-line flags. The functions in this file read
# beam_size, max_dec_steps, min_dec_steps, batch_size and max_ext_steps from it.
FLAGS = tf.app.flags.FLAGS
class Hypothesis(object):
    """One partial output sequence kept on the beam.

    A hypothesis records the full per-step history of a candidate summary.
    Nothing in this module mutates one in place: extend() builds a brand-new
    Hypothesis whose history lists are new lists with one item appended.
    """

    def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage):
        """Build a hypothesis from its accumulated history.

        Args:
          tokens: list of int token ids produced so far (when seeded by
            run_beam_search, the first one is the start-decoding id).
          log_probs: list of float log probabilities, one per entry of tokens.
          state: decoder state after the most recent step (an LSTMStateTuple
            in this module's usage).
          attn_dists: list of attention distributions, one per decoding step.
            run_beam_search seeds it empty, so it holds one item fewer than
            tokens.
          p_gens: list of generation probabilities, one per decoding step;
            seeded empty just like attn_dists.
          coverage: current coverage vector, or None when coverage is unused.
        """
        self.tokens, self.log_probs = tokens, log_probs
        self.attn_dists, self.p_gens = attn_dists, p_gens
        self.state, self.coverage = state, coverage

    def extend(self, token, log_prob, state, attn_dist, p_gen, coverage):
        """Return a new Hypothesis that is this one plus one more decoding step.

        self is left unchanged. token, log_prob, attn_dist and p_gen are
        appended to copies of the corresponding history lists. state and
        coverage replace the current values outright.

        Args:
          token: int id of the newly chosen token.
          log_prob: float log probability of that token.
          state: decoder state after this step.
          attn_dist: attention distribution produced at this step.
          p_gen: generation probability produced at this step.
          coverage: coverage vector after this step, or None.

        Returns:
          The extended Hypothesis.
        """
        return Hypothesis(
            self.tokens + [token],
            self.log_probs + [log_prob],
            state,
            self.attn_dists + [attn_dist],
            self.p_gens + [p_gen],
            coverage,
        )

    @property
    def latest_token(self):
        """Id of the last token in the sequence."""
        return self.tokens[-1]

    @property
    def log_prob(self):
        """Log probability of the whole sequence: the sum of per-token log probs."""
        return sum(self.log_probs)

    @property
    def avg_log_prob(self):
        """log_prob divided by the number of tokens.

        Normalizing by length keeps the ranking from always favouring shorter
        sequences, whose summed log probability would otherwise be higher.
        """
        num_tokens = len(self.tokens)
        return self.log_prob / num_tokens
def run_beam_search(sess, model, vocab, batch):
    """Decode one example with beam search, then run the sentence extractor.

    Args:
      sess: a tf.Session.
      model: the seq2seq model; must provide run_encoder, decode_onestep and
        extract_sentence.
      vocab: Vocabulary object.
      batch: Batch object holding the same example repeated across the batch.

    Returns:
      A tuple (best_hyp, ext_sen):
        best_hyp: the Hypothesis with the highest average log probability.
        ext_sen: list with FLAGS.max_ext_steps entries; entry k is ext_ids[0]
          as returned by model.extract_sentence on extraction step k.
    """
    # Run the encoder once; its outputs feed both the abstractive beam search
    # and the sentence extractor.
    (enc_states, dec_in_state, time_key, local_states, global_states,
     sen_states) = model.run_encoder(sess, batch)

    best_hyp = _beam_search_decode(sess, model, vocab, batch, enc_states,
                                   dec_in_state, time_key, local_states,
                                   global_states)
    ext_sen = _extract_sentences(sess, model, batch, dec_in_state, sen_states)
    return best_hyp, ext_sen


def _beam_search_decode(sess, model, vocab, batch, enc_states, dec_in_state,
                        time_key, local_states, global_states):
    """Run step-by-step beam search and return the best Hypothesis found."""
    # These never change during decoding; look them up once instead of on
    # every step / every hypothesis.
    vocab_size = vocab.size()
    unk_id = vocab.word2id(data.UNKNOWN_TOKEN)
    stop_id = vocab.word2id(data.STOP_DECODING)

    # beam_size copies of the initial hypothesis (just the start token).
    hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
                       log_probs=[0.0],
                       state=dec_in_state,
                       attn_dists=[],
                       p_gens=[],
                       # zero vector of length attention_length
                       coverage=np.zeros([batch.enc_batch.shape[1]]))
            for _ in range(FLAGS.beam_size)]
    results = []  # finished hypotheses (those that emitted the stop token)
    steps = 0
    while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
        # Replace in-article temporary OOV ids (>= vocab_size) with the [UNK]
        # id so the decoder can look up an embedding. A plain comparison is
        # used rather than `t in range(...)`: for numpy integer ids that
        # membership test falls back to a linear scan of the whole range.
        latest_tokens = [h.latest_token for h in hyps]
        latest_tokens = [t if 0 <= t < vocab_size else unk_id
                         for t in latest_tokens]
        states = [h.state for h in hyps]
        prev_coverage = [h.coverage for h in hyps]  # coverage vectors (or None)

        (topk_ids, topk_log_probs, new_states, attn_dists, p_gens,
         new_coverage) = model.decode_onestep(sess=sess,
                                              batch=batch,
                                              latest_tokens=latest_tokens,
                                              enc_states=enc_states,
                                              dec_init_states=states,
                                              time_key=time_key,
                                              local_states=local_states,
                                              global_states=global_states,
                                              prev_coverage=prev_coverage)

        # On the first step every hypothesis is the same initial one, so only
        # the first is expanded; afterwards all hypotheses are distinct.
        num_orig_hyps = 1 if steps == 0 else len(hyps)
        all_hyps = []
        for i in range(num_orig_hyps):
            h = hyps[i]
            for j in range(FLAGS.beam_size * 2):  # top 2*beam_size options
                all_hyps.append(h.extend(token=topk_ids[i, j],
                                         log_prob=topk_log_probs[i, j],
                                         state=new_states[i],
                                         attn_dist=attn_dists[i],
                                         p_gen=p_gens[i],
                                         coverage=new_coverage[i]))

        # Walk candidates best-first: finished ones go to results (only if
        # long enough), the rest form the next beam.
        hyps = []
        for h in sort_hyps(all_hyps):
            if h.latest_token == stop_id:
                if steps >= FLAGS.min_dec_steps:
                    results.append(h)
            else:
                hyps.append(h)
            if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
                break
        steps += 1

    # Either beam_size results were collected or max_dec_steps was reached.
    # With no finished hypothesis, fall back to the unfinished ones.
    if not results:
        results = hyps
    return sort_hyps(results)[0]


def _extract_sentences(sess, model, batch, dec_in_state, sen_states):
    """Run FLAGS.max_ext_steps extraction steps and collect ext_ids[0] from each."""
    # NOTE(review): 24 looks like the number of candidate sentences the
    # extractor scores; confirm against the model definition.
    num_sen = 24
    # Tile the single decoder state across the batch. This previously used a
    # hard-coded 4 copies while last_sen used FLAGS.batch_size rows; both now
    # use FLAGS.batch_size (unchanged behavior whenever batch_size == 4).
    batch_size = FLAGS.batch_size
    dec_state = tf.contrib.rnn.LSTMStateTuple(
        np.stack([dec_in_state.c] * batch_size, axis=0),
        np.stack([dec_in_state.h] * batch_size, axis=0))

    last_sen = np.zeros([batch_size, num_sen])
    ext_sen = []
    for _ in range(FLAGS.max_ext_steps):
        ext_input = np.expand_dims(last_sen, 1)
        ext_ids, dec_state, sen_attn_dists = model.extract_sentence(
            sess, batch, sen_states, ext_input, dec_state)
        ext_sen.append(ext_ids[0])
        # Feed back a one-hot vector of the argmax index of sen_attn_dists[0]
        # as the next step's input.
        last_sen = np.eye(num_sen)[np.squeeze(
            np.argmax(np.expand_dims(sen_attn_dists[0], 1), -1))]
    return ext_sen
def sort_hyps(hyps):
    """Rank hypotheses from best to worst.

    Args:
      hyps: iterable of Hypothesis objects.

    Returns:
      A new list ordered by descending avg_log_prob. The input is not
      modified, and ties keep their original relative order.
    """
    def _score(hyp):
        return hyp.avg_log_prob

    ranked = list(hyps)
    ranked.sort(key=_score, reverse=True)
    return ranked