From fb7cc5b0faf55f3ac8641d7ee5cc88e0b8acbe33 Mon Sep 17 00:00:00 2001 From: Mariusz Ryndzionek Date: Sat, 17 Aug 2024 19:24:14 +0200 Subject: [PATCH] Few small fixes --- src/fast_grnn.c | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fast_grnn.c b/src/fast_grnn.c index 0c84bdf..da7e9b7 100644 --- a/src/fast_grnn.c +++ b/src/fast_grnn.c @@ -69,9 +69,9 @@ void sha_rnn_rnn0_process(const sha_rnn_input_t input, sha_rnn_rnn1_input_t outp for (size_t k = 0; k < SHARNN_BRICK_SIZE; k++) { - memset(output, 0, sizeof(float) * 64); + memset(output, 0, sizeof(sha_rnn_rnn1_input_t)); rnn0_process(input[k], hidden, output); - memcpy(hidden, output, sizeof(float) * 64); + memcpy(hidden, output, sizeof(hidden)); } } @@ -109,17 +109,17 @@ static void rnn1_process(const float input[64], const float hidden[32], float ou void sha_rnn_rnn1_process(const sha_rnn_rnn1_input_t input, sha_rnn_fc_input_t output) { - static float rnn1_input_hist[9][64]; - static size_t rnn1_hist_idx; + static sha_rnn_rnn1_input_t rnn1_input_hist[9]; + static size_t rnn1_hist_idx = 0; float rnn1_hidden[32] = {0.0}; memcpy(rnn1_input_hist[rnn1_hist_idx], input, sizeof(sha_rnn_rnn1_input_t)); - memset(output, 0, sizeof(sha_rnn_fc_input_t)); for (size_t i = 0; i < 9; i++) { size_t j = (rnn1_hist_idx + 1 + i) % 9; + memset(output, 0, sizeof(sha_rnn_fc_input_t)); rnn1_process(rnn1_input_hist[j], rnn1_hidden, output); memcpy(rnn1_hidden, output, sizeof(sha_rnn_fc_input_t)); }