diff --git a/docs/source/optim.rst b/docs/source/optim.rst index f9fcbcb3dd5..3ebb75161b6 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -10,6 +10,7 @@ Optimizers Optimizer, RMSprop, SGD, + LAMB, lr_scheduler .. automodule:: oneflow.optim.lr_scheduler diff --git a/oneflow/api/python/functional/dispatch_stateful_ops.cpp b/oneflow/api/python/functional/dispatch_stateful_ops.cpp index a97a6526d11..240ebb0ed4f 100644 --- a/oneflow/api/python/functional/dispatch_stateful_ops.cpp +++ b/oneflow/api/python/functional/dispatch_stateful_ops.cpp @@ -480,6 +480,26 @@ ONEFLOW_FUNCTION_LIBRARY(m) { JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); return Maybe::Ok(); }); + m.add_functor("DispatchLambUpdate", + [](const std::shared_ptr& op, const TensorTuple& inputs, + float learning_rate, float bias_correction1, float bias_correction2, + double scale, float l1, float l2, float beta1, float beta2, float epsilon, + float weight_decay, bool do_bias_correction) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("learning_rate_val", learning_rate)); + JUST(attrs.SetAttr("bias_correction1_val", bias_correction1)); + JUST(attrs.SetAttr("bias_correction2_val", bias_correction2)); + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("l1", l1)); + JUST(attrs.SetAttr("l2", l2)); + JUST(attrs.SetAttr("beta1", beta1)); + JUST(attrs.SetAttr("beta2", beta2)); + JUST(attrs.SetAttr("epsilon", epsilon)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(attrs.SetAttr("do_bias_correction", do_bias_correction)); + JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); + return Maybe::Ok(); + }); m.add_functor("DispatchEagerNcclAllReduce", [](const std::shared_ptr& op, const std::shared_ptr& input, const std::string& parallel_conf, bool async_launch) -> Maybe { diff --git a/oneflow/api/python/functional/dispatch_stateful_ops.yaml b/oneflow/api/python/functional/dispatch_stateful_ops.yaml index 87fa7647059..713bdc08d38 100644 --- a/oneflow/api/python/functional/dispatch_stateful_ops.yaml +++ b/oneflow/api/python/functional/dispatch_stateful_ops.yaml @@ -138,6 +138,10 @@ - name: "dispatch_sgd_update" signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float weight_decay=0) => DispatchSgdUpdate" bind_python: True + +- name: "dispatch_lamb_update" + signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool do_bias_correction=True) => DispatchLambUpdate" + bind_python: True - name: "dispatch_eager_nccl_all_reduce" signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce" diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 8b1d2153a9d..88432aa7504 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -47,9 +47,10 @@ message LazyAdamModelUpdateConf { } message LambModelUpdateConf { - required float beta1 = 1; - required float beta2 = 2; - required float epsilon = 3; + optional float beta1 = 1 [default = 0.9]; + optional float beta2 = 2 [default = 0.999]; + optional float epsilon = 3 [default = 1e-8]; + optional bool do_bias_correction = 4 [default = true]; } message AdagradModelUpdateConf { diff --git a/oneflow/core/job_rewriter/adam_optm.cpp b/oneflow/core/job_rewriter/adam_optm.cpp index b325a992317..382307f2f2d 100644 --- a/oneflow/core/job_rewriter/adam_optm.cpp +++ b/oneflow/core/job_rewriter/adam_optm.cpp @@ -44,25 +44,6 @@ struct hash { namespace oneflow { -namespace { - -std::string GenVariableOutputLbn(const OperatorConf& op_conf) { - CHECK(op_conf.has_variable_conf()); - return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); -} - -OperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name, - const float initial_value) { - OperatorConf helper_variable_op(op.op_conf()); - helper_variable_op.set_name(op.op_name() + "-" + name); - helper_variable_op.mutable_variable_conf()->set_out("out"); - InitializerConf constant_initializer; - constant_initializer.mutable_constant_conf()->set_value(initial_value); - *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; - helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); - return helper_variable_op; -} - class BiasCorrectionFactorState final : public JobPassState { public: BiasCorrectionFactorState() {} @@ -88,6 +69,25 @@ class BiasCorrectionFactorState final : public JobPassState { HashMap key2lbn_; }; +namespace { + +std::string GenVariableOutputLbn(const OperatorConf& op_conf) { + CHECK(op_conf.has_variable_conf()); + return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out()); +} + +OperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name, + const float initial_value) { + OperatorConf helper_variable_op(op.op_conf()); + helper_variable_op.set_name(op.op_name() + "-" + name); + helper_variable_op.mutable_variable_conf()->set_out("out"); + InitializerConf constant_initializer; + constant_initializer.mutable_constant_conf()->set_value(initial_value); + *(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer; + helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id()); + return helper_variable_op; +} + void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { diff --git a/oneflow/core/job_rewriter/lamb_optm.cpp b/oneflow/core/job_rewriter/lamb_optm.cpp index d764637b17a..999f6d34d10 100644 --- a/oneflow/core/job_rewriter/lamb_optm.cpp +++ b/oneflow/core/job_rewriter/lamb_optm.cpp @@ -15,8 +15,46 @@ limitations under the License. */ #include "oneflow/core/job_rewriter/optimizer.h" #include "oneflow/core/framework/framework.h" + +namespace oneflow { + +struct BiasCorrectionFactorCacheKey { + float beta = 1.0; + ParallelConf parallel_conf; +}; + +bool operator==(const BiasCorrectionFactorCacheKey& lhs, const BiasCorrectionFactorCacheKey& rhs); + +} // namespace oneflow + +namespace std { +template<> +struct hash { + size_t operator()(const oneflow::BiasCorrectionFactorCacheKey& key) const { + const auto& float_hash = std::hash(); + const auto& parallel_conf_hash = std::hash(); + return float_hash(key.beta) ^ parallel_conf_hash(key.parallel_conf); + } +}; + +} // namespace std + namespace oneflow { +// Forward declaration for bias correction factor +class BiasCorrectionFactorState final : public JobPassState { + public: + BiasCorrectionFactorState() {} + ~BiasCorrectionFactorState() override = default; + + std::string GetLbn(float beta, std::string bias_correction_name, ParallelConf parallel_conf, + const std::function& + BiasCorrectionFactorStateOp); + + private: + HashMap key2lbn_; +}; + namespace { std::string GenVariableOutputLbn(const OperatorConf& op_conf) { @@ -47,37 +85,98 @@ void SetScalarShapeAndNdSbpConf(const ParallelDesc& parallel_desc, OperatorConf* } void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, - const std::string& model_diff_lbn, const OptimizerConf optimizer_conf, + const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf, JobBuilder* job_builder) { const VariableOp* var_op = dynamic_cast(&var_op_node.op()); CHECK_NOTNULL(var_op); + OperatorConf m_var = GenerateLAMBHelperVariableOpConf(*var_op, "m", 0.f); OperatorConf v_var = GenerateLAMBHelperVariableOpConf(*var_op, "v", 0.f); + job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var}); - OperatorConf beta1_t_var; - OperatorConf beta2_t_var; + user_op::UserOpConfWrapperBuilder lamb_update_op_builder(var_op->op_name() + "_optimizer"); + const LambModelUpdateConf& lamb_conf = optimizer_conf.lamb_conf(); - beta1_t_var = GenerateLAMBHelperVariableOpConf(*var_op, "beta1_t", lamb_conf.beta1()); - SetScalarShapeAndNdSbpConf(var_op_node.parallel_desc(), &beta1_t_var); - beta2_t_var = GenerateLAMBHelperVariableOpConf(*var_op, "beta2_t", lamb_conf.beta2()); - SetScalarShapeAndNdSbpConf(var_op_node.parallel_desc(), &beta2_t_var); - job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {beta1_t_var, beta2_t_var}); + float beta1 = lamb_conf.beta1(); + float beta2 = lamb_conf.beta2(); + float epsilon = lamb_conf.epsilon(); + bool do_bias_correction = lamb_conf.do_bias_correction(); + + const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn(); + const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn(); + + if (do_bias_correction) { + // Reuse adam bias_correction job pass + const std::string& job_pass_state_key = "adam_bias_correction_factor"; + const bool has_state = CHECK_JUST(ctx->HasState(job_pass_state_key)); + if (!has_state) { + CHECK_JUST( + ctx->ResetState(job_pass_state_key, std::make_unique())); + } + auto* state = CHECK_JUST(ctx->MutableState(job_pass_state_key)); + ParallelConf bias_correction_parallel_conf; + const auto& lr_parallel_conf = + CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn))); + const auto& train_step_parallel_conf = + CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn))); + if (lr_parallel_conf == train_step_parallel_conf) { + bias_correction_parallel_conf = lr_parallel_conf; + } else { + bias_correction_parallel_conf = var_op_node.parallel_desc().parallel_conf(); + } + auto AddLambBiasCorrectionFactorOp = [&](float beta_val, + const std::string& op_name) -> std::string { + user_op::UserOpConfWrapperBuilder op_builder(var_op->op_name() + op_name); + const auto lamb_bias_correction_factor_op = + op_builder.OpTypeName("adam_bias_correction_factor") + .Input("train_step", train_step_lbn) + .Attr("beta", beta_val) + .Output("out") + .ScopeSymbolId(var_op->op_conf().scope_symbol_id()) + .Build(); + + job_builder->AddOps(bias_correction_parallel_conf, + {lamb_bias_correction_factor_op.op_conf()}); + return lamb_bias_correction_factor_op.output("out", 0); + }; + + const std::string bias_correction1_lbn = + state->GetLbn(beta1, "lamb_bias_correction_factor1", bias_correction_parallel_conf, + AddLambBiasCorrectionFactorOp); + const std::string bias_correction2_lbn = + state->GetLbn(beta2, "lamb_bias_correction_factor2", bias_correction_parallel_conf, + AddLambBiasCorrectionFactorOp); + + lamb_update_op_builder.OpTypeName("lamb_update") + .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) + .Input("model_diff", model_diff_lbn) + .Input("m", GenVariableOutputLbn(m_var)) + .Input("v", GenVariableOutputLbn(v_var)) + .Input("learning_rate", learning_rate_lbn) + .Input("bias_correction1", bias_correction1_lbn) + .Input("bias_correction2", bias_correction2_lbn) + .Attr("beta1", beta1) + .Attr("beta2", beta2) + .Attr("epsilon", epsilon) + .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) + .Attr("do_bias_correction", true) + .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); + } else { + lamb_update_op_builder.OpTypeName("lamb_update") + .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) + .Input("model_diff", model_diff_lbn) + .Input("m", GenVariableOutputLbn(m_var)) + .Input("v", GenVariableOutputLbn(v_var)) + .Input("learning_rate", learning_rate_lbn) + .Attr("beta1", beta1) + .Attr("beta2", beta2) + .Attr("epsilon", epsilon) + .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) + .Attr("do_bias_correction", false) + .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); + } - user_op::UserOpConfWrapperBuilder lamb_update_op_builder(var_op->op_name() + "_optimizer"); - lamb_update_op_builder.OpTypeName("lamb_update") - .Input("m", GenVariableOutputLbn(m_var)) - .Input("v", GenVariableOutputLbn(v_var)) - .Input("beta1_t", GenVariableOutputLbn(beta1_t_var)) - .Input("beta2_t", GenVariableOutputLbn(beta2_t_var)) - .Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out"))) - .Input("model_diff", model_diff_lbn) - .Input("learning_rate", optimizer_conf.learning_rate_lbn()) - .Attr("beta1", lamb_conf.beta1()) - .Attr("beta2", lamb_conf.beta2()) - .Attr("epsilon", lamb_conf.epsilon()) - .Attr("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op)) - .ScopeSymbolId(var_op->op_conf().scope_symbol_id()); SetDynamicLossScaleSkipIf(ctx, &lamb_update_op_builder); const auto lamb_update_op = lamb_update_op_builder.Build(); job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {lamb_update_op.op_conf()}); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 926d6d594db..d52e2b829b5 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -5406,24 +5406,28 @@ def OneFlow_IndexedSlicesSgdUpdateOp : OneFlow_BaseOp<"indexed_slices_sgd_update def OneFlow_LambUpdateOp : OneFlow_BaseOp<"lamb_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins - OneFlow_Tensor:$m, - OneFlow_Tensor:$v, - OneFlow_Tensor:$beta1_t, - OneFlow_Tensor:$beta2_t, OneFlow_Tensor:$model, OneFlow_Tensor:$model_diff, - OneFlow_Tensor:$learning_rate, + Optional:$learning_rate, Optional:$scale_by_tensor, - Optional:$skip_if + Optional:$skip_if, + Optional:$bias_correction1, + Optional:$bias_correction2, + OneFlow_Tensor:$m, + OneFlow_Tensor:$v ); let attrs = (ins - DefaultValuedAttr:$beta1, - DefaultValuedAttr:$beta2, - DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$bias_correction1_val, + DefaultValuedAttr:$bias_correction2_val, DefaultValuedAttr:$scale, DefaultValuedAttr:$l1, DefaultValuedAttr:$l2, - DefaultValuedAttr:$weight_decay + DefaultValuedAttr:$beta1, + DefaultValuedAttr:$beta2, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$weight_decay, + DefaultValuedAttr:$do_bias_correction ); let trait_attrs = (ins I32ElementsAttr:$operand_segment_sizes diff --git a/oneflow/user/kernels/model_update_kernel_util.cpp b/oneflow/user/kernels/model_update_kernel_util.cpp index 9a85aebfcf8..36e4ad9a136 100644 --- a/oneflow/user/kernels/model_update_kernel_util.cpp +++ b/oneflow/user/kernels/model_update_kernel_util.cpp @@ -272,30 +272,38 @@ template struct AdagradUpdateKernelUtil; template struct LambUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, - float beta2, float epsilon, float weight_decay, const float* learning_rate, + float beta2, float epsilon, float weight_decay, float learning_rate_val, + bool do_bias_correction, float bias_correction1_val, + float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, - T* adam_diff, T* model, T* m, T* v, T* norm_buffer, T* beta1_t, T* beta2_t); + T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template void LambUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, - float epsilon, float weight_decay, const float* learning_rate, const T* scale_by_ptr, - const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer, - T* beta1_t, T* beta2_t) { + float epsilon, float weight_decay, float learning_rate_val, bool do_bias_correction, + float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, + const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, + T* norm_buffer) { if (skip_if != nullptr && *skip_if != 0) { return; } - *beta1_t *= beta1; - *beta2_t *= beta2; + if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } + if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } + if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } + FOR_RANGE(int64_t, i, 0, n) { - LambGradFunctor()(beta1_t, beta2_t, model_diff + i, adam_diff + i, model + i, m + i, - v + i, scale, l1, l2, beta1, beta2, epsilon); + LambGradFunctor()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2, + beta1, beta2, epsilon, do_bias_correction, bias_correction1_val, + bias_correction2_val); } T* w_norm_2 = norm_buffer; T* g_norm_2 = norm_buffer + 1; Memset(stream, norm_buffer, 0, 2 * sizeof(T)); SumSquares2(n, model, w_norm_2, adam_diff, g_norm_2); - const float lr = LambLRFunctor()(*learning_rate, w_norm_2, g_norm_2); + const float lr = LambLRFunctor()(learning_rate_val, w_norm_2, g_norm_2); FOR_RANGE(int64_t, i, 0, n) { LambUpdateFunctor()(lr, weight_decay, adam_diff + i, model + i); } diff --git a/oneflow/user/kernels/model_update_kernel_util.cu b/oneflow/user/kernels/model_update_kernel_util.cu index d858fe444a4..4dd53735a10 100644 --- a/oneflow/user/kernels/model_update_kernel_util.cu +++ b/oneflow/user/kernels/model_update_kernel_util.cu @@ -322,23 +322,29 @@ __global__ void IndexedSlicesAdamUpdateGpu( template __global__ void LambGradGpu(int64_t n, T scale, float l1, float l2, float beta1, float beta2, - float epsilon, const T* beta1_t, const T* beta2_t, - const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, - T* adam_diff, T* model, T* m, T* v) { + float epsilon, const T* scale_by_ptr, const int64_t* skip_if, + const G* model_diff, T* adam_diff, T* model, T* m, T* v, + bool do_bias_correction, float bias_correction1_val, + float bias_correction2_val, const float* bias_correction1_ptr, + const float* bias_correction2_ptr) { if (skip_if != nullptr && *skip_if != 0) { return; } if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; } + if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; } + if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; } CUDA_1D_KERNEL_LOOP(i, n) { - LambGradFunctor()(beta1_t, beta2_t, model_diff + i, adam_diff + i, model + i, m + i, - v + i, scale, l1, l2, beta1, beta2, epsilon); + LambGradFunctor()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2, + beta1, beta2, epsilon, do_bias_correction, bias_correction1_val, + bias_correction2_val); } } template -__global__ void LambUpdateGpu(int64_t n, float weight_decay, const float* learning_rate, - const int64_t* skip_if, const T* w_norm_2, const T* g_norm_2, - const T* beta1_t, const T* beta2_t, const T* adam_diff, T* model) { +__global__ void LambUpdateGpu(int64_t n, float weight_decay, float learning_rate_val, + const float* learning_rate_ptr, const int64_t* skip_if, + const T* w_norm_2, const T* g_norm_2, const T* adam_diff, T* model) { if (skip_if != nullptr && *skip_if != 0) { return; } - const float lr = LambLRFunctor()(*learning_rate, w_norm_2, g_norm_2); + if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; } + const float lr = LambLRFunctor()(learning_rate_val, w_norm_2, g_norm_2); CUDA_1D_KERNEL_LOOP(i, n) { LambUpdateFunctor()(lr, weight_decay, adam_diff + i, model + i); } } @@ -447,23 +453,27 @@ template struct AdagradUpdateKernelUtil; template struct LambUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, - float beta2, float epsilon, float weight_decay, const float* learning_rate, + float beta2, float epsilon, float weight_decay, float learning_rate_val, + bool do_bias_correction, float bias_correction1_val, + float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, - T* adam_diff, T* model, T* m, T* v, T* norm_buffer, T* beta1_t, T* beta2_t); + T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template void LambUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, - float epsilon, float weight_decay, const float* learning_rate, const T* scale_by_ptr, - const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer, - T* beta1_t, T* beta2_t) { - AdamUpdateBetaTGpu<<<1, 1, 0, stream->As()->cuda_stream()>>>( - beta1, beta2, skip_if, beta1_t, beta2_t); + float epsilon, float weight_decay, float learning_rate_val, bool do_bias_correction, + float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, + const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, + T* norm_buffer) { LambGradGpu<<As()->cuda_stream()>>>( - n, scale, l1, l2, beta1, beta2, epsilon, beta1_t, beta2_t, scale_by_ptr, skip_if, model_diff, - adam_diff, model, m, v); + n, scale, l1, l2, beta1, beta2, epsilon, scale_by_ptr, skip_if, model_diff, adam_diff, model, + m, v, do_bias_correction, bias_correction1_val, bias_correction2_val, bias_correction1_ptr, + bias_correction2_ptr); T* w_norm_2 = norm_buffer; T* g_norm_2 = norm_buffer + 1; Memset(stream, norm_buffer, 0, 2 * sizeof(T)); @@ -472,28 +482,34 @@ void LambUpdateKernelUtil::Update( stream->As()->cuda_stream()>>>(n, model, w_norm_2, adam_diff, g_norm_2); LambUpdateGpu<<As()->cuda_stream()>>>( - n, weight_decay, learning_rate, skip_if, w_norm_2, g_norm_2, beta1_t, beta2_t, adam_diff, + n, weight_decay, learning_rate_val, learning_rate_ptr, skip_if, w_norm_2, g_norm_2, adam_diff, model); } template struct LambUpdateKernelUtil { static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, - float beta2, float epsilon, float weight_decay, const float* learning_rate, + float beta2, float epsilon, float weight_decay, float learning_rate_val, + bool do_bias_correction, float bias_correction1_val, + float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, - T* adam_diff, T* model, T* m, T* v, T* norm_buffer, T* beta1_t, T* beta2_t); + T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template void LambUpdateKernelUtil::Update( ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2, - float epsilon, float weight_decay, const float* learning_rate, const T* scale_by_ptr, + float epsilon, float weight_decay, float learning_rate_val, bool do_bias_correction, + float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const float16* model_diff, T* adam_diff, T* model, T* m, T* v, - T* norm_buffer, T* beta1_t, T* beta2_t) { + T* norm_buffer) { LambUpdateKernelUtil::Update( - stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate, scale_by_ptr, - skip_if, reinterpret_cast(model_diff), adam_diff, model, m, v, norm_buffer, - beta1_t, beta2_t); + stream, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate_val, + do_bias_correction, bias_correction1_val, bias_correction2_val, learning_rate_ptr, + bias_correction1_ptr, bias_correction2_ptr, scale_by_ptr, skip_if, + reinterpret_cast(model_diff), adam_diff, model, m, v, norm_buffer); } template struct LambUpdateKernelUtil; diff --git a/oneflow/user/kernels/model_update_kernel_util.h b/oneflow/user/kernels/model_update_kernel_util.h index b81428209e1..03ae9b819c0 100644 --- a/oneflow/user/kernels/model_update_kernel_util.h +++ b/oneflow/user/kernels/model_update_kernel_util.h @@ -120,27 +120,36 @@ struct AdagradUpdateFunctor { template struct LambGradFunctor { OF_DEVICE_FUNC - void operator()(const T* beta1_t, const T* beta2_t, const G* model_diff, T* adam_diff, T* model, - T* m, T* v, float scale, float l1, float l2, float beta1, float beta2, - float epsilon) const { + void operator()(const G* model_diff, T* adam_diff, T* model, T* m, T* v, float scale, float l1, + float l2, float beta1, float beta2, float epsilon, bool do_bias_correction, + float bias_correction1, float bias_correction2) const { const T model_val = *model; T model_diff_t = CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_m = beta1 * *m + (1 - beta1) * model_diff_t; const T next_v = beta2 * *v + (1 - beta2) * model_diff_t * model_diff_t; - *adam_diff = (next_m / (1 - *beta1_t)) / (std::sqrt(next_v / (1 - *beta2_t)) + epsilon); *m = next_m; *v = next_v; + T numerator = 0; + T denominator = 0; + if (do_bias_correction) { + numerator = next_m / bias_correction1; + denominator = (sqrt(next_v) / sqrt(bias_correction2)) + epsilon; + } else { + numerator = next_m; + denominator = sqrt(next_v) + epsilon; + } + *adam_diff = numerator / denominator; } }; template struct LambLRFunctor { OF_DEVICE_FUNC - float operator()(const float learning_rate, const T* w_norm_2, const T* g_norm_2) const { - float lr = learning_rate; - const T w_norm_val = std::sqrt(*w_norm_2); - const T g_norm_val = std::sqrt(*g_norm_2); + float operator()(const float learning_rate_val, const T* w_norm_2, const T* g_norm_2) const { + float lr = learning_rate_val; + const T w_norm_val = sqrt(*w_norm_2); + const T g_norm_val = sqrt(*g_norm_2); T trust_ratio = 1; if (w_norm_val > 0 && g_norm_val > 0) { trust_ratio = w_norm_val / g_norm_val; } lr *= trust_ratio; @@ -216,9 +225,12 @@ template struct LambUpdateKernelUtil { public: static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, - float beta2, float epsilon, float weight_decay, const float* learning_rate, + float beta2, float epsilon, float weight_decay, float learning_rate_val, + bool do_bias_correction, float bias_correction1_val, + float bias_correction2_val, const float* learning_rate_ptr, + const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff, - T* adam_diff, T* model, T* m, T* v, T* norm_buffer, T* beta1_t, T* beta2_t); + T* adam_diff, T* model, T* m, T* v, T* norm_buffer); }; template diff --git a/oneflow/user/kernels/model_update_kernels.cpp b/oneflow/user/kernels/model_update_kernels.cpp index b233b7d42f5..5a8b29326ce 100644 --- a/oneflow/user/kernels/model_update_kernels.cpp +++ b/oneflow/user/kernels/model_update_kernels.cpp @@ -676,15 +676,13 @@ class LambUpdateKernel final : public user_op::OpKernel, public user_op::CudaGra private: void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const user_op::Tensor* model_diff = ctx->Tensor4ArgNameAndIndex("model_diff", 0); user_op::Tensor* model = ctx->Tensor4ArgNameAndIndex("model", 0); user_op::Tensor* m = ctx->Tensor4ArgNameAndIndex("m", 0); user_op::Tensor* v = ctx->Tensor4ArgNameAndIndex("v", 0); - user_op::Tensor* beta1_t = ctx->Tensor4ArgNameAndIndex("beta1_t", 0); - user_op::Tensor* beta2_t = ctx->Tensor4ArgNameAndIndex("beta2_t", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); LambTmpBufferManager tbm(tmp_buffer->mut_dptr(), model->shape().elem_cnt()); + const auto scale = ctx->Attr("scale"); const auto l1 = ctx->Attr("l1"); const auto l2 = ctx->Attr("l2"); @@ -692,6 +690,31 @@ class LambUpdateKernel final : public user_op::OpKernel, public user_op::CudaGra const auto beta2 = ctx->Attr("beta2"); const auto epsilon = ctx->Attr("epsilon"); const auto weight_decay = ctx->Attr("weight_decay"); + + const bool do_bias_correction = ctx->Attr("do_bias_correction"); + const float bias_correction1_val = ctx->Attr("bias_correction1_val"); + const float* bias_correction1_ptr = nullptr; + if (ctx->has_input("bias_correction1", 0)) { + const user_op::Tensor* bias_correction1 = ctx->Tensor4ArgNameAndIndex("bias_correction1", 0); + // Just for Lazy optional input check. + CHECK_EQ(bias_correction1->shape().elem_cnt(), 1); + bias_correction1_ptr = bias_correction1->dptr(); + } + const float bias_correction2_val = ctx->Attr("bias_correction2_val"); + const float* bias_correction2_ptr = nullptr; + if (ctx->has_input("bias_correction2", 0)) { + const user_op::Tensor* bias_correction2 = ctx->Tensor4ArgNameAndIndex("bias_correction2", 0); + CHECK_EQ(bias_correction2->shape().elem_cnt(), 1); + bias_correction2_ptr = bias_correction2->dptr(); + } + + const float learning_rate_val = ctx->Attr("learning_rate_val"); + const float* learning_rate_ptr = nullptr; + if (ctx->has_input("learning_rate", 0)) { + const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); + learning_rate_ptr = learning_rate->dptr(); + } + const T* scale_by_ptr = nullptr; if (ctx->has_input("scale_by_tensor", 0)) { const user_op::Tensor* scale_by_tensor = ctx->Tensor4ArgNameAndIndex("scale_by_tensor", 0); @@ -699,17 +722,20 @@ class LambUpdateKernel final : public user_op::OpKernel, public user_op::CudaGra CHECK_EQ(scale_by_tensor->shape().elem_cnt(), 1); scale_by_ptr = scale_by_tensor->dptr(); } + const int64_t* skip_if_ptr = nullptr; if (ctx->has_input("skip_if", 0)) { const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); CHECK_EQ(skip_if->shape().elem_cnt(), 1); skip_if_ptr = skip_if->dptr(); } + LambUpdateKernelUtil::Update( ctx->stream(), m->shape().elem_cnt(), scale, l1, l2, beta1, beta2, epsilon, weight_decay, - learning_rate->dptr(), scale_by_ptr, skip_if_ptr, model_diff->dptr(), - tbm.AdamDiffPtr(), model->mut_dptr(), m->mut_dptr(), v->mut_dptr(), - tbm.NormBufferPtr(), beta1_t->mut_dptr(), beta2_t->mut_dptr()); + learning_rate_val, do_bias_correction, bias_correction1_val, bias_correction2_val, + learning_rate_ptr, bias_correction1_ptr, bias_correction2_ptr, scale_by_ptr, skip_if_ptr, + model_diff->dptr(), tbm.AdamDiffPtr(), model->mut_dptr(), m->mut_dptr(), + v->mut_dptr(), tbm.NormBufferPtr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; } }; diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index d0da22056f9..7a18d87289c 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -253,10 +253,6 @@ Maybe InferLambUpdateTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); JUST(CheckShapeLike(&v, &model)); JUST(CheckLearningRateShape(ctx)); - const user_op::TensorDesc& beta1_t = ctx->InputTensorDesc("beta1_t", 0); - const user_op::TensorDesc& beta2_t = ctx->InputTensorDesc("beta2_t", 0); - JUST(CheckScalarShape(&beta1_t)); - JUST(CheckScalarShape(&beta2_t)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); JUST(CheckScalarShape(&scale_by_tensor)); @@ -270,11 +266,6 @@ Maybe InferLambUpdateDataType(user_op::InferContext* ctx) { JUST(CheckDataTypeLike(&m, &model)); const user_op::TensorDesc& v = ctx->InputTensorDesc("v", 0); JUST(CheckDataTypeLike(&v, &model)); - const DataType data_type = model.data_type(); - const user_op::TensorDesc& beta1_t = ctx->InputTensorDesc("beta1_t", 0); - const user_op::TensorDesc& beta2_t = ctx->InputTensorDesc("beta2_t", 0); - JUST(CheckScalarDataType(&beta1_t, data_type)); - JUST(CheckScalarDataType(&beta2_t, data_type)); JUST(CheckLearningRateDataType(ctx)); if (ctx->has_input("scale_by_tensor", 0)) { const auto& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); @@ -282,6 +273,7 @@ Maybe InferLambUpdateDataType(user_op::InferContext* ctx) { } return Maybe::Ok(); } + Maybe SetInputArgModifierMutable(const user_op::GetInputArgModifier& GetInputArgModifierFn, const std::string& arg_name, int32_t arg_index) { user_op::InputArgModifier* arg_modifier = GetInputArgModifierFn(arg_name, arg_index); @@ -311,8 +303,6 @@ Maybe LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArg JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0)); JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0)); - JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "beta1_t", 0)); - JUST(SetInputArgModifierMutable(GetInputArgModifierFn, "beta2_t", 0)); return Maybe::Ok(); } diff --git a/python/oneflow/nn/optimizer/lamb.py b/python/oneflow/nn/optimizer/lamb.py new file mode 100644 index 00000000000..f4f7fdf5b99 --- /dev/null +++ b/python/oneflow/nn/optimizer/lamb.py @@ -0,0 +1,254 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from typing import Callable, Dict, Iterator, List, Union, Tuple + +import math +import oneflow as flow +from oneflow.nn.optimizer.optimizer import Optimizer +from oneflow.nn.parameter import Parameter + + +class LAMB(Optimizer): + """Implements LAMB algorithm. + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + The equation of parameters updating is: + + .. math:: + + & V_t = \\beta_1*V_{t-1} + (1-\\beta_1)*grad + + & S_t = \\beta_2*S_{t-1} + (1-\\beta_2)*{grad} \\odot {grad} + + & \\hat{u} = \\frac{{V_t}}{\\sqrt{{S_t}}+\\epsilon} + + & \\hat{r} = learning\\_rate * \\frac{||param_{old}||_2}{||\\hat{u}||_2} + + & param_{new} = param_{old} - \\hat{r} * \\hat{u} + + Args: + parameters (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam_w_mode (bool, optional): apply L2 regularization or weight decay True for + decoupled weight decay (also known as AdamW) (default: True) + do_bias_correction (bool, optional): whether to do bias correction (default: True) + amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm. + NOT SUPPORTED now! (default: False) + + .. _Large Batch Optimization for Deep Learning\\: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + + For example: + + Example 1: + + .. code-block:: python + + # Assume net is a custom model. + lamb = flow.optim.LAMB(net.parameters(), lr=1e-3) + + for epoch in range(epochs): + # Read data, Compute the loss and so on. + # ... + loss.backward() + lamb.step() + lamb.zero_grad() + + Example 2: + + .. code-block:: python + + # Assume net is a custom model. + lamb = flow.optim.LAMB( + [ + { + "params": net.parameters(), + "lr": learning_rate, + "clip_grad_max_norm": 0.5, + "clip_grad_norm_type": 2.0, + } + ], + ) + + for epoch in range(epochs): + # Read data, Compute the loss and so on. + # ... + loss.backward() + lamb.clip_grad() + lamb.step() + lamb.zero_grad() + + If you want to use clip_grad, you can refer this example. + + For more details of `clip_grad_max_norm` and `clip_grad_norm_type`, you can refer to :func:`oneflow.nn.utils.clip_grad_norm_`. + """ + + def __init__( + self, + parameters: Union[Iterator[Parameter], List[Dict]], + lr: float = 0.001, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + adam_w_mode: bool = True, + do_bias_correction: bool = True, + amsgrad: bool = False, + ): + if amsgrad: + # TODO: supported amsgrad in Lamb + raise RuntimeError("LAMB does not support AMSGrad variant.") + assert lr >= 0.0, f"Invalid learning rate: {lr}" + assert eps >= 0.0, f"Invalid epsilon value: {eps}" + assert ( + betas[0] >= 0.0 and betas[0] < 1.0 + ), f"Invalid beta parameter at index 0: {betas[0]}" + assert ( + betas[1] >= 0.0 and betas[1] < 1.0 + ), f"Invalid beta parameter at index 1: {betas[1]}" + assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" + + options = dict() + options["lr"] = lr + options["eps"] = eps + options["betas"] = betas + options["weight_decay"] = weight_decay + options["amsgrad"] = amsgrad + options["adam_w_mode"] = adam_w_mode + options["bias_correction1"] = 1.0 + options["bias_correction2"] = 1.0 + options["do_bias_correction"] = do_bias_correction + + super().__init__(parameters, options) + + for param_group in self.param_groups: + for param in param_group.parameters: + assert param.is_leaf, "parameters must be leaf tensor" + self._state[param] = dict() + + self._op = ( + flow.stateful_op("lamb_update") + .Input("model") + .Input("model_diff") + .Input("m") + .Input("v") + .Build() + ) + + def step(self, closure: Callable = None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + with flow.no_grad(): + loss = None + if closure is not None: + loss = closure() + + for param_group in self.param_groups: + if param_group["do_bias_correction"]: + param_group["bias_correction1"] = 1.0 - math.pow( + param_group["betas"][0], self._state["step"] + 1 + ) + param_group["bias_correction2"] = 1.0 - math.pow( + param_group["betas"][1], self._state["step"] + 1 + ) + + kwargs = { + "learning_rate": param_group["lr"], + "bias_correction1": param_group["bias_correction1"], + "bias_correction2": param_group["bias_correction2"], + "beta1": param_group["betas"][0], + "beta2": param_group["betas"][1], + "epsilon": param_group["eps"], + "do_bias_correction": param_group["do_bias_correction"], + } + if param_group["adam_w_mode"]: + kwargs["weight_decay"] = param_group["weight_decay"] + kwargs["l2"] = 0.0 + else: + kwargs["l2"] = param_group["weight_decay"] + kwargs["weight_decay"] = 0.0 + for param in param_group.parameters: + if param.grad is None: + continue + if "exp_avg" not in self._state[param]: + self._state[param]["exp_avg"] = flow.zeros_like(param) + if "exp_avg_sq" not in self._state[param]: + self._state[param]["exp_avg_sq"] = flow.zeros_like(param) + m_tensor = self._state[param]["exp_avg"] + v_tensor = self._state[param]["exp_avg_sq"] + + flow._C.dispatch_lamb_update( + self._op, (param, param.grad, m_tensor, v_tensor), **kwargs + ) + + self._state["step"] += 1 + + return loss + + def _generate_conf_for_graph(self, train_conf, vars_conf): + new_opt_confs = [] + for param_group in self.param_groups: + optimizer_conf = train_conf.mutable_optimizer_conf().Add() + + lr = ( + param_group["initial_lr"] + if "initial_lr" in param_group + else param_group["lr"] + ) + adam_w_mode = param_group["adam_w_mode"] + weight_decay = param_group["weight_decay"] + beta1 = param_group["betas"][0] + beta2 = param_group["betas"][1] + do_bias_correction = param_group["do_bias_correction"] + epsilon = param_group["eps"] + + optimizer_conf.set_base_learning_rate(lr) + + optimizer_conf.mutable_lamb_conf().set_beta1(beta1) + optimizer_conf.mutable_lamb_conf().set_beta2(beta2) + optimizer_conf.mutable_lamb_conf().set_epsilon(epsilon) + optimizer_conf.mutable_lamb_conf().set_do_bias_correction( + do_bias_correction + ) + + self._generate_grad_clip_conf_for_optim_conf(param_group, optimizer_conf) + + if adam_w_mode: + optimizer_conf.mutable_weight_decay_conf().set_weight_decay_rate( + weight_decay + ) + else: + optimizer_conf.mutable_weight_decay_conf().set_weight_decay_rate(0.0) + + for param in param_group.parameters: + if not adam_w_mode: + # Set l2 penalty as weight decay if **NOT** using adam_w_mode + vars_conf[param].l2 = weight_decay + if param.requires_grad: + optimizer_conf.add_variable_op_names(vars_conf[param].name) + + new_opt_confs.append(optimizer_conf) + return new_opt_confs diff --git a/python/oneflow/optim/__init__.py b/python/oneflow/optim/__init__.py index baf0061095c..0484175bafa 100644 --- a/python/oneflow/optim/__init__.py +++ b/python/oneflow/optim/__init__.py @@ -19,6 +19,7 @@ from oneflow.nn.optimizer.rmsprop import RMSprop from oneflow.nn.optimizer.sgd import SGD from oneflow.nn.optimizer.adagrad import Adagrad +from oneflow.nn.optimizer.lamb import LAMB from . import lr_scheduler from . import utils diff --git a/python/oneflow/test/graph/test_graph_optim_lamb.py b/python/oneflow/test/graph/test_graph_optim_lamb.py new file mode 100644 index 00000000000..791b4fd98af --- /dev/null +++ b/python/oneflow/test/graph/test_graph_optim_lamb.py @@ -0,0 +1,181 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from collections import OrderedDict +import numpy as np + +from test_util import GenArgList +from optimizer_test_util import clip_grad_norm_np + +import oneflow as flow + + +def compare_with_numpy_lamb( + test_case, + device, + x_shape, + learning_rate, + train_iters, + betas, + weight_decay, + eps, + do_bias_correction, + adam_w_mode, + clip_grad_max_norm, + clip_grad_norm_type, +): + + np.random.seed(1000) + + random_grad_seq = [] + for _ in range(train_iters): + random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) + init_value = np.random.uniform(size=x_shape).astype(np.float32) + + class CustomModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.param = flow.nn.Parameter( + flow.Tensor(init_value, device=flow.device(device)) + ) + + def forward(self, mask): + return self.param * mask + + simp_module = CustomModule() + simp_module.to(device) + simp_module.train() + + optim_kwargs = { + "params": simp_module.parameters(), + "lr": learning_rate, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "adam_w_mode": adam_w_mode, + "do_bias_correction": do_bias_correction, + } + + if clip_grad_max_norm != -1: + optim_kwargs["clip_grad_max_norm"] = clip_grad_max_norm + optim_kwargs["clip_grad_norm_type"] = clip_grad_norm_type + + lamb_optim = flow.optim.LAMB([optim_kwargs]) + + class CustomLambGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.m = simp_module + self.add_optimizer(lamb_optim) + + def build(self, mask_tensor): + loss = flow.sum(self.m(mask_tensor)) + loss.backward() + return loss + + lamb_graph = CustomLambGraph() + + for i in range(train_iters): + mask_tensor = flow.tensor( + random_grad_seq[i], + dtype=flow.float32, + requires_grad=False, + device=flow.device(device), + ) + lamb_graph(mask_tensor) + + of_res = simp_module.param.numpy() + + def train_by_numpy(): + x = init_value + mt = np.zeros_like(x) + vt = np.zeros_like(x) + beta1 = betas[0] + beta2 = betas[1] + if adam_w_mode: + l2 = 0 + wd = weight_decay + else: + l2 = weight_decay + wd = 0 + + def np_train_one_iter(step, grad): + if clip_grad_max_norm != -1: + _, grad = clip_grad_norm_np( + grad, clip_grad_max_norm, clip_grad_norm_type + ) + + grad = grad + l2 * x + + bias_correction1 = 1.0 + bias_correction2 = 1.0 + + if do_bias_correction: + bias_correction1 = 1.0 - np.power(beta1, step + 1) + bias_correction2 = 1.0 - np.power(beta2, step + 1) + + m = beta1 * mt + (1 - beta1) * grad + v = beta2 * vt + (1 - beta2) * grad * grad + + denom = np.sqrt(v) / np.sqrt(bias_correction2) + eps + + adam_diff = m / bias_correction1 / denom + + w_norm = np.linalg.norm(x, ord=2) + g_norm = np.linalg.norm(adam_diff, ord=2) + if w_norm > 0 and g_norm > 0: + trust_ratio = w_norm / g_norm + else: + trust_ratio = 1.0 + + param = x - learning_rate * trust_ratio * (adam_diff + wd * x) + return (param, m, v) + + for i in range(train_iters): + (x, mt, vt) = np_train_one_iter(i, random_grad_seq[i]) + return x + + np_res = train_by_numpy() + + test_case.assertTrue( + np.allclose(of_res.flatten(), np_res.flatten(), rtol=1e-3, atol=1e-3) + ) + + +@flow.unittest.skip_unless_1n1d() +class TestLamb(flow.unittest.TestCase): + def test_lamb(test_case): + arg_dict = OrderedDict() + arg_dict["device"] = ["cpu", "cuda"] + arg_dict["x_shape"] = [(10,)] + arg_dict["learning_rate"] = [0.1, 1e-3] + arg_dict["train_iters"] = [10] + arg_dict["betas"] = [(0.99, 0.9)] + arg_dict["weight_decay"] = [0.001, 0.1] + arg_dict["eps"] = [1e-8, 1e-6] + arg_dict["do_bias_correction"] = [True, False] + arg_dict["adam_w_mode"] = [True, False] + # NOTE(l1aoxingyu): max_norm = -1 means no clip grad + # nn.Graph only support `clip_grad_max_norm == 1.0` and `clip_grad_norm_type == 2.0` + arg_dict["clip_grad_max_norm"] = [-1, 1.0] + arg_dict["clip_grad_norm_type"] = [2.0] + + for arg in GenArgList(arg_dict): + compare_with_numpy_lamb(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_optim_lamb.py b/python/oneflow/test/modules/test_optim_lamb.py new file mode 100644 index 00000000000..3c7a38337c9 --- /dev/null +++ b/python/oneflow/test/modules/test_optim_lamb.py @@ -0,0 +1,178 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import tempfile +import unittest +from collections import OrderedDict + +import numpy as np +from optimizer_test_util import clip_grad_norm_np +from test_util import GenArgList + +import oneflow as flow + + +def compare_with_numpy_lamb( + test_case, + device, + x_shape, + learning_rate, + train_iters, + betas, + weight_decay, + eps, + do_bias_correction, + adam_w_mode, + clip_grad_max_norm, + clip_grad_norm_type, + reload_state_step, + save_load_by_pickle, +): + + np.random.seed(1000) + + random_grad_seq = [] + for _ in range(train_iters): + random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32)) + init_value = np.random.uniform(size=x_shape).astype(np.float32) + + def train_by_oneflow(): + x = flow.nn.Parameter(flow.Tensor(init_value, device=flow.device(device))) + + optim_kwargs = { + "params": [x], + "lr": learning_rate, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "adam_w_mode": adam_w_mode, + "do_bias_correction": do_bias_correction, + } + + if clip_grad_max_norm != -1: + optim_kwargs["clip_grad_max_norm"] = clip_grad_max_norm + optim_kwargs["clip_grad_norm_type"] = clip_grad_norm_type + + lamb = flow.optim.LAMB([optim_kwargs]) + + def train_one_iter(grad): + grad_tensor = flow.tensor( + grad, + dtype=flow.float32, + requires_grad=False, + device=flow.device(device), + ) + + loss = flow.sum(x * grad_tensor) + loss.backward() + if clip_grad_max_norm != -1: + lamb.clip_grad() + lamb.step() + lamb.zero_grad() + + for i in range(train_iters): + train_one_iter(random_grad_seq[i]) + if i == reload_state_step: + state_dict = lamb.state_dict() + lamb = flow.optim.LAMB([optim_kwargs]) + if save_load_by_pickle: + with tempfile.TemporaryDirectory() as save_dir: + flow.save(state_dict, save_dir) + state_dict = flow.load(save_dir) + lamb.load_state_dict(state_dict) + return x + + def train_by_numpy(): + x = init_value + mt = np.zeros_like(x) + vt = np.zeros_like(x) + beta1 = betas[0] + beta2 = betas[1] + if adam_w_mode: + l2 = 0 + wd = weight_decay + else: + l2 = weight_decay + wd = 0 + + def np_train_one_iter(step, grad): + if clip_grad_max_norm != -1: + _, grad = clip_grad_norm_np( + grad, clip_grad_max_norm, clip_grad_norm_type + ) + + grad = grad + l2 * x + + bias_correction1 = 1.0 + bias_correction2 = 1.0 + + if do_bias_correction: + bias_correction1 = 1.0 - np.power(beta1, step + 1) + bias_correction2 = 1.0 - np.power(beta2, step + 1) + + m = beta1 * mt + (1 - beta1) * grad + v = beta2 * vt + (1 - beta2) * grad * grad + + denom = np.sqrt(v) / np.sqrt(bias_correction2) + eps + + adam_diff = m / bias_correction1 / denom + + w_norm = np.linalg.norm(x, ord=2) + g_norm = np.linalg.norm(adam_diff, ord=2) + if w_norm > 0 and g_norm > 0: + trust_ratio = w_norm / g_norm + else: + trust_ratio = 1.0 + + param = x - learning_rate * trust_ratio * (adam_diff + wd * x) + return (param, m, v) + + for i in range(train_iters): + (x, mt, vt) = np_train_one_iter(i, random_grad_seq[i]) + return x + + of_res = train_by_oneflow().numpy() + np_res = train_by_numpy() + + test_case.assertTrue( + np.allclose(of_res.flatten(), np_res.flatten(), rtol=1e-3, atol=1e-3) + ) + + +@flow.unittest.skip_unless_1n1d() +class TestLamb(flow.unittest.TestCase): + def test_lamb(test_case): + arg_dict = OrderedDict() + arg_dict["device"] = ["cpu", "cuda"] + arg_dict["x_shape"] = [(1,)] + arg_dict["learning_rate"] = [0.1, 1e-3] + arg_dict["train_iters"] = [10] + arg_dict["betas"] = [(0.99, 0.9)] + arg_dict["weight_decay"] = [0.001, 0.1] + arg_dict["eps"] = [1e-8, 1e-6] + arg_dict["do_bias_correction"] = [True, False] + arg_dict["adam_w_mode"] = [True, False] + # NOTE(l1aoxingyu): max_norm = -1 means no clip grad + arg_dict["clip_grad_max_norm"] = [-1, 0.0, 0.5, 1.0] + arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5] + arg_dict["reload_state_step"] = [5] + arg_dict["save_load_by_pickle"] = [False, True] + + for arg in GenArgList(arg_dict): + compare_with_numpy_lamb(test_case, *arg) + + +if __name__ == "__main__": + unittest.main()