Commit

Merge pull request #81 from mattiadg/context-gate
Context Gates
guillaumekln authored Jun 30, 2017
2 parents 87593c5 + d8decfd commit 58c8b52
Showing 4 changed files with 115 additions and 7 deletions.
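
In the notation of onmt/modules/Gate.py below, with e_{t-1} the previous target word embedding, s_t the current decoder state and c_t the attention (source) context, the three gate variants compute roughly the following (a sketch of the added code, biases omitted):

    z_t = \sigma\left(W_z\,[e_{t-1};\, s_t;\, c_t]\right)
    \tilde{t}_t = W_t\,[e_{t-1};\, s_t], \qquad \tilde{s}_t = W_s\, c_t
    \text{source:}\quad o_t = \tanh\big(\tilde{t}_t + z_t \odot \tilde{s}_t\big)
    \text{target:}\quad o_t = \tanh\big(z_t \odot \tilde{t}_t + \tilde{s}_t\big)
    \text{both:}\quad o_t = \tanh\big((1 - z_t) \odot \tilde{t}_t + z_t \odot \tilde{s}_t\big)

o_t then passes through dropout and replaces the plain attention output in the decoder loop (onmt/Models.py).
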
1 change: 1 addition & 0 deletions README.md
@@ -75,6 +75,7 @@ The following OpenNMT features are implemented:
- saving and loading from checkpoints
- inference (translation) with batching and beam search
- multi-GPU
- Context gate

Not yet implemented:

27 changes: 20 additions & 7 deletions onmt/Models.py
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-import onmt.modules
+import onmt
+from onmt.modules.Gate import ContextGateFactory
 from torch.nn.utils.rnn import pad_packed_sequence as unpack
 from torch.nn.utils.rnn import pack_padded_sequence as pack

@@ -116,6 +117,12 @@ def __init__(self, opt, dicts):
         self.rnn = stackedCell(opt.layers, input_size,
                                opt.rnn_size, opt.dropout)
         self.attn = onmt.modules.GlobalAttention(opt.rnn_size)
+        self.context_gate = None
+        if opt.context_gate is not None:
+            self.context_gate = ContextGateFactory(
+                opt.context_gate, opt.word_vec_size,
+                opt.rnn_size, opt.rnn_size, opt.rnn_size
+            )
         self.dropout = nn.Dropout(opt.dropout)

         self.hidden_size = opt.rnn_size
@@ -134,13 +141,19 @@ def forward(self, input, hidden, context, init_output):
         outputs = []
         output = init_output
         for emb_t in emb.split(1):
-            emb_t = emb_t.squeeze(0)
+            emb_inp = emb_t.squeeze(0)
             if self.input_feed:
-                emb_t = torch.cat([emb_t, output], 1)
-
-            output, hidden = self.rnn(emb_t, hidden)
-            output, attn = self.attn(output, context.transpose(0, 1))
-            output = self.dropout(output)
+                emb_inp = torch.cat([emb_inp, output], 1)
+
+            rnn_output, hidden = self.rnn(emb_inp, hidden)
+            attn_output, attn = self.attn(rnn_output, context.transpose(0, 1))
+            if self.context_gate is not None:
+                output = self.context_gate(
+                    emb_t.squeeze(0), rnn_output, attn_output
+                )
+                output = self.dropout(output)
+            else:
+                output = self.dropout(attn_output)
             outputs += [output]

         outputs = torch.stack(outputs)
90 changes: 90 additions & 0 deletions onmt/modules/Gate.py
@@ -0,0 +1,90 @@
"""
A context gate is a decoder module that takes as input the previous word
embedding, the current decoder state and the attention state, and produces
a gate. The gate can be used to select the input from the target-side
context (decoder state), from the source context (attention state), or
from both.
"""
import torch
import torch.nn as nn


def ContextGateFactory(type, embeddings_size, decoder_size,
attention_size, output_size):
"""Returns the correct ContextGate class"""

gate_types = {'source': SourceContextGate,
'target': TargetContextGate,
'both': BothContextGate}

    assert type in gate_types, "Invalid ContextGate type: {0}".format(type)
return gate_types[type](embeddings_size, decoder_size, attention_size,
output_size)


class ContextGate(nn.Module):
"""Implement up to the computation of the gate"""

def __init__(self, embeddings_size, decoder_size,
attention_size, output_size):
super(ContextGate, self).__init__()
input_size = embeddings_size + decoder_size + attention_size
self.gate = nn.Linear(input_size, output_size, bias=True)
self.sig = nn.Sigmoid()
self.source_proj = nn.Linear(attention_size, output_size)
self.target_proj = nn.Linear(embeddings_size + decoder_size,
output_size)

def forward(self, prev_emb, dec_state, attn_state):
        # prev_emb, dec_state and attn_state are (batch, dim) tensors in the
        # decoder loop, so concatenate along dim=1.
        input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1)
        z = self.sig(self.gate(input_tensor))
        proj_source = self.source_proj(attn_state)
        proj_target = self.target_proj(
            torch.cat((prev_emb, dec_state), dim=1))
return z, proj_source, proj_target


class SourceContextGate(nn.Module):
"""Apply the context gate only to the source context"""

def __init__(self, embeddings_size, decoder_size,
attention_size, output_size):
super(SourceContextGate, self).__init__()
self.context_gate = ContextGate(embeddings_size, decoder_size,
attention_size, output_size)
self.tanh = nn.Tanh()

def forward(self, prev_emb, dec_state, attn_state):
z, source, target = self.context_gate(
prev_emb, dec_state, attn_state)
return self.tanh(target + z * source)


class TargetContextGate(nn.Module):
"""Apply the context gate only to the target context"""

def __init__(self, embeddings_size, decoder_size,
attention_size, output_size):
super(TargetContextGate, self).__init__()
self.context_gate = ContextGate(embeddings_size, decoder_size,
attention_size, output_size)
self.tanh = nn.Tanh()

def forward(self, prev_emb, dec_state, attn_state):
z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
return self.tanh(z * target + source)


class BothContextGate(nn.Module):
"""Apply the context gate to both contexts"""

def __init__(self, embeddings_size, decoder_size,
attention_size, output_size):
super(BothContextGate, self).__init__()
self.context_gate = ContextGate(embeddings_size, decoder_size,
attention_size, output_size)
self.tanh = nn.Tanh()

def forward(self, prev_emb, dec_state, attn_state):
z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
return self.tanh((1. - z) * target + z * source)
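
A minimal sketch (not part of the commit) of how the new module can be exercised on its own; the sizes are hypothetical, and it assumes the decoder-side convention above of (batch, dim) inputs (on PyTorch versions older than 0.4 the tensors would need to be wrapped in Variable):

import torch
from onmt.modules.Gate import ContextGateFactory

batch_size, emb_size, rnn_size = 4, 500, 600
# 'both' interpolates the projected target and source contexts with the gate.
gate = ContextGateFactory('both', emb_size, rnn_size, rnn_size, rnn_size)

prev_emb = torch.randn(batch_size, emb_size)    # previous target word embedding
dec_state = torch.randn(batch_size, rnn_size)   # current decoder hidden state
attn_state = torch.randn(batch_size, rnn_size)  # attention (source context) vector

output = gate(prev_emb, dec_state, attn_state)
print(output.size())  # torch.Size([batch_size, rnn_size])

In the decoder itself (onmt/Models.py above), this call replaces the plain attention output whenever a context gate type is selected at training time.
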
4 changes: 4 additions & 0 deletions train.py
@@ -52,6 +52,10 @@
parser.add_argument('-brnn_merge', default='concat',
help="""Merge action for the bidirectional hidden states:
[concat|sum]""")
parser.add_argument('-context_gate', type=str, default=None,
choices=['source', 'target', 'both'],
help="""Type of context gate to use [source|target|both].
                         Omit the option to disable the context gate.""")

# Optimization options
parser.add_argument('-encoder_type', default='text',
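
For completeness, a tiny sketch (hypothetical, not part of the diff) of how the parsed option maps to a gate class: opt.context_gate stays None when the flag is omitted, so the decoder in onmt/Models.py builds no gate; otherwise the string is handed to ContextGateFactory:

from onmt.modules.Gate import ContextGateFactory

for choice in ('source', 'target', 'both'):
    gate = ContextGateFactory(choice, 500, 600, 600, 600)  # hypothetical sizes
    print(choice, '->', type(gate).__name__)
# source -> SourceContextGate
# target -> TargetContextGate
# both -> BothContextGate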
