From bcfb82d33e431d621317f97d3c0703d9b002a8ee Mon Sep 17 00:00:00 2001
From: typhoonzero <typhoonzero1986@gmail.com>
Date: Mon, 15 Jan 2018 20:55:48 +0800
Subject: [PATCH 1/5] dist train support split selectedrows
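
The new split_selected_rows/split_variable helpers split a SelectedRows
gradient row-wise across pservers, the counterpart of how
split_dense_variable blocks dense tensors. Below is a minimal standalone
sketch of the intended row split (ceiling division of row_count over
pserver_count); split_rows and the VarBlock namedtuple here are illustrative
stand-ins, not the transpiler's API:

from collections import namedtuple

# Stand-in for the block descriptor the transpiler uses for dense variables:
# (variable name, block index, number of rows in the block).
VarBlock = namedtuple("VarBlock", ["varname", "block_id", "size"])


def split_rows(varname, row_count, pserver_count):
    # ceil(row_count / pserver_count) rows per block; the last block may be
    # shorter. Assumes row_count >= 1.
    rows_per_block = (row_count + pserver_count - 1) // pserver_count
    split_count = (row_count + rows_per_block - 1) // rows_per_block
    blocks = []
    for block_id in range(split_count):
        curr_block_rows = min(rows_per_block,
                              row_count - block_id * rows_per_block)
        blocks.append(VarBlock(varname, block_id, curr_block_rows))
    return blocks


# split_rows("embedding@GRAD", 10, 3) ->
# [VarBlock('embedding@GRAD', 0, 4), VarBlock('embedding@GRAD', 1, 4),
#  VarBlock('embedding@GRAD', 2, 2)]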

---
 .../paddle/v2/fluid/distribute_transpiler.py  | 45 +++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index d17f9815cca5e3..00fe3e68c90086 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -59,6 +59,51 @@ def split_dense_variable(var_list,
     return blocks
 
 
+def split_selected_rows(var,
+                        pserver_count,
+                        min_block_size=1024,
+                        max_block_size=1048576):
+    assert ((len(var.shape)) <= 1)
+
+    split_count = pserver_count
+    indices = var.desc.selected_rows().dims()
+    var_width = reduce(lambda x, y: x * y, var.shape[1:])
+    row_count = len(indices)
+    rows_per_block = 1
+    if var_width < min_block_size:
+        rows_per_block = 1
+        split_count = row_count
+    else:
+        rows_per_block = row_count / pserver_count
+        if not rows_per_block % pserver_count:
+            rows_per_block += 1
+        split_count = row_count / rows_per_block
+        if not row_count % rows_per_block:
+            split_count += 1
+    blocks = []
+    for block_id in xrange(split_count):
+        curr_block_rows = min(rows_per_block,
+                              row_count - (block_id * rows_per_block))
+        block = VarBlock(var.name, block_id, curr_block_rows)
+        blocks.append(block)
+    return blocks
+
+
+def split_variable(var_list,
+                   pserver_count,
+                   min_block_size=1024,
+                   max_block_size=1048576):
+    for var in var_list:
+        if var.type == core.VarDesc.VarType.LOD_TENSOR:
+            split_dense_variable(var_list, pserver_count, min_block_size,
+                                 max_block_size)
+        elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            split_selected_rows(var_list, pserver_count, min_block_size,
+                                max_block_size)
+        else:
+            raise TypeError("variable must be lodtensor or selected rows")
+
+
 class DistributeTranspiler:
     def transpile(self,
                   optimize_ops,

From 02ea349101662e5ad5199dac47b48f1835eda361 Mon Sep 17 00:00:00 2001
From: typhoonzero <typhoonzero1986@gmail.com>
Date: Wed, 17 Jan 2018 18:02:45 +0800
Subject: [PATCH 2/5] enhance dist train performance
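
Three changes work together here: send_op waits for all sends to complete
before issuing gets, recv_op parses the OptimizeProgram once outside the
receive loop, and the transpiler pre-creates one gradient variable per
trainer on the pserver so RecvOp no longer creates variables while a
gradient is arriving. A small illustrative helper (not part of the patch)
showing the ".trainer_<i>" naming convention both sides now rely on:

def per_trainer_grad_names(grad_name, trainer_count):
    """Names under which a pserver expects a gradient when `trainer_count`
    trainers push concurrently; a single trainer keeps the plain name.

    >>> per_trainer_grad_names("fc_0.w_0@GRAD", 2)
    ['fc_0.w_0@GRAD.trainer_0', 'fc_0.w_0@GRAD.trainer_1']
    """
    if trainer_count <= 1:
        return [grad_name]
    return ["%s.trainer_%d" % (grad_name, i) for i in range(trainer_count)]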

---
 paddle/operators/detail/grpc_client.cc        |  5 +-
 paddle/operators/detail/grpc_client.h         |  2 +-
 paddle/operators/recv_op.cc                   | 66 ++++++++-----------
 paddle/operators/send_op.cc                   |  6 +-
 .../paddle/v2/fluid/distribute_transpiler.py  | 15 ++++-
 .../notest_recognize_digits_conv_dist.py      | 17 ++---
 6 files changed, 55 insertions(+), 56 deletions(-)

diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc
index 5a4db2d7e686ce..521760228b5d77 100644
--- a/paddle/operators/detail/grpc_client.cc
+++ b/paddle/operators/detail/grpc_client.cc
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(var_name);
 
-  auto* var = scope.FindVar(var_name);
-  SerializeToMessage(var_name, var, ctx, &req);
-
   // varhandle
   VarHandle var_h;
   var_h.ep = ep;
@@ -87,7 +84,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   return true;
 }
 
-bool RPCClient::wait() {
+bool RPCClient::Wait() {
   bool ok = true;
 
   while (true) {
diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h
index d27b5ced9ece67..a62e70a2533ae5 100644
--- a/paddle/operators/detail/grpc_client.h
+++ b/paddle/operators/detail/grpc_client.h
@@ -130,7 +130,7 @@ class RPCClient {
                         const framework::Scope& scope,
                         const std::string& var_name,
                         int64_t time_out = 600 * 1000);
-  bool wait();
+  bool Wait();
 
  private:
   bool Proceed();
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 55b33343af4380..dea7db391cf563 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/operators/detail/grpc_server.h"
 #include "paddle/operators/detail/sendrecvop_utils.h"
 #include "paddle/operators/detail/simple_block_queue.h"
+#include "paddle/string/printf.h"
 
 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 
@@ -77,35 +78,37 @@ class RecvOp : public framework::OperatorBase {
     if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
-    char ret[256];
-    snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
-             grads_counter_[varname]++);
-    return std::string(ret);
+    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }
 
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
-    // FIXME(typhoonzero): no new scopes for every run.
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
     framework::Scope &recv_scope = scope.NewScope();
     rpc_service_->SetScope(&recv_scope);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto trainer_count = Attr<int>("Trainers");
+    auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
 
+    std::string program_str = Attr<std::string>("OptimizeProgram");
+    framework::proto::ProgramDesc program_desc;
+    program_desc.ParseFromString(program_str);
+    framework::ProgramDesc program(program_desc);
+    framework::Executor executor(dev_place);
+
     rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
     while (!exit_flag) {
-      // TODO(gognwb): simply this loop.
-      // Get from multiple trainers, we don't care about order in which
-      // the gradient arrives, just add suffix 0~n then average the gradient.
-      for (size_t i = 0; i < param_count * trainer_count; ++i) {
-        // blocking get one var from client.
+      // Get from multiple trainers, we don't care about the order in which
+      // the gradients arrive, just add suffix 0~n and merge the gradients.
+      for (size_t i = 0; i < param_count * fan_in; ++i) {
         const detail::MessageWithName &v = rpc_service_->Get();
         auto grad_var_name = v.first;
         if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
-          VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
+          LOG(INFO) << "received terminate message and exit";
           exit_flag = true;
           break;
         }
@@ -114,44 +117,27 @@ class RecvOp : public framework::OperatorBase {
         if (it != grad_list.end()) {
           param_var_name = param_list[it - grad_list.begin()];
         } else {
-          LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
-                     << "\"";
+          LOG(ERROR) << "grad has no paired param: " << grad_var_name;
         }
         VLOG(3) << "recved grad: " << grad_var_name
                 << " updating param: " << param_var_name;
-
-        auto *merged_grad = recv_scope.FindVar(grad_var_name);
-        if (merged_grad == nullptr) {
-          auto *ptr = recv_scope.Var(grad_var_name);
-          CreateTensorFromMessageType(ptr, v.second.type());
-          VLOG(3) << "Create Variable " << grad_var_name
-                  << " on recv scope, which pointer is " << ptr << " type is "
-                  << v.second.type();
+        // Assume grad_var_name must appear in global scope.
+        std::string grad_var_name_trainer;
+        if (fan_in > 1) {
+          grad_var_name_trainer = this->GetGradVarNameForTrainer(grad_var_name);
         }
-
-        if (trainer_count > 1) {
-          grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
+        auto *var = recv_scope.FindVar(grad_var_name_trainer);
+        if (var == nullptr) {
+          LOG(ERROR) << "can not find server side var: "
+                     << grad_var_name_trainer;
+          PADDLE_THROW("can not find server side var");
         }
-
-        auto *var = recv_scope.Var(grad_var_name);
-        platform::DeviceContextPool &pool =
-            platform::DeviceContextPool::Instance();
-        auto &dev_ctx = *pool.Get(dev_place);
         detail::DeserializeFromMessage(v.second, dev_ctx, var);
       }
-
       if (exit_flag) {
         break;
       }
-
       rpc_service_->Reset();
-
-      std::string program_str = Attr<std::string>("OptimizeProgram");
-      framework::proto::ProgramDesc program_desc;
-      program_desc.ParseFromString(program_str);
-      framework::ProgramDesc program(program_desc);
-      framework::Executor executor(dev_place);
-      // Run sub graph to get optimized tensor
       try {
         executor.Run(program, &recv_scope, 0, /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
@@ -195,7 +181,7 @@ This operator will recv tensor from send_op
         "GradList", "type list of string",
         "grad->param name mapping to find which param to optimize.")
         .SetDefault({});
-    AddAttr<int>("Trainers", "type int",
+    AddAttr<int>("Fanin", "type int",
                  "Number of trainers in the current cluster job")
         .SetDefault(1);
   }
diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc
index 4d145250bdc736..d65153c1fdb5bd 100644
--- a/paddle/operators/send_op.cc
+++ b/paddle/operators/send_op.cc
@@ -41,14 +41,16 @@ class SendOp : public framework::OperatorBase {
     // FIXME(gongwb): DeviceContext?
     auto ctx = platform::CPUDeviceContext();
     for (size_t i = 0; i < ins.size(); i++) {
+      VLOG(3) << "sending " << ins[i];
       client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
     }
+    client_.Wait();
 
     for (size_t i = 0; i < outs.size(); i++) {
+      VLOG(3) << "getting " << outs[i];
       client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
-
-    client_.wait();
+    client_.Wait();
   }
 
  private:
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 00fe3e68c90086..9876296a37ae1a 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -452,6 +452,19 @@ def get_pserver_program(self, endpoint, optimize_ops):
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            pserver_program.global_block().create_var(
+                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
+            for trainer_id in xrange(self.trainers):
+                print("create variable for program: %s.trainer_%d" %
+                      (v.name, trainer_id))
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (v.name, trainer_id),
+                    persistable=True,
+                    dtype=v.dtype,
+                    shape=v.shape)
         # step6
         optimize_sub_program = Program()
         for idx, opt_op in enumerate(optimize_ops):
@@ -481,7 +494,7 @@ def get_pserver_program(self, endpoint, optimize_ops):
                     p.name
                     for p in self.param_grad_ep_mapping[endpoint]["grads"]
                 ],
-                "Trainers": self.trainers
+                "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
index 20b4a8b34cd085..e563e0ddc5d799 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
@@ -39,26 +39,27 @@
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
 
-t = fluid.DistributeTranspiler()
-# all parameter server endpoints list for spliting parameters
-pserver_endpoints = os.getenv("PSERVERS")
-# server endpoint for current node
-current_endpoint = os.getenv("SERVER_ENDPOINT")
-# run as trainer or parameter server
+pserver_endpoints = os.getenv("PSERVERS")  # all pserver endpoints
+trainers = int(os.getenv("TRAINERS"))  # total trainer count
+current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
 
 if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
-    exe.run(fluid.default_startup_program())
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
     trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
     exe.run(fluid.default_startup_program())
 
     for pass_id in range(PASS_NUM):

From ae19d2ea1ecd28db7f5704da4cb07c59e038e195 Mon Sep 17 00:00:00 2001
From: typhoonzero <typhoonzero1986@gmail.com>
Date: Thu, 18 Jan 2018 18:27:32 +0800
Subject: [PATCH 3/5] fix comm issues
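
The send/recv handshake is rebuilt around two primitives on AsyncGRPCServer:
a condition-variable barrier (SetCond/WaitCond) that gates when send and get
requests may proceed, and a blocking queue that every finished Get RPC pushes
into so RecvOp can wait until all clients have fetched (WaitClientGet). A
rough Python sketch of the barrier half, using only the standard threading
module; the class and method names are stand-ins for the C++ members
barrier_mutex_, barrier_cond_step_ and barrier_condition_:

import threading


class CondBarrier(object):
    """Request handlers block in wait_cond(step) until the recv op moves the
    server to that step with set_cond(step)."""

    def __init__(self):
        self._cond = threading.Condition()
        self._step = 0

    def wait_cond(self, step):
        with self._cond:
            # same predicate as barrier_condition_.wait in grpc_server.cc
            while self._step != step:
                self._cond.wait()

    def set_cond(self, step):
        with self._cond:
            self._step = step
            self._cond.notify_all()  # wake every handler waiting on this step

The WaitClientGet half is a counted pop from a blocking queue: each completed
Get pushes one token, and the recv op pops param_count * fan_in tokens before
starting the next round.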

---
 paddle/operators/detail/grpc_server.cc | 47 +++++++++++++++-----------
 paddle/operators/detail/grpc_server.h  | 15 ++++----
 paddle/operators/recv_op.cc            | 15 +++++---
 3 files changed, 48 insertions(+), 29 deletions(-)

diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc
index c0b94746a0b7f6..42d3cc57584d9d 100644
--- a/paddle/operators/detail/grpc_server.cc
+++ b/paddle/operators/detail/grpc_server.cc
@@ -36,7 +36,10 @@ class RequestBase {
 
   CallStatus Status() { return status_; }
   void SetStatus(CallStatus status) { status_ = status; }
-  virtual std::string GetReqName() { assert(false); }
+  virtual std::string GetReqName() {
+    assert(false);
+    return "";
+  }
 
  protected:
   grpc::ServerContext ctx_;
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
  public:
   explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
                       grpc::ServerCompletionQueue* cq, framework::Scope* scope,
-                      const platform::DeviceContext* dev_ctx)
+                      const platform::DeviceContext* dev_ctx,
+                      SimpleBlockQueue<char>* queue)
       : RequestBase(service, cq),
         responder_(&ctx_),
         scope_(scope),
-        dev_ctx_(dev_ctx) {
+        dev_ctx_(dev_ctx),
+        queue_(queue) {
     service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
   }
 
@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
     // TODO(gongwb): check var's info.
     responder_.Finish(reply_, grpc::Status::OK, this);
     status_ = FINISH;
+    queue_->Push('c');
   }
 
  protected:
@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
   ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
   framework::Scope* scope_;
   const platform::DeviceContext* dev_ctx_;
+  SimpleBlockQueue<char>* queue_;
 };
 
+void AsyncGRPCServer::WaitClientGet(int count) {
+  for (int i = 0; i < count; ++i) {
+    var_get_queue_.Pop();
+  }
+}
+
 void AsyncGRPCServer::RunSyncUpdate() {
   grpc::ServerBuilder builder;
   builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
@@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
   if (is_shut_down_) {
     return;
   }
-  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
+  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
+                                   &var_get_queue_);
   VLOG(4) << "create Requestget status:" << get->Status();
 }
 
@@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
     }
 
     PADDLE_ENFORCE(tag);
-    if (wait && !done_) {
-      Wait();
-    }
+    if (cq_name == "cq_get") WaitCond(2);
+    if (cq_name == "cq_send") WaitCond(0);
 
     RequestBase* base = (RequestBase*)tag;
     // reference:
@@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
   }
 }
 
-void AsyncGRPCServer::Wait() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  condition_.wait(lock, [=] { return this->done_ == true; });
-}
-
-void AsyncGRPCServer::Reset() {
-  std::lock_guard<std::mutex> lock(this->mutex_);
-  done_ = false;
+void AsyncGRPCServer::WaitCond(int cond) {
+  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
+  barrier_condition_.wait(lock,
+                          [=] { return this->barrier_cond_step_ == cond; });
 }
 
-void AsyncGRPCServer::Done() {
+void AsyncGRPCServer::SetCond(int cond) {
   {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    done_ = true;
+    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
+    barrier_cond_step_ = cond;
   }
-  condition_.notify_all();
+  barrier_condition_.notify_all();
 }
 
 }  // namespace detail
diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h
index 2c078b77771656..5c7be5f5bd2560 100644
--- a/paddle/operators/detail/grpc_server.h
+++ b/paddle/operators/detail/grpc_server.h
@@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
 
   void RunSyncUpdate();
 
-  void Reset();
-
+  // functions to sync server barrier status.
+  void WaitStart();
+  void WaitDone();
+  void Start();
   void Done();
+  void WaitClientGet(int count);
 
   void SetScope(framework::Scope *scope) { scope_ = scope; }
 
@@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   void ShutDown();
 
  protected:
-  void Wait();
   void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
                      std::string cq_name,
                      std::function<void()> TryToRegisterNewOne);
@@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   const platform::DeviceContext *dev_ctx_;
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<MessageWithName> var_recv_queue_;
+  SimpleBlockQueue<char> var_get_queue_;
 
   // condition of the sub program
-  std::mutex mutex_;
-  volatile mutable bool done_;
-  std::condition_variable condition_;
+  std::mutex barrier_mutex_;
+  mutable int barrier_cond_step_;
+  std::condition_variable barrier_condition_;
 
   std::unique_ptr<std::thread> t_send_;
   std::unique_ptr<std::thread> t_get_;
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index b77d150dccfbed..2ecd56671f1c40 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -34,6 +34,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+constexpr int kCondStart = 0;
+constexpr int kCondRunning = 1;
+constexpr int kCondDone = 2;
+
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
@@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase {
     framework::ProgramDesc program(program_desc);
     framework::Executor executor(dev_place);
 
-    rpc_service_->Reset();
+    // rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
     while (!exit_flag) {
       // Get from multiple trainers, we don't care about the order in which
       // the gradients arrive, just add suffix 0~n and merge the gradients.
+      rpc_service_->SetCond(kCondStart);
+      VLOG(3) << "================ start get from service ===========";
       for (size_t i = 0; i < param_count * fan_in; ++i) {
         const detail::MessageWithName &v = rpc_service_->Get();
         auto grad_var_name = v.first;
@@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase {
       if (exit_flag) {
         break;
       }
-      rpc_service_->Reset();
+      // rpc_service_->Reset();
       try {
         executor.Run(program, &recv_scope, 0, /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
-
-      rpc_service_->Done();
+      VLOG(3) << "================ run sub program end ===========";
+      rpc_service_->SetCond(kCondDone);
+      rpc_service_->WaitClientGet(param_count * fan_in);
       grads_counter_.clear();
     }  // while(true)
   }

From 5f4d9130f01833dfef44dac2eadb7089accbe0ba Mon Sep 17 00:00:00 2001
From: typhoonzero <typhoonzero1986@gmail.com>
Date: Thu, 18 Jan 2018 19:27:20 +0800
Subject: [PATCH 4/5] merge codes
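
With the barrier reduced to two steps (0 = accepting sends, 1 = serving
gets), one pserver round now follows the outline below. This is a schematic
sketch with stand-in Python objects, not the operator code; rpc_service
mimics the SetCond/Get/WaitClientGet calls made in RecvOp::Run, and
executor.Run stands in for running the optimize sub-program:

def run_one_round(rpc_service, executor, optimize_program, recv_scope,
                  param_count, fan_in):
    """One cluster-batch on the pserver side (schematic)."""
    barrier_size = param_count * fan_in
    rpc_service.SetCond(0)                    # step 0: accept gradient sends
    for _ in range(barrier_size):
        grad_name, msg = rpc_service.Get()    # blocking pop of one gradient
        if grad_name == "TERMINATE@RECV":
            return False                      # a trainer asked us to exit
        # deserialize msg into recv_scope, under a ".trainer_<i>" suffix when
        # fan_in > 1 so concurrent trainers do not overwrite each other
    executor.Run(optimize_program, recv_scope)   # apply the optimize block
    rpc_service.SetCond(1)                    # step 1: serve parameter gets
    rpc_service.WaitClientGet(barrier_size)   # wait until every trainer fetched
    return True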

---
 paddle/operators/detail/grpc_server.cc |  5 +++--
 paddle/operators/detail/grpc_server.h  |  6 ++----
 paddle/operators/recv_op.cc            | 15 +++++----------
 3 files changed, 10 insertions(+), 16 deletions(-)

diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc
index 42d3cc57584d9d..3ddcd839bdd235 100644
--- a/paddle/operators/detail/grpc_server.cc
+++ b/paddle/operators/detail/grpc_server.cc
@@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
 }
 
 // This URL explains why shutdown is complicate:
-// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
 void AsyncGRPCServer::ShutDown() {
   server_->Shutdown();
   ShutdownQueue();
@@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
   VLOG(4) << "create Requestget status:" << get->Status();
 }
 
+// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
 void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
                                     std::string cq_name,
                                     std::function<void()> TryToRegisterNewOne) {
@@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
     }
 
     PADDLE_ENFORCE(tag);
-    if (cq_name == "cq_get") WaitCond(2);
+    // FIXME(typhoonzero): decouple the barriers from recv_op
+    if (cq_name == "cq_get") WaitCond(1);
     if (cq_name == "cq_send") WaitCond(0);
 
     RequestBase* base = (RequestBase*)tag;
diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h
index 5c7be5f5bd2560..1ca9086c744c55 100644
--- a/paddle/operators/detail/grpc_server.h
+++ b/paddle/operators/detail/grpc_server.h
@@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   void RunSyncUpdate();
 
   // functions to sync server barrier status.
-  void WaitStart();
-  void WaitDone();
-  void Start();
-  void Done();
+  void WaitCond(int cond);
+  void SetCond(int cond);
   void WaitClientGet(int count);
 
   void SetScope(framework::Scope *scope) { scope_ = scope; }
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 2ecd56671f1c40..8d1479bdd63117 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase {
     framework::ProgramDesc program(program_desc);
     framework::Executor executor(dev_place);
 
-    // rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
+    int64_t barrier_size = param_count * fan_in;
     while (!exit_flag) {
       // Get from multiple trainers, we don't care about the order in which
       // the gradients arrive, just add suffix 0~n and merge the gradients.
-      rpc_service_->SetCond(kCondStart);
-      VLOG(3) << "================ start get from service ===========";
-      for (size_t i = 0; i < param_count * fan_in; ++i) {
+      rpc_service_->SetCond(0);
+      for (size_t i = 0; i < barrier_size; ++i) {
         const detail::MessageWithName &v = rpc_service_->Get();
         auto grad_var_name = v.first;
         if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase {
         }
         VLOG(3) << "recved grad: " << grad_var_name
                 << " updating param: " << param_var_name;
-        // Assume grad_var_name must appear in global scope.
-        std::string grad_var_name_trainer;
         if (fan_in > 1) {
           grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
         }
@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase {
       if (exit_flag) {
         break;
       }
-      // rpc_service_->Reset();
       try {
         executor.Run(program, &recv_scope, 0, /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
-      VLOG(3) << "================ run sub program end ===========";
-      rpc_service_->SetCond(kCondDone);
-      rpc_service_->WaitClientGet(param_count * fan_in);
+      rpc_service_->SetCond(1);
+      rpc_service_->WaitClientGet(barrier_size);
       grads_counter_.clear();
     }  // while(true)
   }

From 30529e314e7e9bdce78aa0adf9667da3fe9977cb Mon Sep 17 00:00:00 2001
From: typhoonzero <typhoonzero1986@gmail.com>
Date: Thu, 18 Jan 2018 20:02:26 +0800
Subject: [PATCH 5/5] delete debug transpiler code

---
 .../paddle/v2/fluid/distribute_transpiler.py  | 45 -------------------
 1 file changed, 45 deletions(-)

diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 3cba015fc5250c..13d2bb8325200b 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -72,51 +72,6 @@ def split_dense_variable(var_list,
     return blocks
 
 
-def split_selected_rows(var,
-                        pserver_count,
-                        min_block_size=1024,
-                        max_block_size=1048576):
-    assert ((len(var.shape)) <= 1)
-
-    split_count = pserver_count
-    indices = var.desc.selected_rows().dims()
-    var_width = reduce(lambda x, y: x * y, var.shape[1:])
-    row_count = len(indices)
-    rows_per_block = 1
-    if var_width < min_block_size:
-        rows_per_block = 1
-        split_count = row_count
-    else:
-        rows_per_block = row_count / pserver_count
-        if not rows_per_block % pserver_count:
-            rows_per_block += 1
-        split_count = row_count / rows_per_block
-        if not row_count % rows_per_block:
-            split_count += 1
-    blocks = []
-    for block_id in xrange(split_count):
-        curr_block_rows = min(rows_per_block,
-                              row_count - (block_id * rows_per_block))
-        block = VarBlock(var.name, block_id, curr_block_rows)
-        blocks.append(block)
-    return blocks
-
-
-def split_variable(var_list,
-                   pserver_count,
-                   min_block_size=1024,
-                   max_block_size=1048576):
-    for var in var_list:
-        if var.type == core.VarDesc.VarType.LOD_TENSOR:
-            split_dense_variable(var_list, pserver_count, min_block_size,
-                                 max_block_size)
-        elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
-            split_selected_rows(var_list, pserver_count, min_block_size,
-                                max_block_size)
-        else:
-            raise TypeError("variable must be lodtensor or selected rows")
-
-
 class DistributeTranspiler:
     def transpile(self,
                   optimize_ops,