Skip to content

Commit e892d09

Browse files
committed
Fix the arg variable name
1 parent 6b4619e commit e892d09

File tree

1 file changed

+37
-44
lines changed

1 file changed

+37
-44
lines changed

testgres/operations/remote_ops.py

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import logging
21
import os
2+
import socket
33
import subprocess
44
import tempfile
55
import platform
@@ -45,47 +45,44 @@ def __init__(self, conn_params: ConnectionParams):
4545
self.conn_params = conn_params
4646
self.host = conn_params.host
4747
self.ssh_key = conn_params.ssh_key
48+
self.port = conn_params.port
49+
self.ssh_args = []
4850
if self.ssh_key:
49-
self.ssh_cmd = ["-i", self.ssh_key]
50-
else:
51-
self.ssh_cmd = []
51+
self.ssh_args += ["-i", self.ssh_key]
52+
if self.port:
53+
self.ssh_args += ["-p", self.port]
5254
self.remote = True
5355
self.username = conn_params.username
5456
self.ssh_dest = f"{self.username}@{self.host}" if self.username else self.host
5557
self.add_known_host(self.host)
5658
self.tunnel_process = None
59+
self.tunnel_port = None
5760

5861
def __enter__(self):
5962
return self
6063

6164
def __exit__(self, exc_type, exc_val, exc_tb):
6265
self.close_ssh_tunnel()
6366

64-
def establish_ssh_tunnel(self, local_port, remote_port):
65-
"""
66-
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
67-
"""
68-
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
69-
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)
67+
@staticmethod
68+
def is_port_open(host, port):
69+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
70+
sock.settimeout(1) # Таймаут для попытки соединения
71+
try:
72+
sock.connect((host, port))
73+
return True
74+
except socket.error:
75+
return False
7076

7177
def close_ssh_tunnel(self):
72-
if hasattr(self, 'tunnel_process'):
78+
if self.tunnel_process:
7379
self.tunnel_process.terminate()
7480
self.tunnel_process.wait()
81+
print("SSH tunnel closed.")
7582
del self.tunnel_process
7683
else:
7784
print("No active tunnel to close.")
7885

79-
def add_known_host(self, host):
80-
known_hosts_path = os.path.expanduser("~/.ssh/known_hosts")
81-
cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path)
82-
83-
try:
84-
subprocess.check_call(cmd, shell=True)
85-
logging.info("Successfully added %s to known_hosts." % host)
86-
except subprocess.CalledProcessError as e:
87-
raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e)))
88-
8986
def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
9087
encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
9188
stderr=None, get_process=None, timeout=None):
@@ -96,9 +93,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
9693
"""
9794
ssh_cmd = []
9895
if isinstance(cmd, str):
99-
ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_cmd + [cmd]
96+
ssh_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, cmd]
10097
elif isinstance(cmd, list):
101-
ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_cmd + cmd
98+
ssh_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest] + cmd
10299
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
103100
if get_process:
104101
return process
@@ -243,9 +240,9 @@ def mkdtemp(self, prefix=None):
243240
- prefix (str): The prefix of the temporary directory name.
244241
"""
245242
if prefix:
246-
command = ["ssh"] + self.ssh_cmd + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"]
243+
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"mktemp -d {prefix}XXXXX"]
247244
else:
248-
command = ["ssh"] + self.ssh_cmd + [self.ssh_dest, "mktemp -d"]
245+
command = ["ssh"] + self.ssh_args + [self.ssh_dest, "mktemp -d"]
249246

250247
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
251248

@@ -288,8 +285,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
288285
mode = "r+b" if binary else "r+"
289286

290287
with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
288+
# Because in scp we set up port using -P option
289+
scp_ssh_cmd = ['-P' if x == '-p' else x for x in self.ssh_cmd]
290+
291291
if not truncate:
292-
scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.ssh_dest}:{filename}", tmp_file.name]
292+
scp_cmd = ['scp'] + self.ssh_args + [f"{self.ssh_dest}:{filename}", tmp_file.name]
293293
subprocess.run(scp_cmd, check=False) # The file might not exist yet
294294
tmp_file.seek(0, os.SEEK_END)
295295

@@ -305,11 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
305305
tmp_file.write(data)
306306

307307
tmp_file.flush()
308-
scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.ssh_dest}:{filename}"]
308+
scp_cmd = ['scp'] + self.ssh_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"]
309309
subprocess.run(scp_cmd, check=True)
310310

311311
remote_directory = os.path.dirname(filename)
312-
mkdir_cmd = ['ssh'] + self.ssh_cmd + [self.ssh_dest, f"mkdir -p {remote_directory}"]
312+
mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"]
313313
subprocess.run(mkdir_cmd, check=True)
314314

315315
os.remove(tmp_file.name)
@@ -374,7 +374,7 @@ def get_pid(self):
374374
return int(self.exec_command("echo $$", encoding=get_default_encoding()))
375375

376376
def get_process_children(self, pid):
377-
command = ["ssh"] + self.ssh_cmd + [self.ssh_dest, f"pgrep -P {pid}"]
377+
command = ["ssh"] + self.ssh_args + [self.ssh_dest, f"pgrep -P {pid}"]
378378

379379
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
380380

@@ -386,18 +386,11 @@ def get_process_children(self, pid):
386386

387387
# Database control
388388
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
389-
"""
390-
Established SSH tunnel and Connects to a PostgreSQL
391-
"""
392-
self.establish_ssh_tunnel(local_port=port, remote_port=5432)
393-
try:
394-
conn = pglib.connect(
395-
host=host,
396-
port=port,
397-
database=dbname,
398-
user=user,
399-
password=password,
400-
)
401-
return conn
402-
except Exception as e:
403-
raise Exception(f"Could not connect to the database. Error: {e}")
389+
conn = pglib.connect(
390+
host=host,
391+
port=port,
392+
database=dbname,
393+
user=user,
394+
password=password,
395+
)
396+
return conn

0 commit comments

Comments
 (0)