Skip to content

Commit

Permalink
feat: add convenience flag for reverse ssh tunnels (#562)
Browse files Browse the repository at this point in the history
* Add convenience flag for reverse ssh tunnels

* Add unit test for reverse SshMachine tunnel
  • Loading branch information
oakreid authored Nov 22, 2021
1 parent 27adf10 commit e5f1e3e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
5 changes: 4 additions & 1 deletion plumbum/machines/ssh_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def session(self, isatty=False, new_session=False):
)

def tunnel(
self, lport, dport, lhost="localhost", dhost="localhost", connect_timeout=5
self, lport, dport, lhost="localhost", dhost="localhost", connect_timeout=5, reverse=False
):
r"""Creates an SSH tunnel from the TCP port (``lport``) of the local machine
(``lhost``, defaults to ``"localhost"``, but it can be any IP you can ``bind()``)
Expand Down Expand Up @@ -276,6 +276,9 @@ def tunnel(
ssh_opts = [
"-L",
"{}{}:{}{}".format(formatted_lhost, lport, formatted_dhost, dport),
] if not reverse else [
"-R",
"{}{}:{}{}".format(formatted_dhost, dport, formatted_lhost, lport),
]
proc = self.popen((), ssh_opts=ssh_opts, new_session=True)
return SshTunnel(
Expand Down
41 changes: 41 additions & 0 deletions tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import time
from copy import deepcopy
from multiprocessing import Process, Queue

import env
import pytest
Expand Down Expand Up @@ -477,6 +478,46 @@ def test_tunnel(self):
print(p.communicate())
assert data == b"hello world"


def test_reverse_tunnel(self):

def serve_reverse_tunnel(queue):
s = socket.socket()
s.bind(("", 12222))
s.listen(1)
s2, _ = s.accept()
data = s2.recv(100).decode("ascii").strip()
queue.put(data)
s2.close()
s.close()

with self._connect() as rem:
get_unbound_socket_remote = """import sys, socket
s = socket.socket()
s.bind(("", 0))
s.listen(1)
sys.stdout.write(str(s.getsockname()[1]))
sys.stdout.flush()
s.close()
"""
p = (rem.python["-u"] << get_unbound_socket_remote).popen()
remote_socket = p.stdout.readline().decode("ascii").strip()
queue = Queue()
tunnel_server = Process(target=serve_reverse_tunnel, args=(queue,))
tunnel_server.start()
message = str(time.time_ns())
with rem.tunnel(12222, remote_socket, dhost="localhost", reverse=True):
remote_send_af_inet = """import sys, socket
s = socket.socket()
s.connect(("localhost", {}))
s.send("{}".encode("ascii"))
s.close()
""".format(remote_socket, message)
(rem.python["-u"] << remote_send_af_inet).popen()
tunnel_server.join()
assert queue.get() == message


def test_get(self):
with self._connect() as rem:
assert str(rem["ls"]) == str(rem.get("ls"))
Expand Down

0 comments on commit e5f1e3e

Please sign in to comment.