Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align Momentum Optimizer #8549

Merged
merged 13 commits into from
Jul 6, 2022
6 changes: 5 additions & 1 deletion oneflow/api/python/functional/dispatch_stateful_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,17 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor(
"DispatchMomentumUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
double scale, float l1, float l2, float beta, float weight_decay) -> Maybe<void> {
double scale, float l1, float l2, float beta, float dampening, bool nesterov,
bool maximize, float weight_decay) -> Maybe<void> {
MutableAttrMap attrs;
JUST(attrs.SetAttr("learning_rate_val", learning_rate));
JUST(attrs.SetAttr("scale", scale));
JUST(attrs.SetAttr("l1", l1));
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("beta", beta));
JUST(attrs.SetAttr("dampening", dampening));
JUST(attrs.SetAttr("nesterov", nesterov));
JUST(attrs.SetAttr("maximize", maximize));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
Expand Down
2 changes: 1 addition & 1 deletion oneflow/api/python/functional/dispatch_stateful_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
bind_python: True

- name: "dispatch_momentum_update"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float weight_decay=0) => DispatchMomentumUpdate"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float dampening=0.0, Bool nesterov=False, Bool maximize=False, Float weight_decay=0) => DispatchMomentumUpdate"
bind_python: True

- name: "dispatch_sgd_update"
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/job/job_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ message NaiveModelUpdateConf {

message MomentumModelUpdateConf {
optional float beta = 1 [default = 0.9];
optional float dampening = 2 [default = 0.0];
optional bool nesterov = 3 [default = false];
optional bool maximize = 4 [default = false];
}

message RMSPropModelUpdateConf {
Expand Down
5 changes: 4 additions & 1 deletion oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
// do nothing
} else if (user_op_conf.op_type_name() == "momentum_update") {
fused_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
.Attr<float>("beta", user_op_conf.attr<float>("beta"))
.Attr<float>("dampening", user_op_conf.attr<float>("dampening"))
.Attr<bool>("nesterov", user_op_conf.attr<bool>("nesterov"))
.Attr<bool>("maximize", user_op_conf.attr<bool>("maximize"));
} else if (user_op_conf.op_type_name() == "adam_update") {
fused_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ Maybe<void> IndexedSlicesOptimizerRewritePass::Apply(const OpGraph& op_graph,
// do nothing
} else if (user_op_conf.op_type_name() == "momentum_update") {
indexed_slices_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
.Attr<float>("beta", user_op_conf.attr<float>("beta"))
.Attr<float>("dampening", user_op_conf.attr<float>("dampening"))
.Attr<bool>("nesterov", user_op_conf.attr<bool>("nesterov"))
.Attr<bool>("maximize", user_op_conf.attr<bool>("maximize"));
} else if (user_op_conf.op_type_name() == "adam_update") {
indexed_slices_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/job_rewriter/momentum_optm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
.Input("learning_rate", optimizer_conf.learning_rate_lbn())
.Input("momentum", GenLogicalBlobName(op_name, momentum_var.variable_conf().out()))
.Attr<float>("beta", optimizer_conf.momentum_conf().beta())
.Attr<float>("dampening", optimizer_conf.momentum_conf().dampening())
.Attr<bool>("nesterov", optimizer_conf.momentum_conf().nesterov())
.Attr<bool>("maximize", optimizer_conf.momentum_conf().maximize())
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
.ScopeSymbolId(var_op->op_conf().scope_symbol_id());
SetDynamicLossScaleSkipIf(ctx, &momentum_update_op_builder);
Expand Down
6 changes: 6 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5889,6 +5889,9 @@ def OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<"indexed_slices_momen
);
let attrs = (ins
DefaultValuedAttr<F32Attr, "0.9">:$beta,
DefaultValuedAttr<F32Attr, "0.0">:$dampening,
DefaultValuedAttr<BoolAttr, "false">:$nesterov,
DefaultValuedAttr<BoolAttr, "false">:$maximize,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay
);
let has_logical_tensor_desc_infer_fn = 1;
Expand Down Expand Up @@ -5993,6 +5996,9 @@ def OneFlow_MomentumUpdateOp : OneFlow_BaseOp<"momentum_update", [NoGrad, AttrSi
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.9">:$beta,
DefaultValuedAttr<F32Attr, "0.0">:$dampening,
DefaultValuedAttr<BoolAttr, "false">:$nesterov,
DefaultValuedAttr<BoolAttr, "false">:$maximize,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay
);
let trait_attrs = (ins
Expand Down
33 changes: 18 additions & 15 deletions oneflow/user/kernels/model_update_kernel_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,23 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_
template<typename T, typename G>
struct MomentumUpdateKernelUtil<DeviceType::kCPU, T, G> {
static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
T* momentum);
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
};

template<typename T, typename G>
void MomentumUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum) {
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
T* model, T* momentum) {
if (skip_if != nullptr && *skip_if != 0) { return; }
if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
for (int64_t i = 0; i != n; ++i) {
MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,
weight_decay, learning_rate_val);
dampening, nesterov, maximize, weight_decay, learning_rate_val);
}
}

