-
Notifications
You must be signed in to change notification settings - Fork 0
/
char_decoder.py
126 lines (94 loc) · 5.97 KB
/
char_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 5
"""
import torch
import torch.nn as nn
class CharDecoder(nn.Module):
    """Character-level LSTM decoder.

    Embeds character ids, runs them through a unidirectional LSTM and
    projects every hidden state onto the target character vocabulary.
    """

    def __init__(self, hidden_size, char_embedding_size=50, target_vocab=None):
        """ Init Character Decoder.

        @param hidden_size (int): Hidden size of the decoder LSTM
        @param char_embedding_size (int): dimensionality of character embeddings
        @param target_vocab (VocabEntry): vocabulary for the target language. Required in
            practice: it must expose `char2id` (containing a '<pad>' entry); the decoding
            methods additionally read `id2char`, `start_of_word` and `end_of_word`.
        @raises ValueError: if target_vocab is None.
        """
        if target_vocab is None:
            # The default exists only for signature compatibility; every layer
            # below is sized from the character vocabulary.
            raise ValueError("CharDecoder requires a target_vocab; got None")

        super(CharDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.char_embedding_size = char_embedding_size
        self.target_vocab = target_vocab

        char_vocab_size = len(target_vocab.char2id)
        pad_idx = target_vocab.char2id['<pad>']

        # Layers are created in the same order as before (embedding, LSTM,
        # projection) so that seeded parameter initialisation is unchanged.
        self.decoderCharEmb = nn.Embedding(char_vocab_size, char_embedding_size, padding_idx=pad_idx)
        self.charDecoder = nn.LSTM(char_embedding_size, hidden_size)
        self.char_output_projection = nn.Linear(hidden_size, char_vocab_size)

    def forward(self, input, dec_hidden=None):
        """ Forward pass of character decoder.

        @param input: tensor of integers, shape (length, batch)
        @param dec_hidden: internal state of the LSTM before reading the input characters.
            A tuple of two tensors of shape (1, batch, hidden_size), or None to let the
            LSTM start from its default initial state.

        @returns scores: unnormalised character scores, shape (length, batch, vocab_size)
        @returns dec_hidden: internal state of the LSTM after reading the input characters.
            A tuple of two tensors of shape (1, batch, hidden_size)
        """
        # The embedding lookup is applied per index, so (length, batch) maps
        # straight to (length, batch, char_embedding_size) without permuting.
        char_embeddings = self.decoderCharEmb(input)
        hidden_states, dec_hidden = self.charDecoder(char_embeddings, dec_hidden)
        scores = self.char_output_projection(hidden_states)
        return scores, dec_hidden

    def train_forward(self, char_sequence, dec_hidden=None):
        """ Forward computation during training.

        @param char_sequence: tensor of integers, shape (length, batch). Note that "length"
            here and in forward() need not be the same. The sequence is expected to hold
            x_1 ... x_{n+1} (e.g., <START>,m,u,s,i,c,<END>).
        @param dec_hidden: initial internal state of the LSTM, obtained from the output of
            the word-level decoder. A tuple of two tensors of shape (1, batch, hidden_size)

        @returns The cross-entropy loss, computed as the *sum* of cross-entropy losses of
            all the words in the batch. Padding targets do not contribute.
        """
        # Inputs are every character but the last; targets are the same
        # sequence shifted one step to the left.
        scores, _ = self.forward(char_sequence[:-1], dec_hidden)
        targets = char_sequence[1:]

        # Flatten (length, batch) into one axis; row l * batch + b of the scores
        # lines up with element l * batch + b of the targets.
        return nn.functional.cross_entropy(
            scores.reshape(-1, scores.size(-1)),
            targets.reshape(-1),
            ignore_index=self.target_vocab.char2id['<pad>'],
            reduction='sum',
        )

    def decode_greedy(self, initialStates, device, max_length=21):
        """ Greedy decoding

        @param initialStates: initial internal state of the LSTM, a tuple of two tensors of
            size (1, batch, hidden_size)
        @param device: torch.device (indicates whether the model is on CPU or GPU)
        @param max_length: maximum length of words to decode

        @returns decodedWords: a list (of length batch) of strings, each of which has
            length <= max_length. The decoded strings do NOT contain the start-of-word
            and end-of-word characters.
        """
        batch_size = initialStates[0].size(1)
        end_id = self.target_vocab.end_of_word
        id2char = self.target_vocab.id2char

        decoded_chars = [[] for _ in range(batch_size)]
        finished = [False] * batch_size
        dec_hidden = initialStates
        current_chars = torch.full(
            (1, batch_size),
            self.target_vocab.start_of_word,
            dtype=torch.long,
            device=device,
        )

        # Only strings are returned, so no autograd graph is needed.
        with torch.no_grad():
            for _ in range(max_length):
                # Once every word has produced its end-of-word character,
                # further LSTM steps cannot change the output.
                if all(finished):
                    break

                scores, dec_hidden = self.forward(current_chars, dec_hidden)
                current_chars = scores.argmax(dim=-1)  # (1, batch)

                # One host transfer per step instead of one per character.
                step_ids = current_chars.squeeze(0).tolist()
                for word_idx, char_id in enumerate(step_ids):
                    if finished[word_idx]:
                        continue
                    if char_id == end_id:
                        finished[word_idx] = True
                    else:
                        decoded_chars[word_idx].append(id2char[char_id])

        return ["".join(chars) for chars in decoded_chars]