diff --git a/src/layer/lstm.cpp b/src/layer/lstm.cpp
index 744bd75a57a..de1fe54f677 100644
--- a/src/layer/lstm.cpp
+++ b/src/layer/lstm.cpp
@@ -155,54 +155,39 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
         // tanh(G)
         // c_t := f_t .* c_{t-1} + i_t .* g_t
         // h_t := o_t .* tanh[c_t]
-        if (num_output == hidden_size)
+        float* output_data = top_blob.row(ti);
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = 0; q < hidden_size; q++)
         {
-            float* output_data = top_blob.row(ti);
-            #pragma omp parallel for num_threads(opt.num_threads)
-            for (int q = 0; q < hidden_size; q++)
-            {
-                const float* gates_data = gates.row(q);
+            const float* gates_data = gates.row(q);
+
+            float I = gates_data[0];
+            float F = gates_data[1];
+            float O = gates_data[2];
+            float G = gates_data[3];
-
-                float I = gates_data[0];
-                float F = gates_data[1];
-                float O = gates_data[2];
-                float G = gates_data[3];
+            I = 1.f / (1.f + exp(-I));
+            F = 1.f / (1.f + exp(-F));
+            O = 1.f / (1.f + exp(-O));
+            G = tanh(G);
-
-                I = 1.f / (1.f + exp(-I));
-                F = 1.f / (1.f + exp(-F));
-                O = 1.f / (1.f + exp(-O));
-                G = tanh(G);
+            float cell2 = F * cell_state[q] + I * G;
+            float H = O * tanh(cell2);
+            cell_state[q] = cell2;
-
-                float cell2 = F * cell_state[q] + I * G;
-                float H = O * tanh(cell2);
-                cell_state[q] = cell2;
+            if (num_output == hidden_size)
+            {
                 hidden_state[q] = H;
                 output_data[q] = H;
             }
-        }
-        else
-        {
-            #pragma omp parallel for num_threads(opt.num_threads)
-            for (int q = 0; q < hidden_size; q++)
+            else
             {
-                const float* gates_data = gates.row(q);
-
-                float I = gates_data[0];
-                float F = gates_data[1];
-                float O = gates_data[2];
-                float G = gates_data[3];
-
-                I = 1.f / (1.f + exp(-I));
-                F = 1.f / (1.f + exp(-F));
-                O = 1.f / (1.f + exp(-O));
-                G = tanh(G);
-
-                float cell2 = F * cell_state[q] + I * G;
-                float H = O * tanh(cell2);
-                cell_state[q] = cell2;
                 tmp_hidden_state[q] = H;
             }
+        }
+
+        if (num_output != hidden_size)
+        {
             float* output_data = top_blob.row(ti);
 
             #pragma omp parallel for num_threads(opt.num_threads)
             for (int q = 0; q < num_output; q++)
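Note: a compilable scalar sketch of the single-loop step that the hunk above introduces, useful as a reference when reading the x86 kernels below. The function and variable names here are illustrative stand-ins, not code from the patch.

#include <cmath>
#include <vector>

// One LSTM time step as restructured above: a single loop over hidden_size
// computes the cell update, and the destination of H depends on whether a
// hidden projection (num_output != hidden_size) is in use.
static void lstm_step_ref(const std::vector<float>& gates, // hidden_size x [I,F,O,G]
                          std::vector<float>& cell_state,
                          std::vector<float>& hidden_state,
                          std::vector<float>& tmp_hidden_state,
                          std::vector<float>& output,
                          int hidden_size, int num_output)
{
    for (int q = 0; q < hidden_size; q++)
    {
        const float* g = &gates[q * 4];
        float I = 1.f / (1.f + std::exp(-g[0]));
        float F = 1.f / (1.f + std::exp(-g[1]));
        float O = 1.f / (1.f + std::exp(-g[2]));
        float G = std::tanh(g[3]);

        float cell2 = F * cell_state[q] + I * G;
        float H = O * std::tanh(cell2);
        cell_state[q] = cell2;

        if (num_output == hidden_size)
        {
            hidden_state[q] = H;
            output[q] = H;
        }
        else
        {
            tmp_hidden_state[q] = H; // projected by weight_hr afterwards
        }
    }
}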
diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp
index 9544138cff2..78f7b3c5c30 100644
--- a/src/layer/x86/lstm_x86.cpp
+++ b/src/layer/x86/lstm_x86.cpp
@@ -30,12 +30,156 @@ LSTM_x86::LSTM_x86()
 
 int LSTM_x86::create_pipeline(const Option& opt)
 {
-    (void)(opt);
+    // pack IFOG
+    int num_directions = direction == 2 ? 2 : 1;
+    int size = weight_data_size / num_directions / hidden_size / 4;
+
+#if __AVX__
+    weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8);
+    bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4);
+    weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8);
+#else
+    weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4);
+    bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4);
+    weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4);
+#endif
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int dr = 0; dr < num_directions; dr++)
+    {
+        const Mat weight_xc = weight_xc_data.channel(dr);
+        const Mat bias_c = bias_c_data.channel(dr);
+        const Mat weight_hc = weight_hc_data.channel(dr);
+
+        Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
+        Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
+        Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);
+
+        const float* bias_c_I = bias_c.row(0);
+        const float* bias_c_F = bias_c.row(1);
+        const float* bias_c_O = bias_c.row(2);
+        const float* bias_c_G = bias_c.row(3);
+
+        float* bias_c_IFOG = bias_c_data_packed_dr.row(0);
+
+        int q = 0;
+#if __AVX__
+        for (; q + 1 < hidden_size; q += 2)
+        {
+            bias_c_IFOG[0] = bias_c_I[q];
+            bias_c_IFOG[1] = bias_c_F[q];
+            bias_c_IFOG[2] = bias_c_O[q];
+            bias_c_IFOG[3] = bias_c_G[q];
+            bias_c_IFOG[4] = bias_c_I[q + 1];
+            bias_c_IFOG[5] = bias_c_F[q + 1];
+            bias_c_IFOG[6] = bias_c_O[q + 1];
+            bias_c_IFOG[7] = bias_c_G[q + 1];
+
+            bias_c_IFOG += 8;
+
+            const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
+            const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
+            const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
+            const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);
+            const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1);
+            const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1);
+            const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1);
+            const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1);
+
+            const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
+            const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
+            const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
+            const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);
+            const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1);
+            const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1);
+            const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1);
+            const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1);
+
+            float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2);
+            float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2);
+
+            for (int i = 0; i < size; i++)
+            {
+                weight_xc_IFOG[0] = weight_xc_I[i];
+                weight_xc_IFOG[1] = weight_xc_F[i];
+                weight_xc_IFOG[2] = weight_xc_O[i];
+                weight_xc_IFOG[3] = weight_xc_G[i];
+                weight_xc_IFOG[4] = weight_xc_I_1[i];
+                weight_xc_IFOG[5] = weight_xc_F_1[i];
+                weight_xc_IFOG[6] = weight_xc_O_1[i];
+                weight_xc_IFOG[7] = weight_xc_G_1[i];
+
+                weight_xc_IFOG += 8;
+            }
+
+            for (int i = 0; i < num_output; i++)
+            {
+                weight_hc_IFOG[0] = weight_hc_I[i];
+                weight_hc_IFOG[1] = weight_hc_F[i];
+                weight_hc_IFOG[2] = weight_hc_O[i];
+                weight_hc_IFOG[3] = weight_hc_G[i];
+                weight_hc_IFOG[4] = weight_hc_I_1[i];
+                weight_hc_IFOG[5] = weight_hc_F_1[i];
+                weight_hc_IFOG[6] = weight_hc_O_1[i];
+                weight_hc_IFOG[7] = weight_hc_G_1[i];
+
+                weight_hc_IFOG += 8;
+            }
+        }
+#endif // __AVX__
+        for (; q < hidden_size; q++)
+        {
+            bias_c_IFOG[0] = bias_c_I[q];
+            bias_c_IFOG[1] = bias_c_F[q];
+            bias_c_IFOG[2] = bias_c_O[q];
+            bias_c_IFOG[3] = bias_c_G[q];
+
+            bias_c_IFOG += 4;
+
+            const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
+            const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
+            const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
+            const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);
+
+            const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
+            const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
+            const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
+            const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);
+
+#if __AVX__
+            float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2 + q % 2);
+            float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2 + q % 2);
+#else
+            float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q);
+            float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q);
+#endif
+
+            for (int i = 0; i < size; i++)
+            {
+                weight_xc_IFOG[0] = weight_xc_I[i];
+                weight_xc_IFOG[1] = weight_xc_F[i];
+                weight_xc_IFOG[2] = weight_xc_O[i];
+                weight_xc_IFOG[3] = weight_xc_G[i];
+
+                weight_xc_IFOG += 4;
+            }
+
+            for (int i = 0; i < num_output; i++)
+            {
+                weight_hc_IFOG[0] = weight_hc_I[i];
+                weight_hc_IFOG[1] = weight_hc_F[i];
+                weight_hc_IFOG[2] = weight_hc_O[i];
+                weight_hc_IFOG[3] = weight_hc_G[i];
+
+                weight_hc_IFOG += 4;
+            }
+        }
+    }
 
     return 0;
 }
 
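Note: the packed layout built by create_pipeline above, shown as a standalone toy program. Under AVX, two consecutive output units q and q+1 share one packed row, and every input column i contributes eight adjacent floats [I(q) F(q) O(q) G(q) I(q+1) F(q+1) O(q+1) G(q+1)], so the kernel can fetch one __m256 per column. All sizes and names here are invented for illustration.

#include <cstdio>

int main()
{
    const int hidden_size = 4, size = 3;
    float w[4 * hidden_size][size]; // rows: I0..I3, F0..F3, O0..O3, G0..G3
    for (int r = 0; r < 4 * hidden_size; r++)
        for (int i = 0; i < size; i++)
            w[r][i] = r + i * 0.1f;

    // interleave two units per packed row, gate-major within each column
    float packed[hidden_size / 2][size * 8];
    for (int q = 0; q + 1 < hidden_size; q += 2)
        for (int i = 0; i < size; i++)
            for (int g = 0; g < 4; g++)
            {
                packed[q / 2][i * 8 + g] = w[g * hidden_size + q][i];
                packed[q / 2][i * 8 + 4 + g] = w[g * hidden_size + q + 1][i];
            }

    printf("%f\n", packed[0][4]); // I gate of unit 1, input column 0
    return 0;
}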
-#ifdef __AVX__
-static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
+
+static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
 {
     int size = bottom_blob.w;
     int T = bottom_blob.h;
@@ -44,10 +188,18 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
     int hidden_size = cell_state.w;
 
     // 4 x hidden_size
-    Mat gates(hidden_size, 4, 4u, opt.workspace_allocator);
+    Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
     if (gates.empty())
         return -100;
 
+    Mat tmp_hidden_state;
+    if (num_output != hidden_size)
+    {
+        tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
+        if (tmp_hidden_state.empty())
+            return -100;
+    }
+
     // unroll
     for (int t = 0; t < T; t++)
     {
@@ -60,6 +212,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
 
         int ti = reverse ? T - 1 - t : t;
 
+#if __AVX__
         int nn_hidden_size = hidden_size >> 1;
         int remain_hidden_size_start = nn_hidden_size << 1;
         #pragma omp parallel for num_threads(opt.num_threads)
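Note: the hunk below replaces the old per-gate dot products (which ended in a HorizontalSums reduction) with a broadcast-and-accumulate walk over the packed columns. A scalar model of that accumulation, with invented names — acc[] plays the role of one _IFOG register, so no horizontal reduction is needed at the end:

// Scalar model of the packed-IFOG matrix-vector product: for each input
// column i, x[i] is "broadcast" and multiply-accumulated with the four
// packed weights of that column, on top of the preloaded bias.
static void gemv_ifog_packed(const float* packed_w,   // size columns x 4 floats
                             const float* bias_ifog,  // 4 floats [I F O G]
                             const float* x, int size,
                             float acc[4])
{
    for (int g = 0; g < 4; g++)
        acc[g] = bias_ifog[g];

    for (int i = 0; i < size; i++)
        for (int g = 0; g < 4; g++)
            acc[g] += packed_w[i * 4 + g] * x[i];
}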
@@ -67,260 +220,180 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
         for (int qq = 0; qq < nn_hidden_size; qq++)
         {
             int q = qq * 2;
 
-            const float* x = bottom_blob.row(ti);
-            const float* hidden_ptr_r = hidden_state;
-            const float* bias_c_I = bias_c.row(0);
-            const float* bias_c_F = bias_c.row(1);
-            const float* bias_c_O = bias_c.row(2);
-            const float* bias_c_G = bias_c.row(3);
-
-            float* gates_data_I = gates.row(0);
-            float* gates_data_F = gates.row(1);
-            float* gates_data_O = gates.row(2);
-            float* gates_data_G = gates.row(3);
+            const float* bias_c_IFOG = (const float*)bias_c + q * 4;
+
             // gate I F O G
-            const float* weight_xc_I_0 = weight_xc.row(num_output * 0 + q);
-            const float* weight_xc_F_0 = weight_xc.row(num_output * 1 + q);
-            const float* weight_xc_O_0 = weight_xc.row(num_output * 2 + q);
-            const float* weight_xc_G_0 = weight_xc.row(num_output * 3 + q);
-            const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + (q + 1));
-            const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + (q + 1));
-            const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + (q + 1));
-            const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + (q + 1));
-
-            const float* weight_hc_I_0 = weight_hc.row(num_output * 0 + q);
-            const float* weight_hc_F_0 = weight_hc.row(num_output * 1 + q);
-            const float* weight_hc_O_0 = weight_hc.row(num_output * 2 + q);
-            const float* weight_hc_G_0 = weight_hc.row(num_output * 3 + q);
-            const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + (q + 1));
-            const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + (q + 1));
-            const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + (q + 1));
-            const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + (q + 1));
-
-            // float I = bias_c_I[q];
-            // float F = bias_c_F[q];
-            // float O = bias_c_O[q];
-            // float G = bias_c_G[q];
-            __m256 _sumI_0 = _mm256_setzero_ps();
-            __m256 _sumF_0 = _mm256_setzero_ps();
-            __m256 _sumO_0 = _mm256_setzero_ps();
-            __m256 _sumG_0 = _mm256_setzero_ps();
-            __m256 _sumI_1 = _mm256_setzero_ps();
-            __m256 _sumF_1 = _mm256_setzero_ps();
-            __m256 _sumO_1 = _mm256_setzero_ps();
-            __m256 _sumG_1 = _mm256_setzero_ps();
-            int nn_num_size = size >> 3;
-            int remain_size = size & 7;
-            for (; nn_num_size > 0; nn_num_size--)
+            const float* weight_xc_IFOG = weight_xc.row(q / 2);
+            const float* weight_hc_IFOG = weight_hc.row(q / 2);
+
+            __m256 _IFOG = _mm256_loadu_ps(bias_c_IFOG);
+            __m256 _sum1 = _mm256_setzero_ps();
+            __m256 _sum2 = _mm256_setzero_ps();
+            __m256 _sum3 = _mm256_setzero_ps();
+
+            const float* x = bottom_blob.row(ti);
+
+            int i = 0;
+            for (; i + 3 < size; i += 4)
             {
-                __m256 xi = _mm256_loadu_ps(x);
-                _sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_0), xi, _sumI_0);
-                _sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_0), xi, _sumF_0);
-                _sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_0), xi, _sumO_0);
-                _sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_0), xi, _sumG_0);
-                _sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_1), xi, _sumI_1);
-                _sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_1), xi, _sumF_1);
-                _sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_1), xi, _sumO_1);
-                _sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_1), xi, _sumG_1);
-                x += 8;
-                weight_xc_I_0 += 8;
-                weight_xc_F_0 += 8;
-                weight_xc_O_0 += 8;
-                weight_xc_G_0 += 8;
-                weight_xc_I_1 += 8;
-                weight_xc_F_1 += 8;
-                weight_xc_O_1 += 8;
-                weight_xc_G_1 += 8;
+                __m256 _xi0 = _mm256_broadcast_ss(x);
+                __m256 _xi1 = _mm256_broadcast_ss(x + 1);
+                __m256 _xi2 = _mm256_broadcast_ss(x + 2);
+                __m256 _xi3 = _mm256_broadcast_ss(x + 3);
+                __m256 _weight_xc_IFOG0 = _mm256_loadu_ps(weight_xc_IFOG);
+                __m256 _weight_xc_IFOG1 = _mm256_loadu_ps(weight_xc_IFOG + 8);
+                __m256 _weight_xc_IFOG2 = _mm256_loadu_ps(weight_xc_IFOG + 16);
+                __m256 _weight_xc_IFOG3 = _mm256_loadu_ps(weight_xc_IFOG + 24);
+                _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG);
+                _sum1 = _mm256_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1);
+                _sum2 = _mm256_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2);
+                _sum3 = _mm256_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3);
+
+                x += 4;
+                weight_xc_IFOG += 32;
             }
-            int nn_num_output = num_output >> 3;
-            int remain_num_output = num_output & 7;
-            for (; nn_num_output > 0; nn_num_output--)
+            for (; i < size; i++)
             {
-                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);
-
-                _sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_0), h_cont, _sumI_0);
-                _sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_0), h_cont, _sumF_0);
-                _sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_0), h_cont, _sumO_0);
-                _sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_0), h_cont, _sumG_0);
-                _sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_1), h_cont, _sumI_1);
-                _sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_1), h_cont, _sumF_1);
-                _sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_1), h_cont, _sumO_1);
-                _sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_1), h_cont, _sumG_1);
-                hidden_ptr_r += 8;
-                weight_hc_I_0 += 8;
-                weight_hc_F_0 += 8;
-                weight_hc_O_0 += 8;
-                weight_hc_G_0 += 8;
-                weight_hc_I_1 += 8;
-                weight_hc_F_1 += 8;
-                weight_hc_O_1 += 8;
-                weight_hc_G_1 += 8;
+                __m256 _xi = _mm256_broadcast_ss(x);
+                __m256 _weight_xc_IFOG = _mm256_loadu_ps(weight_xc_IFOG);
+                _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG);
+
+                x += 1;
+                weight_xc_IFOG += 8;
             }
-            float sums[8];
-            _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1));
-            sums[0] += bias_c_I[q];
-            sums[1] += bias_c_F[q];
-            sums[2] += bias_c_O[q];
-            sums[3] += bias_c_G[q];
-            sums[4] += bias_c_I[q + 1];
-            sums[5] += bias_c_F[q + 1];
-            sums[6] += bias_c_O[q + 1];
-            sums[7] += bias_c_G[q + 1];
-
-            for (; remain_size > 0; remain_size--)
+
+            const float* hidden_ptr = hidden_state;
+
+            i = 0;
+            for (; i + 3 < num_output; i += 4)
             {
-                float xi = *x;
-                sums[0] += *weight_xc_I_0 * xi;
-                sums[1] += *weight_xc_F_0 * xi;
-                sums[2] += *weight_xc_O_0 * xi;
-                sums[3] += *weight_xc_G_0 * xi;
-                sums[4] += *weight_xc_I_1 * xi;
-                sums[5] += *weight_xc_F_1 * xi;
-                sums[6] += *weight_xc_O_1 * xi;
-                sums[7] += *weight_xc_G_1 * xi;
-                x++;
-                weight_xc_I_0++;
-                weight_xc_F_0++;
-                weight_xc_O_0++;
-                weight_xc_G_0++;
-                weight_xc_I_1++;
-                weight_xc_F_1++;
-                weight_xc_O_1++;
-                weight_xc_G_1++;
+                __m256 _h_cont0 = _mm256_broadcast_ss(hidden_ptr);
+                __m256 _h_cont1 = _mm256_broadcast_ss(hidden_ptr + 1);
+                __m256 _h_cont2 = _mm256_broadcast_ss(hidden_ptr + 2);
+                __m256 _h_cont3 = _mm256_broadcast_ss(hidden_ptr + 3);
+                __m256 _weight_hc_IFOG0 = _mm256_loadu_ps(weight_hc_IFOG);
+                __m256 _weight_hc_IFOG1 = _mm256_loadu_ps(weight_hc_IFOG + 8);
+                __m256 _weight_hc_IFOG2 = _mm256_loadu_ps(weight_hc_IFOG + 16);
+                __m256 _weight_hc_IFOG3 = _mm256_loadu_ps(weight_hc_IFOG + 24);
+                _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG);
+                _sum1 = _mm256_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1);
+                _sum2 = _mm256_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2);
+                _sum3 = _mm256_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3);
+
+                hidden_ptr += 4;
+                weight_hc_IFOG += 32;
             }
-
-            for (; remain_num_output > 0; remain_num_output--)
+            for (; i < num_output; i++)
             {
-                float h_cont = *hidden_ptr_r;
-                sums[0] += *weight_hc_I_0 * h_cont;
-                sums[1] += *weight_hc_F_0 * h_cont;
-                sums[2] += *weight_hc_O_0 * h_cont;
-                sums[3] += *weight_hc_G_0 * h_cont;
-                sums[4] += *weight_hc_I_1 * h_cont;
-                sums[5] += *weight_hc_F_1 * h_cont;
-                sums[6] += *weight_hc_O_1 * h_cont;
-                sums[7] += *weight_hc_G_1 * h_cont;
-                hidden_ptr_r++;
-                weight_hc_I_0++;
-                weight_hc_F_0++;
-                weight_hc_O_0++;
-                weight_hc_G_0++;
-                weight_hc_I_1++;
-                weight_hc_F_1++;
-                weight_hc_O_1++;
-                weight_hc_G_1++;
+                __m256 _h_cont = _mm256_broadcast_ss(hidden_ptr);
+                __m256 _weight_hc_IFOG = _mm256_loadu_ps(weight_hc_IFOG);
+                _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG);
+
+                hidden_ptr += 1;
+                weight_hc_IFOG += 8;
             }
-            gates_data_I[q] = sums[0];
-            gates_data_F[q] = sums[1];
-            gates_data_O[q] = sums[2];
-            gates_data_G[q] = sums[3];
-            gates_data_I[q + 1] = sums[4];
-            gates_data_F[q + 1] = sums[5];
-            gates_data_O[q + 1] = sums[6];
-            gates_data_G[q + 1] = sums[7];
+
+            float* gates_data = gates.row(q);
+
+            _IFOG = _mm256_add_ps(_IFOG, _sum1);
+            _sum2 = _mm256_add_ps(_sum2, _sum3);
+            _IFOG = _mm256_add_ps(_IFOG, _sum2);
+
+            _mm256_storeu_ps(gates_data, _IFOG);
         }
+#else
+        int nn_hidden_size = 0;
+        int remain_hidden_size_start = 0;
+#endif // __AVX__
+
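Note: the remainder loop below indexes the packed weights with q / 2 + q % 2 under AVX. A tiny standalone program illustrating that row mapping: pairs occupy rows 0..hidden_size/2-1, and an odd hidden_size leaves one half-filled trailing row, which is exactly the row the tail unit resolves to. Purely illustrative, not library code.

#include <cstdio>

int main()
{
    for (int hidden_size = 3; hidden_size <= 7; hidden_size += 2)
    {
        // first unit not covered by the 2-wide packed pairs
        int q = (hidden_size >> 1) << 1;
        printf("hidden_size=%d tail q=%d packed row=%d of %d\n",
               hidden_size, q, q / 2 + q % 2,
               hidden_size / 2 + hidden_size % 2);
    }
    return 0;
}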
         #pragma omp parallel for num_threads(opt.num_threads)
         for (int q = remain_hidden_size_start; q < hidden_size; q++)
         {
-            const float* x = bottom_blob.row(ti);
-            const float* hidden_ptr_r = hidden_state;
-            const float* bias_c_I = bias_c.row(0);
-            const float* bias_c_F = bias_c.row(1);
-            const float* bias_c_O = bias_c.row(2);
-            const float* bias_c_G = bias_c.row(3);
-
-            float* gates_data_I = gates.row(0);
-            float* gates_data_F = gates.row(1);
-            float* gates_data_O = gates.row(2);
-            float* gates_data_G = gates.row(3);
+            const float* bias_c_IFOG = (const float*)bias_c + q * 4;
+
             // gate I F O G
-            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
-            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
-            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
-            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
-
-            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
-            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
-            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
-            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
-
-            // float I = bias_c_I[q];
-            // float F = bias_c_F[q];
-            // float O = bias_c_O[q];
-            // float G = bias_c_G[q];
-            __m256 _sumI = _mm256_setzero_ps();
-            __m256 _sumF = _mm256_setzero_ps();
-            __m256 _sumO = _mm256_setzero_ps();
-            __m256 _sumG = _mm256_setzero_ps();
-            int nn_num_size = size >> 3;
-            int remain_size = size & 7;
-            for (; nn_num_size > 0; nn_num_size--)
+#if __AVX__
+            const float* weight_xc_IFOG = weight_xc.row(q / 2 + q % 2);
+            const float* weight_hc_IFOG = weight_hc.row(q / 2 + q % 2);
+#else // __AVX__
+            const float* weight_xc_IFOG = weight_xc.row(q);
+            const float* weight_hc_IFOG = weight_hc.row(q);
+#endif // __AVX__
+
+            __m128 _IFOG = _mm_loadu_ps(bias_c_IFOG);
+            __m128 _sum1 = _mm_setzero_ps();
+            __m128 _sum2 = _mm_setzero_ps();
+            __m128 _sum3 = _mm_setzero_ps();
+
+            const float* x = bottom_blob.row(ti);
+
+            int i = 0;
+            for (; i + 3 < size; i += 4)
             {
-                __m256 xi = _mm256_loadu_ps(x);
-                _sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I), xi, _sumI);
-                _sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F), xi, _sumF);
-                _sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O), xi, _sumO);
-                _sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G), xi, _sumG);
-                x += 8;
-                weight_xc_I += 8;
-                weight_xc_F += 8;
-                weight_xc_O += 8;
-                weight_xc_G += 8;
+                __m128 _xi0 = _mm_load1_ps(x);
+                __m128 _xi1 = _mm_load1_ps(x + 1);
+                __m128 _xi2 = _mm_load1_ps(x + 2);
+                __m128 _xi3 = _mm_load1_ps(x + 3);
+                __m128 _weight_xc_IFOG0 = _mm_loadu_ps(weight_xc_IFOG);
+                __m128 _weight_xc_IFOG1 = _mm_loadu_ps(weight_xc_IFOG + 4);
+                __m128 _weight_xc_IFOG2 = _mm_loadu_ps(weight_xc_IFOG + 8);
+                __m128 _weight_xc_IFOG3 = _mm_loadu_ps(weight_xc_IFOG + 12);
+                _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG);
+                _sum1 = _mm_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1);
+                _sum2 = _mm_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2);
+                _sum3 = _mm_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3);
+
+                x += 4;
+                weight_xc_IFOG += 16;
             }
-            int nn_num_output = num_output >> 3;
-            int remain_num_output = num_output & 7;
-            for (; nn_num_output > 0; nn_num_output--)
+            for (; i < size; i++)
             {
-                __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);
-
-                _sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I), h_cont, _sumI);
-                _sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F), h_cont, _sumF);
-                _sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O), h_cont, _sumO);
-                _sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G), h_cont, _sumG);
-                hidden_ptr_r += 8;
-                weight_hc_I += 8;
-                weight_hc_F += 8;
-                weight_hc_O += 8;
-                weight_hc_G += 8;
+                __m128 _xi = _mm_load1_ps(x);
+                __m128 _weight_xc_IFOG = _mm_loadu_ps(weight_xc_IFOG);
+                _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG);
+
+                x += 1;
+                weight_xc_IFOG += 4;
             }
-            float sums[4];
-            _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG));
-            sums[0] += bias_c_I[q];
-            sums[1] += bias_c_F[q];
-            sums[2] += bias_c_O[q];
-            sums[3] += bias_c_G[q];
-
-            for (; remain_size > 0; remain_size--)
+
+            const float* hidden_ptr = hidden_state;
+
+            i = 0;
+            for (; i + 3 < num_output; i += 4)
             {
-                float xi = *x;
-                sums[0] += *weight_xc_I * xi;
-                sums[1] += *weight_xc_F * xi;
-                sums[2] += *weight_xc_O * xi;
-                sums[3] += *weight_xc_G * xi;
-                x++;
-                weight_xc_I++;
-                weight_xc_F++;
-                weight_xc_O++;
-                weight_xc_G++;
+                __m128 _h_cont0 = _mm_load1_ps(hidden_ptr);
+                __m128 _h_cont1 = _mm_load1_ps(hidden_ptr + 1);
+                __m128 _h_cont2 = _mm_load1_ps(hidden_ptr + 2);
+                __m128 _h_cont3 = _mm_load1_ps(hidden_ptr + 3);
+                __m128 _weight_hc_IFOG0 = _mm_loadu_ps(weight_hc_IFOG);
+                __m128 _weight_hc_IFOG1 = _mm_loadu_ps(weight_hc_IFOG + 4);
+                __m128 _weight_hc_IFOG2 = _mm_loadu_ps(weight_hc_IFOG + 8);
+                __m128 _weight_hc_IFOG3 = _mm_loadu_ps(weight_hc_IFOG + 12);
+                _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG);
+                _sum1 = _mm_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1);
+                _sum2 = _mm_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2);
+                _sum3 = _mm_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3);
+
+                hidden_ptr += 4;
+                weight_hc_IFOG += 16;
             }
-
-            for (; remain_num_output > 0; remain_num_output--)
+            for (; i < num_output; i++)
             {
-                float h_cont = *hidden_ptr_r;
-                sums[0] += *weight_hc_I * h_cont;
-                sums[1] += *weight_hc_F * h_cont;
-                sums[2] += *weight_hc_O * h_cont;
-                sums[3] += *weight_hc_G * h_cont;
-                hidden_ptr_r++;
-                weight_hc_I++;
-                weight_hc_F++;
-                weight_hc_O++;
-                weight_hc_G++;
+                __m128 _h_cont = _mm_load1_ps(hidden_ptr);
+                __m128 _weight_hc_IFOG = _mm_loadu_ps(weight_hc_IFOG);
+                _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG);
+
+                hidden_ptr += 1;
+                weight_hc_IFOG += 4;
             }
-            gates_data_I[q] = sums[0];
-            gates_data_F[q] = sums[1];
-            gates_data_O[q] = sums[2];
-            gates_data_G[q] = sums[3];
+
+            float* gates_data = gates.row(q);
+
+            _IFOG = _mm_add_ps(_IFOG, _sum1);
+            _sum2 = _mm_add_ps(_sum2, _sum3);
+            _IFOG = _mm_add_ps(_IFOG, _sum2);
+
+            _mm_storeu_ps(gates_data, _IFOG);
         }
 
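Note: the activation stage below loads four consecutive gate rows (each holding one unit's [I F O G]) and transposes them with _MM_TRANSPOSE4_PS so that all four I values, all four F values, and so on land in separate registers for the vectorized sigmoid/tanh. A minimal compilable demo of just that transpose, with made-up values:

#include <xmmintrin.h>
#include <cstdio>

int main()
{
    float gates[16]; // row r holds unit r's [I F O G]
    for (int r = 0; r < 4; r++)
        for (int g = 0; g < 4; g++)
            gates[r * 4 + g] = g * 10 + r; // I = 0..3, F = 10..13, ...

    __m128 _r0 = _mm_loadu_ps(gates);
    __m128 _r1 = _mm_loadu_ps(gates + 4);
    __m128 _r2 = _mm_loadu_ps(gates + 8);
    __m128 _r3 = _mm_loadu_ps(gates + 12);

    _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); // _r0 now holds I of units 0..3

    float I[4];
    _mm_storeu_ps(I, _r0);
    printf("I gates: %g %g %g %g\n", I[0], I[1], I[2], I[3]);
    return 0;
}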
         // lstm unit
@@ -331,69 +404,117 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
         // c_t := f_t .* c_{t-1} + i_t .* g_t
         // h_t := o_t .* tanh[c_t]
         float* output_data = top_blob.row(ti);
+
         float* cell_ptr = cell_state;
         float* hidden_ptr = hidden_state;
-        const float* gates_data_I = gates.row(0);
-        const float* gates_data_F = gates.row(1);
-        const float* gates_data_O = gates.row(2);
-        const float* gates_data_G = gates.row(3);
-        int nn_activation = hidden_size >> 3;
-        int remain_activations = hidden_size & 7;
-        for (; nn_activation > 0; nn_activation--)
+        float* tmp_hidden_ptr = tmp_hidden_state;
+
+        nn_hidden_size = hidden_size >> 2;
+        remain_hidden_size_start = nn_hidden_size << 2;
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int qq = 0; qq < nn_hidden_size; qq++)
         {
-            __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I));
-            __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F));
-            __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O));
-            __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G));
-            __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G));
-            __m256 H = _mm256_mul_ps(O, tanh_avx(cell2));
-            _mm256_storeu_ps(cell_ptr, cell2);
-            _mm256_storeu_ps(hidden_ptr, H);
-            _mm256_storeu_ps(output_data, H);
-            cell_ptr += 8;
-            output_data += 8;
-            hidden_ptr += 8;
-            gates_data_I += 8;
-            gates_data_F += 8;
-            gates_data_O += 8;
-            gates_data_G += 8;
+            int q = qq * 4;
+
+            const float* gates_data = gates.row(q);
+
+            __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data);
+            __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4);
+            __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8);
+            __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12);
+
+            _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3);
+
+            __m128 _I = sigmoid_sse(_IFOG_4x4_0);
+            __m128 _F = sigmoid_sse(_IFOG_4x4_1);
+            __m128 _O = sigmoid_sse(_IFOG_4x4_2);
+            __m128 _G = tanh_sse(_IFOG_4x4_3);
+
+            __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_I, _G));
+            __m128 _H = _mm_mul_ps(_O, tanh_sse(_cell2));
+
+            _mm_storeu_ps(cell_ptr + q, _cell2);
+
+            if (num_output == hidden_size)
+            {
+                _mm_storeu_ps(hidden_ptr + q, _H);
+                _mm_storeu_ps(output_data + q, _H);
+            }
+            else
+            {
+                _mm_storeu_ps(tmp_hidden_ptr + q, _H);
+            }
         }
-        for (; remain_activations > 0; remain_activations--)
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = remain_hidden_size_start; q < hidden_size; q++)
         {
-            float I = *gates_data_I;
-            float F = *gates_data_F;
-            float O = *gates_data_O;
-            float G = *gates_data_G;
+            const float* gates_data = gates.row(q);
+
+            float I = gates_data[0];
+            float F = gates_data[1];
+            float O = gates_data[2];
+            float G = gates_data[3];
 
             I = 1.f / (1.f + exp(-I));
             F = 1.f / (1.f + exp(-F));
             O = 1.f / (1.f + exp(-O));
             G = tanh(G);
-            float cell2 = F * *cell_ptr + I * G;
+
+            float cell2 = F * cell_ptr[q] + I * G;
             float H = O * tanh(cell2);
-            *cell_ptr = cell2;
-            *hidden_ptr = H;
-            *output_data = H;
-            cell_ptr++;
-            output_data++;
-            hidden_ptr++;
-            gates_data_I++;
-            gates_data_F++;
-            gates_data_O++;
-            gates_data_G++;
+
+            cell_ptr[q] = cell2;
+            if (num_output == hidden_size)
+            {
+                hidden_ptr[q] = H;
+                output_data[q] = H;
+            }
+            else
+            {
+                tmp_hidden_ptr[q] = H;
+            }
         }
 
-        // no cell output here
+        if (num_output != hidden_size)
+        {
+            float* output_data = top_blob.row(ti);
+
+            float* hidden_ptr = hidden_state;
+
+            // int nn_num_output = num_output >> 2;
+            // int remain_num_output_start = nn_num_output << 2;
+            // #pragma omp parallel for num_threads(opt.num_threads)
+            // for (int qq = 0; qq < nn_num_output; qq++)
+            // {
+            //     int q = qq * 4;
+            //
+            // }
+            int remain_num_output_start = 0;
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = remain_num_output_start; q < num_output; q++)
+            {
+                const float* hr = weight_hr.row(q);
+                const float* tmp_hidden_ptr = tmp_hidden_state;
+
+                float s = 0;
+                for (int i = 0; i < hidden_size; i++)
+                {
+                    s += tmp_hidden_ptr[i] * hr[i];
+                }
+
+                output_data[q] = s;
+                hidden_ptr[q] = s;
+            }
+        }
     }
 
     return 0;
 }
-#endif
 
 int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
 {
-#if 0//__AVX__
     int T = bottom_blob.h;
+
     int num_directions = direction == 2 ? 2 : 1;
 
     // initial hidden state
@@ -401,7 +522,7 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
     if (hidden.empty())
         return -100;
     hidden.fill(0.f);
-    // internal cell state
+
     Mat cell(hidden_size, 4u, opt.workspace_allocator);
     if (cell.empty())
         return -100;
@@ -414,7 +535,7 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
     // Uni directional
     if (direction == 0 || direction == 1)
     {
-        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
+        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
         if (ret != 0)
             return ret;
     }
@@ -429,14 +550,14 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
         if (top_blob_reverse.empty())
             return -100;
 
-        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
+        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
         if (ret0 != 0)
             return ret0;
 
         hidden.fill(0.0f);
         cell.fill(0.0f);
 
-        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
+        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
         if (ret1 != 0)
             return ret1;
 
@@ -453,14 +574,10 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
     }
 
     return 0;
-#else
-    return LSTM::forward(bottom_blob, top_blob, opt);
-#endif
 }
 
 int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
 {
-#if 0//__AVX__
     const Mat& bottom_blob = bottom_blobs[0];
     int T = bottom_blob.h;
     int num_directions = direction == 2 ? 2 : 1;
@@ -494,7 +611,7 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
     // Uni directional
     if (direction == 0 || direction == 1)
    {
-        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
+        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
         if (ret != 0)
             return ret;
     }
@@ -511,15 +628,13 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
 
         Mat hidden0 = hidden.row_range(0, 1);
         Mat cell0 = cell.row_range(0, 1);
-
-        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
+        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
         if (ret0 != 0)
             return ret0;
 
         Mat hidden1 = hidden.row_range(1, 1);
         Mat cell1 = cell.row_range(1, 1);
-
-        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
+        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
         if (ret1 != 0)
             return ret1;
 
@@ -542,9 +657,6 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
     }
 
     return 0;
-#else
-    return LSTM::forward(bottom_blobs, top_blobs, opt);
-#endif
 }
 
 } // namespace ncnn
diff --git a/src/layer/x86/lstm_x86.h b/src/layer/x86/lstm_x86.h
index 51ffb413916..cab7d7e32fa 100644
--- a/src/layer/x86/lstm_x86.h
+++ b/src/layer/x86/lstm_x86.h
@@ -31,6 +31,9 @@ class LSTM_x86 : virtual public LSTM
     virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
 
 public:
+    Mat weight_xc_data_packed;
+    Mat bias_c_data_packed;
+    Mat weight_hc_data_packed;
 };
 
 } // namespace ncnn