standard update for sparse sgd_mom_update #9189
Conversation
python/mxnet/optimizer.py
Outdated
@@ -464,16 +465,19 @@ class SGD(Optimizer):
----------
momentum : float, optional
    The momentum value.
lazy_update : bool, optional
    If True, standard updates are applied.
Clarify the default value, too
@@ -96,11 +147,15 @@ only the row slices whose indices appear in grad.indices are updated (for both w
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, false, true, false>)
Also update the doc here, too?
Updated
src/operator/optimizer_op-inl.h
Outdated
}
if (!dispatched && in_attrs->at(0) == kRowSparseStorage &&
    in_attrs->at(1) == kRowSparseStorage &&
    (in_attrs->at(2) == kRowSparseStorage || in_attrs->at(2) == kDefaultStorage)) {
Save in_attrs->at(2) in a local var to improve readability?
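For concreteness, the suggested refactor might look like this (a minimal sketch; the variable name mom_stype is my own choice, not taken from the PR):

```cpp
// Cache the momentum's storage type once instead of reading
// in_attrs->at(2) twice in the condition. "mom_stype" is a hypothetical name.
const int mom_stype = in_attrs->at(2);
if (!dispatched && in_attrs->at(0) == kRowSparseStorage &&
    in_attrs->at(1) == kRowSparseStorage &&
    (mom_stype == kRowSparseStorage || mom_stype == kDefaultStorage)) {
  // dispatch to the sparse implementation, as in the original branch
}
```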
src/operator/optimizer_op.cu
Outdated
DType* mom_data = mom.dptr<DType>();
DType* out_data = out->dptr<DType>();
nnvm::dim_t num_rows = weight.shape_[0];
auto row_length = weight.shape_.ProdShape(1, weight.ndim());
let's not use auto
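Spelled out with an explicit type, the line would read as below (a sketch; assuming the flattened row length fits the same nnvm::dim_t type used for num_rows above):

```cpp
// Explicit type instead of auto, matching num_rows above.
const nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
```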
src/operator/optimizer_op.cc
Outdated
nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
// mark row flags
Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
Fill uses memset, which is single-threaded. It's slow for a large number of elements. Let's use Kernel<set_zero, cpu>::Launch instead.
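The replacement would look roughly like this (a sketch; Kernel and set_zero are the element-wise kernel helpers from src/operator/mxnet_op.h):

```cpp
// Zero the row-flag buffer with a parallelized element-wise kernel
// rather than a single-threaded memset.
using namespace mxnet_op;
Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
```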
python/mxnet/optimizer.py
Outdated
@@ -433,7 +433,8 @@ def _get_wd(self, index):
class SGD(Optimizer):
    """The SGD optimizer with momentum and weight decay.

    The optimizer updates the weight by::
    If any storage type of weight, state or grad is ``default``, \
    **standard updates** are applied by::
I think we should reverse the order and mention ``lazy_update``, since users of the optimizer don't know how state stypes are created:

If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, **lazy updates** are applied by::

    for row in grad.indices:
        ...

Otherwise, **standard updates** are applied by::

    ...
src/operator/optimizer_op-inl.h
Outdated
: prefix_sum[i] > prefix_sum[i-1];

for (index_t j = 0; j < row_length; j++) {
  const index_t data_i = i * row_length + j;
i * row_length can be cached and computed only once. Same for (prefix_sum[i]-1) * row_length.
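Hoisting both products out of the inner loop would look something like this (a sketch; the offset variable names are hypothetical, not from the PR):

```cpp
// Compute each row offset once per row instead of once per element.
const index_t data_offset = i * row_length;
const index_t grad_offset = (prefix_sum[i] - 1) * row_length;
for (index_t j = 0; j < row_length; j++) {
  const index_t data_i = data_offset + j;
  // ... index the gradient row with grad_offset + j ...
}
```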
LGTM in general
src/operator/optimizer_op-inl.h
Outdated
@@ -460,6 +461,99 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
                                      mom.data(), req, &out_blob);
}

template<int n_rsp, int n_rsp_dns>
Let's add some description on what the template params mean
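Going by the storage-type checks earlier in the review, the description could read roughly as follows (a sketch; the parameter meanings are my inference from that logic, not wording from the PR):

```cpp
/*!
 * \brief  Sketch of the requested documentation; meanings inferred from the
 *         storage-type inference logic reviewed above.
 * \tparam n_rsp      number of inputs that must have row_sparse storage
 * \tparam n_rsp_dns  number of inputs that may be row_sparse or dense
 */
template<int n_rsp, int n_rsp_dns>
```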
* standard sparse sgd mom update
* update
* update comments
* address comments
* revise
* more general infer stype
* fix
* fix
* add comments for stype inference func
* update
Description
#9177
cc @eric-haibin-lin
Checklist
Essentials
Passed code style checking (make lint)
Changes
Comments