Skip to content

Commit

Permalink
Merge pull request #485 from balgillo/zerodeploy-close-timeout
Browse files Browse the repository at this point in the history
Add a configurable timeout on the zerodeploy close() method
  • Loading branch information
comrumino authored May 17, 2022
2 parents 7ea2d24 + 7726059 commit fddc19f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
7 changes: 6 additions & 1 deletion docs/docs/zerodeploy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,9 @@ under the user's permissions. You can connect as an unprivileged user to make su
``rm -rf /``. Second, it creates an SSH tunnel for the transport, so everything is kept encrypted on the wire.
And you get these features for free -- just configuring SSH accounts will do.


Timeouts
--------
You can pass a ``timeout`` argument, in seconds, to the ``close()`` method. A ``TimeoutExpired`` is raised if
any subprocess communication takes longer than the timeout, after the subprocess has been told to terminate. By
default, the timeout is ``None`` i.e. infinite. A timeout value prevents a ``close()`` call blocking
indefinitely.
18 changes: 14 additions & 4 deletions rpyc/utils/zerodeploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Requires [plumbum](http://plumbum.readthedocs.org/)
"""
from __future__ import with_statement
from subprocess import TimeoutExpired
import sys
import socket # noqa: F401
from rpyc.lib.compat import BYTES_LITERAL
Expand Down Expand Up @@ -151,27 +152,36 @@ def __enter__(self):
def __exit__(self, t, v, tb):
self.close()

def close(self):
def close(self, timeout=None):
if self.proc is not None:
try:
self.proc.terminate()
self.proc.communicate()
self.proc.communicate(timeout=timeout)
except TimeoutExpired:
self.proc.kill()
raise
except Exception:
pass
self.proc = None
if self.tun is not None:
try:
self.tun._session.proc.terminate()
self.tun._session.proc.communicate()
self.tun._session.proc.communicate(timeout=timeout)
self.tun.close()
except TimeoutExpired:
self.tun._session.proc.kill()
raise
except Exception:
pass
self.tun = None
if self.remote_machine is not None:
try:
self.remote_machine._session.proc.terminate()
self.remote_machine._session.proc.communicate()
self.remote_machine._session.proc.communicate(timeout=timeout)
self.remote_machine.close()
except TimeoutExpired:
self.remote_machine._session.proc.kill()
raise
except Exception:
pass
self.remote_machine = None
Expand Down
49 changes: 48 additions & 1 deletion tests/test_deploy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import with_statement

import unittest
import subprocess
import sys

from plumbum import SshMachine
from plumbum.machines.paramiko_machine import ParamikoMachine
from rpyc.utils.zerodeploy import DeployedServer
Expand All @@ -11,7 +14,6 @@
_paramiko_import_failed = True


@unittest.skipIf(_paramiko_import_failed, "Paramiko is not available")
class TestDeploy(unittest.TestCase):
def test_deploy(self):
rem = SshMachine("localhost")
Expand All @@ -30,6 +32,51 @@ def test_deploy(self):
self.fail("expected an EOFError")
rem.close()

def test_close_timeout(self):
expected_timeout = 4
observed_timeouts = []
original_communicate = subprocess.Popen.communicate

def replacement_communicate(self, input=None, timeout=None):
observed_timeouts.append(timeout)
return original_communicate(self, input, timeout)

try:
subprocess.Popen.communicate = replacement_communicate
rem = SshMachine("localhost")
SshMachine.python = rem[sys.executable]
dep = DeployedServer(rem)
dep.classic_connect()
dep.close(timeout=expected_timeout)
rem.close()
finally:
subprocess.Popen.communicate = original_communicate
# The last three calls to communicate() happen during close(), so check they
# applied the timeout.
assert observed_timeouts[-3:] == [expected_timeout] * 3

def test_close_timeout_default_none(self):
observed_timeouts = []
original_communicate = subprocess.Popen.communicate

def replacement_communicate(self, input=None, timeout=None):
observed_timeouts.append(timeout)
return original_communicate(self, input, timeout)

try:
subprocess.Popen.communicate = replacement_communicate
rem = SshMachine("localhost")
SshMachine.python = rem[sys.executable]
dep = DeployedServer(rem)
dep.classic_connect()
dep.close()
rem.close()
finally:
subprocess.Popen.communicate = original_communicate
# No timeout specified, so Popen.communicate should have been called with timeout None.
assert observed_timeouts == [None] * len(observed_timeouts)

@unittest.skipIf(_paramiko_import_failed, "Paramiko is not available")
def test_deploy_paramiko(self):
rem = ParamikoMachine("localhost", missing_host_policy=paramiko.AutoAddPolicy())
with DeployedServer(rem) as dep:
Expand Down

0 comments on commit fddc19f

Please sign in to comment.