Skip to content

Commit

Permalink
Align Momentum Optimizer (#8549)
Browse files Browse the repository at this point in the history
* fix momentum update

* align momentum

* fix bug and finish eager unittest

* Support Graph optimizer

* fix momentum bug

* refine beta

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
MARD1NO and mergify[bot] authored Jul 6, 2022
1 parent 28690a2 commit 1531b06
Show file tree
Hide file tree
Showing 16 changed files with 292 additions and 102 deletions.
6 changes: 5 additions & 1 deletion oneflow/api/python/functional/dispatch_stateful_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,17 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor(
"DispatchMomentumUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
double scale, float l1, float l2, float beta, float weight_decay) -> Maybe<void> {
double scale, float l1, float l2, float beta, float dampening, bool nesterov,
bool maximize, float weight_decay) -> Maybe<void> {
MutableAttrMap attrs;
JUST(attrs.SetAttr("learning_rate_val", learning_rate));
JUST(attrs.SetAttr("scale", scale));
JUST(attrs.SetAttr("l1", l1));
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("beta", beta));
JUST(attrs.SetAttr("dampening", dampening));
JUST(attrs.SetAttr("nesterov", nesterov));
JUST(attrs.SetAttr("maximize", maximize));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
Expand Down
2 changes: 1 addition & 1 deletion oneflow/api/python/functional/dispatch_stateful_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
bind_python: True

- name: "dispatch_momentum_update"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float weight_decay=0) => DispatchMomentumUpdate"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float dampening=0.0, Bool nesterov=False, Bool maximize=False, Float weight_decay=0) => DispatchMomentumUpdate"
bind_python: True

- name: "dispatch_sgd_update"
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/job/job_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ message NaiveModelUpdateConf {

// Settings for the momentum model-update (optimizer) step.
// GenerateOptimizerOpConf in job_rewriter/momentum_optm.cpp reads beta,
// dampening, nesterov and maximize from this message and passes each one as
// the same-named attr of the op it builds.
// NOTE(review): the exact update formula lives in MomentumUpdateFunctor, which
// is not visible here; the field semantics below are presumably those of
// PyTorch's SGD options of the same names -- confirm against the functor.
message MomentumModelUpdateConf {
// Passed as the float `beta` attr; 0.9 when unset.
optional float beta = 1 [default = 0.9];
// Passed as the float `dampening` attr; 0.0 when unset.
optional float dampening = 2 [default = 0.0];
// Passed as the bool `nesterov` attr; false when unset.
optional bool nesterov = 3 [default = false];
// Passed as the bool `maximize` attr; false when unset.
optional bool maximize = 4 [default = false];
}

message RMSPropModelUpdateConf {
Expand Down
5 changes: 4 additions & 1 deletion oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
// do nothing
} else if (user_op_conf.op_type_name() == "momentum_update") {
fused_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
.Attr<float>("beta", user_op_conf.attr<float>("beta"))
.Attr<float>("dampening", user_op_conf.attr<float>("dampening"))
.Attr<bool>("nesterov", user_op_conf.attr<bool>("nesterov"))
.Attr<bool>("maximize", user_op_conf.attr<bool>("maximize"));
} else if (user_op_conf.op_type_name() == "adam_update") {
fused_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ Maybe<void> IndexedSlicesOptimizerRewritePass::Apply(const OpGraph& op_graph,
// do nothing
} else if (user_op_conf.op_type_name() == "momentum_update") {
indexed_slices_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
.Attr<float>("beta", user_op_conf.attr<float>("beta"))
.Attr<float>("dampening", user_op_conf.attr<float>("dampening"))
.Attr<bool>("nesterov", user_op_conf.attr<bool>("nesterov"))
.Attr<bool>("maximize", user_op_conf.attr<bool>("maximize"));
} else if (user_op_conf.op_type_name() == "adam_update") {
indexed_slices_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/job_rewriter/momentum_optm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
.Input("learning_rate", optimizer_conf.learning_rate_lbn())
.Input("momentum", GenLogicalBlobName(op_name, momentum_var.variable_conf().out()))
.Attr<float>("beta", optimizer_conf.momentum_conf().beta())
.Attr<float>("dampening", optimizer_conf.momentum_conf().dampening())
.Attr<bool>("nesterov", optimizer_conf.momentum_conf().nesterov())
.Attr<bool>("maximize", optimizer_conf.momentum_conf().maximize())
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
.ScopeSymbolId(var_op->op_conf().scope_symbol_id());
SetDynamicLossScaleSkipIf(ctx, &momentum_update_op_builder);
Expand Down
6 changes: 6 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5889,6 +5889,9 @@ def OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<"indexed_slices_momen
);
let attrs = (ins
DefaultValuedAttr<F32Attr, "0.9">:$beta,
DefaultValuedAttr<F32Attr, "0.0">:$dampening,
DefaultValuedAttr<BoolAttr, "false">:$nesterov,
DefaultValuedAttr<BoolAttr, "false">:$maximize,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay
);
let has_logical_tensor_desc_infer_fn = 1;
Expand Down Expand Up @@ -5993,6 +5996,9 @@ def OneFlow_MomentumUpdateOp : OneFlow_BaseOp<"momentum_update", [NoGrad, AttrSi
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.9">:$beta,
DefaultValuedAttr<F32Attr, "0.0">:$dampening,
DefaultValuedAttr<BoolAttr, "false">:$nesterov,
DefaultValuedAttr<BoolAttr, "false">:$maximize,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay
);
let trait_attrs = (ins
Expand Down
33 changes: 18 additions & 15 deletions oneflow/user/kernels/model_update_kernel_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,23 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_
template<typename T, typename G>
struct MomentumUpdateKernelUtil<DeviceType::kCPU, T, G> {
static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
T* momentum);
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
};

template<typename T, typename G>
void MomentumUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum) {
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
T* model, T* momentum) {
if (skip_if != nullptr && *skip_if != 0) { return; }
if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
for (int64_t i = 0; i != n; ++i) {
MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,
weight_decay, learning_rate_val);
dampening, nesterov, maximize, weight_decay, learning_rate_val);
}
}

