-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_embeddings.py
63 lines (50 loc) · 2.03 KB
/
model_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 4
model_embeddings.py: Embeddings for the NMT model
Pencheng Yin <pcyin@cs.cmu.edu>
Sahil Chopra <schopra8@stanford.edu>
Anand Dhoot <anandd@stanford.edu>
"""
import torch.nn as nn
class ModelEmbeddings(nn.Module):
    """
    Class that converts input words to their embeddings.

    Holds one ``nn.Embedding`` lookup table per language:
    ``self.source`` for the source vocabulary and ``self.target`` for the
    target vocabulary. Both tables share the same dimensionality
    (``self.embed_size``).
    """

    def __init__(self, embed_size, vocab):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              (``vocab.src`` and ``vocab.tgt``). Each must
                              support ``len(...)`` and ``[...]`` lookup of the
                              '<pad>' token. See vocab.py for documentation.

        Embedding layer reference:
            https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        """
        super(ModelEmbeddings, self).__init__()
        self.embed_size = embed_size

        # Index of the '<pad>' token in each vocabulary. Passing it as
        # padding_idx makes nn.Embedding initialise that row to zeros and
        # keep it out of gradient updates, so padded positions do not
        # affect training. nn.Embedding requires padding_idx to be a valid
        # row, i.e. < the vocabulary size.
        src_pad_token_idx = vocab.src['<pad>']
        tgt_pad_token_idx = vocab.tgt['<pad>']

        # One lookup table per language, each sized to its own vocabulary.
        self.source = nn.Embedding(len(vocab.src), embed_size,
                                   padding_idx=src_pad_token_idx)
        self.target = nn.Embedding(len(vocab.tgt), embed_size,
                                   padding_idx=tgt_pad_token_idx)