-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch.py
682 lines (562 loc) · 30.7 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
import warnings
from typing import Callable, Dict, List, Optional, Tuple, cast

import torch
# Decoder state threaded through the step function: a dict of tensors whose
# leading dimension is the group dimension (batch, or batch * beam).
StateType = Dict[str, torch.Tensor]

# step(last_predictions, state, timestep) -> (log_probabilities, new_state)
StepFunctionType = Callable[
    [torch.Tensor, StateType, int],
    Tuple[torch.Tensor, StateType],
]

# Same as `StepFunctionType`, but without the timestep argument.
StepFunctionTypeNoTimestep = Callable[
    [torch.Tensor, StateType],
    Tuple[torch.Tensor, StateType],
]
class BeamSearch:
    """
    Implements the beam search algorithm for decoding the most likely sequences.

    [0]: https://arxiv.org/abs/1702.01806

    # Parameters

    end_index : `int`
        The index of the "stop" or "end" token in the target vocabulary.
    max_steps : `int`, optional (default = `50`)
        The maximum number of decoding steps to take, i.e. the maximum length
        of the predicted sequences.
    beam_size : `int`, optional (default = `10`)
        The width of the beam used.
    per_node_beam_size : `int`, optional (default = `beam_size`)
        The maximum number of candidates to consider per node, at each step in the search.
        If not given, this just defaults to `beam_size`. Setting this parameter
        to a number smaller than `beam_size` may give better results, as it can introduce
        more diversity into the search. See [Beam Search Strategies for Neural Machine
        Translation. Freitag and Al-Onaizan, 2017][0].
    """

    def __init__(
        self,
        end_index: int,
        max_steps: int = 50,
        beam_size: int = 10,
        per_node_beam_size: Optional[int] = None,
    ) -> None:
        self.end_index = end_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        # Any falsy value (None or 0) falls back to the full beam width.
        self.per_node_beam_size = per_node_beam_size or beam_size

    @staticmethod
    def reconstruct_sequences(predictions, backpointers):
        """
        Follow the backpointers from the last step to recover each beam's history.

        # Parameters

        predictions : `List[torch.Tensor]`
            One `(batch_size, beam_size)` tensor of token indices per decoding step.
        backpointers : `List[torch.Tensor]`
            One `(batch_size, beam_size)` tensor per step after the first;
            `backpointers[t - 1][b, i]` is the beam index at step `t - 1` that
            `predictions[t][b, i]` extends.

        # Returns

        `List[torch.Tensor]`
            `(batch_size, beam_size, 1)` tensors, ordered from the LAST step to the first.
        """
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]
        # With a single decoding step there is no history to follow (and
        # `backpointers[-1]` would raise IndexError on the empty list).
        if not backpointers:
            return reconstructed_predictions
        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]
        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
            reconstructed_predictions.append(cur_preds)
            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
        reconstructed_predictions.append(final_preds)
        return reconstructed_predictions

    @staticmethod
    def _tile_state(state, batch_size, beam_size):
        """Replace (in place) every `(batch_size, *)` tensor in `state` with
        `beam_size` copies of each row, giving `(batch_size * beam_size, *)`.
        `None` entries are skipped."""
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, beam_size, *last_dims)
                .reshape(batch_size * beam_size, *last_dims)
            )

    @staticmethod
    def _reorder_state(state, backpointer, batch_size, beam_size):
        """Replace (in place) the rows of every `(batch_size * beam_size, *)` tensor
        in `state` with the rows of the parent beams chosen by `backpointer`
        (shape `(batch_size, beam_size)`). `None` entries are skipped."""
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # shape: (batch_size, beam_size, *)
            expanded_backpointer = backpointer.view(
                batch_size, beam_size, *([1] * len(last_dims))
            ).expand(batch_size, beam_size, *last_dims)
            # shape: (batch_size * beam_size, *)
            state[key] = (
                state_tensor.reshape(batch_size, beam_size, *last_dims)
                .gather(1, expanded_backpointer)
                .reshape(batch_size * beam_size, *last_dims)
            )

    @torch.no_grad()
    def search(
        self, start_predictions: torch.Tensor, start_state: "StateType", step: "StepFunctionType"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a starting state and a step function, apply beam search to find the
        most likely target sequences.

        Notes
        -----
        If your step function returns `-inf` for some log probabilities
        (like if you're using a masked log-softmax) then some of the "best"
        sequences returned may also have `-inf` log probability. Specifically
        this happens when the beam size is larger than the number of actions
        with finite log probability (non-zero probability) returned by the step function.
        Therefore if you're using a mask you may want to check the results from `search`
        and potentially discard sequences with non-finite log probability.

        The state dict returned by `step` is updated in place (expanded to the beam
        and reordered to follow the surviving beams) before being passed back to `step`.

        # Parameters

        start_predictions : `torch.Tensor`
            A tensor containing the initial predictions with shape `(batch_size,)`.
            Usually the initial predictions are just the index of the "start" token
            in the target vocabulary.
        start_state : `StateType`
            The initial state passed to the `step` function. Each value of the state dict
            should be a tensor of shape `(batch_size, *)`, where `*` means any other
            number of dimensions.
        step : `StepFunctionType`
            A function that is responsible for computing the next most likely tokens,
            given the current state and the predictions from the last time step.
            It is called with three arguments: a tensor of shape `(group_size,)` holding
            the tokens predicted at the last time step, the current state, and the
            current timestep (`0` for the initial call, then `1, 2, ...`).
            The `group_size` will be `batch_size * beam_size`, except in the initial
            step, for which it will just be `batch_size`.
            The function is expected to return a tuple, where the first element
            is a tensor of shape `(group_size, target_vocab_size)` containing
            the log probabilities of the tokens for the next step, and the second
            element is the updated state. The tensor in the state should have shape
            `(group_size, *)`, where `*` means any other number of dimensions.

        # Returns

        `Tuple[torch.Tensor, torch.Tensor]`
            Tuple of `(predictions, log_probabilities)`, where `predictions`
            has shape `(batch_size, beam_size, num_steps)` and `log_probabilities`
            has shape `(batch_size, beam_size)`. `num_steps <= max_steps`; it is
            smaller when every beam has produced `end_index` before `max_steps`.
            Beams are ordered from highest to lowest log probability.
        """
        batch_size = start_predictions.size(0)

        # One (batch_size, beam_size) tensor per time step. The start symbols
        # are implicit and not included.
        predictions: List[torch.Tensor] = []
        # One (batch_size, beam_size) tensor per time step after the first; each
        # entry is the index of the parent beam at the previous step.
        backpointers: List[torch.Tensor] = []

        # The first step is special: it expands the `batch_size` start rows into
        # `beam_size` beams, whereas later steps choose `beam_size` survivors
        # from `beam_size * per_node_beam_size` candidates.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)
        num_classes = start_class_log_probabilities.size(1)

        # shape (both): (batch_size, beam_size)
        start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(
            self.beam_size
        )
        if self.beam_size == 1 and (start_predicted_classes == self.end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities
        predictions.append(start_predicted_classes)

        # Replacement distribution for finished beams: all mass on the end token,
        # so a finished beam can only extend with `end_index` at zero extra cost.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes), float("-inf")
        )
        log_probs_after_end[:, self.end_index] = 0.0

        # Every beam of a batch element starts from the same state.
        self._tile_state(state, batch_size, self.beam_size)

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
            # Stop early once every beam has produced the end token.
            if (last_predictions == self.end_index).all():
                break

            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Force finished beams to keep emitting the end token.
            # shape: (batch_size * beam_size, 1) broadcast against num_classes
            finished = (last_predictions == self.end_index).unsqueeze(-1)
            cleaned_log_probabilities = torch.where(
                finished, log_probs_after_end, class_log_probabilities
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(
                self.per_node_beam_size
            )

            # Add each parent's running score to its children's step scores.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            # shape (both): (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )

            # Keep only the top `beam_size` candidates per batch element.
            # shape (both): (batch_size, beam_size)
            restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk(
                self.beam_size
            )
            restricted_predicted_classes = reshaped_predicted_classes.gather(
                1, restricted_beam_indices
            )
            predictions.append(restricted_predicted_classes)
            last_log_probabilities = restricted_beam_log_probs

            # Candidates sharing a parent are contiguous in the flattened
            # `beam_size * per_node_beam_size` axis, so integer division by
            # `per_node_beam_size` recovers the parent beam index.
            # shape: (batch_size, beam_size)
            backpointer = restricted_beam_indices // self.per_node_beam_size
            backpointers.append(backpointer)

            # Carry forward only the state rows of the surviving parents.
            self._reorder_state(state, backpointer, batch_size, self.beam_size)

        if not torch.isfinite(last_log_probabilities).all():
            warnings.warn(
                "Infinite log probabilities encountered. Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        reconstructed_predictions = self.reconstruct_sequences(predictions, backpointers)
        # shape: (batch_size, beam_size, num_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
        return all_predictions, last_log_probabilities
class DiverseBeamSearch:
    """
    Implements the diverse beam search algorithm: beam search in which each
    candidate's score is penalized by `gamma` times its rank among the children
    of the same parent beam before the top `beam_size` candidates are kept [0].

    [0]: https://arxiv.org/abs/1611.08562
    [1]: https://arxiv.org/abs/1702.01806

    # Parameters

    end_index : `int`
        The index of the "stop" or "end" token in the target vocabulary.
    max_steps : `int`, optional (default = `50`)
        The maximum number of decoding steps to take, i.e. the maximum length
        of the predicted sequences.
    beam_size : `int`, optional (default = `10`)
        The width of the beam used.
    per_node_beam_size : `int`, optional (default = `beam_size`)
        The maximum number of candidates to consider per node, at each step in the search.
        If not given, this just defaults to `beam_size`. Setting this parameter
        to a number smaller than `beam_size` may give better results, as it can introduce
        more diversity into the search. See [Beam Search Strategies for Neural Machine
        Translation. Freitag and Al-Onaizan, 2017][1].
    gamma : `float`, optional (default = `0.1`)
        The diversity rate. A candidate that is the `r`-th best child (0-based) of its
        parent has `gamma * r` subtracted from its score when choosing which candidates
        survive. The log probabilities returned by `search` are not penalized.
    """

    def __init__(
        self,
        end_index: int,
        max_steps: int = 50,
        beam_size: int = 10,
        per_node_beam_size: Optional[int] = None,
        gamma: float = 0.1,
    ) -> None:
        self.end_index = end_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        # Any falsy value (None or 0) falls back to the full beam width.
        self.per_node_beam_size = per_node_beam_size or beam_size
        self.gamma = gamma

    @staticmethod
    def reconstruct_sequences(predictions, backpointers):
        """
        Follow the backpointers from the last step to recover each beam's history.

        # Parameters

        predictions : `List[torch.Tensor]`
            One `(batch_size, beam_size)` tensor of token indices per decoding step.
        backpointers : `List[torch.Tensor]`
            One `(batch_size, beam_size)` tensor per step after the first;
            `backpointers[t - 1][b, i]` is the beam index at step `t - 1` that
            `predictions[t][b, i]` extends.

        # Returns

        `List[torch.Tensor]`
            `(batch_size, beam_size, 1)` tensors, ordered from the LAST step to the first.
        """
        # shape: [(batch_size, beam_size, 1)]
        reconstructed_predictions = [predictions[-1].unsqueeze(2)]
        # With a single decoding step there is no history to follow (and
        # `backpointers[-1]` would raise IndexError on the empty list).
        if not backpointers:
            return reconstructed_predictions
        # shape: (batch_size, beam_size)
        cur_backpointers = backpointers[-1]
        for timestep in range(len(predictions) - 2, 0, -1):
            # shape: (batch_size, beam_size, 1)
            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
            reconstructed_predictions.append(cur_preds)
            # shape: (batch_size, beam_size)
            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
        # shape: (batch_size, beam_size, 1)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
        reconstructed_predictions.append(final_preds)
        return reconstructed_predictions

    @staticmethod
    def _tile_state(state, batch_size, beam_size):
        """Replace (in place) every `(batch_size, *)` tensor in `state` with
        `beam_size` copies of each row, giving `(batch_size * beam_size, *)`.
        `None` entries are skipped."""
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, beam_size, *last_dims)
                .reshape(batch_size * beam_size, *last_dims)
            )

    @staticmethod
    def _reorder_state(state, backpointer, batch_size, beam_size):
        """Replace (in place) the rows of every `(batch_size * beam_size, *)` tensor
        in `state` with the rows of the parent beams chosen by `backpointer`
        (shape `(batch_size, beam_size)`). `None` entries are skipped."""
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # shape: (batch_size, beam_size, *)
            expanded_backpointer = backpointer.view(
                batch_size, beam_size, *([1] * len(last_dims))
            ).expand(batch_size, beam_size, *last_dims)
            # shape: (batch_size * beam_size, *)
            state[key] = (
                state_tensor.reshape(batch_size, beam_size, *last_dims)
                .gather(1, expanded_backpointer)
                .reshape(batch_size * beam_size, *last_dims)
            )

    @torch.no_grad()
    def search(
        self, start_predictions: torch.Tensor, start_state: "StateType", step: "StepFunctionType"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run diverse beam search from `start_predictions` / `start_state`.

        # Parameters

        start_predictions : `torch.Tensor`
            Initial tokens, shape `(batch_size,)`.
        start_state : `StateType`
            Initial state for `step`; each tensor has shape `(batch_size, *)`.
        step : `StepFunctionType`
            Called as `step(last_predictions, state, timestep)` with `timestep = 0`
            for the initial call and `1, 2, ...` afterwards. Must return
            `(log_probabilities, new_state)` where `log_probabilities` has shape
            `(group_size, num_classes)`; `group_size` is `batch_size` on the initial
            call and `batch_size * beam_size` afterwards. The returned state dict
            is updated in place by this method.

        # Returns

        `Tuple[torch.Tensor, torch.Tensor]`
            `(predictions, log_probabilities)` with shapes
            `(batch_size, beam_size, num_steps)` and `(batch_size, beam_size)`,
            where `num_steps <= max_steps` (decoding stops early once every beam
            has produced `end_index`). Log probabilities are the unpenalized sums;
            beams are ordered by the *penalized* score, so they need not be sorted
            by log probability.
        """
        batch_size = start_predictions.size(0)

        # One (batch_size, beam_size) tensor per time step. The start symbols
        # are implicit and not included.
        predictions: List[torch.Tensor] = []
        # One (batch_size, beam_size) tensor per time step after the first; each
        # entry is the index of the parent beam at the previous step.
        backpointers: List[torch.Tensor] = []

        # The first step expands the `batch_size` start rows into `beam_size`
        # beams; later steps choose survivors from `beam_size * per_node_beam_size`
        # candidates.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)
        num_classes = start_class_log_probabilities.size(1)

        # shape (both): (batch_size, beam_size)
        start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(
            self.beam_size
        )
        if self.beam_size == 1 and (start_predicted_classes == self.end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # shape: (batch_size, beam_size)
        last_log_probabilities = start_top_log_probabilities
        predictions.append(start_predicted_classes)

        # Replacement distribution for finished beams: all mass on the end token.
        # shape: (batch_size * beam_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.beam_size, num_classes), float("-inf")
        )
        log_probs_after_end[:, self.end_index] = 0.0

        # Every beam of a batch element starts from the same state.
        self._tile_state(state, batch_size, self.beam_size)

        # Sibling ranks 0, 1, ..., per_node_beam_size - 1 (children come out of
        # topk sorted best-first). Loop-invariant, so built once.
        # shape: (1, per_node_beam_size)
        sibling_ranks = torch.arange(
            self.per_node_beam_size,
            dtype=start_class_log_probabilities.dtype,
            device=start_class_log_probabilities.device,
        ).unsqueeze(0)

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
            # Stop early once every beam has produced the end token.
            if (last_predictions == self.end_index).all():
                break

            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Force finished beams to keep emitting the end token.
            # shape: (batch_size * beam_size, 1) broadcast against num_classes
            finished = (last_predictions == self.end_index).unsqueeze(-1)
            cleaned_log_probabilities = torch.where(
                finished, log_probs_after_end, class_log_probabilities
            )

            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(
                self.per_node_beam_size
            )

            # Add each parent's running score to its children's step scores.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_log_probabilities = (
                last_log_probabilities.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            # Diversity penalty: used only to choose survivors, never reported.
            # shape: (batch_size * beam_size, per_node_beam_size)
            rewritten_summed_top_log_probabilities = (
                summed_top_log_probabilities - self.gamma * sibling_ranks
            )

            # shape (all): (batch_size, beam_size * per_node_beam_size)
            flat_shape = (batch_size, self.beam_size * self.per_node_beam_size)
            reshaped_summed = summed_top_log_probabilities.reshape(flat_shape)
            reshaped_rewritten_summed = rewritten_summed_top_log_probabilities.reshape(flat_shape)
            reshaped_predicted_classes = predicted_classes.reshape(flat_shape)

            # Select survivors by penalized score, but keep their raw log probs.
            # shape (both): (batch_size, beam_size)
            _, restricted_beam_indices = reshaped_rewritten_summed.topk(self.beam_size)
            restricted_beam_log_probs = reshaped_summed.gather(1, restricted_beam_indices)
            restricted_predicted_classes = reshaped_predicted_classes.gather(
                1, restricted_beam_indices
            )
            predictions.append(restricted_predicted_classes)
            last_log_probabilities = restricted_beam_log_probs

            # Candidates sharing a parent are contiguous in the flattened axis,
            # so integer division by `per_node_beam_size` gives the parent beam.
            # shape: (batch_size, beam_size)
            backpointer = restricted_beam_indices // self.per_node_beam_size
            backpointers.append(backpointer)

            # Carry forward only the state rows of the surviving parents.
            self._reorder_state(state, backpointer, batch_size, self.beam_size)

        if not torch.isfinite(last_log_probabilities).all():
            warnings.warn(
                "Infinite log probabilities encountered. Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        reconstructed_predictions = self.reconstruct_sequences(predictions, backpointers)
        # shape: (batch_size, beam_size, num_steps)
        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
        return all_predictions, last_log_probabilities
class RandomSample:
    """
    Draws `sample_size` independent sequences per batch element with top-k sampling.

    At every step each sequence keeps only its `k` highest-scoring next tokens and
    samples one of them in proportion to its probability. A sequence that has
    produced `end_index` keeps producing it.

    # Parameters

    end_index : `int`
        The index of the "stop" or "end" token in the target vocabulary.
    max_steps : `int`, optional (default = `50`)
        The maximum number of decoding steps to take.
    sample_size : `int`, optional (default = `10`)
        Number of sequences sampled for each batch element.
    k : `int`, optional (default = `10`)
        Number of top candidates to sample from at each step. Values larger than
        the vocabulary size are clamped to the vocabulary size.
    """

    def __init__(
        self,
        end_index: int,
        max_steps: int = 50,
        sample_size: int = 10,
        k: int = 10,
    ) -> None:
        self.end_index = end_index
        self.max_steps = max_steps
        self.sample_size = sample_size
        self.k = k

    @staticmethod
    def _sample_top_k(log_probabilities, k, num_samples):
        """
        Sample `num_samples` class indices per row from the `k` highest-scoring entries.

        Returns `(classes, log_probs)`, both of shape `(rows, num_samples)`, where
        `log_probs` are the entries of `log_probabilities` at the sampled classes.
        """
        # shape (both): (rows, k)
        top_log_probabilities, top_classes = log_probabilities.topk(k)
        # softmax renormalizes over the k survivors without the overflow/underflow
        # that exp() on raw scores can hit; multinomial only needs relative weights.
        top_probabilities = torch.softmax(top_log_probabilities, dim=-1)
        # shape: (rows, num_samples), positions within the k survivors
        picks = torch.multinomial(top_probabilities, num_samples, replacement=True)
        return top_classes.gather(1, picks), top_log_probabilities.gather(1, picks)

    @torch.no_grad()
    def search(
        self, start_predictions: torch.Tensor, start_state: "StateType", step: "StepFunctionType"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample sequences starting from `start_predictions` / `start_state`.

        # Parameters

        start_predictions : `torch.Tensor`
            Initial tokens, shape `(batch_size,)`.
        start_state : `StateType`
            Initial state for `step`; each tensor has shape `(batch_size, *)`.
        step : `StepFunctionType`
            Called as `step(last_predictions, state, timestep)` with `timestep = 0`
            for the initial call and `1, 2, ...` afterwards. Must return
            `(log_probabilities, new_state)` with `log_probabilities` of shape
            `(group_size, num_classes)`; `group_size` is `batch_size` on the initial
            call and `batch_size * sample_size` afterwards. The returned state dict
            is updated in place by this method.

        # Returns

        `Tuple[torch.Tensor, torch.Tensor]`
            `(predictions, log_probabilities)` with shapes
            `(batch_size, sample_size, num_steps)` and `(batch_size, sample_size)`,
            where `num_steps <= max_steps` (sampling stops early once every sequence
            has produced `end_index`). `log_probabilities` is the sum of the scores
            returned by `step` for the sampled tokens.
        """
        batch_size = start_predictions.size(0)

        # One (batch_size, sample_size) tensor per time step. The start symbols
        # are implicit and not included.
        predictions: List[torch.Tensor] = []

        # The first step runs on `batch_size` rows and draws `sample_size` tokens
        # per row; every later step draws one token for each of the
        # `batch_size * sample_size` sequences.
        # shape: (batch_size, num_classes)
        start_class_log_probabilities, state = step(start_predictions, start_state, 0)
        num_classes = start_class_log_probabilities.size(1)
        # topk cannot select more entries than exist.
        k = min(self.k, num_classes)

        # shape (both): (batch_size, sample_size)
        start_predicted_classes, start_top_log_probabilities = self._sample_top_k(
            start_class_log_probabilities, k, self.sample_size
        )
        if self.sample_size == 1 and (start_predicted_classes == self.end_index).all():
            warnings.warn(
                "Empty sequences predicted. You may want to increase the beam size or ensure "
                "your step function is working properly.",
                RuntimeWarning,
            )
            return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities

        # shape: (batch_size, sample_size)
        last_log_probabilities = start_top_log_probabilities
        predictions.append(start_predicted_classes)

        # Replacement distribution for finished sequences: all mass on the end token.
        # shape: (batch_size * sample_size, num_classes)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * self.sample_size, num_classes), float("-inf")
        )
        log_probs_after_end[:, self.end_index] = 0.0

        # Every sample of a batch element starts from the same state. Samples are
        # independent afterwards, so the rows never need reordering.
        for key, state_tensor in state.items():
            if state_tensor is None:
                continue
            _, *last_dims = state_tensor.size()
            # shape: (batch_size * sample_size, *)
            state[key] = (
                state_tensor.unsqueeze(1)
                .expand(batch_size, self.sample_size, *last_dims)
                .reshape(batch_size * self.sample_size, *last_dims)
            )

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * sample_size,)
            last_predictions = predictions[-1].reshape(batch_size * self.sample_size)
            # Stop early once every sequence has produced the end token.
            if (last_predictions == self.end_index).all():
                break

            # shape: (batch_size * sample_size, num_classes)
            class_log_probabilities, state = step(last_predictions, state, timestep + 1)

            # Force finished sequences to keep emitting the end token.
            # shape: (batch_size * sample_size, 1) broadcast against num_classes
            finished = (last_predictions == self.end_index).unsqueeze(-1)
            cleaned_log_probabilities = torch.where(
                finished, log_probs_after_end, class_log_probabilities
            )

            # shape (both): (batch_size * sample_size, 1)
            predicted_classes, top_log_probabilities = self._sample_top_k(
                cleaned_log_probabilities, k, 1
            )
            # shape (both): (batch_size, sample_size)
            predictions.append(predicted_classes.reshape(batch_size, self.sample_size))
            last_log_probabilities = last_log_probabilities + top_log_probabilities.reshape(
                batch_size, self.sample_size
            )

        if not torch.isfinite(last_log_probabilities).all():
            warnings.warn(
                "Infinite log probabilities encountered. Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

        # shape: (batch_size, sample_size, num_steps)
        all_predictions = torch.stack(predictions, 2)
        return all_predictions, last_log_probabilities