
Seq2seq #989

Merged · 13 commits · Jun 2, 2019
160 changes: 160 additions & 0 deletions tensorlayer/models/seq2seq.py
@@ -0,0 +1,160 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf
import tensorlayer as tl
import numpy as np
from tensorlayer.models import Model


class Seq2seq(Model):
"""vanilla stacked layer Seq2Seq model.

Parameters
----------
decoder_seq_length: int
The length of your target sequence
cell_enc : str, tf.function
The RNN function cell for your encoder stack, e.g tf.keras.layers.GRUCell
cell_dec : str, tf.function
The RNN function cell for your decoder stack, e.g. tf.keras.layers.GRUCell
n_layer : int
The number of your RNN layers for both encoder and decoder block
embedding_layer : tl.Layer
A embedding layer, e.g. tl.layers.Embedding(vocabulary_size=voc_size, embedding_size=emb_dim)
name : str
The model name

Examples
---------
Classify stacked-layer Seq2Seq model, see `chatbot <https://github.com/tensorlayer/seq2seq-chatbot>`__

Returns
-------
static stacked-layer Seq2Seq model.
"""

def __init__(self, decoder_seq_length, cell_enc, cell_dec, n_units=256, n_layer=3, embedding_layer=None, name=None):
super(Seq2seq, self).__init__(name=name)
self.embedding_layer = embedding_layer
self.vocabulary_size = embedding_layer.vocabulary_size
self.embedding_size = embedding_layer.embedding_size
self.n_layer = n_layer
self.enc_layers = []
self.dec_layers = []
for i in range(n_layer):
# The first layer consumes embeddings; deeper layers consume the previous layer's output.
in_channels = self.embedding_size if i == 0 else n_units
self.enc_layers.append(
tl.layers.RNN(cell=cell_enc(units=n_units), in_channels=in_channels, return_last_state=True)
)

for i in range(n_layer):
in_channels = self.embedding_size if i == 0 else n_units
self.dec_layers.append(
tl.layers.RNN(cell=cell_dec(units=n_units), in_channels=in_channels, return_last_state=True)
)

self.reshape_layer = tl.layers.Reshape([-1, n_units])
self.dense_layer = tl.layers.Dense(n_units=self.vocabulary_size, in_channels=n_units)
self.reshape_layer_after = tl.layers.Reshape([-1, decoder_seq_length, self.vocabulary_size])
self.reshape_layer_individual_sequence = tl.layers.Reshape([-1, 1, self.vocabulary_size])
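# Shape notes (assuming batch-major inputs): the decoder output of shape
# [batch, decoder_seq_length, n_units] is flattened to
# [batch * decoder_seq_length, n_units] for the dense projection, then reshaped
# back to [batch, decoder_seq_length, vocabulary_size]; during step-by-step
# inference the per-step reshape yields [batch, 1, vocabulary_size] instead.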

def inference(self, encoding, seq_length, start_token, top_n):
"""Inference mode.

Parameters
----------
encoding : input tensor
The source sequences.
seq_length : int
The expected length of the predicted sequence.
start_token : int
The token id that marks the start of a sequence (<SOS>).
top_n : int
Sample each decoded token from the top_n candidates ranked by
probability (see the standalone sampling sketch after this file's diff).
"""
feed_output = self.embedding_layer(encoding[0])
state = [None for i in range(self.n_layer)]

for i in range(self.n_layer):
feed_output, state[i] = self.enc_layers[i](feed_output, return_state=True)
batch_size = len(encoding[0].numpy())
decoding = [[start_token] for i in range(batch_size)]
feed_output = self.embedding_layer(decoding)
for i in range(self.n_layer):
feed_output, state[i] = self.dec_layers[i](feed_output, initial_state=state[i], return_state=True)

feed_output = self.reshape_layer(feed_output)
feed_output = self.dense_layer(feed_output)
feed_output = self.reshape_layer_individual_sequence(feed_output)
feed_output = tf.argmax(feed_output, -1)
# [B, 1]
final_output = feed_output

for _ in range(seq_length - 1):
feed_output = self.embedding_layer(feed_output)
for i in range(self.n_layer):
feed_output, state[i] = self.dec_layers[i](feed_output, initial_state=state[i], return_state=True)
feed_output = self.reshape_layer(feed_output)
feed_output = self.dense_layer(feed_output)
feed_output = self.reshape_layer_individual_sequence(feed_output)
ori_feed_output = feed_output
if top_n is not None:
samples = []
for k in range(batch_size):
scores = ori_feed_output[k][0].numpy()
# Indices of the top_n highest-scoring tokens.
idx = np.argpartition(scores, -top_n)[-top_n:]
# The dense layer emits unnormalized scores, so apply a softmax
# over the top_n scores before sampling from them.
probs = np.exp(scores[idx] - np.max(scores[idx]))
probs = probs / np.sum(probs)
samples.append([np.random.choice(idx, p=probs)])
feed_output = tf.convert_to_tensor(samples, dtype=tf.int64)
else:
feed_output = tf.argmax(feed_output, -1)
final_output = tf.concat([final_output, feed_output], 1)

return final_output, state

def forward(self, inputs, seq_length=20, start_token=None, return_state=False, top_n=None):

state = [None for i in range(self.n_layer)]
if self.is_train:
encoding = inputs[0]
enc_output = self.embedding_layer(encoding)

for i in range(self.n_layer):
enc_output, state[i] = self.enc_layers[i](enc_output, return_state=True)

decoding = inputs[1]
dec_output = self.embedding_layer(decoding)

for i in range(self.n_layer):
dec_output, state[i] = self.dec_layers[i](dec_output, initial_state=state[i], return_state=True)

dec_output = self.reshape_layer(dec_output)
denser_output = self.dense_layer(dec_output)
output = self.reshape_layer_after(denser_output)
else:
encoding = inputs
output, state = self.inference(encoding, seq_length, start_token, top_n)

if return_state:
return output, state
else:
return output
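Since the top_n sampling step is the least obvious part of inference, here is a minimal standalone NumPy sketch of the same idea (the logits values are made up, and the softmax normalization is an assumption, since the dense layer above emits unnormalized scores):

import numpy as np

def sample_top_n(logits, top_n):
    # Indices of the top_n highest-scoring tokens.
    idx = np.argpartition(logits, -top_n)[-top_n:]
    # Softmax over the top_n scores so they form a valid distribution.
    probs = np.exp(logits[idx] - np.max(logits[idx]))
    probs = probs / probs.sum()
    return int(np.random.choice(idx, p=probs))

logits = np.array([0.1, 2.3, -1.0, 0.7, 1.9])  # hypothetical vocabulary scores
print(sample_top_n(logits, top_n=3))  # prints one of the ids 1, 4 or 3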
96 changes: 96 additions & 0 deletions tests/models/test_seq2seq_model.py
@@ -0,0 +1,96 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import unittest

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tqdm import tqdm
from sklearn.utils import shuffle
from tensorlayer.models.seq2seq import Seq2seq
from tests.utils import CustomTestCase
from tensorlayer.cost import cross_entropy_seq


class Model_SEQ2SEQ_Test(CustomTestCase):

@classmethod
def setUpClass(cls):

cls.batch_size = 16

cls.vocab_size = 20
cls.embedding_size = 32
cls.dec_seq_length = 5
cls.trainX = np.random.randint(cls.vocab_size, size=(50, 6))
cls.trainY = np.random.randint(cls.vocab_size, size=(50, cls.dec_seq_length + 1))
cls.trainY[:, 0] = 0  # start_token == 0

# Parameters
cls.src_len = len(cls.trainX)
cls.tgt_len = len(cls.trainY)

assert cls.src_len == cls.tgt_len

cls.num_epochs = 100
cls.n_step = cls.src_len // cls.batch_size

@classmethod
def tearDownClass(cls):
pass

def test_basic_simpleSeq2Seq(self):
model_ = Seq2seq(
decoder_seq_length=5,
cell_enc=tf.keras.layers.GRUCell,
cell_dec=tf.keras.layers.GRUCell,
n_layer=3,
n_units=128,
embedding_layer=tl.layers.Embedding(vocabulary_size=self.vocab_size, embedding_size=self.embedding_size),
)

optimizer = tf.optimizers.Adam(learning_rate=0.001)

for epoch in range(self.num_epochs):
model_.train()
trainX, trainY = shuffle(self.trainX, self.trainY)
total_loss, n_iter = 0, 0
for X, Y in tqdm(tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=self.batch_size,
shuffle=False), total=self.n_step,
desc='Epoch[{}/{}]'.format(epoch + 1, self.num_epochs), leave=False):
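
# Teacher forcing: the decoder input is Y without its final token, and
# the training target is Y shifted left by one (Y[:, 0] is the start token).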

dec_seq = Y[:, :-1]
target_seq = Y[:, 1:]

with tf.GradientTape() as tape:
# compute outputs
output = model_(inputs=[X, dec_seq])

output = tf.reshape(output, [-1, self.vocab_size])

loss = cross_entropy_seq(logits=output, target_seqs=target_seq)

grad = tape.gradient(loss, model_.all_weights)
optimizer.apply_gradients(zip(grad, model_.all_weights))

total_loss += loss
n_iter += 1

model_.eval()
test_sample = trainX[0:2, :].tolist()

prediction = model_([test_sample], seq_length=self.dec_seq_length, start_token=0, top_n=1)
print("Prediction: >>>>> ", prediction, "\n Target: >>>>> ", trainY[0:2, 1:], "\n\n")

# printing average loss after every epoch
print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, self.num_epochs, total_loss / n_iter))


if __name__ == '__main__':
unittest.main()
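
For reference, the prediction returned in eval mode above is a [batch, seq_length] tensor of token ids. A minimal sketch of mapping it back to tokens (the id_to_word lookup is a hypothetical placeholder, not part of this PR):

prediction = model_([test_sample], seq_length=5, start_token=0, top_n=1)
ids = prediction.numpy().tolist()  # a batch_size x seq_length list of token ids
# id_to_word is a hypothetical vocabulary lookup built for your dataset
sentences = [[id_to_word[t] for t in seq] for seq in ids]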