add lamb optimizer #7389

Merged · 14 commits · Jan 28, 2022
1 change: 1 addition & 0 deletions docs/source/optim.rst
@@ -10,6 +10,7 @@ Optimizers
Optimizer,
RMSprop,
SGD,
LAMB,
lr_scheduler

.. automodule:: oneflow.optim.lr_scheduler
20 changes: 20 additions & 0 deletions oneflow/api/python/functional/dispatch_stateful_ops.cpp
@@ -480,6 +480,26 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchLambUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
float learning_rate, float bias_correction1, float bias_correction2,
double scale, float l1, float l2, float beta1, float beta2, float epsilon,
float weight_decay, bool do_bias_correction) -> Maybe<void> {
MutableAttrMap attrs;
JUST(attrs.SetAttr("learning_rate_val", learning_rate));
JUST(attrs.SetAttr("bias_correction1_val", bias_correction1));
JUST(attrs.SetAttr("bias_correction2_val", bias_correction2));
JUST(attrs.SetAttr("scale", scale));
JUST(attrs.SetAttr("l1", l1));
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("beta1", beta1));
JUST(attrs.SetAttr("beta2", beta2));
JUST(attrs.SetAttr("epsilon", epsilon));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(attrs.SetAttr("do_bias_correction", do_bias_correction));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchEagerNcclAllReduce",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& parallel_conf, bool async_launch) -> Maybe<Tensor> {
4 changes: 4 additions & 0 deletions oneflow/api/python/functional/dispatch_stateful_ops.yaml
@@ -138,6 +138,10 @@
- name: "dispatch_sgd_update"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float weight_decay=0) => DispatchSgdUpdate"
bind_python: True

- name: "dispatch_lamb_update"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool do_bias_correction=True) => DispatchLambUpdate"
bind_python: True

- name: "dispatch_eager_nccl_all_reduce"
signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce"
7 changes: 4 additions & 3 deletions oneflow/core/job/job_conf.proto
@@ -47,9 +47,10 @@ message LazyAdamModelUpdateConf {
}

message LambModelUpdateConf {
required float beta1 = 1;
required float beta2 = 2;
required float epsilon = 3;
optional float beta1 = 1 [default = 0.9];
optional float beta2 = 2 [default = 0.999];
optional float epsilon = 3 [default = 1e-8];
optional bool do_bias_correction = 4 [default = true];
}

message AdagradModelUpdateConf {
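Note on the new proto defaults above: they configure the standard LAMB update rule (You et al., "Large Batch Optimization for Deep Learning", 2019). As background, not part of the diff, one step for a variable \(\theta\) with gradient \(g_t\) is

\[
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{2},\\
\hat m_t &= m_t/(1-\beta_1^{t}), \qquad
\hat v_t = v_t/(1-\beta_2^{t}) \quad \text{(skipped when do\_bias\_correction is false)},\\
r_t &= \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} + \lambda\,\theta_{t-1}, \qquad
\theta_t = \theta_{t-1} - \eta\,\frac{\lVert\theta_{t-1}\rVert}{\lVert r_t\rVert}\, r_t .
\end{aligned}
\]

Here \(\eta\) is the learning rate, \(\lambda\) is weight_decay, and the norm ratio (the trust ratio) is computed per variable. The new do_bias_correction flag (default true) toggles the \(1-\beta^{t}\) division, and the job pass below reuses the Adam bias-correction machinery to produce those factors.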
38 changes: 19 additions & 19 deletions oneflow/core/job_rewriter/adam_optm.cpp
@@ -44,25 +44,6 @@ struct hash<oneflow::BiasCorrectionFactorCacheKey> {

namespace oneflow {

namespace {

std::string GenVariableOutputLbn(const OperatorConf& op_conf) {
CHECK(op_conf.has_variable_conf());
return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());
}

OperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name,
const float initial_value) {
OperatorConf helper_variable_op(op.op_conf());
helper_variable_op.set_name(op.op_name() + "-" + name);
helper_variable_op.mutable_variable_conf()->set_out("out");
InitializerConf constant_initializer;
constant_initializer.mutable_constant_conf()->set_value(initial_value);
*(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
return helper_variable_op;
}

class BiasCorrectionFactorState final : public JobPassState {
public:
BiasCorrectionFactorState() {}
@@ -88,6 +69,25 @@ class BiasCorrectionFactorState final : public JobPassState {
HashMap<BiasCorrectionFactorCacheKey, std::string> key2lbn_;
};

namespace {

std::string GenVariableOutputLbn(const OperatorConf& op_conf) {
CHECK(op_conf.has_variable_conf());
return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());
}

OperatorConf GenerateAdamHelperVariableOpConf(const VariableOp& op, const std::string& name,
const float initial_value) {
OperatorConf helper_variable_op(op.op_conf());
helper_variable_op.set_name(op.op_name() + "-" + name);
helper_variable_op.mutable_variable_conf()->set_out("out");
InitializerConf constant_initializer;
constant_initializer.mutable_constant_conf()->set_value(initial_value);
*(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
return helper_variable_op;
}

void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,
JobBuilder* job_builder) {
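The reshuffle in this file appears to move BiasCorrectionFactorState (together with the file-local helpers) out of the anonymous namespace, presumably so the new LAMB pass can name the same class and reuse the state registered under "adam_bias_correction_factor". A minimal two-file sketch of the linkage point, with illustrative names rather than OneFlow code:

```cpp
// file: adam_pass.cpp
namespace {
struct HiddenState {};  // internal linkage: only this .cpp can use the type
}  // namespace
struct SharedState {};  // external linkage: another .cpp may repeat this identical
                        // definition and refer to the very same type

// file: lamb_pass.cpp
struct SharedState {};  // the same type as in adam_pass.cpp (definitions must match)
// A HiddenState defined here would be an unrelated type; RTTI-based lookups such as
// dynamic_cast<HiddenState*> could never bridge the two files.
```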
143 changes: 121 additions & 22 deletions oneflow/core/job_rewriter/lamb_optm.cpp
@@ -15,8 +15,46 @@ limitations under the License.
*/
#include "oneflow/core/job_rewriter/optimizer.h"
#include "oneflow/core/framework/framework.h"

namespace oneflow {

struct BiasCorrectionFactorCacheKey {
float beta = 1.0;
ParallelConf parallel_conf;
};

bool operator==(const BiasCorrectionFactorCacheKey& lhs, const BiasCorrectionFactorCacheKey& rhs);

} // namespace oneflow

namespace std {
template<>
struct hash<oneflow::BiasCorrectionFactorCacheKey> {
size_t operator()(const oneflow::BiasCorrectionFactorCacheKey& key) const {
const auto& float_hash = std::hash<float>();
const auto& parallel_conf_hash = std::hash<oneflow::ParallelConf>();
return float_hash(key.beta) ^ parallel_conf_hash(key.parallel_conf);
}
};

} // namespace std

namespace oneflow {

// Forward declaration for bias correction factor
class BiasCorrectionFactorState final : public JobPassState {
public:
BiasCorrectionFactorState() {}
~BiasCorrectionFactorState() override = default;

std::string GetLbn(float beta, std::string bias_correction_name, ParallelConf parallel_conf,
const std::function<std::string(float beta_val, std::string op_name)>&
BiasCorrectionFactorStateOp);

private:
HashMap<BiasCorrectionFactorCacheKey, std::string> key2lbn_;
};

namespace {

std::string GenVariableOutputLbn(const OperatorConf& op_conf) {
@@ -47,37 +85,98 @@ void SetScalarShapeAndNdSbpConf(const ParallelDesc& parallel_desc, OperatorConf*
}

void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
const std::string& model_diff_lbn, const OptimizerConf optimizer_conf,
const std::string& model_diff_lbn, const OptimizerConf& optimizer_conf,
JobBuilder* job_builder) {
const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());
CHECK_NOTNULL(var_op);

OperatorConf m_var = GenerateLAMBHelperVariableOpConf(*var_op, "m", 0.f);
OperatorConf v_var = GenerateLAMBHelperVariableOpConf(*var_op, "v", 0.f);

job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {m_var, v_var});

OperatorConf beta1_t_var;
OperatorConf beta2_t_var;
user_op::UserOpConfWrapperBuilder lamb_update_op_builder(var_op->op_name() + "_optimizer");

const LambModelUpdateConf& lamb_conf = optimizer_conf.lamb_conf();
beta1_t_var = GenerateLAMBHelperVariableOpConf(*var_op, "beta1_t", lamb_conf.beta1());
SetScalarShapeAndNdSbpConf(var_op_node.parallel_desc(), &beta1_t_var);
beta2_t_var = GenerateLAMBHelperVariableOpConf(*var_op, "beta2_t", lamb_conf.beta2());
SetScalarShapeAndNdSbpConf(var_op_node.parallel_desc(), &beta2_t_var);
job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {beta1_t_var, beta2_t_var});
float beta1 = lamb_conf.beta1();
float beta2 = lamb_conf.beta2();
float epsilon = lamb_conf.epsilon();
bool do_bias_correction = lamb_conf.do_bias_correction();

const std::string& train_step_lbn = job_builder->job().job_conf().train_conf().train_step_lbn();
const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();

if (do_bias_correction) {
// Reuse adam bias_correction job pass
const std::string& job_pass_state_key = "adam_bias_correction_factor";
const bool has_state = CHECK_JUST(ctx->HasState<BiasCorrectionFactorState>(job_pass_state_key));
if (!has_state) {
CHECK_JUST(
ctx->ResetState(job_pass_state_key, std::make_unique<BiasCorrectionFactorState>()));
}
auto* state = CHECK_JUST(ctx->MutableState<BiasCorrectionFactorState>(job_pass_state_key));
ParallelConf bias_correction_parallel_conf;
const auto& lr_parallel_conf =
CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn)));
const auto& train_step_parallel_conf =
CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn)));
if (lr_parallel_conf == train_step_parallel_conf) {
bias_correction_parallel_conf = lr_parallel_conf;
} else {
bias_correction_parallel_conf = var_op_node.parallel_desc().parallel_conf();
}
auto AddLambBiasCorrectionFactorOp = [&](float beta_val,
const std::string& op_name) -> std::string {
user_op::UserOpConfWrapperBuilder op_builder(var_op->op_name() + op_name);
const auto lamb_bias_correction_factor_op =
op_builder.OpTypeName("adam_bias_correction_factor")
.Input("train_step", train_step_lbn)
.Attr<float>("beta", beta_val)
.Output("out")
.ScopeSymbolId(var_op->op_conf().scope_symbol_id())
.Build();

job_builder->AddOps(bias_correction_parallel_conf,
{lamb_bias_correction_factor_op.op_conf()});
return lamb_bias_correction_factor_op.output("out", 0);
};

const std::string bias_correction1_lbn =
state->GetLbn(beta1, "lamb_bias_correction_factor1", bias_correction_parallel_conf,
AddLambBiasCorrectionFactorOp);
const std::string bias_correction2_lbn =
state->GetLbn(beta2, "lamb_bias_correction_factor2", bias_correction_parallel_conf,
AddLambBiasCorrectionFactorOp);

lamb_update_op_builder.OpTypeName("lamb_update")
.Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out")))
.Input("model_diff", model_diff_lbn)
.Input("m", GenVariableOutputLbn(m_var))
.Input("v", GenVariableOutputLbn(v_var))
.Input("learning_rate", learning_rate_lbn)
.Input("bias_correction1", bias_correction1_lbn)
.Input("bias_correction2", bias_correction2_lbn)
.Attr<float>("beta1", beta1)
.Attr<float>("beta2", beta2)
.Attr<float>("epsilon", epsilon)
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
.Attr<bool>("do_bias_correction", true)
.ScopeSymbolId(var_op->op_conf().scope_symbol_id());
} else {
lamb_update_op_builder.OpTypeName("lamb_update")
.Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out")))
.Input("model_diff", model_diff_lbn)
.Input("m", GenVariableOutputLbn(m_var))
.Input("v", GenVariableOutputLbn(v_var))
.Input("learning_rate", learning_rate_lbn)
.Attr<float>("beta1", beta1)
.Attr<float>("beta2", beta2)
.Attr<float>("epsilon", epsilon)
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
.Attr<bool>("do_bias_correction", false)
.ScopeSymbolId(var_op->op_conf().scope_symbol_id());
}

user_op::UserOpConfWrapperBuilder lamb_update_op_builder(var_op->op_name() + "_optimizer");
lamb_update_op_builder.OpTypeName("lamb_update")
.Input("m", GenVariableOutputLbn(m_var))
.Input("v", GenVariableOutputLbn(v_var))
.Input("beta1_t", GenVariableOutputLbn(beta1_t_var))
.Input("beta2_t", GenVariableOutputLbn(beta2_t_var))
.Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out")))
.Input("model_diff", model_diff_lbn)
.Input("learning_rate", optimizer_conf.learning_rate_lbn())
.Attr<float>("beta1", lamb_conf.beta1())
.Attr<float>("beta2", lamb_conf.beta2())
.Attr<float>("epsilon", lamb_conf.epsilon())
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
.ScopeSymbolId(var_op->op_conf().scope_symbol_id());
SetDynamicLossScaleSkipIf(ctx, &lamb_update_op_builder);
const auto lamb_update_op = lamb_update_op_builder.Build();
job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {lamb_update_op.op_conf()});
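Two details worth noting in the hunk above. First, optimizer_conf is now passed by const reference rather than by value. Second, the bias-correction ops are requested through the shared job-pass state, so a 1 − βᵗ factor with a given beta and placement is built only once per job even when several variables (or the Adam pass itself) ask for it. A minimal standalone sketch of that caching contract, with illustrative names rather than OneFlow code:

```cpp
#include <functional>
#include <string>
#include <unordered_map>

// Simplified stand-in for BiasCorrectionFactorState::GetLbn: the first request for a
// given (beta, placement) builds the "1 - beta^t" op and remembers its output name;
// later requests return the cached name instead of adding a duplicate op.
struct FactorCache {
  std::string GetLbn(float beta, const std::string& placement,
                     const std::function<std::string(float)>& build_op) {
    const std::string key = std::to_string(beta) + "/" + placement;
    auto it = cache_.find(key);
    if (it != cache_.end()) { return it->second; }  // an equivalent factor op already exists
    const std::string lbn = build_op(beta);         // build it exactly once
    cache_.emplace(key, lbn);
    return lbn;
  }

 private:
  std::unordered_map<std::string, std::string> cache_;
};
```

In the pass above, build_op plays the role of AddLambBiasCorrectionFactorOp, and the returned string is the logical blob name wired into the lamb_update op's bias_correction1/bias_correction2 inputs.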
24 changes: 14 additions & 10 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -5406,24 +5406,28 @@ def OneFlow_IndexedSlicesSgdUpdateOp : OneFlow_BaseOp<"indexed_slices_sgd_update

def OneFlow_LambUpdateOp : OneFlow_BaseOp<"lamb_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$m,
OneFlow_Tensor:$v,
OneFlow_Tensor:$beta1_t,
OneFlow_Tensor:$beta2_t,
OneFlow_Tensor:$model,
OneFlow_Tensor:$model_diff,
OneFlow_Tensor:$learning_rate,
Optional<OneFlow_Tensor>:$learning_rate,
Optional<OneFlow_Tensor>:$scale_by_tensor,
Optional<OneFlow_Tensor>:$skip_if
Optional<OneFlow_Tensor>:$skip_if,
Optional<OneFlow_Tensor>:$bias_correction1,
Optional<OneFlow_Tensor>:$bias_correction2,
OneFlow_Tensor:$m,
OneFlow_Tensor:$v
);
let attrs = (ins
DefaultValuedAttr<F32Attr, "0.">:$beta1,
DefaultValuedAttr<F32Attr, "0.">:$beta2,
DefaultValuedAttr<F32Attr, "0.">:$epsilon,
DefaultValuedAttr<F32Attr, "0.">:$learning_rate_val,
DefaultValuedAttr<F32Attr, "1.">:$bias_correction1_val,
DefaultValuedAttr<F32Attr, "1.">:$bias_correction2_val,
DefaultValuedAttr<F64Attr, "1.">:$scale,
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay
DefaultValuedAttr<F32Attr, "0.9">:$beta1,
DefaultValuedAttr<F32Attr, "0.999">:$beta2,
DefaultValuedAttr<F32Attr, "0.">:$epsilon,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay,
DefaultValuedAttr<BoolAttr, "true">:$do_bias_correction
);
let trait_attrs = (ins
I32ElementsAttr:$operand_segment_sizes
28 changes: 18 additions & 10 deletions oneflow/user/kernels/model_update_kernel_util.cpp
@@ -272,30 +272,38 @@ template struct AdagradUpdateKernelUtil<DeviceType::kCPU, double, double>;
template<typename T, typename G>
struct LambUpdateKernelUtil<DeviceType::kCPU, T, G> {
static void Update(ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1,
float beta2, float epsilon, float weight_decay, const float* learning_rate,
float beta2, float epsilon, float weight_decay, float learning_rate_val,
bool do_bias_correction, float bias_correction1_val,
float bias_correction2_val, const float* learning_rate_ptr,
const float* bias_correction1_ptr, const float* bias_correction2_ptr,
const T* scale_by_ptr, const int64_t* skip_if, const G* model_diff,
T* adam_diff, T* model, T* m, T* v, T* norm_buffer, T* beta1_t, T* beta2_t);
T* adam_diff, T* model, T* m, T* v, T* norm_buffer);
};

template<typename T, typename G>
void LambUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(
ep::Stream* stream, int64_t n, float scale, float l1, float l2, float beta1, float beta2,
float epsilon, float weight_decay, const float* learning_rate, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v, T* norm_buffer,
T* beta1_t, T* beta2_t) {
float epsilon, float weight_decay, float learning_rate_val, bool do_bias_correction,
float bias_correction1_val, float bias_correction2_val, const float* learning_rate_ptr,
const float* bias_correction1_ptr, const float* bias_correction2_ptr, const T* scale_by_ptr,
const int64_t* skip_if, const G* model_diff, T* adam_diff, T* model, T* m, T* v,
T* norm_buffer) {
if (skip_if != nullptr && *skip_if != 0) { return; }
*beta1_t *= beta1;
*beta2_t *= beta2;
if (learning_rate_ptr != nullptr) { learning_rate_val = *learning_rate_ptr; }
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
if (bias_correction1_ptr != nullptr) { bias_correction1_val = *bias_correction1_ptr; }
if (bias_correction2_ptr != nullptr) { bias_correction2_val = *bias_correction2_ptr; }

FOR_RANGE(int64_t, i, 0, n) {
LambGradFunctor<T, G>()(beta1_t, beta2_t, model_diff + i, adam_diff + i, model + i, m + i,
v + i, scale, l1, l2, beta1, beta2, epsilon);
LambGradFunctor<T, G>()(model_diff + i, adam_diff + i, model + i, m + i, v + i, scale, l1, l2,
beta1, beta2, epsilon, do_bias_correction, bias_correction1_val,
bias_correction2_val);
}
T* w_norm_2 = norm_buffer;
T* g_norm_2 = norm_buffer + 1;
Memset<DeviceType::kCPU>(stream, norm_buffer, 0, 2 * sizeof(T));
SumSquares2(n, model, w_norm_2, adam_diff, g_norm_2);
const float lr = LambLRFunctor<T>()(*learning_rate, w_norm_2, g_norm_2);
const float lr = LambLRFunctor<T>()(learning_rate_val, w_norm_2, g_norm_2);
FOR_RANGE(int64_t, i, 0, n) {
LambUpdateFunctor<T>()(lr, weight_decay, adam_diff + i, model + i);
}
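For checking the arithmetic the CPU kernel now performs, here is a from-scratch, single-tensor reference of one LAMB step written directly from the paper. It is for comparison only; the functor internals (LambGradFunctor, LambLRFunctor, LambUpdateFunctor) may differ in details such as where weight decay enters the norm or how a zero norm is handled:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// One LAMB step over a single flat parameter tensor (illustrative reference, not OneFlow code).
void LambStepReference(std::vector<float>& model, const std::vector<float>& grad,
                       std::vector<float>& m, std::vector<float>& v, int64_t step,
                       float lr, float beta1, float beta2, float epsilon,
                       float weight_decay, bool do_bias_correction) {
  const int64_t n = static_cast<int64_t>(model.size());
  std::vector<float> update(n);
  // Bias-correction denominators (1 when correction is disabled).
  const float bc1 = do_bias_correction ? 1.0f - std::pow(beta1, static_cast<float>(step)) : 1.0f;
  const float bc2 = do_bias_correction ? 1.0f - std::pow(beta2, static_cast<float>(step)) : 1.0f;
  double w_norm2 = 0.0, u_norm2 = 0.0;
  for (int64_t i = 0; i < n; ++i) {
    // Adam-style moments and per-element direction.
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
    const float m_hat = m[i] / bc1;
    const float v_hat = v[i] / bc2;
    update[i] = m_hat / (std::sqrt(v_hat) + epsilon) + weight_decay * model[i];
    w_norm2 += static_cast<double>(model[i]) * model[i];
    u_norm2 += static_cast<double>(update[i]) * update[i];
  }
  // Trust ratio: rescale the step by ||w|| / ||update|| (1 if either norm is zero).
  const double w_norm = std::sqrt(w_norm2);
  const double u_norm = std::sqrt(u_norm2);
  const float trust = (w_norm > 0.0 && u_norm > 0.0) ? static_cast<float>(w_norm / u_norm) : 1.0f;
  for (int64_t i = 0; i < n; ++i) { model[i] -= lr * trust * update[i]; }
}
```

As in the kernel above, the squared norms are accumulated in a first pass, and the per-variable trust ratio then rescales the learning rate before the elementwise update.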