Multimodal emo and transformer #202

Merged
merged 9 commits into from
Apr 9, 2020
41 changes: 29 additions & 12 deletions delta/data/task/speech_cls_task.py
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
''' emotion speech task '''
import re
import ast
import os
import copy
@@ -42,6 +43,10 @@ def _load_text(text_path):
  return text


def _process_text(text):
Collaborator: You can make it a class static method.

  text = re.findall(r"[\w']+|[.,!?;]", text.lower())
  return text
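
A minimal sketch of the reviewer's suggestion, moving the tokenizer onto the class as a static method; the class name below is only a stand-in for the real SpeechClsTask, not delta's actual code:

import re


class SpeechClsTaskSketch:  # stand-in, not the actual delta class

  @staticmethod
  def _process_text(text):
    ''' lowercase, then split into word and punctuation tokens '''
    return re.findall(r"[\w']+|[.,!?;]", text.lower())


# SpeechClsTaskSketch._process_text("Hello, world!") -> ['hello', ',', 'world', '!']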

#pylint: disable=too-many-public-methods,too-many-instance-attributes
@registers.task.register
class SpeechClsTask(SpeechTask):
@@ -117,8 +122,8 @@ def __init__(self, config, mode):
    self.generate_meta(mode)

    # load text vocab table
    use_text = self.taskconf['text']['enable']
    if use_text:
    self.use_text = self.taskconf['text']['enable']
    if self.use_text:
      self.load_text_vocab_table()

    # use distilling
@@ -683,6 +688,7 @@ def _word_table_lookup(self, text):
    ''' convert text to id'''
    max_text_len = self._max_text_len
    text2id = np.zeros(shape=[max_text_len])
    text = _process_text(text)
    pad_len = min(max_text_len, len(text))
    for char_num in range(pad_len):
      ## handle unk
@@ -695,7 +701,6 @@
  #pylint: disable=too-many-statements,too-many-locals,too-many-branches
  def generate_data(self):
    ''' generate one example'''
    use_text = self.taskconf['text']['enable']

    # total files
    total = len(self._train_by_filename.values())
@@ -707,7 +712,7 @@
#logging.info("example info", filename, examples)

# convert txt to ids
if use_text:
if self.use_text:
text = _load_text('.'.join(filename.split('.')[:-1]))
text2id = self._word_table_lookup(text)
else:
@@ -732,7 +737,7 @@
class_num = self.taskconf['classes']['num']
soft_label = [0] * class_num

if use_text:
if self.use_text:
if clip_id == 0:
# only add into batch when meet the first clip
batch.append(
@@ -780,7 +785,7 @@
# convert string label to int label
labelid = self.class_id(label)

if use_text:
if self.use_text:
if clip_id == 0:
# only add into batch when meet the first clip
batch.append(
@@ -882,6 +887,7 @@ def __init__(self, config, mode):
    assert self.subset in ('impro', 'script', 'all')
    logging.info(f"using subset data: {self.subset}, shuffle: {self.shuffle}")


    self.examples_meta = []
    for _, (filename, examples) in enumerate(self.data_items):
      for label, seg, clip_id in examples:
@@ -932,7 +938,8 @@ def __getitem__(self, batch_index):
    feats = []
    labels = []
    filenames = []
    for _, (filename, label, seg) in enumerate(batch_meta):
    texts = []
    for i, (filename, label, seg) in enumerate(batch_meta):
      feat = np.load(filename)

      # shape : [nframe, feat_size, 3]
@@ -953,15 +960,25 @@

      # convert string label to int label
      labelid = self.class_id(label)

      if self.use_text:
        text = _load_text('.'.join(filename.split('.')[:-1]))
        text2id = self._word_table_lookup(text)
        texts.append(text2id)
      feats.append(feat)
      filenames.append(filename)
      labels.append(labelid)

    features = {
        'inputs': np.array(feats, dtype=np.float64),
        'labels': np.array(labels, dtype=np.int32),
    }
    if self.use_text:
      features = {
          'inputs': np.array(feats, dtype=np.float32),
          'labels': np.array(labels, dtype=np.int32),
          'texts': np.array(texts, dtype=np.int32),
      }
    else:
      features = {
          'inputs': np.array(feats, dtype=np.float32),
          'labels': np.array(labels, dtype=np.int32),
      }

    one_hot_label = np.array(labels, dtype=np.int32)
    one_hot_label = tf.keras.utils.to_categorical(
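
For reference, a self-contained sketch of the text-to-id path this diff feeds into the batch (tokenize with _process_text, look tokens up, zero-pad to max_text_len); the vocab dict and unk id here are toy values, not what load_text_vocab_table actually produces:

import re

import numpy as np


def process_text(text):
  return re.findall(r"[\w']+|[.,!?;]", text.lower())


def word_table_lookup(text, vocab, max_text_len=8, unk_id=1):
  ''' toy counterpart of _word_table_lookup: map tokens to ids, pad with zeros '''
  text2id = np.zeros(shape=[max_text_len])
  tokens = process_text(text)
  pad_len = min(max_text_len, len(tokens))
  for i in range(pad_len):
    text2id[i] = vocab.get(tokens[i], unk_id)  # unknown tokens fall back to unk_id
  return text2id


vocab = {'i': 2, 'am': 3, 'happy': 4, '!': 5}
print(word_table_lookup("I am happy!", vocab))  # [2. 3. 4. 5. 0. 0. 0. 0.]
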
12 changes: 6 additions & 6 deletions delta/data/task/text_seq2seq_task.py
@@ -109,8 +109,8 @@ def generate_data(self):
lambda x: compute_sen_lens(x, padding_token=utils.PAD_IDX),
num_parallel_calls=self.num_parallel_calls)

src_ds = src_ds.map(
self.exclude_padding, num_parallel_calls=self.num_parallel_calls)
# src_ds = src_ds.map(
Collaborator: why not remove it?

# self.exclude_padding, num_parallel_calls=self.num_parallel_calls)

if self.infer_without_label:
data_set = tf.data.Dataset.zip((src_ds, src_size_ds))
@@ -130,8 +130,8 @@ def generate_data(self):
lambda x: compute_sen_lens(x, padding_token=utils.PAD_IDX),
num_parallel_calls=self.num_parallel_calls)

tgt_in_ds = tgt_in_ds.map(
self.exclude_padding, num_parallel_calls=self.num_parallel_calls)
# tgt_in_ds = tgt_in_ds.map(
# self.exclude_padding, num_parallel_calls=self.num_parallel_calls)

inp_ds = tf.data.Dataset.zip(
(src_ds, src_size_ds, tgt_in_ds, tgt_in_size_ds))
@@ -145,8 +145,8 @@
target_vocab_file_path),
num_parallel_calls=self.num_parallel_calls)

tgt_out_ds = tgt_out_ds.map(
self.exclude_padding, num_parallel_calls=self.num_parallel_calls)
# tgt_out_ds = tgt_out_ds.map(
# self.exclude_padding, num_parallel_calls=self.num_parallel_calls)
data_set = tf.data.Dataset.zip((inp_ds, tgt_out_ds))

vocab_dict = load_vocab_dict(self.text_vocab_file_path)
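
As context for the commented-out exclude_padding calls above: when padded ids are kept, the length datasets built with compute_sen_lens still allow masking downstream. A minimal sketch in plain TensorFlow, with made-up shapes:

import tensorflow as tf

# a padded batch of token ids (0 is the padding id) and the true lengths
ids = tf.constant([[4, 7, 9, 0, 0],
                   [3, 5, 0, 0, 0]], dtype=tf.int32)
lens = tf.constant([3, 2], dtype=tf.int32)

# mask the padded positions instead of stripping them from the dataset
mask = tf.sequence_mask(lens, maxlen=tf.shape(ids)[1])  # bool, shape [2, 5]
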
4 changes: 3 additions & 1 deletion delta/layers/__init__.py
@@ -20,7 +20,9 @@
from delta.layers.attention import MatchAttention
from delta.layers.recurrent import RnnEncoder
from delta.layers.recurrent import RnnDecoder
from delta.layers.transformer import PositionEmbedding
from delta.layers.sub_tf import MultiHeadAttention
from delta.layers.sub_tf import PositionEmbedding
from delta.layers.sub_tf import PositionwiseFeedForward
from delta.layers.transformer import TransformerEncoder
from delta.layers.transformer import TransformerDecoder

209 changes: 209 additions & 0 deletions delta/layers/sub_tf.py
@@ -0,0 +1,209 @@
# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer sub layers."""
from absl import logging
import delta.compat as tf
import numpy as np
from delta.layers.base_layer import Layer
from delta.layers.utils_tf import shape_list

#pylint: disable=invalid-name, too-many-instance-attributes, too-many-arguments


class PositionEmbedding(Layer):
  """
  PositionEmbedding represents the positional information of tokens
  consisting of two optional types: constant (untrainable) and trainable.
  """

  def __init__(self, max_len, embed_dim, use_const, name, **kwargs):
    super().__init__(**kwargs)
    self.max_len = max_len
    self.embed_dim = embed_dim
    self.use_const = use_const
    self.pos_name = name
    self.pos_embed = self.get_pos_embedding_matrix(self.max_len,
                                                   self.embed_dim,
                                                   self.use_const,
                                                   self.pos_name)

  @staticmethod
  def get_pos_embedding_matrix(max_len, embed_dim, use_const, name):
    """
    Generate the position embedding matrix, two optional types:
    constant (untrainable) and trainable.
    Args:
      max_len, embed_dim, use_const, name

    Return:
      pos_embed: [1, max_len, embed_dim]
    """
    # First part of the PE function: sin and cos argument
    if use_const:
      pos_embed = np.array([[
          pos / np.power(10000, (i - i % 2) / embed_dim)
          for i in range(embed_dim)
      ] for pos in range(max_len)])

      # Second part, apply the sine to even columns and cosine to odds.
      pos_embed[:, 0::2] = np.sin(pos_embed[:, 0::2])  # dim 2i
      pos_embed[:, 1::2] = np.cos(pos_embed[:, 1::2])  # dim 2i+1
      pos_embed = pos_embed[np.newaxis, ...]
      pos_embed = tf.cast(pos_embed, dtype=tf.float32)
    else:
      pos_embed = tf.get_variable(
          name=name,
          shape=[max_len, embed_dim],
          initializer=tf.random_uniform_initializer(-0.1, 0.1))
      pos_embed = tf.expand_dims(pos_embed, 0)

    return pos_embed

  def call(self, inputs, training=None, mask=None):
    """
    Args:
      inputs: [batch_size, seq_x_len, embed_dim]
    Return:
      pos_embed: [1, seq_x_len, embed_dim]
    """
    seq_len = shape_list(inputs)[1]
    pos_embed = self.pos_embed[:, :seq_len, :]
    return pos_embed
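
A small standalone check of the constant (use_const=True) branch above, using NumPy only; it reproduces the sin/cos interleaving for a toy max_len and embed_dim:

import numpy as np

max_len, embed_dim = 4, 6
pe = np.array([[pos / np.power(10000, (i - i % 2) / embed_dim)
                for i in range(embed_dim)]
               for pos in range(max_len)])
pe[:, 0::2] = np.sin(pe[:, 0::2])  # even dims: sin(pos / 10000^(2i/d))
pe[:, 1::2] = np.cos(pe[:, 1::2])  # odd dims:  cos(pos / 10000^(2i/d))
print(pe.shape)  # (4, 6); the layer adds a leading batch axis of 1
print(pe[0])     # position 0 -> [0. 1. 0. 1. 0. 1.]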


class PositionwiseFeedForward(Layer):
  """
  A two-layer Feed-Forward-Network.
  """

  def __init__(self, d_model, dff, act_func, **kwargs):
    super().__init__(**kwargs)
    self.dense1 = tf.keras.layers.Dense(dff, activation=act_func)
    self.dense2 = tf.keras.layers.Dense(d_model)

  def call(self, inputs, training=None, mask=None):
    """
    The implementation of PositionwiseFeedForward.
    Args:
      inputs: [batch_size, seq_x_len, d_model]
    Return:
      ffn: [batch_size, seq_x_len, d_model]
    """
    ffn = self.dense2(self.dense1(inputs))
    return ffn
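
A brief shape check for the feed-forward block, using a Keras-only stand-in with the same two Dense layers (the real class also inherits from delta's base Layer):

import tensorflow as tf

d_model, dff = 8, 32
dense1 = tf.keras.layers.Dense(dff, activation='relu')
dense2 = tf.keras.layers.Dense(d_model)

x = tf.random.normal([2, 5, d_model])  # [batch_size, seq_x_len, d_model]
ffn = dense2(dense1(x))                # expand to dff, project back to d_model
print(ffn.shape)                       # (2, 5, 8)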


class MultiHeadAttention(Layer):
  """
  Multi-headed attention, based on "Attention Is All You Need"
  (https://arxiv.org/pdf/1706.03762.pdf).
  """

  def __init__(self, hidden_size, num_heads, **kwargs):
    super().__init__(**kwargs)
    self.hidden_size, self.num_heads = hidden_size, num_heads
    assert self.hidden_size % self.num_heads == 0

    self.depth = self.hidden_size // self.num_heads

    self.wq = tf.keras.layers.Dense(self.hidden_size)
    self.wk = tf.keras.layers.Dense(self.hidden_size)
    self.wv = tf.keras.layers.Dense(self.hidden_size)

    self.dense = tf.keras.layers.Dense(self.hidden_size)

  def split_heads(self, x, batch_size):
    """
    Split hidden_size into depth (hidden_size // num_heads) for
    multi-head attention.
    Args:
      x: (batch_size, seq_len_x, hidden_size)
      batch_size

    Returns:
      split_x: (batch_size, num_heads, seq_len_x, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    split_x = tf.transpose(x, perm=[0, 2, 1, 3])
    return split_x

  def call(self, inputs, training=None, mask=None):
    """
    The implementation of multi-head attention.
    Args:
      inputs = (v, k, q)
      q: (batch_size, seq_len_q, hidden_size)
      k: (batch_size, seq_len_k, hidden_size)
      v: (batch_size, seq_len_v, hidden_size)
      mask: (batch_size, seq_len_q, seq_len_k)

    Returns:
      output: (batch_size, seq_len_q, hidden_size)
      attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
    """
    q, k, v = inputs
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len_q, hidden_size)
    k = self.wk(k)  # (batch_size, seq_len_k, hidden_size)
    v = self.wv(v)  # (batch_size, seq_len_v, hidden_size)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = self.scaled_dot_product_attention(
        q, k, v, mask)

    # (batch_size, seq_len_q, num_heads, depth)
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

    # (batch_size, seq_len_q, hidden_size)
    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.hidden_size))

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, hidden_size)

    return output, attention_weights

  @staticmethod
  def scaled_dot_product_attention(q, k, v, mask):
    """
    The implementation of scaled dot-product attention.
    Args:
      q: (..., seq_len_q, depth)
      k: (..., seq_len_k, depth)
      v: (..., seq_len_v, depth)
      mask: broadcastable to (..., seq_len_q, seq_len_k)

    Returns:
      output: (..., seq_len_q, depth)
      attention_weights: (..., seq_len_q, seq_len_k)
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # Scaled
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Masked
    if mask is not None:
      scaled_attention_logits += (mask * -1e9)

    # Normalized
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    # Weighted sum
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth)

    return output, attention_weights
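
To make the head bookkeeping above concrete, a standalone sketch of scaled dot-product attention over already-split heads, in plain TensorFlow with made-up sizes:

import tensorflow as tf

batch, num_heads, seq_q, seq_k, depth = 2, 4, 5, 7, 16

q = tf.random.normal([batch, num_heads, seq_q, depth])
k = tf.random.normal([batch, num_heads, seq_k, depth])
v = tf.random.normal([batch, num_heads, seq_k, depth])

logits = tf.matmul(q, k, transpose_b=True)          # (2, 4, 5, 7)
logits /= tf.math.sqrt(tf.cast(depth, tf.float32))  # scale by sqrt(d_k)
weights = tf.nn.softmax(logits, axis=-1)            # normalize over seq_k
out = tf.matmul(weights, v)                         # (2, 4, 5, 16)

# merge the heads back: (batch, seq_q, num_heads * depth)
merged = tf.reshape(tf.transpose(out, perm=[0, 2, 1, 3]), [batch, seq_q, -1])
print(merged.shape)  # (2, 5, 64)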