[MXNET-72] Improve sparse.adam_update #10062

Merged: 19 commits, Mar 27, 2018

78 changes: 78 additions & 0 deletions benchmark/python/sparse/updater.py
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import mxnet as mx
from mxnet.ndarray.sparse import adam_update
import numpy as np
import argparse

mx.random.seed(0)
np.random.seed(0)

parser = argparse.ArgumentParser(description='Benchmark adam updater')
parser.add_argument('--dim-in', type=int, default=240000, help='weight.shape[0]')
parser.add_argument('--dim-out', type=int, default=512, help='weight.shape[1]')
parser.add_argument('--nnr', type=int, default=5000, help='grad.indices.shape[0]')
parser.add_argument('--repeat', type=int, default=1000, help='num repeat')
parser.add_argument('--dense-grad', action='store_true',
                    help='if set to true, both gradient and weight are dense.')
parser.add_argument('--dense-state', action='store_true',
                    help='if set to true, states are dense, indicating standard update')
parser.add_argument('--cpu', action='store_true')


args = parser.parse_args()
dim_in = args.dim_in
dim_out = args.dim_out
nnr = args.nnr
ctx = mx.cpu() if args.cpu else mx.gpu()

ones = mx.nd.ones((dim_in, dim_out), ctx=ctx)

if not args.dense_grad:
    weight = ones.tostype('row_sparse')
    indices = np.arange(dim_in)
    np.random.shuffle(indices)
    indices = np.unique(indices[:nnr])
    indices = mx.nd.array(indices, ctx=ctx)
    grad = mx.nd.sparse.retain(weight, indices)
else:
    weight = ones.copy()
    grad = ones.copy()

if args.dense_state:
    mean = ones.copy()
else:
    mean = ones.tostype('row_sparse')

var = mean.copy()

# warmup
for i in range(10):
    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()

# measure speed
a = time.time()
for i in range(args.repeat):
    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
b = time.time()
print(b - a)
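
For reference, the script can be run directly, e.g. `python benchmark/python/sparse/updater.py --nnr 5000 --cpu` (drop `--cpu` to measure on GPU; other values default to the argparse settings above). The sketch below, with made-up small shapes, shows the two modes the `--dense-state` flag switches between: row_sparse mean/var give the lazy update that only touches the rows present in `grad.indices`, while dense mean/var give the standard update over every row.

```python
import mxnet as mx
from mxnet.ndarray.sparse import adam_update

ctx = mx.cpu()
ones = mx.nd.ones((6, 4), ctx=ctx)

weight = ones.tostype('row_sparse')
# gradient with non-zero row-slices only at rows 1 and 4
grad = mx.nd.sparse.retain(weight, mx.nd.array([1, 4], ctx=ctx))

# lazy update: row_sparse mean/var, only rows 1 and 4 are touched
mean = ones.tostype('row_sparse')
var = mean.copy()
adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
            beta2=0.99, rescale_grad=0.5, epsilon=1e-8)

# standard update: dense mean/var, every row is decayed and updated
mean_d = ones.copy()
var_d = ones.copy()
adam_update(weight, grad, mean_d, var_d, out=weight, lr=1, wd=0, beta1=0.9,
            beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
```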
78 changes: 39 additions & 39 deletions src/operator/optimizer_op-inl.h
@@ -749,14 +749,17 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
});
}

+template<int req, typename xpu>
+struct AdamDnsRspDnsKernel;
+
/*!
* Note: this kernel performs sparse adam update. For each row-slice in row_sparse
* gradient, it finds the corresponding elements in weight, mean and var and performs
* the update.
* The kernel assumes dense weight/mean/var, and row_sparse gradient
*/
template<int req>
-struct AdamDnsRspDnsKernel {
+struct AdamDnsRspDnsKernel<req, cpu> {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
@@ -788,6 +791,33 @@ struct AdamDnsRspDnsKernel {
};
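
As a plain-NumPy illustration of what this kernel computes (a sketch only; `lazy_adam_update_rows` is a hypothetical helper, not part of the PR), the lazy path walks the rows listed in the gradient's indices and leaves every other row of weight/mean/var untouched:

```python
import numpy as np

def lazy_adam_update_rows(weight, mean, var, grad_idx, grad_data,
                          lr, beta1, beta2, epsilon, wd, rescale_grad,
                          clip_gradient=-1.0):
    # dense weight/mean/var, row_sparse gradient given as (grad_idx, grad_data)
    for row_id, row in zip(grad_idx, grad_data):   # one CPU "thread" per non-zero row
        g = row * rescale_grad + wd * weight[row_id]
        if clip_gradient >= 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        mean[row_id] = beta1 * mean[row_id] + (1 - beta1) * g
        var[row_id] = beta2 * var[row_id] + (1 - beta2) * g * g
        weight[row_id] -= lr * mean[row_id] / (np.sqrt(var[row_id]) + epsilon)
```

The CPU specialization processes a whole row per thread, while the GPU specialization added below is launched with `num_rows * row_length` threads, one per element (see the `num_threads` change in `AdamUpdateDnsRspDnsImpl` further down).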


+template<int req>
+struct AdamDnsRspDnsKernel<req, gpu> {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
+    const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t row_id = i / row_length;
+    const dim_t col_id = i % row_length;
+    const dim_t row_offset = grad_idx[row_id] * row_length;
+    // index in data/mean/var
+    const dim_t data_i = row_offset + col_id;
+    // index in grad
+    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[data_i] * wd;
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+    mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
+    var_data[data_i] = beta2 * var_data[data_i] +
+      (1.f - beta2) * grad_rescaled * grad_rescaled;
+    KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
+                  (square_root::Map(var_data[data_i]) + epsilon));
+  }
+};
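
A quick check (illustration only, toy sizes) that the flattened per-element indexing used by the GPU kernel above visits exactly the same (row, column) pairs as a per-row loop:

```python
row_length, num_rows = 4, 3   # toy sizes
per_row = [(r, c) for r in range(num_rows) for c in range(row_length)]
per_element = [divmod(i, row_length) for i in range(num_rows * row_length)]
assert per_row == per_element  # i // row_length, i % row_length recover (row, col)
```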

template<typename xpu>
inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
const OpContext& ctx,
@@ -817,8 +847,12 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
DType* out_data = out->dptr<DType>();
nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
-Kernel<AdamDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
-  out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
+size_t num_threads = num_rows;
+if (std::is_same<xpu, gpu>::value) {
+  num_threads = num_rows * row_length;
+}
+Kernel<AdamDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads,
+  row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
@@ -858,42 +892,8 @@ inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
var.data(), req, &out_blob);
}

-template<int req>
-struct AdamStdDnsRspDnsKernel {
-  template<typename DType, typename IType, typename RType>
-  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
-    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
-    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
-    const DType beta1, const DType beta2, const DType lr, const DType wd,
-    const DType epsilon, const DType rescale_grad) {
-    using namespace mshadow_op;
-    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
-                                   : prefix_sum[i] > prefix_sum[i-1];
-
-    const index_t row_i = i * row_length;
-    const RType grad_i = (prefix_sum[i]-1) * row_length;
-    for (index_t j = 0; j < row_length; j++) {
-      const index_t data_i = row_i + j;
-      const DType grad_rescaled = non_zero ? static_cast<DType>(
-                                               grad_data[grad_i + j] * rescale_grad +
-                                               weight_data[data_i] * wd)
-                                           : static_cast<DType>(weight_data[data_i] * wd);
-      if (clip_gradient >= 0.0f) {
-        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
-                            clip::Map(grad_rescaled, clip_gradient);
-        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
-                           clip::Map(grad_rescaled, clip_gradient));
-      } else {
-        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
-        var_data[data_i] = beta2 * var_data[data_i] +
-                           (1.f - beta2) * square::Map(grad_rescaled);
-      }
-      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
-                    (square_root::Map(var_data[data_i]) + epsilon));
-    }
-  }
-};
-
+template<int req, typename xpu>
+struct AdamStdDnsRspDnsKernel;

template<typename xpu>
void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
39 changes: 38 additions & 1 deletion src/operator/optimizer_op.cc
@@ -149,6 +149,43 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
});
}

+template<int req>
+struct AdamStdDnsRspDnsKernel<req, cpu> {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType beta1, const DType beta2, const DType lr, const DType wd,
+    const DType epsilon, const DType rescale_grad) {
+    using namespace mshadow_op;
+    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
+                                   : prefix_sum[i] > prefix_sum[i-1];
+
+    const index_t row_i = i * row_length;
+    const RType grad_i = (prefix_sum[i]-1) * row_length;
+    for (index_t j = 0; j < row_length; j++) {
+      const index_t data_i = row_i + j;
+      const DType grad_rescaled = non_zero ? static_cast<DType>(
+                                               grad_data[grad_i + j] * rescale_grad +
+                                               weight_data[data_i] * wd)
+                                           : static_cast<DType>(weight_data[data_i] * wd);
+      if (clip_gradient >= 0.0f) {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
+                            clip::Map(grad_rescaled, clip_gradient);
+        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
+                           clip::Map(grad_rescaled, clip_gradient));
+      } else {
+        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
+        var_data[data_i] = beta2 * var_data[data_i] +
+                           (1.f - beta2) * square::Map(grad_rescaled);
+      }
+      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
+                    (square_root::Map(var_data[data_i]) + epsilon));
+    }
+  }
+};
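
For reference, a NumPy sketch of the standard (non-lazy) update that this kernel and its GPU counterpart implement (`std_adam_update` is a hypothetical helper; gradient clipping is omitted for brevity). An inclusive prefix sum over a 0/1 row-occupancy mask tells each output row whether the row-sparse gradient contains it and, if it does, where that row's data lives:

```python
import numpy as np

def std_adam_update(weight, mean, var, grad_idx, grad_data,
                    lr, beta1, beta2, epsilon, wd, rescale_grad):
    num_rows = weight.shape[0]
    occupancy = np.zeros(num_rows, dtype=np.int64)
    occupancy[np.asarray(grad_idx)] = 1
    prefix_sum = np.cumsum(occupancy)          # plays the role of the kernel's prefix_sum
    for i in range(num_rows):                  # every row, not just the non-zero ones
        non_zero = prefix_sum[i] > (prefix_sum[i - 1] if i > 0 else 0)
        g_row = grad_data[prefix_sum[i] - 1] if non_zero else 0.0
        g = g_row * rescale_grad + wd * weight[i]
        mean[i] = beta1 * mean[i] + (1 - beta1) * g
        var[i] = beta2 * var[i] + (1 - beta2) * g * g
        weight[i] -= lr * mean[i] / (np.sqrt(var[i]) + epsilon)
```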


template<>
void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
const OpContext& ctx,
@@ -194,7 +231,7 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
}
}

-Kernel<AdamStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
+Kernel<AdamStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
37 changes: 33 additions & 4 deletions src/operator/optimizer_op.cu
@@ -94,6 +94,35 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
});
}

+template<int req>
+struct AdamStdDnsRspDnsKernel<req, gpu> {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
+    const DType beta1, const DType beta2, const DType lr, const DType wd,
+    const DType epsilon, const DType rescale_grad) {
+    using namespace mshadow_op;
+    using nnvm::dim_t;
+    const dim_t row_id = i / row_length;
+    const dim_t col_id = i % row_length;
+    const bool non_zero = (row_id == 0) ? prefix_sum[0] > 0
+                                        : prefix_sum[row_id] > prefix_sum[row_id - 1];
+    const RType grad_offset = (prefix_sum[row_id] - 1) * row_length + col_id;
+    DType grad_rescaled = non_zero ? static_cast<DType>(grad_data[grad_offset] * rescale_grad
+                                                        + weight_data[i] * wd)
+                                   : static_cast<DType>(weight_data[i] * wd);
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] +
+      (1.f - beta2) * square::Map(grad_rescaled);
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
+                  (square_root::Map(var_data[i]) + epsilon));
+  }
+};

template<>
void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
const OpContext& ctx,
@@ -122,8 +151,8 @@ void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
DType* mean_data = mean.dptr<DType>();
DType* var_data = var.dptr<DType>();
DType* out_data = out->dptr<DType>();
-nnvm::dim_t num_rows = weight.shape_[0];
-nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
+const nnvm::dim_t num_rows = weight.shape_[0];
+const nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
nnvm::dim_t* prefix_sum = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
@@ -152,8 +181,8 @@ void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
Stream<gpu>::GetStream(s));
}

-Kernel<AdamStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
-  out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
+Kernel<AdamStdDnsRspDnsKernel<req_type, gpu>, gpu>::Launch(s, weight.shape_.Size(),
+  row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
16 changes: 8 additions & 8 deletions tests/python/unittest/test_optimizer.py
@@ -543,13 +543,13 @@ def test_ftml():
class PyAdam(mx.optimizer.Optimizer):
"""python reference implemenation of adam"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-             decay_factor=(1 - 1e-8), sparse_update=False, **kwargs):
+             decay_factor=(1 - 1e-8), lazy_update=False, **kwargs):
super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.decay_factor = decay_factor
-self.sparse_update = sparse_update
+self.lazy_update = lazy_update

def create_state(self, index, weight):
"""Create additional optimizer state: mean, variance
@@ -595,7 +595,7 @@ def update(self, index, weight, grad, state):
# check row slices of all zeros
all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
# skip zeros during sparse update
-if all_zeros and self.sparse_update:
+if all_zeros and self.lazy_update:
continue
grad[row] = grad[row] * self.rescale_grad + wd * weight[row]
# clip gradients
@@ -638,7 +638,7 @@ def test_adam():
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
rtol=1e-4, atol=2e-5)
# atol 2e-5 needed to pass with seed 781809840
-compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
dtype, w_stype='row_sparse', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)
compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
@@ -883,12 +883,12 @@ class PyFtrl(mx.optimizer.Optimizer):
\\eta_{t,i} = \\frac{learningrate}{\\beta+\\sqrt{\\sum_{s=1}^tg_{s,i}^t}}
"""

-def __init__(self, lamda1=0.01, learning_rate=0.1, beta=1, sparse_update=False, **kwargs):
+def __init__(self, lamda1=0.01, learning_rate=0.1, beta=1, lazy_update=False, **kwargs):
super(PyFtrl, self).__init__(**kwargs)
self.lamda1 = lamda1
self.beta = beta
self.lr = learning_rate
-self.sparse_update = sparse_update
+self.lazy_update = lazy_update

def create_state(self, index, weight):
return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # dn
@@ -903,7 +903,7 @@ def update(self, index, weight, grad, state):
dn, n = state
for row in range(num_rows):
all_zeros = mx.test_utils.almost_equal(grad[row].asnumpy(), np.zeros_like(grad[row].asnumpy()))
-if all_zeros and self.sparse_update:
+if all_zeros and self.lazy_update:
continue
grad[row] = grad[row] * self.rescale_grad
if self.clip_gradient is not None:
@@ -933,7 +933,7 @@ def test_ftrl():
{'clip_gradient': 0.5, 'wd': 0.07, 'lamda1': 1.0}]
for kwarg in kwargs:
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
-compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
+compare_optimizer(opt1(lazy_update=True, **kwarg), opt2(**kwarg), shape,
np.float32, w_stype='row_sparse', g_stype='row_sparse')

@with_seed(1234)