diff --git a/src/fast_grnn.c b/src/fast_grnn.c index 0c84bdf..da7e9b7 100644 --- a/src/fast_grnn.c +++ b/src/fast_grnn.c @@ -69,9 +69,9 @@ void sha_rnn_rnn0_process(const sha_rnn_input_t input, sha_rnn_rnn1_input_t outp for (size_t k = 0; k < SHARNN_BRICK_SIZE; k++) { - memset(output, 0, sizeof(float) * 64); + memset(output, 0, sizeof(sha_rnn_rnn1_input_t)); rnn0_process(input[k], hidden, output); - memcpy(hidden, output, sizeof(float) * 64); + memcpy(hidden, output, sizeof(hidden)); } } @@ -109,17 +109,17 @@ static void rnn1_process(const float input[64], const float hidden[32], float ou void sha_rnn_rnn1_process(const sha_rnn_rnn1_input_t input, sha_rnn_fc_input_t output) { - static float rnn1_input_hist[9][64]; - static size_t rnn1_hist_idx; + static sha_rnn_rnn1_input_t rnn1_input_hist[9]; + static size_t rnn1_hist_idx = 0; float rnn1_hidden[32] = {0.0}; memcpy(rnn1_input_hist[rnn1_hist_idx], input, sizeof(sha_rnn_rnn1_input_t)); - memset(output, 0, sizeof(sha_rnn_fc_input_t)); for (size_t i = 0; i < 9; i++) { size_t j = (rnn1_hist_idx + 1 + i) % 9; + memset(output, 0, sizeof(sha_rnn_fc_input_t)); rnn1_process(rnn1_input_hist[j], rnn1_hidden, output); memcpy(rnn1_hidden, output, sizeof(sha_rnn_fc_input_t)); }