feat: implemented tensorflow backend for lstm on cpu and gpu, and updated torch frontend lstm to use ivy.lstm_update (ivy-llc#27762)
mattbarrett98 authored Dec 18, 2023
1 parent 5abff06 commit e25f92c
Showing 3 changed files with 158 additions and 58 deletions.
105 changes: 105 additions & 0 deletions ivy/functional/backends/tensorflow/layers.py
@@ -6,6 +6,7 @@

import tensorflow as tf
from tensorflow.python.types.core import Tensor
from keras.src.layers.rnn import gru_lstm_utils

# local
import ivy
@@ -700,6 +701,110 @@ def conv_general_transpose(
return res


def _cpu_lstm(
x, init_h, init_c, kernel, recurrent_kernel, bias, recurrent_bias, time_major
):
def step(cell_inputs, cell_states):
h_tm1 = cell_states[0] # previous memory state
c_tm1 = cell_states[1] # previous carry state

z = tf.keras.backend.dot(cell_inputs, kernel) + bias
z += tf.keras.backend.dot(h_tm1, recurrent_kernel) + recurrent_bias

z0, z1, z2, z3 = tf.split(z, 4, axis=-1)

i = tf.sigmoid(z0)
f = tf.sigmoid(z1)
c = f * c_tm1 + i * tf.tanh(z2)
o = tf.sigmoid(z3)

h = o * tf.tanh(c)
return h, [h, c]

_, outputs, new_states = tf.keras.backend.rnn(
step,
x,
[init_h, init_c],
time_major=time_major,
)
return outputs, new_states


def _gpu_lstm(
x, init_h, init_c, kernel, recurrent_kernel, bias, recurrent_bias, time_major
):
if not time_major:
x = tf.transpose(x, perm=(1, 0, 2))

init_h = tf.expand_dims(init_h, axis=0)
init_c = tf.expand_dims(init_c, axis=0)

weights = tf.split(kernel, 4, axis=1)
weights += tf.split(recurrent_kernel, 4, axis=1)
full_bias = tf.concat((recurrent_bias, bias), axis=0)
params = gru_lstm_utils.canonical_to_params(
weights=weights,
biases=tf.split(full_bias, 8),
shape=tf.constant([-1]),
transpose_weights=True,
)
outputs, h, c, _ = tf.raw_ops.CudnnRNN(
input=x,
input_h=init_h,
input_c=init_c,
params=params,
rnn_mode="lstm",
)
return outputs, (h, c)


def lstm_update(
x: Union[tf.Tensor, tf.Variable],
init_h: Union[tf.Tensor, tf.Variable],
init_c: Union[tf.Tensor, tf.Variable],
kernel: Union[tf.Tensor, tf.Variable],
recurrent_kernel: Union[tf.Tensor, tf.Variable],
/,
*,
bias: Optional[Union[tf.Tensor, tf.Variable]] = None,
recurrent_bias: Optional[Union[tf.Tensor, tf.Variable]] = None,
time_major: bool = False,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
dev = x.device
x = x.data
init_h = init_h.data
init_c = init_c.data
kernel = kernel.data
recurrent_kernel = recurrent_kernel.data
bias = bias.data if bias is not None else bias
recurrent_bias = (
recurrent_bias.data if recurrent_bias is not None else recurrent_bias
)
if "cpu" in dev:
outputs, new_states = _cpu_lstm(
x,
init_h,
init_c,
kernel,
recurrent_kernel,
bias,
recurrent_bias,
time_major,
)
else:
outputs, new_states = _gpu_lstm(
x,
init_h,
init_c,
kernel,
recurrent_kernel,
bias,
recurrent_bias,
time_major,
)
return outputs, new_states


def nms(
boxes,
scores=None,
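A minimal usage sketch of the new backend path, assuming the tensorflow backend is set; the dimensions, the zero biases and the ivy.random_normal/ivy.zeros setup are illustrative, with shapes taken from the ivy.lstm_update docstring further down:

import ivy

ivy.set_backend("tensorflow")

# Illustrative dimensions only.
batch, t, in_dim, out_dim = 2, 5, 3, 4
x = ivy.random_normal(shape=(batch, t, in_dim))               # [batch, t, in]
h0 = ivy.zeros((batch, out_dim))                              # [batch, out]
c0 = ivy.zeros((batch, out_dim))                              # [batch, out]
kernel = ivy.random_normal(shape=(in_dim, 4 * out_dim))       # [in, 4*out]
rec_kernel = ivy.random_normal(shape=(out_dim, 4 * out_dim))  # [out, 4*out]
bias = ivy.zeros((4 * out_dim,))                              # [4*out]
rec_bias = ivy.zeros((4 * out_dim,))                          # [4*out]

out, (h, c) = ivy.lstm_update(
    x, h0, c0, kernel, rec_kernel, bias=bias, recurrent_bias=rec_bias
)
# out: (batch, t, out_dim); h and c: (batch, out_dim).
# On a CUDA device this should dispatch to the CudnnRNN path (_gpu_lstm);
# on CPU it runs the tf.keras.backend.rnn path (_cpu_lstm).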
89 changes: 36 additions & 53 deletions ivy/functional/frontends/torch/nn/functional/layer_functions.py
@@ -96,6 +96,7 @@ def _generic_lstm(
(weight_ih, weight_hh),
(bias_i, bias_h),
bidirectional,
batch_first=batch_first,
batch_sizes=batch_sizes,
)
h_outs.append(h_out)
@@ -114,59 +115,32 @@


def _lstm_cell(
x, init_h, init_c, kernel, recurrent_kernel, bias, recurrent_bias, batch_sizes=None
x,
init_h,
init_c,
kernel,
recurrent_kernel,
bias,
recurrent_bias,
batch_first,
batch_sizes=None,
):
x_shape = x.shape
batch_shape = x_shape[1:-1]
timesteps = x_shape[0]
input_channels = x_shape[-1]

Wi = kernel
Wi_x = ivy.reshape(
ivy.matmul(ivy.reshape(x, (-1, input_channels)), Wi)
+ (bias if bias is not None else 0),
[timesteps, *batch_shape, -1],
init_h = ivy.squeeze(init_h, axis=0)
init_c = ivy.squeeze(init_c, axis=0)
out, states = ivy.lstm_update(
x,
init_h,
init_c,
kernel,
recurrent_kernel,
bias=bias,
recurrent_bias=recurrent_bias,
time_major=not batch_first,
)
Wii_x, Wif_x, Wig_x, Wio_x = ivy.split(Wi_x, num_or_size_splits=4, axis=-1)
Wh = recurrent_kernel
ht = init_h
ct = init_c
ht_list = []
ct_list = []

for Wii_xt, Wif_xt, Wig_xt, Wio_xt in zip(
ivy.unstack(Wii_x, axis=0),
ivy.unstack(Wif_x, axis=0),
ivy.unstack(Wig_x, axis=0),
ivy.unstack(Wio_x, axis=0),
):
htm1 = ht
ctm1 = ct
Wh_htm1 = ivy.matmul(htm1, Wh) + (
recurrent_bias if recurrent_bias is not None else 0
)
Whi_htm1, Whf_htm1, Whg_htm1, Who_htm1 = ivy.split(
Wh_htm1, num_or_size_splits=4, axis=-1
)
it = ivy.sigmoid(Wii_xt + Whi_htm1)
ft = ivy.sigmoid(Wif_xt + Whf_htm1)
gt = ivy.tanh(Wig_xt + Whg_htm1)
ot = ivy.sigmoid(Wio_xt + Who_htm1)
ct = ft * ctm1 + it * gt
ht = ot * ivy.tanh(ct)
ct_list.append(ct)
ht_list.append(ht)

if batch_sizes is None:
c = ct_list[-1]
h = ht_list[-1]
output = ivy.concat(ht_list, axis=0)
else:
ct_list = ivy.concat(ct_list, axis=0)
output = ht_list = ivy.concat(ht_list, axis=0)
c = _extract_states(ct_list, batch_sizes)
h = _extract_states(ht_list, batch_sizes)
return output, (h, c)
h, c = states
h = ivy.expand_dims(h) if len(h.shape) == 2 else h
c = ivy.expand_dims(c) if len(c.shape) == 2 else c
return out, (h, c)


def _lstm_full(
@@ -193,10 +167,17 @@ def _lstm_full(
)


def _lstm_layer(x, hidden, weights, biases, bidirectional, batch_sizes=None):
def _lstm_layer(
x, hidden, weights, biases, bidirectional, batch_first, batch_sizes=None
):
if not bidirectional:
result, (h, c) = _lstm_cell(
x, *hidden, *weights, *biases, batch_sizes=batch_sizes
x,
*hidden,
*weights,
*biases,
batch_first=batch_first,
batch_sizes=batch_sizes,
)
else:
result_fw, (h_fw, c_fw) = _lstm_cell(
@@ -207,6 +188,7 @@ def _lstm_layer(x, hidden, weights, biases, bidirectional, batch_sizes=None):
weights[1][0],
biases[0][0],
biases[1][0],
batch_first=batch_first,
batch_sizes=batch_sizes,
)
x_reversed = ivy.flip(x, axis=0)
@@ -218,6 +200,7 @@ def _lstm_layer(x, hidden, weights, biases, bidirectional, batch_sizes=None):
weights[1][1],
biases[0][1],
biases[1][1],
batch_first=batch_first,
batch_sizes=batch_sizes,
)
result_bw = ivy.flip(result_bw, axis=0)
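A rough sketch of the shape handling the rewritten _lstm_cell performs for a single-layer, unidirectional, batch_first=False call; the dimensions, zero biases and torch-style (1, batch, out) hidden states are illustrative assumptions:

import ivy

t, batch, in_dim, out_dim = 5, 2, 3, 4
x = ivy.random_normal(shape=(t, batch, in_dim))          # batch_first=False -> time-major input
h0 = ivy.zeros((1, batch, out_dim))                      # torch keeps a leading layer/direction dim
c0 = ivy.zeros((1, batch, out_dim))
kernel = ivy.random_normal(shape=(in_dim, 4 * out_dim))
rec_kernel = ivy.random_normal(shape=(out_dim, 4 * out_dim))
bias = ivy.zeros((4 * out_dim,))
rec_bias = ivy.zeros((4 * out_dim,))

# _lstm_cell squeezes the leading dim of the hidden states, defers to
# ivy.lstm_update with time_major=not batch_first, then restores the
# leading dim on the final states.
out, (h, c) = ivy.lstm_update(
    x,
    ivy.squeeze(h0, axis=0),
    ivy.squeeze(c0, axis=0),
    kernel,
    rec_kernel,
    bias=bias,
    recurrent_bias=rec_bias,
    time_major=True,
)
h, c = ivy.expand_dims(h), ivy.expand_dims(c)
# out: (t, batch, out_dim); h and c: (1, batch, out_dim), matching what _lstm_cell returns.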
22 changes: 17 additions & 5 deletions ivy/functional/ivy/layers.py
@@ -2286,14 +2286,16 @@ def lstm_update(
*,
bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
recurrent_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
) -> Tuple[ivy.Array, ivy.Array]:
time_major: bool = False,
) -> Tuple[ivy.Array, Tuple[ivy.Array, ivy.Array]]:
"""Perform long-short term memory update by unrolling time dimension of
input array.
Parameters
----------
x
input tensor of LSTM layer *[batch_shape, t, in]*.
input tensor of LSTM layer *[batch_shape, t, in]* if time_major=False,
else *[t, batch_shape, in]*.
init_h
initial state tensor for the cell output *[batch_shape, out]*.
init_c
@@ -2306,13 +2308,19 @@
bias for cell kernel *[4 x out]*. (Default value = None)
recurrent_bias
bias for cell recurrent kernel *[4 x out]*. (Default value = None)
time_major
whether or not the input tensor `x` has the time dimension before batch dim.
Returns
-------
ret
hidden state for all timesteps *[batch_shape,t,out]* and cell state for last
timestep *[batch_shape,out]*
hidden state for all timesteps of shape *[batch_shape,t,out]* if time_major
is False, else *[t, batch_shape, out]*, and a tuple containing the final cell
states, both of shape *[batch_shape,out]*.
"""
# ToDo: test_lstm_update needs to be fixed
if time_major:
x = ivy.swapaxes(x, 0, 1)
# get shapes
x_shape = list(x.shape)
batch_shape = x_shape[:-2]
@@ -2364,7 +2372,11 @@

hts_list.append(ivy.expand_dims(ht, axis=-2))

return ivy.concat(hts_list, axis=-2), ct
ret = ivy.concat(hts_list, axis=-2)
if time_major:
ret = ivy.swapaxes(ret, 0, 1)

return ret, (ht, ct)


# Helpers #
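A small consistency check for the new time_major flag and the (output, (h, c)) return, assuming a backend without its own lstm_update (e.g. numpy) so that the compositional implementation above runs; dimensions are illustrative:

import ivy

ivy.set_backend("numpy")  # assumption: numpy falls back to the compositional lstm_update

t, batch, in_dim, out_dim = 5, 2, 3, 4
x_tm = ivy.random_normal(shape=(t, batch, in_dim))  # time-major layout
h0 = ivy.zeros((batch, out_dim))
c0 = ivy.zeros((batch, out_dim))
k = ivy.random_normal(shape=(in_dim, 4 * out_dim))
rk = ivy.random_normal(shape=(out_dim, 4 * out_dim))

out_tm, (h_tm, c_tm) = ivy.lstm_update(x_tm, h0, c0, k, rk, time_major=True)
out_bm, (h_bm, c_bm) = ivy.lstm_update(
    ivy.swapaxes(x_tm, 0, 1), h0, c0, k, rk, time_major=False
)

# The two calls should agree up to the layout of the output sequence, and both
# now return the final (h, c) pair rather than only the last cell state.
assert tuple(out_tm.shape) == (t, batch, out_dim)
assert tuple(out_bm.shape) == (batch, t, out_dim)
assert ivy.allclose(ivy.swapaxes(out_tm, 0, 1), out_bm)
assert ivy.allclose(h_tm, h_bm) and ivy.allclose(c_tm, c_bm)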
