Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Cpu lstm inference #9977

Merged
merged 26 commits into from
Mar 10, 2018
Merged

Cpu lstm inference #9977

merged 26 commits into from
Mar 10, 2018

Conversation

Jerryzcn
Copy link
Contributor

@Jerryzcn Jerryzcn commented Mar 3, 2018

Description

(Brief description on what this PR is about)
CPU LSTM inference kernel.
This is around 9.5x faster than gluon LSTM cell.

Verified on speech recognition task, as well as on unittest.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@pengzhao-intel
Copy link
Contributor

@Jerryzcn It's very great to explore the full CPU power for RNN cell.

FYI, we have implemented the fused LSTM OP of CPU in local including inference and training.
And the new OP is registered with NNVM.

I think we can co-operate together to merge the code. @TaoLv @sherry-zhang
We can PR our code in your repo so that the whole LSTM solution can be ready in the same time.

What's your opinion?

@Jerryzcn
Copy link
Contributor Author

Jerryzcn commented Mar 3, 2018

@pengzhao-intel This is great, we can definitely collaborate. The reason I am sending this PR is for one of our own project. and we would like to have something to use ASAP. Do you have a timeline for the fused LSTM OP?

@reminisce
Copy link
Contributor

It's great to have cpu version implemented. We are deprecating operator implementation using the legacy interface. It's would be better if you can refactor code using the nnvm interface for operator implementation. One example is https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/convolution-inl.h#L155

@pengzhao-intel
Copy link
Contributor

@Jerryzcn Thanks for the info. So, I think it's better to merge this PR as-is.

Our LSTM/GRU will be ready in this month and we will submit the code separately for the review :)

@szha
Copy link
Member

szha commented Mar 4, 2018

@pengzhao-intel will the elman RNN be part of your PR? cudnn currently supports it, so we support it as part of the RNN layers.

@TaoLv
Copy link
Member

TaoLv commented Mar 5, 2018

@szha To make it easier to review, we would like to split the whole RNN implementation on CPU into several PRs. Firstly, We will submit code for single-layer and unidirectional LSTM/GRU. Then, multi-layer and bidirectional support will be added for LSTM/GRU. Vanilla RNN (maybe elman RNN in your words) will be supported after we finish LSTM/GRU. Actually, we have implemented fused vanilla RNN, but I think it should be a low priority to integrated it into mxnet, compared with LSTM/GRU.

@szha What about your opinion? We can set a detailed plan for this PRs if needed.
@pengzhao-intel Correct me if I missed anything.

@szha
Copy link
Member

szha commented Mar 5, 2018

@TaoLv sounds good. What timeline are we looking at for feature parity with cudnn?

@TaoLv
Copy link
Member

TaoLv commented Mar 5, 2018

@szha Team need take an internal discussion about it and will back to you with a detailed plan soon.
BTW, do you know anybody can help to refactor the existing RNN operator with nnvm interfaces? Seems that some of gpu code need be changed for that.

@szha
Copy link
Member

szha commented Mar 5, 2018

Pinging @piiswrong for coordination.

@piiswrong
Copy link
Contributor

We don't have to do the refactor now. CPU support is more important.

@@ -214,6 +214,7 @@ def __iter__(self):
worker.start()
workers.append(worker)

idx = -1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idx might be reference before assignment when i run pylint

@CodingCat
Copy link
Contributor

Hi, the community has passed to vote about associating the code changes with JIRA (https://lists.apache.org/thread.html/ab22cf0e35f1bce2c3bf3bec2bc5b85a9583a3fe7fd56ba1bbade55f@%3Cdev.mxnet.apache.org%3E)

We have updated the guidelines for contributors in https://cwiki.apache.org/confluence/display/MXNET/Development+Process, please ensure that you have created a JIRA at https://issues.apache.org/jira/projects/MXNET/issues/ to describe your work in this pull request and include the JIRA title in your PR as [MXNET-xxxx] your title where MXNET-xxxx is the JIRA id

Thanks!

size *= 2;
} else {
size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize,
mode);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are just reformatting the code here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

CHECK_EQ(y.CheckContiguous(), true);

if (ctx.is_train)
LOG(FATAL) << "only inference mode is available for cpu at the moment.";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can do CHECK(!ctx.is_train) << "..."

