From ce75fae442e205800f337de553a2f0392411f6a1 Mon Sep 17 00:00:00 2001 From: Guillaume Infantes Date: Thu, 22 Aug 2019 10:32:30 +0200 Subject: [PATCH] dd api for caffe warmup --- src/backends/caffe/caffelib.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/backends/caffe/caffelib.cc b/src/backends/caffe/caffelib.cc index 12dc426ee..7737d3477 100644 --- a/src/backends/caffe/caffelib.cc +++ b/src/backends/caffe/caffelib.cc @@ -1133,6 +1133,10 @@ namespace dd solver_param.set_lr_policy(ad_solver.get("lr_policy").get()); if (ad_solver.has("base_lr")) solver_param.set_base_lr(ad_solver.get("base_lr").get()); + if (ad_solver.has("warmup_lr")) + solver_param.set_warmup_start_lr(ad_solver.get("warmup_lr").get()); + if (ad_solver.has("warmup_iter")) + solver_param.set_warmup_iter(ad_solver.get("warmup_iter").get()); if (ad_solver.has("gamma")) solver_param.set_gamma(ad_solver.get("gamma").get()); if (ad_solver.has("stepsize")) @@ -1409,6 +1413,8 @@ namespace dd { caffe::SGDSolver *sgd_solver = static_cast*>(solver.get()); this->_logger->info("Iteration {}, lr = {}, smoothed_loss={}",solver->iter_,sgd_solver->GetLearningRate(),this->get_meas("train_loss")); + if (sgd_solver->param_.warmup_iter() > 0) + this->_logger->info("[doing warmup (start_lr = {}, iter = {})]",solver->param_.warmup_start_lr(),solver->param_.warmup_iter()); } try {