
Commit 6c7dfa5

ChaiBapchya authored and Ubuntu committed
[Large Tensor] Add LT support for NN optimizers and 1 activation function (apache#17444)
* fix hard sigmoid
* change int i to index_t i for all Kernel Map functions
* fix lint
* size_t index_t fix
1 parent 8a7e977 commit 6c7dfa5
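The substance of the change is a single, repeated signature fix: every per-element Kernel Map function in these files now takes its element index as index_t instead of int, so element offsets beyond 2^31 - 1 no longer overflow on large tensors. Below is a minimal, self-contained sketch of the pattern, not MXNet's actual launcher; index_t is assumed here to be a 64-bit signed integer (as in builds with int64 tensor size enabled), and SketchSGDKernel / launch_sketch_sgd are hypothetical names.

#include <cstdint>

// Assumption: in large-tensor builds, the index type is a 64-bit signed integer.
using index_t = std::int64_t;

// Illustrative analog of an optimizer kernel's Map: the element index is
// index_t, so it stays valid past the 2^31 - 1 limit of a 32-bit int.
struct SketchSGDKernel {
  template <typename DType>
  static void Map(index_t i, DType* out, const DType* weight, const DType* grad,
                  DType lr, DType wd, DType rescale_grad) {
    out[i] = weight[i] - lr * (rescale_grad * grad[i] + wd * weight[i]);
  }
};

// Hypothetical stand-in for the kernel launcher: the loop variable is also
// index_t, so the whole indexing path is 64-bit.
template <typename DType>
void launch_sketch_sgd(DType* out, const DType* weight, const DType* grad,
                       index_t size, DType lr, DType wd, DType rescale_grad) {
  for (index_t i = 0; i < size; ++i) {
    SketchSGDKernel::Map(i, out, weight, grad, lr, wd, rescale_grad);
  }
}

The diffs below apply exactly this index_t change to each optimizer kernel and to the hard sigmoid forward/backward kernels.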

File tree

2 files changed (+43 lines, -40 lines)


src/operator/optimizer_op-inl.h (+41, -38)
@@ -225,10 +225,10 @@ struct MultiSGDKernelParam {
 template <typename MPDType, bool has_momentum, bool has_mixed_precision>
 struct MultiSGDKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, const MultiSGDKernelParam<DType, MPDType>& param,
+MSHADOW_XINLINE static void Map(index_t i, const MultiSGDKernelParam<DType, MPDType>& param,
 const OpReqType req) {
 for (int index = 0; index < param.count; ++index) {
-if ((size_t)i < param.sizes[index]) {
+if (i < static_cast<index_t>(param.sizes[index])) {
 MPDType w = has_mixed_precision ? param.weights32[index][i] :
 MPDType(param.weights[index][i]);
 MPDType mom = has_momentum ? param.mom[index][i] : MPDType(0);
@@ -381,7 +381,7 @@ inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs,
 
 struct SGDKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
 const DType* grad_data, const DType param_clip_gradient,
 const DType param_lr, const DType param_wd, const DType param_rescale_grad,
 const OpReqType req) {
@@ -429,9 +429,9 @@ struct SGDDnsRspKernel<req, gpu> {
 // IType is row sparse idx type
 // i is the ith element in row sparse gradient
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, const index_t row_length, DType* out, const DType* weight,
-const IType* grad_idx, const DType *grad_val,
-const DType clip_gradient, const DType lr,
+MSHADOW_XINLINE static void Map(index_t i, const index_t row_length, DType* out,
+const DType* weight, const IType* grad_idx,
+const DType *grad_val, const DType clip_gradient, const DType lr,
 const DType wd, const DType rescale_grad) {
 using nnvm::dim_t;
 using namespace mshadow_op;
@@ -457,9 +457,9 @@ struct SGDDnsRspKernel<req, cpu> {
 // IType is row sparse idx type
 // i is the ith row in row sparse gradient
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, const index_t row_length, DType* out, const DType* weight,
-const IType* grad_idx, const DType *grad_val,
-const DType clip_gradient, const DType lr,
+MSHADOW_XINLINE static void Map(index_t i, const index_t row_length, DType* out,
+const DType* weight, const IType* grad_idx,
+const DType *grad_val, const DType clip_gradient, const DType lr,
 const DType wd, const DType rescale_grad) {
 for (index_t j = 0; j < row_length; j++) {
 index_t data_i = grad_idx[i] * row_length + j;
@@ -600,10 +600,11 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
 
 struct SGDMomKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
-const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
-const DType param_lr, const DType param_wd, const DType param_rescale_grad,
-const OpReqType req) {
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
+const DType* weight_data, const DType* grad_data,
+const DType param_clip_gradient, const DType param_momentum,
+const DType param_lr, const DType param_wd,
+const DType param_rescale_grad, const OpReqType req) {
 if (param_clip_gradient >= 0.0f) {
 mom_data[i] = param_momentum*mom_data[i]
 - param_lr*param_wd*weight_data[i]
@@ -654,7 +655,7 @@ inline bool MP_InferType(const nnvm::NodeAttrs& attrs,
 
 struct MP_SGDKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
 const DType* grad_data, float* weight32, const float param_clip_gradient,
 const float param_lr, const float param_wd, const float param_rescale_grad,
 const OpReqType req) {
@@ -698,7 +699,7 @@ inline void MP_SGDUpdate(const nnvm::NodeAttrs& attrs,
 
 struct MP_SGDMomKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mom_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, float* mom_data,
 const DType* weight_data, const DType* grad_data, float* weight32,
 const float param_clip_gradient, const float param_momentum, const float param_lr,
 const float param_wd, const float param_rescale_grad, const OpReqType req) {
@@ -749,7 +750,7 @@ struct SGDMomDnsRspDnsKernel;
 template<int req>
 struct SGDMomDnsRspDnsKernel<req, cpu> {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
 DType* mom_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType momentum,
 const DType lr, const DType wd, const DType rescale_grad) {
@@ -776,7 +777,7 @@ struct SGDMomDnsRspDnsKernel<req, cpu> {
 template<int req>
 struct SGDMomDnsRspDnsKernel<req, gpu> {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
 DType* mom_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType momentum,
 const DType lr, const DType wd, const DType rescale_grad) {
@@ -1060,7 +1061,7 @@ struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
 
 struct NAGMomKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
 const DType* weight_data, const DType* grad_data,
 const DType param_clip_gradient, const DType param_momentum,
 const DType param_lr, const DType param_wd,
@@ -1107,7 +1108,7 @@ inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
 
 struct MP_NAGMomKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 float* mom_data, const DType* weight_data,
 const DType* grad_data, float* weight32,
 const float param_clip_gradient,
@@ -1204,7 +1205,7 @@ struct FTMLParam : public dmlc::Parameter<FTMLParam> {
 
 struct FTMLKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out, DType* weight, DType* grad,
+MSHADOW_XINLINE static void Map(index_t i, DType* out, DType* weight, DType* grad,
 DType* d, DType* v, DType* z, const DType lr, const DType beta1,
 const DType beta2, const DType epsilon, const DType t,
 const DType wd, const DType rescale_grad, const DType clip_grad,
@@ -1291,7 +1292,7 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
 
 struct AdamUpdateKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
 const DType clip_gradient, const DType rescale_grad,
 const DType beta1, const DType beta2,
@@ -1350,7 +1351,7 @@ struct AdamDnsRspDnsKernel;
 template<int req>
 struct AdamDnsRspDnsKernel<req, cpu> {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
 DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
 const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
@@ -1383,7 +1384,7 @@ struct AdamDnsRspDnsKernel<req, cpu> {
 template<int req>
 struct AdamDnsRspDnsKernel<req, gpu> {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
 DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
 const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
@@ -1620,7 +1621,7 @@ struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam>
 
 struct LambUpdatePhaseOneKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
 const DType clip_gradient, const DType rescale_grad,
 const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
@@ -1704,7 +1705,7 @@ inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
 
 struct LambUpdatePhaseTwoKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 const DType* weight_data, const DType* g,
 const DType* r1, const DType* r2,
 DType lr, const DType lower_bound,
@@ -1771,7 +1772,7 @@ inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
 
 struct MPLambUpdatePhaseOneKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, float* out_data,
+MSHADOW_XINLINE static void Map(index_t i, float* out_data,
 float* mean_data, float* var_data, const DType* weight_data,
 const DType* grad_data, const float* weight32_data,
 const float clip_gradient, const float rescale_grad,
@@ -1861,7 +1862,7 @@ inline bool MPLambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
 
 struct MPLambUpdatePhaseTwoKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 const DType* weight_data, const float* g,
 const float* r1, const float* r2, const float* weight32_data,
 float lr, const float lower_bound,
@@ -1952,7 +1953,7 @@ struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> {
 
 struct RMSPropAlexUpdateKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 DType* state_n_data, DType* state_g_data, DType* delta_data,
 const DType* weight_data, const DType* grad_data,
 const DType clip_gradient, const DType rescale_grad,
@@ -2051,7 +2052,7 @@ struct RMSPropParam : public dmlc::Parameter<RMSPropParam> {
 
 struct RMSPropUpdateKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i,
+MSHADOW_XINLINE static void Map(index_t i,
 DType* out_data, DType* state_n_data,
 const DType* weight_data, const DType* grad_data,
 const DType clip_gradient, const DType rescale_grad,
@@ -2132,7 +2133,7 @@ struct FtrlParam : public dmlc::Parameter<FtrlParam> {
 
 struct FtrlUpdateKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
 DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data,
 const DType clip_gradient, const DType rescale_grad,
 const DType beta, const DType lamda1,
@@ -2185,7 +2186,7 @@ inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
 template<int req>
 struct FtrlDnsRspDnsKernel {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
 DType* z_data, DType* n_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType lamda1, const DType beta,
 const DType lr, const DType wd, const DType rescale_grad) {
@@ -2343,7 +2344,7 @@ struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
 
 struct SignSGDKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
 const DType* grad_data, const DType param_clip_gradient,
 const DType param_lr, const DType param_wd, const DType param_rescale_grad,
 const OpReqType req) {
@@ -2411,10 +2412,12 @@ struct SignumParam : public dmlc::Parameter<SignumParam> {
 
 struct SignumKernel {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
-const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
-const DType param_lr, const DType param_wd, const DType param_rescale_grad,
-const DType param_wd_lh, const OpReqType req) {
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
+const DType* weight_data, const DType* grad_data,
+const DType param_clip_gradient, const DType param_momentum,
+const DType param_lr, const DType param_wd,
+const DType param_rescale_grad, const DType param_wd_lh,
+const OpReqType req) {
 if (param_clip_gradient >= 0.0f) {
 mom_data[i] = param_momentum*mom_data[i]
 - (1-param_momentum)*param_wd*weight_data[i]
@@ -2506,7 +2509,7 @@ struct AdagradDnsRspDnsKernel;
 template<>
 struct AdagradDnsRspDnsKernel<cpu> {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
 DType* state_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType epsilon,
 const DType lr, const DType rescale_grad) {
@@ -2533,7 +2536,7 @@ struct AdagradDnsRspDnsKernel<cpu> {
 template<>
 struct AdagradDnsRspDnsKernel<gpu> {
 template<typename DType, typename IType>
-MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
 DType* state_data, const DType* weight_data, const IType* grad_idx,
 const DType* grad_data, const DType clip_gradient, const DType epsilon,
 const DType lr, const DType rescale_grad) {
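Besides the signature changes, the first hunk above also rewrites the bound check in MultiSGDKernel: rather than casting the now 64-bit signed index to size_t, the stored size is cast to index_t so the comparison happens in signed 64-bit space. Both forms agree for non-negative indices; the new form simply keeps everything in index_t. A small sketch of the updated check, with index_in_range as an illustrative name (same int64_t assumption as above):

#include <cstddef>
#include <cstdint>

using index_t = std::int64_t;  // assumption: 64-bit signed index

// Mirrors the updated MultiSGDKernel check: cast the stored (unsigned) size
// to index_t instead of casting the signed index i to size_t.
inline bool index_in_range(index_t i, std::size_t stored_size) {
  return i < static_cast<index_t>(stored_size);
}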

src/operator/tensor/elemwise_unary_op.h (+2, -2)
@@ -495,7 +495,7 @@ struct HardSigmoidParam : public dmlc::Parameter<HardSigmoidParam> {
 template<int req>
 struct hard_sigmoid_forward {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
 const real_t alpha, const real_t beta) {
 DType result = DType(alpha * in_data[i] + beta);
 result = (DType(1) < result) ? DType(1) : result;
@@ -507,7 +507,7 @@ struct hard_sigmoid_forward {
 template<int req>
 struct hard_sigmoid_backward {
 template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType* in_grad, const DType* in_data,
+MSHADOW_XINLINE static void Map(index_t i, DType* in_grad, const DType* in_data,
 const DType* out_grad, const real_t alpha, const real_t beta) {
 DType out_val = DType(alpha) * in_data[i] + DType(beta);
 DType grad = (out_val > DType(0) && out_val < DType(1)) ?
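For context on this second file: hard_sigmoid_forward computes y = clip(alpha * x + beta, 0, 1), and the backward kernel passes a gradient of alpha only where the pre-clip value lies strictly inside (0, 1). A minimal scalar sketch of those formulas, standalone rather than the MXNet kernels; the function names are illustrative and alpha = 0.2, beta = 0.5 are assumed defaults:

#include <algorithm>

// Forward: y = clamp(alpha * x + beta, 0, 1).
inline float hard_sigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
  return std::min(1.0f, std::max(0.0f, alpha * x + beta));
}

// Backward: dL/dx = alpha * dL/dy where 0 < alpha * x + beta < 1, else 0.
inline float hard_sigmoid_grad(float x, float dy,
                               float alpha = 0.2f, float beta = 0.5f) {
  const float lin = alpha * x + beta;
  return (lin > 0.0f && lin < 1.0f) ? alpha * dy : 0.0f;
}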
