diff --git a/oneflow/api/python/functional/dispatch_stateful_ops.cpp b/oneflow/api/python/functional/dispatch_stateful_ops.cpp
index f123ba39f43..eeff32a711e 100644
--- a/oneflow/api/python/functional/dispatch_stateful_ops.cpp
+++ b/oneflow/api/python/functional/dispatch_stateful_ops.cpp
@@ -464,13 +464,17 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor(
       "DispatchMomentumUpdate",
       [](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs, float learning_rate,
-         double scale, float l1, float l2, float beta, float weight_decay) -> Maybe<void> {
+         double scale, float l1, float l2, float beta, float dampening, bool nesterov,
+         bool maximize, float weight_decay) -> Maybe<void> {
         MutableAttrMap attrs;
         JUST(attrs.SetAttr("learning_rate_val", learning_rate));
         JUST(attrs.SetAttr("scale", scale));
         JUST(attrs.SetAttr("l1", l1));
         JUST(attrs.SetAttr("l2", l2));
         JUST(attrs.SetAttr("beta", beta));
+        JUST(attrs.SetAttr("dampening", dampening));
+        JUST(attrs.SetAttr("nesterov", nesterov));
+        JUST(attrs.SetAttr("maximize", maximize));
         JUST(attrs.SetAttr("weight_decay", weight_decay));
         JUST(OpInterpUtil::Dispatch(*op, inputs, attrs));
         return Maybe<void>::Ok();
diff --git a/oneflow/api/python/functional/dispatch_stateful_ops.yaml b/oneflow/api/python/functional/dispatch_stateful_ops.yaml
index bcd2848b6cd..c26ba19d735 100644
--- a/oneflow/api/python/functional/dispatch_stateful_ops.yaml
+++ b/oneflow/api/python/functional/dispatch_stateful_ops.yaml
@@ -137,7 +137,7 @@
   bind_python: True

 - name: "dispatch_momentum_update"
-  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float weight_decay=0) => DispatchMomentumUpdate"
+  signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float dampening=0.0, Bool nesterov=False, Bool maximize=False, Float weight_decay=0) => DispatchMomentumUpdate"
   bind_python: True

 - name: "dispatch_sgd_update"
diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto
index 2ebe5dfbb49..0626109a8ee 100644
--- a/oneflow/core/job/job_conf.proto
+++ b/oneflow/core/job/job_conf.proto
@@ -16,6 +16,9 @@ message NaiveModelUpdateConf {

 message MomentumModelUpdateConf {
   optional float beta = 1 [default = 0.9];
+  optional float dampening = 2 [default = 0.0];
+  optional bool nesterov = 3 [default = false];
+  optional bool maximize = 4 [default = false];
 }

 message RMSPropModelUpdateConf {
diff --git a/oneflow/core/job_rewriter/fuse_update_ops_pass.cpp b/oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
index cd03447ad68..176ad1f70de 100644
--- a/oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
+++ b/oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
@@ -170,7 +170,10 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
       // do nothing
     } else if (user_op_conf.op_type_name() == "momentum_update") {
       fused_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
-          .Attr("beta", user_op_conf.attr<float>("beta"));
+          .Attr("beta", user_op_conf.attr<float>("beta"))
+          .Attr("dampening", user_op_conf.attr<float>("dampening"))
+          .Attr("nesterov", user_op_conf.attr<bool>("nesterov"))
+          .Attr("maximize", user_op_conf.attr<bool>("maximize"));
     } else if (user_op_conf.op_type_name() == "adam_update") {
       fused_op_builder.Input("m", user_op_conf.input("m", 0))
           .Input("v", user_op_conf.input("v", 0))
diff --git a/oneflow/core/job_rewriter/indexed_slices_optimizer_rewrite_pass.cpp b/oneflow/core/job_rewriter/indexed_slices_optimizer_rewrite_pass.cpp
index b092f1f8f22..4dfc51cdd58 100644
--- a/oneflow/core/job_rewriter/indexed_slices_optimizer_rewrite_pass.cpp
+++ b/oneflow/core/job_rewriter/indexed_slices_optimizer_rewrite_pass.cpp
@@ -113,7 +113,10 @@ Maybe<void> IndexedSlicesOptimizerRewritePass::Apply(const OpGraph& op_graph,
       // do nothing
     } else if (user_op_conf.op_type_name() == "momentum_update") {
       indexed_slices_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
-          .Attr("beta", user_op_conf.attr<float>("beta"));
+          .Attr("beta", user_op_conf.attr<float>("beta"))
+          .Attr("dampening", user_op_conf.attr<float>("dampening"))
+          .Attr("nesterov", user_op_conf.attr<bool>("nesterov"))
+          .Attr("maximize", user_op_conf.attr<bool>("maximize"));
     } else if (user_op_conf.op_type_name() == "adam_update") {
       indexed_slices_op_builder.Input("m", user_op_conf.input("m", 0))
           .Input("v", user_op_conf.input("v", 0))
diff --git a/oneflow/core/job_rewriter/momentum_optm.cpp b/oneflow/core/job_rewriter/momentum_optm.cpp
index 8d5f264241b..b718f220da0 100644
--- a/oneflow/core/job_rewriter/momentum_optm.cpp
+++ b/oneflow/core/job_rewriter/momentum_optm.cpp
@@ -58,6 +58,9 @@ void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
       .Input("learning_rate", optimizer_conf.learning_rate_lbn())
       .Input("momentum", GenLogicalBlobName(op_name, momentum_var.variable_conf().out()))
       .Attr("beta", optimizer_conf.momentum_conf().beta())
+      .Attr("dampening", optimizer_conf.momentum_conf().dampening())
+      .Attr("nesterov", optimizer_conf.momentum_conf().nesterov())
+      .Attr("maximize", optimizer_conf.momentum_conf().maximize())
       .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
       .ScopeSymbolId(var_op->op_conf().scope_symbol_id());
   SetDynamicLossScaleSkipIf(ctx, &momentum_update_op_builder);
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index 13535802a01..a59a924c824 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -5889,6 +5889,9 @@ def OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<"indexed_slices_momen
   );
   let attrs = (ins
     DefaultValuedAttr<F32Attr, "0.9">:$beta,
+    DefaultValuedAttr<F32Attr, "0.0">:$dampening,
+    DefaultValuedAttr<BoolAttr, "false">:$nesterov,
+    DefaultValuedAttr<BoolAttr, "false">:$maximize,
     DefaultValuedAttr<F32Attr, "0.0">:$weight_decay
   );
   let has_logical_tensor_desc_infer_fn = 1;
@@ -5993,6 +5996,9 @@ def OneFlow_MomentumUpdateOp : OneFlow_BaseOp<"momentum_update", [NoGrad, AttrSi
     DefaultValuedAttr<F32Attr, "0.0">:$l1,
     DefaultValuedAttr<F32Attr, "0.0">:$l2,
     DefaultValuedAttr<F32Attr, "0.9">:$beta,
+    DefaultValuedAttr<F32Attr, "0.0">:$dampening,
+    DefaultValuedAttr<BoolAttr, "false">:$nesterov,
+    DefaultValuedAttr<BoolAttr, "false">:$maximize,
     DefaultValuedAttr<F32Attr, "0.0">:$weight_decay
   );
   let trait_attrs = (ins
diff --git a/oneflow/user/kernels/model_update_kernel_util.cpp b/oneflow/user/kernels/model_update_kernel_util.cpp
index fc76c6aff67..7368e104ff5 100644
--- a/oneflow/user/kernels/model_update_kernel_util.cpp
+++ b/oneflow/user/kernels/model_update_kernel_util.cpp
@@ -108,22 +108,23 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INITIATE_INDEXED_SLICES_SGD_UPDATE_KERNEL_UTIL_

 template<typename T, typename G>
 struct MomentumUpdateKernelUtil<DeviceType::kCPU, T, G> {
   static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
-                     float weight_decay, float learning_rate_val, const float* learning_rate,
-                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
-                     T* momentum);
+                     float dampening, bool nesterov, bool maximize, float weight_decay,
+                     float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
+                     const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
 };

 template<typename T, typename G>
 void MomentumUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(
-    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
-    float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
-    const int64_t* skip_if, const G* model_diff, T* model, T* momentum) {
+    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
+    bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
+    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
+    T* model, T* momentum) {
   if (skip_if != nullptr && *skip_if != 0) { return; }
   if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }
   if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
   for (int64_t i = 0; i != n; ++i) {
     MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,
-                                  weight_decay, learning_rate_val);
+                                  dampening, nesterov, maximize, weight_decay, learning_rate_val);
   }
 }
@@ -132,17 +133,19 @@ template struct MomentumUpdateKernelUtil;

 template<typename T, typename K, typename IDX>
 struct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {
-  static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
-                     int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
-                     const IDX* num_unique_instance, const float* learning_rate, const K* indices,
-                     const T* values, T* model, T* momentum);
+  static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
+                     float weight_decay, int64_t num_instance, int64_t feature_size,
+                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
+                     const float* learning_rate, const K* indices, const T* values, T* model,
+                     T* momentum);
 };

 template<typename T, typename K, typename IDX>
 void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Update(
-    ep::Stream* stream, T beta, float weight_decay, int64_t num_instance, int64_t feature_size,
-    int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
-    const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) {
+    ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,
+    int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
+    const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
+    T* model, T* momentum) {
   const int64_t n = *num_unique_instance * feature_size;
   const T lr = *learning_rate;
   for (int64_t i = 0; i != n; ++i) {
@@ -152,7 +155,7 @@ void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX>::Updat
     if (instance_id >= lower_bound && instance_id < upper_bound) {
       const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;
       MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx, 1.0, 0.0,
-                                    0.0, beta, weight_decay, lr);
+                                    0.0, beta, dampening, nesterov, maximize, weight_decay, lr);
     }
   }
 }
diff --git a/oneflow/user/kernels/model_update_kernel_util.cu b/oneflow/user/kernels/model_update_kernel_util.cu
index 9c9efd1048f..299f16976e8 100644
--- a/oneflow/user/kernels/model_update_kernel_util.cu
+++ b/oneflow/user/kernels/model_update_kernel_util.cu
@@ -174,23 +174,24 @@ namespace {

 template<typename T, typename G>
 __global__ void MomentumUpdateGpu(int64_t n, T scale, float l1, float l2, float beta,
-                                  float weight_decay, float learning_rate_val,
-                                  const float* learning_rate, const T* scale_by_ptr,
-                                  const int64_t* skip_if, const G* model_diff, T* model,
-                                  T* momentum) {
+                                  float dampening, bool nesterov, bool maximize, float weight_decay,
+                                  float learning_rate_val, const float* learning_rate,
+                                  const T* scale_by_ptr, const int64_t* skip_if,
+                                  const G* model_diff, T* model, T* momentum) {
   if (skip_if != nullptr && *skip_if != 0) { return; }
   if (learning_rate != nullptr) { learning_rate_val = *learning_rate; }
   if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
   CUDA_1D_KERNEL_LOOP(i, n) {
     MomentumUpdateFunctor<T, G>()(model_diff + i, model + i, momentum + i, scale, l1, l2, beta,
-                                  weight_decay, learning_rate_val);
+                                  dampening, nesterov, maximize, weight_decay, learning_rate_val);
   }
 }

 template<typename T, typename K, typename IDX>
-__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64_t feature_size,
-                                               int64_t lower_bound, int64_t upper_bound,
-                                               const IDX* num_unique_instance,
+__global__ void IndexedSlicesMomentumUpdateGpu(T beta, float dampening, bool nesterov,
+                                               bool maximize, float weight_decay,
+                                               int64_t feature_size, int64_t lower_bound,
+                                               int64_t upper_bound, const IDX* num_unique_instance,
                                                const float* learning_rate, const K* indices,
                                                const T* values, T* model, T* momentum) {
   const int64_t n = *num_unique_instance * feature_size;
@@ -202,7 +203,8 @@ __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64
     if (instance_id >= lower_bound && instance_id < upper_bound) {
       const IDX model_idx = (instance_id - lower_bound) * feature_size + inner_idx;
       MomentumUpdateFunctor<T, T>()(values + i, model + model_idx, momentum + model_idx,
-                                    static_cast<T>(1), 0.0, 0.0, beta, weight_decay, lr);
+                                    static_cast<T>(1), 0.0, 0.0, beta, dampening, nesterov,
+                                    maximize, weight_decay, lr);
     }
   }
 }
@@ -211,38 +213,41 @@ __global__ void IndexedSlicesMomentumUpdateGpu(T beta, float weight_decay, int64

 template<typename T, typename G>
 struct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G> {
   static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
-                     float weight_decay, float learning_rate_val, const float* learning_rate,
-                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
-                     T* momentum);
+                     float dampening, bool nesterov, bool maximize, float weight_decay,
+                     float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
+                     const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
 };

 template<typename T, typename G>
 void MomentumUpdateKernelUtil<DeviceType::kCUDA, T, G>::Update(
-    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
-    float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
-    const int64_t* skip_if, const G* model_diff, T* model, T* momentum) {
+    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
+    bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
+    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
+    T* model, T* momentum) {
   MomentumUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                             stream->As<ep::CudaStream>()->cuda_stream()>>>(
-      n, scale, l1, l2, beta, weight_decay, learning_rate_val, learning_rate, scale_by_ptr, skip_if,
-      model_diff, model, momentum);
+      n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay, learning_rate_val,
+      learning_rate, scale_by_ptr, skip_if, model_diff, model, momentum);
 }

 template<typename T>
 struct MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16> {
   static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
-                     float weight_decay, float learning_rate_val, const float* learning_rate,
-                     const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff,
-                     T* model, T* momentum);
+                     float dampening, bool nesterov, bool maximize, float weight_decay,
+                     float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
+                     const int64_t* skip_if, const float16* model_diff, T* model, T* momentum);
 };

 template<typename T>
 void MomentumUpdateKernelUtil<DeviceType::kCUDA, T, float16>::Update(
-    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float weight_decay,
-    float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
-    const int64_t* skip_if, const float16* model_diff, T* model, T* momentum) {
+    ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta, float dampening,
+    bool nesterov, bool maximize, float weight_decay, float learning_rate_val,
+    const float* learning_rate, const T* scale_by_ptr, const int64_t* skip_if,
+    const float16* model_diff, T* model, T* momentum) {
   MomentumUpdateKernelUtil<DeviceType::kCUDA, T, half>::Update(
-      stream, n, scale, l1, l2, beta, weight_decay, learning_rate_val, learning_rate, scale_by_ptr,
-      skip_if, reinterpret_cast<const half*>(model_diff), model, momentum);
+      stream, n, scale, l1, l2, beta, dampening, nesterov, maximize, weight_decay,
+      learning_rate_val, learning_rate, scale_by_ptr, skip_if,
+      reinterpret_cast<const half*>(model_diff), model, momentum);
 }

 template struct MomentumUpdateKernelUtil;
@@ -251,22 +256,24 @@ template struct MomentumUpdateKernelUtil;

 template<typename T, typename K, typename IDX>
 struct IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX> {
-  static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
-                     int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
-                     const IDX* num_unique_instance, const float* learning_rate, const K* indices,
-                     const T* values, T* model, T* momentum);
+  static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
+                     float weight_decay, int64_t num_instance, int64_t feature_size,
+                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
+                     const float* learning_rate, const K* indices, const T* values, T* model,
+                     T* momentum);
 };

 template<typename T, typename K, typename IDX>
 void IndexedSlicesMomentumMdUpdateKernelUtil<DeviceType::kCUDA, T, K, IDX>::Update(
-    ep::Stream* stream, T beta, float weight_decay, int64_t num_instance, int64_t feature_size,
-    int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
-    const float* learning_rate, const K* indices, const T* values, T* model, T* momentum) {
+    ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize, float weight_decay,
+    int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
+    const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
+    T* model, T* momentum) {
   IndexedSlicesMomentumUpdateGpu<T, K, IDX>
       <<<BlocksNum4ThreadsNum(num_instance * feature_size), kCudaThreadsNumPerBlock, 0,
         stream->As<ep::CudaStream>()->cuda_stream()>>>(
-          beta, weight_decay, feature_size, lower_bound, upper_bound, num_unique_instance,
-          learning_rate, indices, values, model, momentum);
+          beta, dampening, nesterov, maximize, weight_decay, feature_size, lower_bound, upper_bound,
+          num_unique_instance, learning_rate, indices, values, model, momentum);
 }

 #define INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDATE_KERNEL_UTIL_CUDA( \
diff --git a/oneflow/user/kernels/model_update_kernel_util.h b/oneflow/user/kernels/model_update_kernel_util.h
index 9cd43cb591f..bca934ffcbf 100644
--- a/oneflow/user/kernels/model_update_kernel_util.h
+++ b/oneflow/user/kernels/model_update_kernel_util.h
@@ -76,13 +76,25 @@ template<typename T, typename G>
 struct MomentumUpdateFunctor {
   OF_DEVICE_FUNC
   void operator()(const G* model_diff, T* model, T* momentum, T scale, float l1, float l2,
-                  float beta, float weight_decay, float learning_rate) const {
+                  float beta, float dampening, bool nesterov, bool maximize, float weight_decay,
+                  float learning_rate) const {
     const T model_val = *model;
     T model_diff_t =
         CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);
-    const T next_momentum = beta * *momentum - learning_rate * model_diff_t;
+
+    T next_momentum = beta * *momentum + (1.0f - dampening) * model_diff_t;
     *momentum = next_momentum;
-    const T next_model = model_val + next_momentum - learning_rate * weight_decay * model_val;
+
+    if (!nesterov) {
+      model_diff_t = next_momentum;
+    } else {
+      model_diff_t += beta * next_momentum;
+    }
+
+    T alpha = -learning_rate;
+    if (maximize) { alpha = learning_rate; }
+    const T next_model =
+        model_val + alpha * model_diff_t - learning_rate * weight_decay * model_val;
     *model = next_model;
   }
 };
@@ -254,17 +266,18 @@ struct BiasCorrectionFactorKernelUtil {
 template<DeviceType device_type, typename T, typename G>
 struct MomentumUpdateKernelUtil {
   static void Update(ep::Stream* stream, int64_t n, T scale, float l1, float l2, float beta,
-                     float weight_decay, float learning_rate_val, const float* learning_rate,
-                     const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, T* model,
-                     T* momentum);
+                     float dampening, bool nesterov, bool maximize, float weight_decay,
+                     float learning_rate_val, const float* learning_rate, const T* scale_by_ptr,
+                     const int64_t* skip_if, const G* model_diff, T* model, T* momentum);
 };

 template<DeviceType device_type, typename T, typename K, typename IDX>
 struct IndexedSlicesMomentumMdUpdateKernelUtil {
-  static void Update(ep::Stream* stream, T beta, float weight_decay, int64_t num_instance,
-                     int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
-                     const IDX* num_unique_instance, const float* learning_rate, const K* indices,
-                     const T* values, T* model, T* momentum);
+  static void Update(ep::Stream* stream, T beta, float dampening, bool nesterov, bool maximize,
+                     float weight_decay, int64_t num_instance, int64_t feature_size,
+                     int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
+                     const float* learning_rate, const K* indices, const T* values, T* model,
+                     T* momentum);
 };

 template<DeviceType device_type, typename T, typename K, typename IDX>
diff --git a/oneflow/user/kernels/model_update_kernels.cpp b/oneflow/user/kernels/model_update_kernels.cpp
index 82aa869dac1..dc6b0aeb4a8 100644
--- a/oneflow/user/kernels/model_update_kernels.cpp
+++ b/oneflow/user/kernels/model_update_kernels.cpp
@@ -265,6 +265,9 @@ class MomentumUpdateKernel final : public user_op::OpKernel, public user_op::Cud
     float l1 = ctx->Attr<float>("l1");
     float l2 = ctx->Attr<float>("l2");
     float beta = ctx->Attr<float>("beta");
+    const float dampening = ctx->Attr<float>("dampening");
+    const bool nesterov = ctx->Attr<bool>("nesterov");
+    const bool maximize = ctx->Attr<bool>("maximize");
     float weight_decay = ctx->Attr<float>("weight_decay");
     const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0);
@@ -290,8 +293,9 @@ class MomentumUpdateKernel final : public user_op::OpKernel, public user_op::Cud
     }
     MomentumUpdateKernelUtil<device_type, T, G>::Update(
         ctx->stream(), model->shape_view().elem_cnt(), static_cast<T>(scale), l1, l2, beta,
-        weight_decay, learning_rate_val, learning_rate_ptr, scale_by_ptr, skip_if_ptr,
-        model_diff->dptr<G>(), model->mut_dptr<T>(), momentum->mut_dptr<T>());
+        dampening, nesterov, maximize, weight_decay, learning_rate_val, learning_rate_ptr,
+        scale_by_ptr, skip_if_ptr, model_diff->dptr<G>(), model->mut_dptr<T>(),
+        momentum->mut_dptr<T>());
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }
 };
@@ -334,6 +338,9 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel {
     user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0);
     user_op::Tensor* momentum = ctx->Tensor4ArgNameAndIndex("momentum", 0);
     const auto beta = ctx->Attr<float>("beta");
+    const float dampening = ctx->Attr<float>("dampening");
+    const bool nesterov = ctx->Attr<bool>("nesterov");
+    const bool maximize = ctx->Attr<bool>("maximize");
     const auto weight_decay = ctx->Attr<float>("weight_decay");
     const int64_t num_indices = model_diff_indices->shape_view().elem_cnt();
     const int64_t num_values = model_diff_values->shape_view().elem_cnt();
@@ -359,8 +366,8 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel {
         buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(),
         buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes());
     MdUpdateUtilT::Update(
-        ctx->stream(), beta, weight_decay, num_indices, feature_size, kernel_cache->lower(),
-        kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(),
+        ctx->stream(), beta, dampening, nesterov, maximize, weight_decay, num_indices, feature_size,
+        kernel_cache->lower(), kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(),
         learning_rate->dptr<float>(), buffer_manager.UniqueDiffIndicesPtr(),
         buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr<T>(), momentum->mut_dptr<T>());
   }
diff --git a/oneflow/user/kernels/one_embedding_update_kernels.cu b/oneflow/user/kernels/one_embedding_update_kernels.cu
index 1a4483234fe..3e274467731 100644
--- a/oneflow/user/kernels/one_embedding_update_kernels.cu
+++ b/oneflow/user/kernels/one_embedding_update_kernels.cu
@@ -56,6 +56,7 @@ __device__ void GetMomentumOffset(const int32_t line_size, const int32_t embeddi
 template<typename T, typename G, typename IDX>
 __global__ void MomentumUpdateKernel(const int64_t line_size, const int64_t embedding_size,
                                      T scale, float l1, float l2, float weight_decay, float beta,
+                                     float dampening, bool nesterov, bool maximize,
                                      const IDX* num_unique_ids, const float* learning_rate,
                                      const T* scale_by_ptr, const T* down_scale_by_ptr,
                                      const int64_t* skip_if, const G* model_diff,
@@ -76,7 +77,7 @@ __global__ void MomentumUpdateKernel(const int64_t line_size, const int64_t embe
       updated_unique_values[momentum_offset] = unique_values[momentum_offset];
       MomentumUpdateFunctor<T, G>()(model_diff + i, updated_unique_values + model_offset,
                                     updated_unique_values + momentum_offset, scale, l1, l2, beta,
-                                    weight_decay, learning_rate_val);
+                                    dampening, nesterov, maximize, weight_decay, learning_rate_val);
     }
   }
 }
@@ -342,6 +343,10 @@ class MomentumEmbeddingUpdateKernel final : public user_op::OpKernel {
     const float l2 = ctx->Attr<float>("l2");
     const auto weight_decay = ctx->Attr<float>("weight_decay");
     const auto beta = ctx->Attr<float>("beta");
+    // TODO: Support dampening, nesterov, maximize in OneEmbeddingMomentumUpdate (zhengzekang).
+    const float dampening = 0.0;
+    const bool nesterov = false;
+    const bool maximize = false;
     const auto scale = ctx->Attr<double>("scale");
     const T* scale_by_ptr = nullptr;
     if (ctx->has_input("scale_by_tensor", 0)) {
@@ -376,10 +381,10 @@ class MomentumEmbeddingUpdateKernel final : public user_op::OpKernel {
     MomentumUpdateKernel<T, G, IDX>
         <<<BlocksNum4ThreadsNum(embedding_grad->shape_view().elem_cnt()), kCudaThreadsNumPerBlock,
            0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
-            line_size, embedding_size, scale, l1, l2, weight_decay, beta,
-            reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr, scale_by_ptr,
-            down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(), unique_embeddings_ptr,
-            updated_unique_embeddings_ptr);
+            line_size, embedding_size, scale, l1, l2, weight_decay, beta, dampening, nesterov,
+            maximize, reinterpret_cast<const IDX*>(num_unique_ids->dptr()), learning_rate_ptr,
+            scale_by_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr<G>(),
+            unique_embeddings_ptr, updated_unique_embeddings_ptr);
     embedding_state->OnEmbeddingUpdateEnd(ctx, current_iter_);
     current_iter_++;
   }
diff --git a/python/oneflow/nn/optimizer/sgd.py b/python/oneflow/nn/optimizer/sgd.py
index 44ac3f9b18e..551301630c7 100644
--- a/python/oneflow/nn/optimizer/sgd.py
+++ b/python/oneflow/nn/optimizer/sgd.py
@@ -14,6 +14,7 @@ limitations under the License.
 """
 import collections
+import warnings
 from typing import Callable, Dict, Iterator, List, Union

 import oneflow as flow
@@ -100,15 +101,25 @@ def __init__(
         params: Union[Iterator[Parameter], List[Dict]],
         lr: float = 0.001,
         momentum: float = 0.0,
+        dampening: float = 0.0,
         weight_decay: float = 0.0,
+        nesterov: bool = False,
+        maximize: bool = False,
     ):
         assert lr >= 0.0, f"Invalid learning rate: {lr}"
         assert momentum >= 0.0, f"Invalid momentum: {momentum}"
        assert weight_decay >= 0.0, f"Invalid weight_decay: {weight_decay}"
+        if maximize:
+            warnings.warn(
+                "Param `maximize` only takes effect when momentum > 0.0.", FutureWarning,
+            )
         options = dict()
         options["lr"] = lr
         options["momentum"] = momentum
+        options["dampening"] = dampening
         options["weight_decay"] = weight_decay
+        options["nesterov"] = nesterov
+        options["maximize"] = maximize
         super().__init__(params, options)

         for param_group in self.param_groups:
@@ -145,6 +156,7 @@ def step(self, closure: Callable = None):
                 if param.grad is None:
                     continue
                 if param_group["momentum"] == 0.0:
+                    # TODO: Support param `maximize` in Naive SGD Optimizer. (zhengzekang)
                     flow._C.dispatch_sgd_update(
                         self._sgd, (param, param.grad), learning_rate=lr, l2=l2
                     )
@@ -153,12 +165,18 @@
                     if "momentum_buf" not in self._state[param]:
                         self._state[param]["momentum_buf"] = flow.zeros_like(param)
                     momentum_buf = self._state[param]["momentum_buf"]
                     beta = param_group["momentum"]
+                    dampening = param_group["dampening"]
+                    nesterov = param_group["nesterov"]
+                    maximize = param_group["maximize"]
                     flow._C.dispatch_momentum_update(
                         self._momentum_sgd,
                         (param, param.grad, momentum_buf),
                         learning_rate=lr,
                         l2=l2,
                         beta=beta,
+                        dampening=dampening,
+                        nesterov=nesterov,
+                        maximize=maximize,
                     )
         self._state["step"] = self._state["step"] + 1
         return loss
@@ -174,12 +192,19 @@ def _generate_conf_for_graph(self, train_conf, vars_conf):
             )
             beta = param_group["momentum"]
             l2 = param_group["weight_decay"]
+            dampening = param_group["dampening"]
+            nesterov = param_group["nesterov"]
+            maximize = param_group["maximize"]

             optimizer_conf.base_learning_rate = lr

             if beta == 0:
                 optimizer_conf.naive_conf.SetInParent()
             else:
                 optimizer_conf.momentum_conf.beta = beta
+                # Only the Momentum optimizer supports these params.
+                optimizer_conf.momentum_conf.dampening = dampening
+                optimizer_conf.momentum_conf.nesterov = nesterov
+                optimizer_conf.momentum_conf.maximize = maximize

             self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf)
diff --git a/python/oneflow/test/graph/test_graph_optim_sgd.py b/python/oneflow/test/graph/test_graph_optim_sgd.py
index 2d2adc240f9..ef20b17c397 100644
--- a/python/oneflow/test/graph/test_graph_optim_sgd.py
+++ b/python/oneflow/test/graph/test_graph_optim_sgd.py
@@ -25,7 +25,16 @@


 def compare_with_numpy_sgd(
-    test_case, device, x_shape, learning_rate, train_iters, momentum, weight_decay
+    test_case,
+    device,
+    x_shape,
+    learning_rate,
+    train_iters,
+    momentum,
+    dampening,
+    nesterov,
+    maximize,
+    weight_decay,
 ):
     random_grad_seq = []
     for _ in range(train_iters):
@@ -51,10 +60,13 @@ def forward(self, mask):
             {
                 "params": simp_module.parameters(),
                 "lr": learning_rate,
-                "momentum": momentum,
                 "weight_decay": weight_decay,
             }
         ],
+        momentum=momentum,
+        dampening=dampening,
+        nesterov=nesterov,
+        maximize=maximize,
     )

     class CustomSGDGraph(flow.nn.Graph):
@@ -85,8 +97,23 @@ def train_by_numpy():
        def np_train_one_iter(grad):
            grad = grad + weight_decay * x
-            v = momentum * vt - learning_rate * grad
-            param = x + v
+            if momentum > 0.0:
+                next_momentum = momentum * vt + (1 - dampening) * grad
+                v = next_momentum
+
+                if nesterov:
+                    grad += momentum * next_momentum
+                else:
+                    grad = next_momentum
+
+                alpha = -learning_rate
+                if maximize:
+                    alpha = learning_rate
+                next_model = x + alpha * grad
+                param = next_model
+            else:
+                v = learning_rate * grad
+                param = x - v
            return (param, v)

        for i in range(train_iters):
@@ -103,6 +130,9 @@ def compare_with_numpy_sgd_clip_grad(
     x_shape,
     learning_rate,
     momentum,
+    dampening,
+    nesterov,
+    maximize,
     weight_decay,
     clip_grad_max_norm,
     clip_grad_norm_type,
@@ -132,12 +162,15 @@ def forward(self, mask):
             {
                 "params": simp_module.parameters(),
                 "lr": learning_rate,
-                "momentum": momentum,
                 "weight_decay": weight_decay,
                 "clip_grad_max_norm": clip_grad_max_norm,
                 "clip_grad_norm_type": clip_grad_norm_type,
             }
-        ]
+        ],
+        momentum=momentum,
+        dampening=dampening,
+        nesterov=nesterov,
+        maximize=maximize,
     )

     class CustomSGDGraph(flow.nn.Graph):
@@ -171,8 +204,23 @@ def np_train_one_iter(grad):
                grad, clip_grad_max_norm, clip_grad_norm_type
            )
            grad = grad + weight_decay * x
-            v = momentum * vt - learning_rate * grad
-            param = x + v
+            if momentum > 0.0:
+                next_momentum = momentum * vt + (1 - dampening) * grad
+                v = next_momentum
+
+                if nesterov:
+                    grad += momentum * next_momentum
+                else:
+                    grad = next_momentum
+
+                alpha = -learning_rate
+                if maximize:
+                    alpha = learning_rate
+                next_model = x + alpha * grad
+                param = next_model
+            else:
+                v = learning_rate * grad
+                param = x - v
            return (param, v)

        for i in range(train_iters):
@@ -185,7 +233,7 @@


 @flow.unittest.skip_unless_1n1d()
-class TestCpuSGD(flow.unittest.TestCase):
+class TestGraphSGD(flow.unittest.TestCase):
     def test_sgd(test_case):
         arg_dict = OrderedDict()
         arg_dict["device"] = ["cpu", "cuda"]
@@ -193,6 +241,9 @@ def test_sgd(test_case):
         arg_dict["learning_rate"] = [1, 1e-3]
         arg_dict["train_iters"] = [10]
         arg_dict["momentum"] = [0.9, 0.8]
+        arg_dict["dampening"] = [0.0, 0.9]
+        arg_dict["nesterov"] = [True, False]
+        arg_dict["maximize"] = [True, False]
         arg_dict["weight_decay"] = [0.001, 0.0]
         for arg in GenArgList(arg_dict):
             compare_with_numpy_sgd(test_case, *arg)
@@ -203,6 +254,9 @@ def test_sgd_with_clip_grad(test_case):
         arg_dict["x_shape"] = [(10,)]
         arg_dict["learning_rate"] = [1, 0.1]
         arg_dict["momentum"] = [0.0, 0.9]
+        arg_dict["dampening"] = [0.0, 0.9]
+        arg_dict["nesterov"] = [True, False]
+        arg_dict["maximize"] = [True, False]
         arg_dict["weight_decay"] = [0.0, 0.9]
         arg_dict["clip_grad_max_norm"] = [1.0]
         arg_dict["clip_grad_norm_type"] = [2.0]
diff --git a/python/oneflow/test/modules/test_one_embedding_sgd.py b/python/oneflow/test/modules/test_one_embedding_sgd.py
index 50170646e3a..53ec1d6d7f6 100644
--- a/python/oneflow/test/modules/test_one_embedding_sgd.py
+++ b/python/oneflow/test/modules/test_one_embedding_sgd.py
@@ -140,15 +140,24 @@ def sgd_by_numpy():
        def train_one_iter(num_valid, grad, model, state):
            grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by)
            next_state = (
-                momentum * state[0:num_valid] if momentum > 0 else 0
-            ) - learning_rate * grad[0:num_valid]
+                (momentum * state[0:num_valid] + grad[0:num_valid])
+                if momentum > 0
+                else 0
+            )
            if momentum > 0:
                state[0:num_valid] = next_state
-            model[0:num_valid] = (
-                model[0:num_valid]
-                + next_state
-                - learning_rate * weight_decay * model[0:num_valid]
-            )
+                model[0:num_valid] = (
+                    model[0:num_valid]
+                    - learning_rate * next_state
+                    - learning_rate * weight_decay * model[0:num_valid]
+                )
+            else:
+                state[0:num_valid] = 0
+                model[0:num_valid] = (
+                    model[0:num_valid]
+                    - learning_rate * grad[0:num_valid]
+                    - learning_rate * weight_decay * model[0:num_valid]
+                )
            return (model, state)

        for i in range(train_iters):
diff --git a/python/oneflow/test/modules/test_optim_sgd.py b/python/oneflow/test/modules/test_optim_sgd.py
index 6480c21804d..61b51feca4d 100644
--- a/python/oneflow/test/modules/test_optim_sgd.py
+++ b/python/oneflow/test/modules/test_optim_sgd.py
@@ -31,6 +31,9 @@ def compare_with_numpy_sgd(
     device,
     x_shape,
     momentum,
+    dampening,
+    nesterov,
+    maximize,
     weight_decay,
     learning_rate,
     train_iters,
@@ -45,14 +48,11 @@ def compare_with_numpy_sgd(
     def train_by_oneflow():
        x = Parameter(flow.Tensor(init_value, device=flow.device(device)))
        sgd = flow.optim.SGD(
-            [
-                {
-                    "params": [x],
-                    "lr": learning_rate,
-                    "momentum": momentum,
-                    "weight_decay": weight_decay,
-                }
-            ]
+            [{"params": [x], "lr": learning_rate, "weight_decay": weight_decay,}],
+            momentum=momentum,
+            dampening=dampening,
+            nesterov=nesterov,
+            maximize=maximize,
        )

        def train_one_iter(grad):
@@ -86,8 +86,23 @@ def train_by_numpy():
        def train_one_iter(grad):
            grad = grad + weight_decay * x
-            v = momentum * vt - learning_rate * grad
-            param = x + v
+            if momentum > 0.0:
+                next_momentum = momentum * vt + (1 - dampening) * grad
+                v = next_momentum
+
+                if nesterov:
+                    grad += momentum * next_momentum
+                else:
+                    grad = next_momentum
+
+                alpha = -learning_rate
+                if maximize:
+                    alpha = learning_rate
+                next_model = x + alpha * grad
+                param = next_model
+            else:
+                v = learning_rate * grad
+                param = x - v
            return (param, v)

        for i in range(train_iters):
@@ -108,6 +123,9 @@ def compare_with_numpy_sgd_clip_grad(
     device,
     x_shape,
     momentum,
+    dampening,
+    nesterov,
+    maximize,
     weight_decay,
     learning_rate,
     clip_grad_max_norm,
@@ -128,12 +146,18 @@ def train_by_oneflow():
            {
                "params": [x],
                "lr": learning_rate,
-                "momentum": momentum,
+                "dampening": dampening,
+                "nesterov": nesterov,
+                "maximize": maximize,
                "weight_decay": weight_decay,
                "clip_grad_max_norm": clip_grad_max_norm,
                "clip_grad_norm_type": clip_grad_norm_type,
            }
-        ]
+        ],
+            momentum=momentum,
+            dampening=dampening,
+            nesterov=nesterov,
+            maximize=maximize,
        )
@@ -171,8 +195,23 @@ def train_one_iter(grad):
                grad, clip_grad_max_norm, clip_grad_norm_type
            )
            grad = grad + weight_decay * x
-            v = momentum * vt - learning_rate * grad
-            param = x + v
+            if momentum > 0.0:
+                next_momentum = momentum * vt + (1 - dampening) * grad
+                v = next_momentum
+
+                if nesterov:
+                    grad += momentum * next_momentum
+                else:
+                    grad = next_momentum
+
+                alpha = -learning_rate
+                if maximize:
+                    alpha = learning_rate
+                next_model = x + alpha * grad
+                param = next_model
+            else:
+                v = learning_rate * grad
+                param = x - v
            return (param, v)

        for i in range(train_iters):
@@ -196,6 +235,9 @@ def test_sgd(test_case):
        arg_dict["device"] = ["cpu", "cuda"]
        arg_dict["x_shape"] = [(10,)]
        arg_dict["momentum"] = [0.0, 0.9]
+        arg_dict["dampening"] = [0.0, 0.9]
+        arg_dict["nesterov"] = [True, False]
+        arg_dict["maximize"] = [True, False]
        arg_dict["weight_decay"] = [0.0, 0.9]
        arg_dict["learning_rate"] = [1, 0.1]
        arg_dict["train_iters"] = [10]
@@ -209,6 +251,9 @@ def test_sgd_clip_grad(test_case):
        arg_dict["device"] = ["cpu", "cuda"]
        arg_dict["x_shape"] = [(10,)]
        arg_dict["momentum"] = [0.0, 0.9]
+        arg_dict["dampening"] = [0.0, 0.9]
+        arg_dict["nesterov"] = [True, False]
+        arg_dict["maximize"] = [True, False]
        arg_dict["weight_decay"] = [0.0, 0.9]
        arg_dict["learning_rate"] = [1, 0.1]
        arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0]
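
Reviewer note: the C++ functor, the CUDA kernels, and the numpy references above all implement the same extended momentum rule. Below is a minimal NumPy sketch of that rule for review purposes — illustrative only (the helper name is hypothetical, and the `scale`/`l1`/`l2` gradient preprocessing that the real `MomentumUpdateFunctor` applies is omitted):

import numpy as np

def momentum_step(model, buf, grad, lr, beta=0.9, dampening=0.0,
                  nesterov=False, maximize=False, weight_decay=0.0):
    # Buffer update: v = beta * v + (1 - dampening) * grad.
    v = beta * buf + (1.0 - dampening) * grad
    # Nesterov looks ahead by adding beta * v to the raw gradient;
    # otherwise the buffer itself is the step direction.
    d = grad + beta * v if nesterov else v
    # maximize flips the learning-rate sign (gradient ascent); weight decay
    # shrinks the parameter directly, as in MomentumUpdateFunctor.
    alpha = lr if maximize else -lr
    return model + alpha * d - lr * weight_decay * model, v

# One step from x = [1, 1] with a zero buffer:
x, buf = np.ones(2), np.zeros(2)
g = np.array([0.5, -0.5])
x, buf = momentum_step(x, buf, g, lr=0.1)
print(x)    # [0.95 1.05] -- moved against the gradient
print(buf)  # [ 0.5 -0.5] -- (1 - dampening) * grad on the first step

Note the behavioral change this diff makes: the buffer previously stored `beta * v - lr * grad`, baking the learning rate into the momentum, whereas it now stores the dampened gradient average and the learning rate is applied only at the model update. Apart from where weight decay enters, this mirrors the PyTorch SGD convention.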