Skip to content

Commit

Permalink
lstm x86
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 12, 2022
1 parent 9e9a465 commit 4c4bf2e
Show file tree
Hide file tree
Showing 3 changed files with 431 additions and 331 deletions.
59 changes: 22 additions & 37 deletions src/layer/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,54 +155,39 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
// tanh(G)
// c_t := f_t .* c_{t-1} + i_t .* g_t
// h_t := o_t .* tanh[c_t]
if (num_output == hidden_size)
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < hidden_size; q++)
{
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);
const float* gates_data = gates.row(q);

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];
I = 1.f / (1.f + exp(-I));
F = 1.f / (1.f + exp(-F));
O = 1.f / (1.f + exp(-O));
G = tanh(G);

I = 1.f / (1.f + exp(-I));
F = 1.f / (1.f + exp(-F));
O = 1.f / (1.f + exp(-O));
G = tanh(G);
float cell2 = F * cell_state[q] + I * G;
float H = O * tanh(cell2);
cell_state[q] = cell2;

float cell2 = F * cell_state[q] + I * G;
float H = O * tanh(cell2);
cell_state[q] = cell2;
if (num_output == hidden_size)
{
hidden_state[q] = H;
output_data[q] = H;
}
}
else
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < hidden_size; q++)
else
{
const float* gates_data = gates.row(q);

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];

I = 1.f / (1.f + exp(-I));
F = 1.f / (1.f + exp(-F));
O = 1.f / (1.f + exp(-O));
G = tanh(G);

float cell2 = F * cell_state[q] + I * G;
float H = O * tanh(cell2);
cell_state[q] = cell2;
tmp_hidden_state[q] = H;
}
}

if (num_output != hidden_size)
{
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
Expand Down
Loading

0 comments on commit 4c4bf2e

Please sign in to comment.