speedup test_listen_and_serv_op #11126
Changes from 1 commit
```diff
@@ -23,7 +23,7 @@
 from op_test import OpTest


-def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
+def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
     x = fluid.layers.data(name='x', shape=[1], dtype='float32')
     y_predict = fluid.layers.fc(input=x, size=1, act=None)
     y = fluid.layers.data(name='y', shape=[1], dtype='float32')
```
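The hunk above only renames the `trainer_count` parameter; the body of `run_pserver` is truncated in this view. For readers unfamiliar with the flow, here is a hedged sketch of how the rest of the function presumably completes the truncated `t.transpile(...)` call, based on standard `fluid.DistributeTranspiler` usage of this Paddle era; the keyword arguments and the `get_*_program` calls are assumptions, not part of this diff:

```python
# Sketch only: the keyword arguments to transpile() and the follow-up
# get_pserver_program/get_startup_program calls are assumed from typical
# DistributeTranspiler usage; the diff truncates this part.
import paddle.fluid as fluid


def run_pserver_tail(exe, trainer_id, pserver_endpoints, trainers,
                     current_endpoint, sync_mode):
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,  # "ip:port" built earlier in run_pserver
        trainers=trainers,
        sync_mode=sync_mode)
    pserver_prog = t.get_pserver_program(current_endpoint)
    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
    exe.run(pserver_startup)
    exe.run(pserver_prog)  # blocks inside listen_and_serv_op until signalled
```

The final `exe.run(pserver_prog)` is the blocking call, which is why the test below needs both a readiness check and an explicit signal to shut the process down.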
```diff
@@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place)

-    port = os.getenv("PADDLE_INIT_PORT", port)
-    pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip)  # ip,ip...
-    eplist = []
-    for ip in pserver_ips.split(","):
-        eplist.append(':'.join([ip, port]))
-    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
-    trainers = int(os.getenv("TRAINERS", trainer_count))
-    current_endpoint = os.getenv("POD_IP", ip) + ":" + port
-    trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
+    pserver_endpoints = ip + ":" + port
+    current_endpoint = ip + ":" + port
     t = fluid.DistributeTranspiler()
     t.transpile(
         trainer_id,
```
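This hunk is the core simplification: the old code derived the endpoint list from `PADDLE_INIT_*` environment variables, with the function arguments serving only as fallbacks, while the new code builds the single local endpoint directly. A minimal sketch (not part of the PR) showing that the two paths agree when none of those environment variables are set, which is the situation in this unit test:

```python
import os

ip, port = "127.0.0.1", "6173"

# Old path: with no PADDLE_INIT_* variables exported, every os.getenv call
# falls back to its default, so the loop yields a single "ip:port" entry.
pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip)
eplist = [":".join([one_ip, port]) for one_ip in pserver_ips.split(",")]
old_endpoints = ",".join(eplist)

# New path: the test always runs one local pserver, so build it directly.
new_endpoints = ip + ":" + port

assert old_endpoints == new_endpoints == "127.0.0.1:6173"
```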
```diff
@@ -62,47 +55,47 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):

 class TestListenAndServOp(OpTest):
     def setUp(self):
-        self.sleep_time = 5
+        self.ps_timeout = 5
         self.ip = "127.0.0.1"
         self.port = "6173"
-        self.trainer_count = 1
+        self.trainers = 1
         self.trainer_id = 1
```
```diff

-    def _raise_signal(self, parent_pid, raised_signal):
-        time.sleep(self.sleep_time)
-        ps_command = subprocess.Popen(
-            "ps -o pid --ppid %d --noheaders" % parent_pid,
-            shell=True,
-            stdout=subprocess.PIPE)
-        ps_output = ps_command.stdout.read()
-        retcode = ps_command.wait()
-        assert retcode == 0, "ps command returned %d" % retcode
-
-        for pid_str in ps_output.split("\n")[:-1]:
-            try:
-                os.kill(int(pid_str), raised_signal)
-            except Exception:
-                continue
-
     def _start_pserver(self, use_cuda, sync_mode):
         p = Process(
             target=run_pserver,
-            args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
+            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
                   self.trainer_id))
         p.start()
         return p.pid

+    def _wait_ps_ready(self, pid):
+        retry_times = self.ps_timeout
+        while True:
+            time.sleep(1)
```
**Review comment:** Can the sleep time be < 1s so the test runs faster?

**Reply:** Sure, but the pserver takes about 1 second to become ready. BTW, the most time-consuming part is loading and initializing the CUDA libraries.
```diff
+            assert retry_times >= 0, "wait ps ready failed"
```
**Review comment:** The …

**Reply:** Done.
```diff
+            try:
+                # the listen_and_serv_op would touch a file which contains the listen port
+                # on the /tmp directory until it was ready to process all the RPC call.
+                os.stat("/tmp/paddle.%d.port" % pid)
+                return
+            except os.error:
+                retry_times -= 1
+
```
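This readiness check is what replaces the old fixed 5-second sleep: per the code comment above, `listen_and_serv_op` touches `/tmp/paddle.<pid>.port` once it can serve RPC calls, and the test simply polls for that file. A self-contained sketch of the same sentinel-file pattern, with hypothetical names (`fake_server`, `wait_ready`) standing in for the real pserver:

```python
import os
import tempfile
import time
from multiprocessing import Process


def fake_server(port_file):
    # Stand-in for the pserver: pretend to initialize, then "touch" the
    # sentinel file the way listen_and_serv_op does once it can serve RPCs.
    time.sleep(0.2)
    with open(port_file, "w") as f:
        f.write("6173")


def wait_ready(port_file, retry_times=5):
    # Same polling loop as _wait_ps_ready: os.stat raises OSError until
    # the sentinel file exists.
    while True:
        assert retry_times >= 0, "wait ps ready failed"
        time.sleep(0.5)
        try:
            os.stat(port_file)
            return
        except OSError:
            retry_times -= 1


if __name__ == "__main__":
    port_file = os.path.join(tempfile.gettempdir(), "paddle.fake.port")
    p = Process(target=fake_server, args=(port_file,))
    p.start()
    wait_ready(port_file)
    p.join()
    os.remove(port_file)
```

The upshot is that the wait ends as soon as the server is actually ready, instead of always paying the worst-case sleep.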
```diff
     def test_handle_signal_in_serv_op(self):
         # run pserver on CPU in sync mode
-        self._start_pserver(False, True)
+        pid = self._start_pserver(False, True)
+        self._wait_ps_ready(pid)

         # raise SIGINT to pserver
-        self._raise_signal(os.getpid(), signal.SIGINT)
+        os.kill(pid, signal.SIGINT)

         # run pserver on CPU in async mode
-        self._start_pserver(False, False)
+        pid = self._start_pserver(False, False)
+        self._wait_ps_ready(pid)

         # raise SIGTERM to pserver
-        self._raise_signal(os.getpid(), signal.SIGTERM)
+        os.kill(pid, signal.SIGINT)
```
**Review comment:** Signal should be `SIGTERM`.

**Reply:** Done.
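Compared with the deleted `_raise_signal`, which slept for a fixed interval and then signalled every child found via `ps --ppid`, the new test signals the one known pid directly. A minimal, self-contained sketch of that kill-and-join pattern on POSIX (the `server_loop` target and the 5-second join bound are illustrative assumptions, not from the PR):

```python
import os
import signal
import time
from multiprocessing import Process


def server_loop():
    # Stand-in for the blocking pserver; the default SIGTERM disposition
    # terminates the process, which is what the test relies on.
    while True:
        time.sleep(0.1)


if __name__ == "__main__":
    p = Process(target=server_loop)
    p.start()
    time.sleep(0.2)                 # let the child enter its loop
    os.kill(p.pid, signal.SIGTERM)  # signal the known pid directly
    p.join(timeout=5)               # bounded wait keeps the test fast
    assert not p.is_alive()
```

Joining with a timeout keeps the test bounded even if the child fails to exit on the signal.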
```diff


 if __name__ == '__main__':
```
**Review comment:** 30 seconds is enough for distributed tests?

**Reply:** Yes, in this unit test the pserver is killed as soon as it becomes ready.