Expand All @@ -132,17 +133,19 @@ template struct MomentumUpdateKernelUtil<DeviceType::kCPU, double, double>;

template<typename T, typename K, typename IDX>
struct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {
static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* momentum);
static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model,
T* momentum);
};

template<typename T, typename K, typename IDX>
void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Update(
ep::Stream* stream, T beta, float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) {
ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,
int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
T* model, T* momentum) {
const int64_t n = *num_unique_instance * feature_size;
const T lr = *learning_rate;
for (int64_t i = 0; i != n; ++i) {
Expand All @@ -152,7 +155,7 @@ void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Updat
if (instance_id >= lower_bound && instance_id < upper_bound) {
const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;
MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx, 1.0, 0.0,
0.0, beta, weight_decay, lr);
0.0, beta, dampening, nesterov, maximize, weight_decay, lr);
}
}
}
Expand Down
75 changes: 41 additions & 34 deletions oneflow/user/kernels/model_update_kernel_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,24 @@ namespace {
template<typename T, typename G>
__global__ void MomentumUpdateGpu(int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model,
T* momentum) {
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if,
const G* model_diff, T* model, T* momentum) {
if (skip_if != nullptr && *skip_if != 0) { return; }
if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
CUDA_1D_KERNEL_LOOP(i, n) {
MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,
weight_decay, learning_rate_val);
dampening, nesterov, maximize, weight_decay, learning_rate_val);
}
}
template<typename T, typename K, typename IDX>
__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance,
__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float dampening, bool nesterov,
bool maximize, float weight_decay,
int64_t feature_size, int64_t lower_bound,
int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices,
const T* values, T* model, T* momentum) {
const int64_t n = *num_unique_instance * feature_size;
Expand All @@ -202,7 +203,8 @@ __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64
if (instance_id >= lower_bound && instance_id < upper_bound) {
const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;
MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx,
static_cast<T>(1), 0.0, 0.0, beta, weight_decay, lr);
static_cast<T>(1), 0.0, 0.0, beta, dampening, nesterov,
maximize, weight_decay, lr);
}
}
}
Expand All @@ -211,38 +213,41 @@ __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64
template<typename T, typename G>
struct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G> {
static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
T* momentum);
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
};
template<typename T, typename G>
void MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* model, T* momentum) {
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
T* model, T* momentum) {
MomentumUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(
n, scale, l1, l2, beta, weight_decay, learning_rate_val, learning_rate, scale_by_ptr, skip_if,
model_diff, model, momentum);
n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val,
learning_rate, scale_by_ptr, skip_if, model_diff, model, momentum);
}
template<typename T>
struct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16> {
static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
float weight_decay, float learning_rate_val, const float* learning_rate,
const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,
T* model, T* momentum);
float dampening, bool nesterov, bool maximize, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const float16* model_diff, T* model, T* momentum);
};
template<typename T>
void MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const float16* model_diff, T* model, T* momentum) {
ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,
const float16* model_diff, T* model, T* momentum) {
MomentumUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(
stream, n, scale, l1, l2, beta, weight_decay, learning_rate_val, learning_rate, scale_by_ptr,
skip_if, reinterpret_cast<const half*>(model_diff), model, momentum);
stream, n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay,
learning_rate_val, learning_rate, scale_by_ptr, skip_if,
reinterpret_cast<const half*>(model_diff), model, momentum);
}
template struct MomentumUpdateKernelUtil<DeviceType::kCUDA, double, double>;
Expand All @@ -251,22 +256,24 @@ template struct MomentumUpdateKernelUtil<DeviceType::kCUDA, float, float16>;
template<typename T, typename K, typename IDX>
struct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX> {
static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* momentum);
static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model,
T* momentum);
};
template<typename T, typename K, typename IDX>
void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX>::Update(
ep::Stream* stream, T beta, float weight_decay, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) {
ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,
int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
T* model, T* momentum) {
IndexedSlicesMomentumUpdateGpu<T, K, IDX>
<<<BlocksNum4ThreadsNum(num_instance * feature_size), kCudaThreadsNumPerBlock, 0,
stream->As<ep::CudaStream>()->cuda_stream()>>>(
beta, weight_decay, feature_size, lower_bound, upper_bound, num_unique_instance,
learning_rate, indices, values, model, momentum);
beta, dampening, nesterov, maximize, weight_decay, feature_size, lower_bound, upper_bound,
num_unique_instance, learning_rate, indices, values, model, momentum);
}
#define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA( \
Expand Down
Loading

0 comments on commit 1531b06

Please sign in to comment.