Expand All @@ -132,17 +133,19 @@ template struct MomentumUpdateKernelUtil<DeviceType::kCPU, double, double>;

template<typename T, typename K, typename IDX>
struct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {
static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* momentum);
static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model,
T* momentum);
};

template<typename T, typename K, typename IDX>
void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Update(
ep::Stream* stream, T beta, float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) {
ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,
int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
T* model, T* momentum) {
const int64_t n = *num_unique_instance * feature_size;
const T lr = *learning_rate;
for (int64_t i = 0; i != n; ++i) {
Expand All @@ -152,7 +155,7 @@ void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Updat
if (instance_id >= lower_bound && instance_id < upper_bound) {
const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;
MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx, 1.0, 0.0,
0.0, beta, weight_decay, lr);
0.0, beta, dampening, nesterov, maximize, weight_decay, lr);
}
}
}
Expand Down
75 changes: 41 additions & 34 deletions oneflow/user/kernels/model_update_kernel_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,24 @@ namespace {

template<typename T, typename G>
__global__ void MomentumUpdateGpu(int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model,
T* momentum) {
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if,
const G* model_diff, T* model, T* momentum) {
if (skip_if != nullptr && *skip_if != 0) { return; }
if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
CUDA_1D_KERNEL_LOOP(i, n) {
MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,
weight_decay, learning_rate_val);
dampening, nesterov, maximize, weight_decay, learning_rate_val);
}
}

template<typename T, typename K, typename IDX>
__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance,
__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float dampening, bool nesterov,
bool maximize, float weight_decay,
int64_t feature_size, int64_t lower_bound,
int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices,
const T* values, T* model, T* momentum) {
const int64_t n = *num_unique_instance * feature_size;
Expand All @@ -202,7 +203,8 @@ __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64
if (instance_id >= lower_bound && instance_id < upper_bound) {
const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;
MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx,
static_cast<T>(1), 0.0, 0.0, beta, weight_decay, lr);
static_cast<T>(1), 0.0, 0.0, beta, dampening, nesterov,
maximize, weight_decay, lr);
}
}
}
Expand All @@ -211,38 +213,41 @@ __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64
template<typename T, typename G>
struct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G> {
static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
T* momentum);
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
};

template<typename T, typename G>
void MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum) {
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
T* model, T* momentum) {
MomentumUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(
n, scale, l1, l2, beta, weight_decay, learning_rate_val, learning_rate, scale_by_ptr, skip_if,
model_diff, model, momentum);
n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val,
learning_rate, scale_by_ptr, skip_if, model_diff, model, momentum);
}

template<typename T>
struct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16> {
static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,
T* model, T* momentum);
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const float16* model_diff, T* model, T* momentum);
};

template<typename T>
void MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const float16* model_diff, T* model, T* momentum) {
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,
const float16* model_diff, T* model, T* momentum) {
MomentumUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(
stream, n, scale, l1, l2, beta, weight_decay, learning_rate_val, learning_rate, scale_by_ptr,
skip_if, reinterpret_cast<const half*>(model_diff), model, momentum);
stream, n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay,
learning_rate_val, learning_rate, scale_by_ptr, skip_if,
reinterpret_cast<const half*>(model_diff), model, momentum);
}

template struct MomentumUpdateKernelUtil<DeviceType::kCUDA, double, double>;
Expand All @@ -251,22 +256,24 @@ template struct MomentumUpdateKernelUtil<DeviceType::kCUDA, float, float16>;

template<typename T, typename K, typename IDX>
struct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX> {
static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* momentum);
static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model,
T* momentum);
};

template<typename T, typename K, typename IDX>
void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX>::Update(
ep::Stream* stream, T beta, float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) {
ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,
int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
T* model, T* momentum) {
IndexedSlicesMomentumUpdateGpu<T, K, IDX>
<<<BlocksNum4ThreadsNum(num_instance * feature_size), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(
beta, weight_decay, feature_size, lower_bound, upper_bound, num_unique_instance,
learning_rate, indices, values, model, momentum);
beta, dampening, nesterov, maximize, weight_decay, feature_size, lower_bound, upper_bound,
num_unique_instance, learning_rate, indices, values, model, momentum);
}

#define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA( \
Expand Down
Loading