-
-
Notifications
You must be signed in to change notification settings - Fork 675
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #48 from philipperemy/sequential
add sequential examples + keras layer
- Loading branch information
Showing
8 changed files
with
165 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from attention.attention import attention_3d_block # noqa | ||
from attention.attention import Attention # noqa | ||
|
||
VERSION = '3.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,32 @@ | ||
from tensorflow.keras.layers import Dense, Lambda, dot, Activation, concatenate | ||
from tensorflow.keras.layers import Layer | ||
|
||
|
||
def attention_3d_block(hidden_states): | ||
""" | ||
Many-to-one attention mechanism for Keras. | ||
@param hidden_states: 3D tensor with shape (batch_size, time_steps, input_dim). | ||
@return: 2D tensor with shape (batch_size, 128) | ||
@author: felixhao28. | ||
""" | ||
hidden_size = int(hidden_states.shape[2]) | ||
# Inside dense layer | ||
# hidden_states dot W => score_first_part | ||
# (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size) | ||
# W is the trainable weight matrix of attention Luong's multiplicative style score | ||
score_first_part = Dense(hidden_size, use_bias=False, name='attention_score_vec')(hidden_states) | ||
# score_first_part dot last_hidden_state => attention_weights | ||
# (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps) | ||
h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,), name='last_hidden_state')(hidden_states) | ||
score = dot([score_first_part, h_t], [2, 1], name='attention_score') | ||
attention_weights = Activation('softmax', name='attention_weight')(score) | ||
# (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size) | ||
context_vector = dot([hidden_states, attention_weights], [1, 1], name='context_vector') | ||
pre_activation = concatenate([context_vector, h_t], name='attention_output') | ||
attention_vector = Dense(128, use_bias=False, activation='tanh', name='attention_vector')(pre_activation) | ||
return attention_vector | ||
class Attention(Layer): | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def __call__(self, hidden_states): | ||
""" | ||
Many-to-one attention mechanism for Keras. | ||
@param hidden_states: 3D tensor with shape (batch_size, time_steps, input_dim). | ||
@return: 2D tensor with shape (batch_size, 128) | ||
@author: felixhao28. | ||
""" | ||
hidden_size = int(hidden_states.shape[2]) | ||
# Inside dense layer | ||
# hidden_states dot W => score_first_part | ||
# (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size) | ||
# W is the trainable weight matrix of attention Luong's multiplicative style score | ||
score_first_part = Dense(hidden_size, use_bias=False, name='attention_score_vec')(hidden_states) | ||
# score_first_part dot last_hidden_state => attention_weights | ||
# (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps) | ||
h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,), name='last_hidden_state')(hidden_states) | ||
score = dot([score_first_part, h_t], [2, 1], name='attention_score') | ||
attention_weights = Activation('softmax', name='attention_weight')(score) | ||
# (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size) | ||
context_vector = dot([hidden_states, attention_weights], [1, 1], name='context_vector') | ||
pre_activation = concatenate([context_vector, h_t], name='attention_output') | ||
attention_vector = Dense(128, use_bias=False, activation='tanh', name='attention_vector')(pre_activation) | ||
return attention_vector |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from keract import get_activations | ||
from tensorflow.keras import Sequential | ||
from tensorflow.keras.callbacks import Callback | ||
from tensorflow.keras.layers import Dense, LSTM | ||
|
||
from attention import Attention | ||
|
||
|
||
class VisualizeAttentionMap(Callback): | ||
|
||
def __init__(self, model, x): | ||
super().__init__() | ||
self.model = model | ||
self.x = x | ||
|
||
def on_epoch_begin(self, epoch, logs=None): | ||
attention_map = get_activations(self.model, self.x, layer_names='attention_weight')['attention_weight'] | ||
x = self.x[..., 0] | ||
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(5, 6)) | ||
maps = [attention_map, create_argmax_mask(attention_map), create_argmax_mask(x)] | ||
maps_names = ['attention layer', 'attention layer - argmax()', 'ground truth - argmax()'] | ||
for i, ax in enumerate(axes.flat): | ||
im = ax.imshow(maps[i], interpolation='none', cmap='jet') | ||
ax.set_ylabel(maps_names[i] + '\n#sample axis') | ||
ax.set_xlabel('sequence axis') | ||
ax.xaxis.set_ticks([]) | ||
ax.yaxis.set_ticks([]) | ||
cbar_ax = fig.add_axes([0.75, 0.15, 0.05, 0.7]) | ||
fig.colorbar(im, cax=cbar_ax) | ||
fig.suptitle(f'Epoch {epoch} - training') | ||
plt.show() | ||
|
||
|
||
def create_argmax_mask(x): | ||
mask = np.zeros_like(x) | ||
for i, m in enumerate(x.argmax(axis=1)): | ||
mask[i, m] = 1 | ||
return mask | ||
|
||
|
||
def main(): | ||
seq_length = 10 | ||
num_samples = 100000 | ||
# https://stats.stackexchange.com/questions/485784/which-distribution-has-its-maximum-uniformly-distributed | ||
# Choose beta(1/N,1) to have max(X_1,...,X_n) ~ U(0, 1) => minimizes amount of knowledge. | ||
# If all the max(s) are concentrated around 1, then it makes the task easy for the model. | ||
x_data = np.random.beta(a=1 / seq_length, b=1, size=(num_samples, seq_length, 1)) | ||
y_data = np.max(x_data, axis=1) | ||
model = Sequential([ | ||
LSTM(128, input_shape=(seq_length, 1), return_sequences=True), | ||
Attention(name='attention_weight'), | ||
Dense(1, activation='linear') | ||
]) | ||
model.compile(loss='mae') | ||
max_epoch = 100 | ||
# visualize the attention on the first samples. | ||
visualize = VisualizeAttentionMap(model, x_data[0:12]) | ||
model.fit(x_data, y_data, epochs=max_epoch, validation_split=0.2, callbacks=[visualize]) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,19 @@ | ||
from setuptools import setup | ||
|
||
from attention import VERSION | ||
|
||
setup( | ||
name='attention', | ||
version='2.2', | ||
description='Keras Attention Many to One', | ||
version=VERSION, | ||
description='Keras Simple Attention', | ||
author='Philippe Remy', | ||
license='Apache 2.0', | ||
long_description_content_type='text/markdown', | ||
long_description=open('README.md').read(), | ||
packages=['attention'], | ||
# manually install tensorflow or tensorflow-gpu | ||
install_requires=[ | ||
'numpy>=1.18.1', | ||
'keras>=2.3.1', | ||
'gast>=0.2.2' | ||
'tensorflow>=2.1' | ||
] | ||
) |