diff --git a/fluid/transformer/.gitignore b/fluid/transformer/.gitignore
new file mode 100644
index 0000000000..0d20b6487c
--- /dev/null
+++ b/fluid/transformer/.gitignore
@@ -0,0 +1 @@
+*.pyc
diff --git a/fluid/transformer/README.md b/fluid/transformer/README.md
new file mode 100644
index 0000000000..4988c6b1f2
--- /dev/null
+++ b/fluid/transformer/README.md
@@ -0,0 +1,19 @@
+# Attention is All You Need: A Paddle Fluid implementation
+
+This is a Paddle Fluid implementation of the Transformer model in [Attention is All You Need](https://arxiv.org/abs/1706.03762) (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, arXiv, 2017).
+
+If you use the dataset/code in your research, please cite the paper:
+
+```text
+@inproceedings{vaswani2017attention,
+  title={Attention is all you need},
+  author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
+  booktitle={Advances in Neural Information Processing Systems},
+  pages={6000--6010},
+  year={2017}
+}
+```
+
+### TODO
+
+This project is still under active development.
diff --git a/fluid/transformer/config.py b/fluid/transformer/config.py
new file mode 100644
index 0000000000..b0ec296e1a
--- /dev/null
+++ b/fluid/transformer/config.py
@@ -0,0 +1,73 @@
+class TrainTaskConfig(object):
+    use_gpu = False
+    # the number of epochs to train.
+    pass_num = 2
+
+    # number of sequences contained in a mini-batch.
+    batch_size = 64
+
+    # the hyper-parameters for the Adam optimizer.
+    learning_rate = 0.001
+    beta1 = 0.9
+    beta2 = 0.98
+    eps = 1e-9
+
+
+class ModelHyperParams(object):
+    # Dictionary size for source and target language. This model directly uses
+    # paddle.dataset.wmt16, in which the <bos>, <eos> and <unk> tokens have
+    # already been added, but the <pad> token is not added. Transformer requires
+    # sequences in a mini-batch to be padded to the same length, so a <pad>
+    # token is added into the original dictionary of paddle.dataset.wmt16.
+
+    # size of source word dictionary.
+    src_vocab_size = 10000
+    # index for the <pad> token in the source language.
+    src_pad_idx = src_vocab_size
+
+    # size of target word dictionary.
+    trg_vocab_size = 10000
+    # index for the <pad> token in the target language.
+    trg_pad_idx = trg_vocab_size
+
+    # position value corresponding to the <pad> token.
+    pos_pad_idx = 0
+
+    # max length of sequences. It should be increased by 1 to include the
+    # position of the padding token for position encoding.
+    max_length = 50
+
+    # the dimension for word embeddings, which is also the last dimension of
+    # the input and output of multi-head attention, position-wise feed-forward
+    # networks, encoder and decoder.
+
+    d_model = 512
+    # size of the hidden layer in position-wise feed-forward networks.
+    d_inner_hid = 1024
+    # the dimension that keys are projected to for dot-product attention.
+    d_key = 64
+    # the dimension that values are projected to for dot-product attention.
+    d_value = 64
+    # number of heads used in multi-head attention.
+    n_head = 8
+    # number of sub-layers to be stacked in the encoder and decoder.
+    n_layer = 6
+    # dropout rate used by all dropout layers.
+    dropout = 0.1
+
+
+# Names of the position encoding tables, which will be initialized externally.
+pos_enc_param_names = (
+    "src_pos_enc_table",
+    "trg_pos_enc_table", )
+
+# Names of all data layers listed in order.
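+# They must match the keys of the feed dict built by prepare_batch_input in
+# train.py.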
+input_data_names = (
+    "src_word",
+    "src_pos",
+    "trg_word",
+    "trg_pos",
+    "src_slf_attn_bias",
+    "trg_slf_attn_bias",
+    "trg_src_attn_bias",
+    "lbl_word", )
diff --git a/fluid/transformer/model.py b/fluid/transformer/model.py
new file mode 100644
index 0000000000..1a5ae39677
--- /dev/null
+++ b/fluid/transformer/model.py
@@ -0,0 +1,477 @@
+from functools import partial
+import numpy as np
+
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+import paddle.v2.fluid.layers as layers
+
+from config import TrainTaskConfig, input_data_names, pos_enc_param_names
+
+# FIXME(guosheng): Remove the batch_size from the model.
+batch_size = TrainTaskConfig.batch_size
+
+
+def position_encoding_init(n_position, d_pos_vec):
+    """
+    Generate the initial values for the sinusoid position encoding table.
+    """
+    position_enc = np.array([[
+        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
+        for j in range(d_pos_vec)
+    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
+    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
+    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
+    return position_enc.astype("float32")
+
+
+def multi_head_attention(queries,
+                         keys,
+                         values,
+                         attn_bias,
+                         d_key,
+                         d_value,
+                         d_model,
+                         num_heads=1,
+                         dropout_rate=0.):
+    """
+    Multi-Head Attention. Note that attn_bias is added to the logits before
+    computing the softmax activation to mask certain selected positions so
+    that they will not be considered in the attention weights.
+    """
+    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
+        raise ValueError(
+            "Inputs: queries, keys and values should all be 3-D tensors.")
+
+    def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
+        """
+        Add linear projection to queries, keys, and values.
+        """
+        q = layers.fc(input=queries,
+                      size=d_key * num_heads,
+                      bias_attr=False,
+                      num_flatten_dims=2)
+        k = layers.fc(input=keys,
+                      size=d_key * num_heads,
+                      bias_attr=False,
+                      num_flatten_dims=2)
+        v = layers.fc(input=values,
+                      size=d_value * num_heads,
+                      bias_attr=False,
+                      num_flatten_dims=2)
+        return q, k, v
+
+    def __split_heads(x, num_heads):
+        """
+        Reshape the last dimension of input tensor x so that it becomes two
+        dimensions and then transpose. Specifically, the input is a tensor
+        with shape [bs, max_sequence_length, num_heads * hidden_dim] and the
+        output is a tensor with shape
+        [bs, num_heads, max_sequence_length, hidden_dim].
+        """
+        if num_heads == 1:
+            return x
+
+        hidden_size = x.shape[-1]
+        # FIXME(guosheng): Decouple the program desc with batch_size.
+        reshaped = layers.reshape(
+            x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])
+
+        # permute the dimensions into:
+        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+    def __combine_heads(x):
+        """
+        Transpose and then reshape the last two dimensions of input tensor x
+        so that they become one dimension, which is the reverse of
+        __split_heads.
+        """
+        if len(x.shape) == 3: return x
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
+
+        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+        # FIXME(guosheng): Decouple the program desc with batch_size.
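+        # trans_x has shape [bs, max_len, n_head, hidden_per_head]; the
+        # reshape below merges its last two dimensions into
+        # [bs, max_len, n_head * hidden_per_head].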
+        return layers.reshape(
+            x=trans_x,
+            shape=map(int,
+                      [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
+
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
+        """
+        Scaled Dot-Product Attention
+        """
+
+        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
+
+        # The current implementation of softmax_op only supports 2-D tensors,
+        # so it cannot be used here directly. Besides, if the reshape_op were
+        # used, the shape of product inferred at compile time is not the
+        # actual shape at run time, so it cannot be used to set the attribute
+        # of reshape_op.
+        # So a softmax is defined here as a temporary solution.
+
+        def __softmax(x, eps=1e-9):
+            exp_out = layers.exp(x=x)
+            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
+            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)
+
+        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
+        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+        weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
+        if dropout_rate:
+            weights = layers.dropout(
+                weights, dropout_prob=dropout_rate, is_test=False)
+        out = layers.matmul(weights, v)
+        return out
+
+    q, k, v = __compute_qkv(queries, keys, values, num_heads, d_key, d_value)
+
+    q = __split_heads(q, num_heads)
+    k = __split_heads(k, num_heads)
+    v = __split_heads(v, num_heads)
+
+    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
+                                                  dropout_rate)
+
+    out = __combine_heads(ctx_multiheads)
+
+    # Project back to the model size.
+    proj_out = layers.fc(input=out,
+                         size=d_model,
+                         bias_attr=False,
+                         num_flatten_dims=2)
+    return proj_out
+
+
+def positionwise_feed_forward(x, d_inner_hid, d_hid):
+    """
+    Position-wise Feed-Forward Networks.
+    This module consists of two linear transformations with a ReLU activation
+    in between, which is applied to each position separately and identically.
+    """
+    hidden = layers.fc(input=x,
+                       size=d_inner_hid,
+                       num_flatten_dims=2,
+                       act="relu")
+    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
+    return out
+
+
+def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
+    """
+    Add residual connection, layer normalization and dropout to the out tensor
+    optionally according to the value of process_cmd.
+
+    This will be used before or after multi-head attention and position-wise
+    feed-forward networks.
+    """
+    for cmd in process_cmd:
+        if cmd == "a":  # add residual connection
+            out = out + prev_out if prev_out else out
+        elif cmd == "n":  # add layer normalization
+            out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
+        elif cmd == "d":  # add dropout
+            if dropout:
+                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
+    return out
+
+
+pre_process_layer = partial(pre_post_process_layer, None)
+post_process_layer = pre_post_process_layer
+
+
+def prepare_encoder(src_word,
+                    src_pos,
+                    src_vocab_size,
+                    src_emb_dim,
+                    src_pad_idx,
+                    src_max_len,
+                    dropout=0.,
+                    pos_pad_idx=0,
+                    pos_enc_param_name=None):
+    """Add word embeddings and position encodings.
+    The output tensor has a shape of:
+    [batch_size, max_src_length_in_batch, d_model].
+
+    This module is used at the bottom of the encoder stacks.
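+
+    The position encoding table named by pos_enc_param_name is created as a
+    non-trainable embedding and is expected to be filled externally with the
+    values produced by position_encoding_init (see train.py).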
+ """ + src_word_emb = layers.embedding( + src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx) + src_pos_enc = layers.embedding( + src_pos, + size=[src_max_len, src_emb_dim], + padding_idx=pos_pad_idx, + param_attr=fluid.ParamAttr( + name=pos_enc_param_name, trainable=False)) + enc_input = src_word_emb + src_pos_enc + + # FIXME(guosheng): Decouple the program desc with batch_size. + enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim]) + return layers.dropout( + enc_input, dropout_prob=dropout, + is_test=False) if dropout else enc_input + + +prepare_encoder = partial( + prepare_encoder, pos_enc_param_name=pos_enc_param_names[0]) +prepare_decoder = partial( + prepare_encoder, pos_enc_param_name=pos_enc_param_names[1]) + + +def encoder_layer(enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """The encoder layers that can be stacked to form a deep encoder. + + This module consits of a multi-head (self) attention followed by + position-wise feed-forward networks and both the two components companied + with the post_process_layer to add residual connection, layer normalization + and droput. + """ + attn_output = multi_head_attention(enc_input, enc_input, enc_input, + attn_bias, d_key, d_value, d_model, + n_head, dropout_rate) + attn_output = post_process_layer(enc_input, attn_output, "dan", + dropout_rate) + ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model) + return post_process_layer(attn_output, ffd_output, "dan", dropout_rate) + + +def encoder(enc_input, + attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """ + The encoder is composed of a stack of identical layers returned by calling + encoder_layer. + """ + for i in range(n_layer): + enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value, + d_model, d_inner_hid, dropout_rate) + enc_input = enc_output + return enc_output + + +def decoder_layer(dec_input, + enc_output, + slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """ The layer to be stacked in decoder part. + + The structure of this module is similar to that in the encoder part except + a multi-head attention is added to implement encoder-decoder attention. + """ + slf_attn_output = multi_head_attention( + dec_input, + dec_input, + dec_input, + slf_attn_bias, + d_key, + d_value, + d_model, + n_head, + dropout_rate, ) + slf_attn_output = post_process_layer( + dec_input, + slf_attn_output, + "dan", # residual connection + dropout + layer normalization + dropout_rate, ) + enc_attn_output = multi_head_attention( + slf_attn_output, + enc_output, + enc_output, + dec_enc_attn_bias, + d_key, + d_value, + d_model, + n_head, + dropout_rate, ) + enc_attn_output = post_process_layer( + slf_attn_output, + enc_attn_output, + "dan", # residual connection + dropout + layer normalization + dropout_rate, ) + ffd_output = positionwise_feed_forward( + enc_attn_output, + d_inner_hid, + d_model, ) + dec_output = post_process_layer( + enc_attn_output, + ffd_output, + "dan", # residual connection + dropout + layer normalization + dropout_rate, ) + return dec_output + + +def decoder(dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate=0.): + """ + The decoder is composed of a stack of identical decoder_layer layers. 
+ """ + for i in range(n_layer): + dec_output = decoder_layer( + dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) + dec_input = dec_output + return dec_output + + +def transformer( + src_vocab_size, + trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + src_pad_idx, + trg_pad_idx, + pos_pad_idx, ): + # The shapes here act as placeholder. + # The shapes set here is to pass the infer-shape in compile time. The actual + # shape of src_word in run time is: + # [batch_size * max_src_length_in_a_batch, 1]. + src_word = layers.data( + name=input_data_names[0], + shape=[batch_size * max_length, 1], + dtype="int64", + append_batch_size=False) + # The actual shape of src_pos in runtime is: + # [batch_size * max_src_length_in_a_batch, 1]. + src_pos = layers.data( + name=input_data_names[1], + shape=[batch_size * max_length, 1], + dtype="int64", + append_batch_size=False) + # The actual shape of trg_word is in runtime is: + # [batch_size * max_trg_length_in_a_batch, 1]. + trg_word = layers.data( + name=input_data_names[2], + shape=[batch_size * max_length, 1], + dtype="int64", + append_batch_size=False) + # The actual shape of trg_pos in runtime is: + # [batch_size * max_trg_length_in_a_batch, 1]. + trg_pos = layers.data( + name=input_data_names[3], + shape=[batch_size * max_length, 1], + dtype="int64", + append_batch_size=False) + # The actual shape of src_slf_attn_bias in runtime is: + # [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch]. + # This input is used to remove attention weights on paddings. + src_slf_attn_bias = layers.data( + name=input_data_names[4], + shape=[batch_size, n_head, max_length, max_length], + dtype="float32", + append_batch_size=False) + # The actual shape of trg_slf_attn_bias in runtime is: + # [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch]. + # This is used to remove attention weights on paddings and subsequent words. + trg_slf_attn_bias = layers.data( + name=input_data_names[5], + shape=[batch_size, n_head, max_length, max_length], + dtype="float32", + append_batch_size=False) + # The actual shape of trg_src_attn_bias in runtime is: + # [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch]. + # This is used to remove attention weights on paddings. + trg_src_attn_bias = layers.data( + name=input_data_names[6], + shape=[batch_size, n_head, max_length, max_length], + dtype="float32", + append_batch_size=False) + + enc_input = prepare_encoder( + src_word, + src_pos, + src_vocab_size, + d_model, + src_pad_idx, + max_length, + dropout_rate, ) + enc_output = encoder( + enc_input, + src_slf_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) + + dec_input = prepare_decoder( + trg_word, + trg_pos, + trg_vocab_size, + d_model, + trg_pad_idx, + max_length, + dropout_rate, ) + dec_output = decoder( + dec_input, + enc_output, + trg_slf_attn_bias, + trg_src_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) + + # TODO(guosheng): Share the weight matrix between the embedding layers and + # the pre-softmax linear transformation. 
+    predict = layers.reshape(
+        x=layers.fc(input=dec_output,
+                    size=trg_vocab_size,
+                    bias_attr=False,
+                    num_flatten_dims=2),
+        shape=[-1, trg_vocab_size],
+        act="softmax")
+    # The actual shape of gold in runtime is:
+    # [batch_size * max_trg_length_in_a_batch, 1].
+    gold = layers.data(
+        name=input_data_names[7],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    cost = layers.cross_entropy(input=predict, label=gold)
+    return layers.mean(x=cost)
diff --git a/fluid/transformer/train.py b/fluid/transformer/train.py
new file mode 100644
index 0000000000..76904669de
--- /dev/null
+++ b/fluid/transformer/train.py
@@ -0,0 +1,140 @@
+import numpy as np
+
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+
+from model import transformer, position_encoding_init
+from config import TrainTaskConfig, ModelHyperParams, \
+    pos_enc_param_names, input_data_names
+
+
+def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
+                        max_length, n_head, place):
+    """
+    Pad the instances to the max sequence length in batch, and generate the
+    corresponding position data and attention bias. Then, convert the numpy
+    data to tensors and return a dict mapping names to tensors.
+    """
+    input_dict = {}
+
+    def __pad_batch_data(insts,
+                         pad_idx,
+                         is_target=False,
+                         return_pos=True,
+                         return_attn_bias=True,
+                         return_max_len=True):
+        """
+        Pad the instances to the max sequence length in batch, and generate the
+        corresponding position data and attention bias.
+        """
+        return_list = []
+        max_len = max(len(inst) for inst in insts)
+        inst_data = np.array(
+            [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
+        return_list += [inst_data.astype("int64").reshape([-1, 1])]
+        if return_pos:
+            inst_pos = np.array([[
+                pos_i + 1 if w_i != pad_idx else 0
+                for pos_i, w_i in enumerate(inst)
+            ] for inst in inst_data])
+
+            return_list += [inst_pos.astype("int64").reshape([-1, 1])]
+        if return_attn_bias:
+            if is_target:
+                # This is used to avoid attention on paddings and subsequent
+                # words.
+                slf_attn_bias_data = np.ones((inst_data.shape[0], max_len,
+                                              max_len))
+                slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
+                    [-1, 1, max_len, max_len])
+                slf_attn_bias_data = np.tile(slf_attn_bias_data,
+                                             [1, n_head, 1, 1]) * [-1e9]
+            else:
+                # This is used to avoid attention on paddings.
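+                # For example, a length-2 instance padded to max_len 4 gets
+                # the bias row [0, 0, -1e9, -1e9]; adding -1e9 before the
+                # softmax drives the attention weights on the padding
+                # positions towards zero.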
+                slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
+                                               (max_len - len(inst))
+                                               for inst in insts])
+                slf_attn_bias_data = np.tile(
+                    slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
+                    [1, n_head, max_len, 1])
+            return_list += [slf_attn_bias_data.astype("float32")]
+        if return_max_len:
+            return_list += [max_len]
+        return return_list if len(return_list) > 1 else return_list[0]
+
+    def data_to_tensor(data_list, name_list, input_dict, place):
+        assert len(data_list) == len(name_list)
+        for i in range(len(name_list)):
+            tensor = fluid.LoDTensor()
+            tensor.set(data_list[i], place)
+            input_dict[name_list[i]] = tensor
+
+    src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
+        [inst[0] for inst in insts], src_pad_idx, is_target=False)
+    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
+        [inst[1] for inst in insts], trg_pad_idx, is_target=True)
+    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
+                                [1, 1, trg_max_len, 1]).astype("float32")
+    lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
+                                False, False, False)
+
+    data_to_tensor([
+        src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
+        trg_slf_attn_bias, trg_src_attn_bias, lbl_word
+    ], input_data_names, input_dict, place)
+
+    return input_dict
+
+
+def main():
+    avg_cost = transformer(
+        ModelHyperParams.src_vocab_size + 1,
+        ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
+        ModelHyperParams.n_layer, ModelHyperParams.n_head,
+        ModelHyperParams.d_key, ModelHyperParams.d_value,
+        ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
+        ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
+        ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
+
+    optimizer = fluid.optimizer.Adam(
+        learning_rate=TrainTaskConfig.learning_rate,
+        beta1=TrainTaskConfig.beta1,
+        beta2=TrainTaskConfig.beta2,
+        epsilon=TrainTaskConfig.eps)
+    optimizer.minimize(avg_cost)
+
+    train_data = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
+                                       ModelHyperParams.trg_vocab_size),
+            buf_size=51200),
+        batch_size=TrainTaskConfig.batch_size)
+
+    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    # Initialize the parameters.
+    exe.run(fluid.framework.default_startup_program())
+    for pos_enc_param_name in pos_enc_param_names:
+        pos_enc_param = fluid.global_scope().find_var(
+            pos_enc_param_name).get_tensor()
+        pos_enc_param.set(
+            position_encoding_init(ModelHyperParams.max_length + 1,
+                                   ModelHyperParams.d_model), place)
+
+    for pass_id in xrange(TrainTaskConfig.pass_num):
+        for batch_id, data in enumerate(train_data()):
+            data_input = prepare_batch_input(
+                data, input_data_names, ModelHyperParams.src_pad_idx,
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
+                ModelHyperParams.n_head, place)
+            outs = exe.run(fluid.framework.default_main_program(),
+                           feed=data_input,
+                           fetch_list=[avg_cost])
+            avg_cost_val = np.array(outs[0])
+            print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
+                  " avg_cost = " + str(avg_cost_val))
+
+
+if __name__ == "__main__":
+    main()
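+
+# A rough usage sketch (assuming a PaddlePaddle build that provides
+# paddle.v2.fluid and the paddle.dataset.wmt16 dataset):
+#   python train.py
+# Set TrainTaskConfig.use_gpu = True in config.py to train on a GPU instead of
+# the CPU.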