lstm proj_size #4263

Merged
merged 11 commits into from Oct 14, 2022
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yml
@@ -227,7 +227,7 @@ jobs:
uses: actions/cache@v3
with:
path: lavapipe-install
key: lavapipe-linux-install-20211127-2
key: lavapipe-linux-install-20211127-3
- name: checkout-lavapipe
if: steps.cache-lavapipe.outputs.cache-hit != 'true'
uses: actions/checkout@v3
10 changes: 6 additions & 4 deletions docs/developer-guide/operators.md
@@ -1026,15 +1026,17 @@ y0, hidden y1, cell y2 = lstm(x0, hidden x1, cell x2)

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | num_output | int | 0 | hidden size of output |
| 0 | num_output | int | 0 | output size of output |
| 1 | weight_data_size| int | 0 | total size of IFOG weight matrix |
| 2 | direction | int | 0 | 0=forward, 1=reverse, 2=bidirectional |
| 3 | hidden_size | int | num_output| hidden size |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| weight_xc_data| float/fp16/int8 | [input_size, num_output * 4, num_directions] |
| bias_c_data | float/fp16/int8 | [num_output, 4, num_directions] |
| weight_hc_data| float/fp16/int8 | [num_output, num_output * 4, num_directions] |
| weight_xc_data| float/fp16/int8 | [input_size, hidden_size * 4, num_directions] |
| bias_c_data | float/fp16/int8 | [hidden_size, 4, num_directions] |
| weight_hc_data| float/fp16/int8 | [num_output, hidden_size * 4, num_directions] |
| weight_hr_data| float/fp16/int8 | [hidden_size, num_output, num_directions] |

Direction flag:
- 0 = forward only
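To make the shapes above concrete, here is a small editorial sketch (not part of the diff; all sizes are hypothetical) of the per-direction element counts implied by the table when a projection is used:

```cpp
#include <cstdio>

int main()
{
    // hypothetical sizes, chosen only for illustration
    const int input_size = 128;
    const int hidden_size = 512; // param id 3
    const int num_output = 256;  // param id 0, the projected output size
    const int num_directions = 1;

    // per-direction element counts from the table above
    const int weight_xc = input_size * hidden_size * 4;  // [input_size, hidden_size * 4]
    const int bias_c = hidden_size * 4;                  // [hidden_size, 4]
    const int weight_hc = num_output * hidden_size * 4;  // [num_output, hidden_size * 4]
    const int weight_hr = hidden_size * num_output;      // [hidden_size, num_output], only when hidden_size != num_output

    // param id 1 (weight_data_size) covers the IFOG input weights;
    // load_model recovers input_size as weight_data_size / num_directions / hidden_size / 4
    const int weight_data_size = weight_xc * num_directions;

    printf("xc=%d bias=%d hc=%d hr=%d weight_data_size=%d\n",
           weight_xc, bias_c, weight_hc, weight_hr, weight_data_size);
    return 0;
}
```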
256 changes: 188 additions & 68 deletions src/layer/arm/lstm_arm.cpp

Large diffs are not rendered by default.

285 changes: 199 additions & 86 deletions src/layer/arm/lstm_arm_asimdhp.cpp

Large diffs are not rendered by default.

97 changes: 70 additions & 27 deletions src/layer/lstm.cpp
@@ -29,43 +29,60 @@ int LSTM::load_param(const ParamDict& pd)
num_output = pd.get(0, 0);
weight_data_size = pd.get(1, 0);
direction = pd.get(2, 0);
hidden_size = pd.get(3, num_output);
return 0;
}
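To make the default above concrete: param id 3 falls back to num_output when it is absent, so existing models without a projection behave exactly as before. Below is a minimal standalone sketch of that lookup (not ncnn's ParamDict API; names and values are made up):

```cpp
#include <map>

// Sketch only: return params[id] if present, otherwise the default.
static int get_param(const std::map<int, int>& params, int id, int def)
{
    std::map<int, int>::const_iterator it = params.find(id);
    return it == params.end() ? def : it->second;
}

int main()
{
    std::map<int, int> pd;
    pd[0] = 256;                                           // num_output
    // no pd[3] entry -> projection disabled
    const int num_output = get_param(pd, 0, 0);
    const int hidden_size = get_param(pd, 3, num_output);  // == 256, same as num_output
    return hidden_size == num_output ? 0 : 1;
}
```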

int LSTM::load_model(const ModelBin& mb)
{
int num_directions = direction == 2 ? 2 : 1;

int size = weight_data_size / num_directions / num_output / 4;
int size = weight_data_size / num_directions / hidden_size / 4;

// raw weight data
weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
weight_xc_data = mb.load(size, hidden_size * 4, num_directions, 0);
if (weight_xc_data.empty())
return -100;

bias_c_data = mb.load(num_output, 4, num_directions, 0);
bias_c_data = mb.load(hidden_size, 4, num_directions, 0);
if (bias_c_data.empty())
return -100;

weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
weight_hc_data = mb.load(num_output, hidden_size * 4, num_directions, 0);
if (weight_hc_data.empty())
return -100;

if (num_output != hidden_size)
{
weight_hr_data = mb.load(hidden_size, num_output, num_directions, 0);
if (weight_hr_data.empty())
return -100;
}

return 0;
}

static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -80,7 +97,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w

const float* x = bottom_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
@@ -90,15 +107,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float* gates_data = gates.row(q);

// gate I F O G
const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

float I = bias_c_I[q];
float F = bias_c_F[q];
@@ -140,7 +157,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

@@ -157,8 +174,34 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float cell2 = F * cell_state[q] + I * G;
float H = O * tanh(cell2);
cell_state[q] = cell2;
hidden_state[q] = H;
output_data[q] = H;

if (num_output == hidden_size)
{
hidden_state[q] = H;
output_data[q] = H;
}
else
{
tmp_hidden_state[q] = H;
}
}

if (num_output != hidden_size)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
{
const float* hr = weight_hr.row(q);

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_state[i] * hr[i];
}

hidden_state[q] = H;
output_data[q] = H;
}
}
}
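In the branch above, when num_output != hidden_size the raw per-step output o .* tanh[c] (length hidden_size) is projected by weight_hr into the num_output-sized hidden/output vector. A minimal standalone sketch of that projection (illustrative std::vector types, not ncnn's Mat API):

```cpp
#include <vector>

// Sketch of the recurrent projection applied each timestep when
// num_output != hidden_size:
//   out[q] = sum_i tmp_hidden[i] * weight_hr[q][i]
static void project_hidden(const std::vector<float>& tmp_hidden,              // hidden_size
                           const std::vector<std::vector<float> >& weight_hr, // num_output rows of hidden_size
                           std::vector<float>& out)                           // num_output
{
    const int num_output = (int)weight_hr.size();
    const int hidden_size = (int)tmp_hidden.size();
    out.assign(num_output, 0.f);
    for (int q = 0; q < num_output; q++)
    {
        float H = 0.f;
        for (int i = 0; i < hidden_size; i++)
            H += tmp_hidden[i] * weight_hr[q][i];
        out[q] = H; // stored as both hidden_state[q] and output_data[q] in the layer
    }
}
```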

Expand All @@ -177,7 +220,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
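Worth noting in the allocation above (an editorial aside, not part of the diff): with a projection the two states have different widths; the hidden state is already projected to num_output while the cell state keeps the full hidden_size. A tiny sketch with illustrative types:

```cpp
#include <vector>

// Sketch only: state widths for a projected LSTM (hypothetical container types).
struct LstmState
{
    std::vector<float> hidden; // num_output wide
    std::vector<float> cell;   // hidden_size wide
};

static LstmState make_state(int num_output, int hidden_size)
{
    LstmState s;
    s.hidden.assign(num_output, 0.f);
    s.cell.assign(hidden_size, 0.f);
    return s;
}
```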
@@ -189,7 +232,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -204,14 +247,14 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.0f);
cell.fill(0.0f);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

Expand Down Expand Up @@ -251,7 +294,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -265,7 +308,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -282,13 +325,13 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;

2 changes: 2 additions & 0 deletions src/layer/lstm.h
@@ -36,10 +36,12 @@ class LSTM : public Layer
int num_output;
int weight_data_size;
int direction; // 0=forward 1=reverse 2=bidirectional
int hidden_size;

Mat weight_hc_data;
Mat weight_xc_data;
Mat bias_c_data;
Mat weight_hr_data;
};

} // namespace ncnn