in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
std::vector<int> dep = {in_data[rnn_enum::kData],
in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not sure why you want to change the code in this function. it seems you just reorganize the code a little bit.

Copy link
Contributor Author

@Jerryzcn Jerryzcn Mar 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it exceeds 80 char per line limit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the coding style in mxnet allows up to 100 char per line.
so the original code is fine.

if (param_.mode == rnn_enum::kLstm)
param_.lstm_q_ = true;
else
param_.lstm_q_ = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems this check can be merged to the switch case statement above.

@szha szha self-assigned this Mar 8, 2018
@@ -114,7 +120,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {

DMLC_DECLARE_FIELD(p).set_default(0.)
.set_range(0, 1)
.describe("Dropout probability, fraction of the input that gets dropped out at training time");
.describe("Dropout probability, fraction of the input that gets dropped"
"out at training time");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this change. Length of this line is less than 100.
BTW, why there are still some parameters don't have their descriptions, like pkeep_, lstm_q_?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pkeep_, lstm_q_ are used in cudnn_rnn-inl.h

CHECK_EQ(y.CheckContiguous(), true);

CHECK(!ctx.is_train) << "only inference mode is available"
"for cpu at the moment.";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about check this at the front of this function?

model = mx.gluon.nn.Sequential()
with model.name_scope():
model.add(mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True))
model.initialize(mx.init.One())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also test the consistency between cpu and gpu, with same random weights and random inputs?

Copy link
Contributor Author

@Jerryzcn Jerryzcn Mar 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will it break CPU tests? It might be too much an effort

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under the current input and weight, your test would still pass even if the weights are iterated backwards. Unfortunately it's not in an acceptable state.

CHECK_EQ(x.CheckContiguous(), true);
CHECK_EQ(w.CheckContiguous(), true);
CHECK_EQ(hx.CheckContiguous(), true);
CHECK_EQ(y.CheckContiguous(), true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK(x.CheckContiguous());

private:
RNNParam param_;

virtual void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why virtual?

int64_t f = i + h_channel;
int64_t c = i + h_channel * 2;
int64_t o = i + h_channel * 3;
h2h_y[j][i] += i2h_y[j][i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many overloaded operator [] calls and temporary Tensor objects generated here. At least you can cache h2h_y[j], i2h_y[j], etc. for each loop.

Copy link
Contributor Author

@Jerryzcn Jerryzcn Mar 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tried, but i did not notice any difference in runtime. i think the tensor object probably does not generate new tensor object here. I think multiple [] are probably implemented as a single dereference operation rather multiple one. I suspect that assigning it to a local variable will actually use one of the register for holding the pointer to the object, which may actually slow down the process

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you use data[i], where data is a 2D tensor, it returns a 1D temporary Tensor object for you, and then call the 1D tensor's operator[]. You would not be able to notice much runtime improvement after you make the change if the program didn't run for a long time only for this loop and the improvement could be dwarfed by other factors that are major bottlenecks. At least, it's not a good practice to write C++ code like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay. but it seems inside mshadow, all the ops are implemented using multiple []
https://github.com/dmlc/mshadow/blob/master/mshadow/tensor_cpu-inl.h#L380
I will probably access the dptr_

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion on this. If you could use dptr_, that's the best for performance because it saves function calls and temp tensor object creation, but it could introduce the issue of code readability and defeat the purpose of OO.

I think the rule of thumb here is try to avoid temp tensor creation and destruction while keep the code readable. So it's okay to use operator[] for a 1D Tensor since it only return values and cache the temp tensor created by calling operator[] for a 2D tensor.

model = mx.gluon.nn.Sequential()
with model.name_scope():
model.add(mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True))
model.initialize(mx.init.One())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model.initialize(mx.init.One())
y = model(x).asnumpy()

mx.test_utils.assert_almost_equal(y, np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are there hardcoded numbers?

Copy link
Contributor Author

@Jerryzcn Jerryzcn Mar 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mv hardcoded number to constant
+ (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ?
(in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
i2h_w_shape);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why slice? i think w.dptr_ + start is same as w.Slice(start, start + (layer < num_dir ? in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_

int64_t ji;
#pragma omp parallel for private(ji)
for (ji = 0; ji < batch_size * h_channel; ji++) {
int64_t j = ji / h_channel; // batch dim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ok to write batch_size * h_channel in condition expression? It will calculate ji times.
And ++ji is better than ji++.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u mean move it out of condition expression?

int64_t ji;
#pragma omp parallel for private(ji)
for (ji = 0; ji < batch_size * h_channel; ji++) {
int64_t j = ji / h_channel; // batch dim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need to set ji private if define ji in for loop. like this:

#pragma omp parallel for
for(int64_t ji = 0; ... ; ....)

Copy link
Member

@szha szha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add consistency tests between CPU and GPU (cudnn) using random weights and random inputs, with dropouts off.

@Jerryzcn
Copy link
Contributor Author

Jerryzcn commented Mar 9, 2018

@szha I think https://github.com/apache/incubator-mxnet/blob/master/tests/python/gpu/test_operator_gpu.py#L1527 check for consistency? Although the inputs are ones.

@szha szha merged commit 13ae4d1 into apache:master Mar 10, 2018
@szha
Copy link
Member

szha commented Mar 10, 2018

Since this change is only useful for inference, RNN layer still needs to remain a Block. Once the backward is in place, we will be able to change it to a HybridBlock.

@Jerryzcn Jerryzcn deleted the cpu-lstm branch March 10, 2018 09:15
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
* fix autograd import path

* cpu lstm working

* remove fatal log

* add simple unittest
remove redundant log
enable openmp

* fused input2hidden gemm

* fix lint

* fix pylint

* fix windows build error

* fix gluon rnn interface

* Update dataloader.py

* address cr

* address cr

* fix import

* revert some cosmetic change

* fix typo

* remove newline

* rm virtual
mv hardcoded number to constant

* address cr
add tests

* simplify test

* fix test

* fix tests

* change magic number scope
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* fix autograd import path

* cpu lstm working

* remove fatal log

* add simple unittest
remove redundant log
enable openmp

* fused input2hidden gemm

* fix lint

* fix pylint

* fix windows build error

* fix gluon rnn interface

* Update dataloader.py

* address cr

* address cr

* fix import

* revert some cosmetic change

* fix typo

* remove newline

* rm virtual
mv hardcoded number to constant

* address cr
add tests

* simplify test

* fix test

* fix tests

* change magic number scope
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* fix autograd import path

* cpu lstm working

* remove fatal log

* add simple unittest
remove redundant log
enable openmp

* fused input2hidden gemm

* fix lint

* fix pylint

* fix windows build error

* fix gluon rnn interface

* Update dataloader.py

* address cr

* address cr

* fix import

* revert some cosmetic change

* fix typo

* remove newline

* rm virtual
mv hardcoded number to constant

* address cr
add tests

* simplify test

* fix test

* fix tests

* change magic number scope
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants