Added Skip GRU
Moved GRU and Skip GRU to skip_gru.py.
Implemented SkipGRU.
Added SkipGRU and ResNet to LSTNet.
Chase-Grajeda committed Jun 11, 2024
1 parent 67b6726 commit d5f9674
Showing 2 changed files with 53 additions and 15 deletions.
23 changes: 12 additions & 11 deletions bayesflow/experimental/networks/lstnet/lstnet.py
@@ -4,6 +4,8 @@
 from bayesflow.experimental.utils import keras_kwargs
 from keras import layers, Sequential, regularizers
 from keras.saving import (register_keras_serializable)
+from .skip_gru import SkipGRU
+from ...networks.resnet import ResNet
 
 @register_keras_serializable(package="bayesflow.networks.lstnet")
 class LSTNet(keras.Model):
@@ -37,7 +39,7 @@ def __init__(

         # TODO: Tidy code and condense comments
 
-        # Define model sequencer
+        # Define model
         self.model = Sequential()
 
         # 1D convolution layer with custom activation
@@ -50,23 +52,22 @@
         )
 
         # Batch normalization layer
-        self.bnorm = layers.BatchNormalization() # TODO: any custom args here?
+        self.bnorm = layers.BatchNormalization()
 
-        # GRU layers
-        self.keep_gru = layers.GRU(gru_out) # temp for gru.py
-        self.skip_gru = layers.GRU(gru_out) # temp for skip_gru.py
-        self.gru_add = layers.Add()
-        # self.gru_concat = layers.Concatenate(axis=-1)
+        # Skip GRU
+        self.skip_gru = SkipGRU(gru_out, [2])
 
         # Final dense layer
         self.final_dense = layers.Dense(dense_out, activation="relu") # TODO: upgrade to ResNet
 
-        # Aggregate layers
-        # In: (batch, time steps, num series)
+        self.resnet = ResNet(width=dense_out)
+
+        # Aggregate layers In: (batch, time steps, num series)
         self.model.add(self.conv1) # -> (batch, reduced time steps, cnn_out)
         self.model.add(self.bnorm) # -> (batch, reduced time steps, cnn_out)
-        self.model.add(self.keep_gru) # -> (batch, gru_out)
-        self.model.add(self.final_dense) # -> (batch, gru_out)
+        self.model.add(self.skip_gru) # -> (batch, gru_out) ...ideally
+        # self.model.add(self.final_dense) # -> (batch, gru_out)
+        self.model.add(self.resnet)
 
 
     def call(self, x: Tensor) -> Tensor:
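For orientation, the rewritten __init__ now stacks Conv1D -> BatchNormalization -> SkipGRU -> ResNet. Below is a rough, standalone sketch of that pipeline with plain Keras layers and invented sizes (window of 64 steps, 5 series, cnn_out=16, gru_out=8, dense_out=32); a vanilla GRU and Dense stand in for bayesflow's SkipGRU and ResNet, so only the tensor shapes noted in the diff comments are mirrored, not the skip or residual connections themselves.

import keras
from keras import layers, Sequential

cnn_out, gru_out, dense_out = 16, 8, 32
sketch = Sequential([
    keras.Input(shape=(64, 5)),                  # (batch, time steps, num series)
    layers.Conv1D(cnn_out, kernel_size=4),       # -> (batch, reduced time steps, cnn_out)
    layers.BatchNormalization(),                 # -> (batch, reduced time steps, cnn_out)
    layers.GRU(gru_out),                         # stand-in for SkipGRU -> (batch, recurrent features)
    layers.Dense(dense_out, activation="relu"),  # stand-in for ResNet  -> (batch, dense_out)
])
print(sketch.output_shape)                       # (None, 32)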
45 changes: 41 additions & 4 deletions bayesflow/experimental/networks/lstnet/skip_gru.py
@@ -1,8 +1,45 @@
 import keras
 from keras.saving import register_keras_serializable
-from bayesflow.experimental.types import Tensor
+from keras import layers, Sequential
+# from bayesflow.experimental.types import Tensor
+from tensorflow import Tensor
 
-@register_keras_serializable(package="bayesflow.networks")
+@register_keras_serializable(package="bayesflow.networks.skip_gru")
 class SkipGRU(keras.Model):
-    # TODO
-    pass
+    def __init__(self, gru_out: int, skip_steps: list[int], **kwargs):
+        super().__init__(**kwargs)
+        self.gru_out = gru_out
+        self.skip_steps = skip_steps
+        self.gru = layers.GRU(gru_out)
+        self.skip_grus = [layers.GRU(gru_out) for _ in range(len(self.skip_steps))]
+
+    def call(self, x: Tensor) -> Tensor:
+        # Standard GRU
+        # In: (batch, reduced time steps, cnn_out)
+        gru = self.gru(x) # -> (batch, gru_out)
+
+        # x = C, gru = R
+
+        # Skip GRU
+        batch_size = x.shape[0]
+        reduced_steps = x.shape[1]
+        for i, skip_step in enumerate(self.skip_steps):
+            # Reshape, remove skipped time points
+            skip_length = reduced_steps // skip_step
+            s = x[:, -skip_length * skip_step:, :] # -> (batch, shrunk time steps, cnn_out)
+            s1 = keras.ops.reshape(s, (s.shape[0], s.shape[2], skip_length, skip_step)) # -> (batch, cnn_out, skip_length, skip_step)
+            s2 = keras.ops.transpose(s1, [0, 3, 2, 1]) # -> (batch, skip step, skip_length, cnn_out)
+            s3 = keras.ops.reshape(s2, (s2.shape[0] * s2.shape[1], s2.shape[2], s2.shape[3])) # -> (batch * skip step, skip_length, cnn_out)
+
+            # GRU on remaining data
+            s4 = self.skip_grus[i](s3) # -> (batch * skip step, gru_out)
+            s5 = keras.ops.reshape(s4, (batch_size, skip_step * s4.shape[1])) # -> (batch, skip step * gru_out)
+
+            # Concat
+            gru = keras.ops.concatenate([gru, s5], axis=1) # -> (batch, gru_out * (1 + sum of skip steps so far))
+
+        return gru
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self(keras.KerasTensor(input_shape))
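The call() added above condenses the skip connection into a few reshapes. As a quick sanity check of the shape flow only (not of the skip-sampling semantics), here is a standalone walkthrough of the same steps with made-up sizes — batch=4, reduced time steps=7, cnn_out=3, gru_out=8, skip_step=2 — none of which come from the commit:

import numpy as np
import keras
from keras import layers

batch, steps, channels, gru_out, skip_step = 4, 7, 3, 8, 2
x = keras.ops.convert_to_tensor(np.random.rand(batch, steps, channels).astype("float32"))

gru = layers.GRU(gru_out)(x)                                          # (4, 8): standard GRU over all steps

skip_length = steps // skip_step                                      # 3 complete skip groups
s = x[:, -skip_length * skip_step:, :]                                # (4, 6, 3): drop the leftover leading step
s = keras.ops.reshape(s, (batch, channels, skip_length, skip_step))   # (4, 3, 3, 2)
s = keras.ops.transpose(s, [0, 3, 2, 1])                              # (4, 2, 3, 3)
s = keras.ops.reshape(s, (batch * skip_step, skip_length, channels))  # (8, 3, 3): one sequence per skip offset
skip = layers.GRU(gru_out)(s)                                         # (8, 8)
skip = keras.ops.reshape(skip, (batch, skip_step * gru_out))          # (4, 16)

out = keras.ops.concatenate([gru, skip], axis=1)                      # (4, 24)
print(out.shape)

With these numbers the concatenated output width is gru_out * (1 + skip_step) = 24 per skip step processed.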
