Skip to content

Commit

Permalink
Merge pull request #11126 from Yancey1989/polish_test_listen_and_serv_op
Browse files Browse the repository at this point in the history
speedup test_listen_and_serv_op
  • Loading branch information
Yancey authored Jun 4, 2018
2 parents f7a6001 + 7f5eb9f commit 524c81e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 36 deletions.
6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,7 @@ foreach(TEST_OP ${TEST_OPS})
endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
# tests that need to be done in fixed timeout
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
# FIXME(Yancey1989): this test would cost much more time on CUDAPlace
# since load cudnn libraries, so we use a longer timeout to make this
# unit test stability.
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 30)
65 changes: 31 additions & 34 deletions python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from op_test import OpTest


def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
Expand All @@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)

port = os.getenv("PADDLE_INIT_PORT", port)
pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip) # ip,ip...
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("TRAINERS", trainer_count))
current_endpoint = os.getenv("POD_IP", ip) + ":" + port
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
pserver_endpoints = ip + ":" + port
current_endpoint = ip + ":" + port
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id,
Expand All @@ -62,47 +55,51 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):

class TestListenAndServOp(OpTest):
def setUp(self):
self.sleep_time = 5
self.ps_timeout = 5
self.ip = "127.0.0.1"
self.port = "6173"
self.trainer_count = 1
self.trainers = 1
self.trainer_id = 1

def _raise_signal(self, parent_pid, raised_signal):
time.sleep(self.sleep_time)
ps_command = subprocess.Popen(
"ps -o pid --ppid %d --noheaders" % parent_pid,
shell=True,
stdout=subprocess.PIPE)
ps_output = ps_command.stdout.read()
retcode = ps_command.wait()
assert retcode == 0, "ps command returned %d" % retcode

for pid_str in ps_output.split("\n")[:-1]:
try:
os.kill(int(pid_str), raised_signal)
except Exception:
continue

def _start_pserver(self, use_cuda, sync_mode):
p = Process(
target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id))
p.start()
return p.pid

def _wait_ps_ready(self, pid):
retry_times = self.ps_timeout
while True:
assert retry_times >= 0, "wait ps ready failed"
time.sleep(0.5)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
retry_times -= 1

def test_rpc_interfaces(self):
# TODO(Yancey1989): need to make sure the rpc interface correctly.
pass

def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode
self._start_pserver(False, True)
pid = self._start_pserver(False, True)
self._wait_ps_ready(pid)

# raise SIGINT to pserver
self._raise_signal(os.getpid(), signal.SIGINT)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)

# run pserver on CPU in async mode
self._start_pserver(False, False)
pid = self._start_pserver(False, False)
self._wait_ps_ready(pid)

# raise SIGTERM to pserver
self._raise_signal(os.getpid(), signal.SIGTERM)
os.kill(pid, signal.SIGTERM)


if __name__ == '__main__':
Expand Down

0 comments on commit 524c81e

Please sign in to comment.