diff --git a/src/SSHLibrary/deco.py b/src/SSHLibrary/deco.py deleted file mode 100644 index b5e24782c..000000000 --- a/src/SSHLibrary/deco.py +++ /dev/null @@ -1,37 +0,0 @@ - -def keyword(types=()): - """Decorator to set custom argument types to keywords. - - This decorator creates ``robot_types`` attribute on the decorated - keyword method or function based on the provided arguments. - Robot Framework checks them to determine the keyword's - argument types. - - Types must be given as a dictionary mapping argument names to types or as a list - (or tuple) of types mapped to arguments based on position. It is OK to - specify types only to some arguments, and setting ``types`` to ``None`` - disables type conversion altogether. - - Examples:: - - @keyword(types={'length': int, 'case_insensitive': bool}) - def types_as_dict(length, case_insensitive=False): - # ... - - @keyword(types=[int, bool]) - def types_as_list(length, case_insensitive=False): - # ... - - @keyword(types=None]) - def no_conversion(length, case_insensitive=False): - # ... - - @keyword - def func(): - # ... - """ - - def decorator(func): - func.robot_types = types - return func - return decorator diff --git a/src/SSHLibrary/library.py b/src/SSHLibrary/library.py index 697cd26a1..aba4026de 100644 --- a/src/SSHLibrary/library.py +++ b/src/SSHLibrary/library.py @@ -16,25 +16,27 @@ from __future__ import print_function import re -from .deco import keyword -try: - from robot.api import logger -except ImportError: - logger = None +from .logger import logger from robot.utils import is_string, is_truthy, plural_or_not +from robot.api.deco import keyword, library from .sshconnectioncache import SSHConnectionCache from .client import SSHClientException from .client import SSHClient -from .config import (Configuration, IntegerEntry, LogLevelEntry, NewlineEntry, - StringEntry, TimeEntry) +from .config import ( + Configuration, + IntegerEntry, + LogLevelEntry, + NewlineEntry, + StringEntry, + TimeEntry, +) from .version import VERSION -__version__ = VERSION - -class SSHLibrary(object): +@library(version=VERSION, scope="GLOBAL") +class SSHLibrary: """SSHLibrary is a Robot Framework test library for SSH and SFTP. This document explains how to use keywords provided by SSHLibrary. @@ -435,33 +437,33 @@ class SSHLibrary(object): | `Run Keyword And Expect Error` | Non-existing index or alias 'conn'. | `Switch Connection` | conn | """ - ROBOT_LIBRARY_SCOPE = 'GLOBAL' - ROBOT_LIBRARY_VERSION = __version__ - DEFAULT_TIMEOUT = '3 seconds' - DEFAULT_NEWLINE = 'LF' + DEFAULT_TIMEOUT = "3 seconds" + DEFAULT_NEWLINE = "LF" DEFAULT_PROMPT = None - DEFAULT_LOGLEVEL = 'INFO' - DEFAULT_TERM_TYPE = 'vt100' + DEFAULT_LOGLEVEL = "INFO" + DEFAULT_TERM_TYPE = "vt100" DEFAULT_TERM_WIDTH = 80 DEFAULT_TERM_HEIGHT = 24 - DEFAULT_PATH_SEPARATOR = '/' - DEFAULT_ENCODING = 'UTF-8' + DEFAULT_PATH_SEPARATOR = "/" + DEFAULT_ENCODING = "UTF-8" DEFAULT_ESCAPE_ANSI = False - DEFAULT_ENCODING_ERRORS = 'strict' - - def __init__(self, - timeout=DEFAULT_TIMEOUT, - newline=DEFAULT_NEWLINE, - prompt=DEFAULT_PROMPT, - loglevel=DEFAULT_LOGLEVEL, - term_type=DEFAULT_TERM_TYPE, - width=DEFAULT_TERM_WIDTH, - height=DEFAULT_TERM_HEIGHT, - path_separator=DEFAULT_PATH_SEPARATOR, - encoding=DEFAULT_ENCODING, - escape_ansi=DEFAULT_ESCAPE_ANSI, - encoding_errors=DEFAULT_ENCODING_ERRORS): + DEFAULT_ENCODING_ERRORS = "strict" + + def __init__( + self, + timeout=DEFAULT_TIMEOUT, + newline=DEFAULT_NEWLINE, + prompt=DEFAULT_PROMPT, + loglevel=DEFAULT_LOGLEVEL, + term_type=DEFAULT_TERM_TYPE, + width=DEFAULT_TERM_WIDTH, + height=DEFAULT_TERM_HEIGHT, + path_separator=DEFAULT_PATH_SEPARATOR, + encoding=DEFAULT_ENCODING, + escape_ansi=DEFAULT_ESCAPE_ANSI, + encoding_errors=DEFAULT_ENCODING_ERRORS, + ): """SSHLibrary allows some import time `configuration`. If the library is imported without any arguments, the library @@ -499,7 +501,7 @@ def __init__(self, path_separator or self.DEFAULT_PATH_SEPARATOR, encoding or self.DEFAULT_ENCODING, escape_ansi or self.DEFAULT_ESCAPE_ANSI, - encoding_errors or self.DEFAULT_ENCODING_ERRORS + encoding_errors or self.DEFAULT_ENCODING_ERRORS, ) self._last_commands = dict() @@ -507,11 +509,21 @@ def __init__(self, def current(self): return self._connections.current - @keyword(types=None) - def set_default_configuration(self, timeout=None, newline=None, prompt=None, - loglevel=None, term_type=None, width=None, - height=None, path_separator=None, - encoding=None, escape_ansi=None, encoding_errors=None): + @keyword(tags=("configuration",)) + def set_default_configuration( + self, + timeout=None, + newline=None, + prompt=None, + loglevel=None, + term_type=None, + width=None, + height=None, + path_separator=None, + encoding=None, + escape_ansi=None, + encoding_errors=None, + ): """Update the default `configuration`. Please note that using this keyword does not affect the already @@ -542,14 +554,34 @@ def set_default_configuration(self, timeout=None, newline=None, prompt=None, | `Should Be Equal As Integers` | ${emea.timeout} | 20 | | `Should Be Equal As Integers` | ${apac.timeout} | 20 | """ - self._config.update(timeout=timeout, newline=newline, prompt=prompt, - loglevel=loglevel, term_type=term_type, width=width, - height=height, path_separator=path_separator, - encoding=encoding, escape_ansi=escape_ansi, encoding_errors=encoding_errors) - - def set_client_configuration(self, timeout=None, newline=None, prompt=None, - term_type=None, width=None, height=None, - path_separator=None, encoding=None, escape_ansi=None, encoding_errors=None): + self._config.update( + timeout=timeout, + newline=newline, + prompt=prompt, + loglevel=loglevel, + term_type=term_type, + width=width, + height=height, + path_separator=path_separator, + encoding=encoding, + escape_ansi=escape_ansi, + encoding_errors=encoding_errors, + ) + + @keyword(tags=("configuration",)) + def set_client_configuration( + self, + timeout=None, + newline=None, + prompt=None, + term_type=None, + width=None, + height=None, + path_separator=None, + encoding=None, + escape_ansi=None, + encoding_errors=None, + ): """Update the `configuration` of the current connection. Only parameters whose value is other than ``None`` are updated. @@ -578,13 +610,20 @@ def set_client_configuration(self, timeout=None, newline=None, prompt=None, | `Open Connection` | 192.168.1.1 | | `Set Client Configuration` | term_type=ansi | width=40 | """ - self.current.config.update(timeout=timeout, newline=newline, - prompt=prompt, term_type=term_type, - width=width, height=height, - path_separator=path_separator, - encoding=encoding, escape_ansi=escape_ansi, - encoding_errors=encoding_errors) + self.current.config.update( + timeout=timeout, + newline=newline, + prompt=prompt, + term_type=term_type, + width=width, + height=height, + path_separator=path_separator, + encoding=encoding, + escape_ansi=escape_ansi, + encoding_errors=encoding_errors, + ) + @keyword(tags=("configuration",)) def enable_ssh_logging(self, logfile): """Enables logging of SSH protocol output to given ``logfile``. @@ -603,12 +642,25 @@ def enable_ssh_logging(self, logfile): | # Check myserver.log for detailed debug information | """ if SSHClient.enable_logging(logfile): - self._log(f'SSH log is written to file.', - 'HTML') - - def open_connection(self, host, alias=None, port=22, timeout=None, - newline=None, prompt=None, term_type=None, width=None, - height=None, path_separator=None, encoding=None, escape_ansi=None, encoding_errors=None): + self._log(f'SSH log is written to file.', "HTML") + + @keyword(tags=("connection",)) + def open_connection( + self, + host, + alias=None, + port=22, + timeout=None, + newline=None, + prompt=None, + term_type=None, + width=None, + height=None, + path_separator=None, + encoding=None, + escape_ansi=None, + encoding_errors=None, + ): """Opens a new SSH connection to the given ``host`` and ``port``. The new connection is made active. Possible existing connections @@ -676,12 +728,26 @@ def open_connection(self, host, alias=None, port=22, timeout=None, encoding = encoding or self._config.encoding escape_ansi = escape_ansi or self._config.escape_ansi encoding_errors = encoding_errors or self._config.encoding_errors - client = SSHClient(host, alias, port, timeout, newline, prompt, - term_type, width, height, path_separator, encoding, escape_ansi, encoding_errors) + client = SSHClient( + host, + alias, + port, + timeout, + newline, + prompt, + term_type, + width, + height, + path_separator, + encoding, + escape_ansi, + encoding_errors, + ) connection_index = self._connections.register(client, alias) client.config.update(index=connection_index) return connection_index + @keyword(tags=("connection",)) def switch_connection(self, index_or_alias): """Switches the active connection by index or alias. @@ -714,6 +780,7 @@ def switch_connection(self, index_or_alias): self._connections.switch(index_or_alias) return old_index + @keyword(tags=("connection",)) def close_connection(self): """Closes the current connection. @@ -730,6 +797,7 @@ def close_connection(self): connections = self._connections connections.close_current() + @keyword(tags=("connection",)) def close_all_connections(self): """Closes all open connections. @@ -748,10 +816,23 @@ def close_all_connections(self): """ self._connections.close_all() - def get_connection(self, index_or_alias=None, index=False, host=False, - alias=False, port=False, timeout=False, newline=False, - prompt=False, term_type=False, width=False, height=False, - encoding=False, escape_ansi=False): + @keyword(tags=("connection",)) + def get_connection( + self, + index_or_alias=None, + index=False, + host=False, + alias=False, + port=False, + timeout=False, + newline=False, + prompt=False, + term_type=False, + width=False, + height=False, + encoding=False, + escape_ansi=False, + ): """Returns information about the connection. Connection is not changed by this keyword, use `Switch Connection` to @@ -840,38 +921,70 @@ def get_connection(self, index_or_alias=None, index=False, host=False, except AttributeError: config = SSHClient(None).config self._log(str(config), self._config.loglevel) - return_values = tuple(self._get_config_values(config, index, host, - alias, port, timeout, - newline, prompt, - term_type, width, height, - encoding, escape_ansi)) + return_values = tuple( + self._get_config_values( + config, + index, + host, + alias, + port, + timeout, + newline, + prompt, + term_type, + width, + height, + encoding, + escape_ansi, + ) + ) if not return_values: return config if len(return_values) == 1: return return_values[0] return return_values - def _log(self, msg, level='INFO'): + def _log(self, msg, level="INFO"): level = self._active_loglevel(level) - if level != 'NONE': + if level != "NONE": msg = msg.strip() if not msg: return if logger: logger.write(msg, level) else: - print(f'*{level}* {msg}') + print(f"*{level}* {msg}") def _active_loglevel(self, level): if level is None: return self._config.loglevel - if is_string(level) and \ - level.upper() in ['TRACE', 'DEBUG', 'INFO', 'WARN', 'HTML', 'NONE']: + if is_string(level) and level.upper() in [ + "TRACE", + "DEBUG", + "INFO", + "WARN", + "HTML", + "NONE", + ]: return level.upper() raise AssertionError(f"Invalid log level '{level}'.") - def _get_config_values(self, config, index, host, alias, port, timeout, - newline, prompt, term_type, width, height, encoding, escape_ansi): + def _get_config_values( + self, + config, + index, + host, + alias, + port, + timeout, + newline, + prompt, + term_type, + width, + height, + encoding, + escape_ansi, + ): if is_truthy(index): yield config.index if is_truthy(host): @@ -897,6 +1010,7 @@ def _get_config_values(self, config, index, host, alias, port, timeout, if is_truthy(escape_ansi): yield config.escape_ansi + @keyword(tags=("connection",)) def get_connections(self): """Returns information about all the open connections. @@ -920,8 +1034,19 @@ def get_connections(self): self._log(str(c), self._config.loglevel) return configs - def login(self, username=None, password=None, allow_agent=False, look_for_keys=False, delay='0.5 seconds', - proxy_cmd=None, read_config=False, jumphost_index_or_alias=None, keep_alive_interval='0 seconds'): + @keyword(tags=("login",)) + def login( + self, + username=None, + password=None, + allow_agent=False, + look_for_keys=False, + delay="0.5 seconds", + proxy_cmd=None, + read_config=False, + jumphost_index_or_alias=None, + keep_alive_interval="0 seconds", + ): """Logs into the SSH server with the given ``username`` and ``password``. Connection must be opened before using this keyword. @@ -980,20 +1105,44 @@ def login(self, username=None, password=None, allow_agent=False, look_for_keys=F | `Open Connection` | linux.server.com | | `Login` | johndoe | allow_agent=True | """ - jumphost_connection_conf = self.get_connection(index_or_alias=jumphost_index_or_alias) \ - if jumphost_index_or_alias else None - jumphost_connection = self._connections.connections[jumphost_connection_conf.index - 1] \ - if jumphost_connection_conf and jumphost_connection_conf.index else None - - return self._login(self.current.login, username, password, is_truthy(allow_agent), - is_truthy(look_for_keys), delay, proxy_cmd, is_truthy(read_config), - jumphost_connection, keep_alive_interval) - - def login_with_public_key(self, username=None, keyfile=None, password='', - allow_agent=False, look_for_keys=False, - delay='0.5 seconds', proxy_cmd=None, - jumphost_index_or_alias=None, - read_config=False, keep_alive_interval='0 seconds'): + jumphost_connection_conf = ( + self.get_connection(index_or_alias=jumphost_index_or_alias) + if jumphost_index_or_alias + else None + ) + jumphost_connection = ( + self._connections.connections[jumphost_connection_conf.index - 1] + if jumphost_connection_conf and jumphost_connection_conf.index + else None + ) + + return self._login( + self.current.login, + username, + password, + is_truthy(allow_agent), + is_truthy(look_for_keys), + delay, + proxy_cmd, + is_truthy(read_config), + jumphost_connection, + keep_alive_interval, + ) + + @keyword(tags=("login",)) + def login_with_public_key( + self, + username=None, + keyfile=None, + password="", + allow_agent=False, + look_for_keys=False, + delay="0.5 seconds", + proxy_cmd=None, + jumphost_index_or_alias=None, + read_config=False, + keep_alive_interval="0 seconds", + ): """Logs into the SSH server using key-based authentication. Connection must be opened before using this keyword. @@ -1049,29 +1198,48 @@ def login_with_public_key(self, username=None, keyfile=None, password='', ``keep_alive_interval`` is new in SSHLibrary 3.7.0. """ if proxy_cmd and jumphost_index_or_alias: - raise ValueError("`proxy_cmd` and `jumphost_connection` are mutually exclusive SSH features.") - jumphost_connection_conf = self.get_connection( - index_or_alias=jumphost_index_or_alias) if jumphost_index_or_alias else None - jumphost_connection = self._connections.connections[ - jumphost_connection_conf.index - 1] if jumphost_connection_conf and jumphost_connection_conf.index else None - return self._login(self.current.login_with_public_key, username, - keyfile, password, is_truthy(allow_agent), - is_truthy(look_for_keys), delay, proxy_cmd, - jumphost_connection, is_truthy(read_config), keep_alive_interval) + raise ValueError( + "`proxy_cmd` and `jumphost_connection` are mutually exclusive SSH features." + ) + jumphost_connection_conf = ( + self.get_connection(index_or_alias=jumphost_index_or_alias) + if jumphost_index_or_alias + else None + ) + jumphost_connection = ( + self._connections.connections[jumphost_connection_conf.index - 1] + if jumphost_connection_conf and jumphost_connection_conf.index + else None + ) + return self._login( + self.current.login_with_public_key, + username, + keyfile, + password, + is_truthy(allow_agent), + is_truthy(look_for_keys), + delay, + proxy_cmd, + jumphost_connection, + is_truthy(read_config), + keep_alive_interval, + ) def _login(self, login_method, username, *args): self._log( f"Logging into '{self.current.config.host}:{self.current.config.port}' as '{self.current.config.host}'.", - self._config.loglevel) + self._config.loglevel, + ) try: login_output = login_method(username, *args) if is_truthy(self.current.config.escape_ansi): login_output = self._escape_ansi_sequences(login_output) - self._log(f'Read output: {login_output}', self._config.loglevel) + self._log(f"Read output: {login_output}", self._config.loglevel) return login_output except SSHClientException as e: raise RuntimeError(e) + @keyword(tags=("login",)) def get_pre_login_banner(self, host=None, port=22): """Returns the banner supplied by the server upon connect. @@ -1099,12 +1267,26 @@ def get_pre_login_banner(self, host=None, port=22): elif self.current: banner = self.current.get_banner() else: - raise RuntimeError("'host' argument is mandatory if there is no open connection.") + raise RuntimeError( + "'host' argument is mandatory if there is no open connection." + ) return banner.decode(self.DEFAULT_ENCODING) - def execute_command(self, command, return_stdout=True, return_stderr=False, - return_rc=False, sudo=False, sudo_password=None, timeout=None, output_during_execution=False, - output_if_timeout=False, invoke_subsystem=False, forward_agent=False): + @keyword(tags=("command",)) + def execute_command( + self, + command, + return_stdout=True, + return_stderr=False, + return_rc=False, + sudo=False, + sudo_password=None, + timeout=None, + output_during_execution=False, + output_if_timeout=False, + invoke_subsystem=False, + forward_agent=False, + ): """Executes ``command`` on the remote machine and returns its outputs. This keyword executes the ``command`` and returns after the execution @@ -1175,14 +1357,28 @@ def execute_command(self, command, return_stdout=True, return_stderr=False, self._log(f"Executing command '{command}'.", self._config.loglevel) else: self._log(f"Executing command 'sudo {command}'.", self._config.loglevel) - opts = self._legacy_output_options(return_stdout, return_stderr, - return_rc) - stdout, stderr, rc = self.current.execute_command(command, sudo, sudo_password, - timeout, output_during_execution, output_if_timeout, - is_truthy(invoke_subsystem), forward_agent) + opts = self._legacy_output_options(return_stdout, return_stderr, return_rc) + stdout, stderr, rc = self.current.execute_command( + command, + sudo, + sudo_password, + timeout, + output_during_execution, + output_if_timeout, + is_truthy(invoke_subsystem), + forward_agent, + ) return self._return_command_output(stdout, stderr, rc, *opts) - def start_command(self, command, sudo=False, sudo_password=None, invoke_subsystem=False, forward_agent=False): + @keyword(tags=("command",)) + def start_command( + self, + command, + sudo=False, + sudo_password=None, + invoke_subsystem=False, + forward_agent=False, + ): """Starts execution of the ``command`` on the remote machine and returns immediately. This keyword returns nothing and does not wait for the ``command`` @@ -1234,10 +1430,18 @@ def start_command(self, command, sudo=False, sudo_password=None, invoke_subsyste else: temp_dict = {self.current.config.index: command} self._last_commands.update(temp_dict) - self.current.start_command(command, sudo, sudo_password, is_truthy(invoke_subsystem), is_truthy(forward_agent)) + self.current.start_command( + command, + sudo, + sudo_password, + is_truthy(invoke_subsystem), + is_truthy(forward_agent), + ) - def read_command_output(self, return_stdout=True, return_stderr=False, - return_rc=False, timeout=None): + @keyword(tags=("command",)) + def read_command_output( + self, return_stdout=True, return_stderr=False, return_rc=False, timeout=None + ): """Returns outputs of the most recent started command. At least one command must have been started using `Start Command` @@ -1285,17 +1489,21 @@ def read_command_output(self, return_stdout=True, return_stderr=False, This keyword logs the read command with log level ``INFO``. """ - self._log(f"Reading output of command '{self._last_commands.get(self.current.config.index)}'.", - self._config.loglevel) - opts = self._legacy_output_options(return_stdout, return_stderr, - return_rc) + self._log( + f"Reading output of command '{self._last_commands.get(self.current.config.index)}'.", + self._config.loglevel, + ) + opts = self._legacy_output_options(return_stdout, return_stderr, return_rc) try: stdout, stderr, rc = self.current.read_command_output(timeout=timeout) except SSHClientException as msg: raise RuntimeError(msg) return self._return_command_output(stdout, stderr, rc, *opts) - def create_local_ssh_tunnel(self, local_port, remote_host, remote_port=22, bind_address=None): + @keyword(tags=("connection",)) + def create_local_ssh_tunnel( + self, local_port, remote_host, remote_port=22, bind_address=None + ): """ The keyword uses the existing connection to set up local port forwarding (the openssh -L option) from a local port through a tunneled @@ -1324,32 +1532,36 @@ def create_local_ssh_tunnel(self, local_port, remote_host, remote_port=22, bind_ ``bind_address`` is new in SSHLibrary 3.3.0. """ - self.current.create_local_ssh_tunnel(local_port, remote_host, remote_port, bind_address) + self.current.create_local_ssh_tunnel( + local_port, remote_host, remote_port, bind_address + ) def _legacy_output_options(self, stdout, stderr, rc): if not is_string(stdout): return stdout, stderr, rc stdout = stdout.lower() - if stdout == 'stderr': + if stdout == "stderr": return False, True, rc - if stdout == 'both': + if stdout == "both": return True, True, rc return stdout, stderr, rc - def _return_command_output(self, stdout, stderr, rc, return_stdout, - return_stderr, return_rc): + def _return_command_output( + self, stdout, stderr, rc, return_stdout, return_stderr, return_rc + ): self._log(f"Command exited with return code {rc}.", self._config.loglevel) ret = [] if is_truthy(return_stdout): - ret.append(stdout.rstrip('\n')) + ret.append(stdout.rstrip("\n")) if is_truthy(return_stderr): - ret.append(stderr.rstrip('\n')) + ret.append(stderr.rstrip("\n")) if is_truthy(return_rc): ret.append(rc) if len(ret) == 1: return ret[0] return ret + @keyword(tags=("command",)) def write(self, text, loglevel=None): """Writes the given ``text`` on the remote machine and appends a newline. @@ -1377,6 +1589,7 @@ def write(self, text, loglevel=None): self._write(text, add_newline=True) return self._read_and_log(loglevel, self.current.read_until_newline) + @keyword(tags=("command",)) def write_bare(self, text): """Writes the given ``text`` on the remote machine without appending a newline. @@ -1402,8 +1615,10 @@ def _write(self, text, add_newline=False): except SSHClientException as e: raise RuntimeError(e) - def write_until_expected_output(self, text, expected, timeout, - retry_interval, loglevel=None): + @keyword(tags=("command",)) + def write_until_expected_output( + self, text, expected, timeout, retry_interval, loglevel=None + ): """Writes the given ``text`` repeatedly until ``expected`` appears in the server output. This keyword returns nothing. @@ -1427,9 +1642,16 @@ def write_until_expected_output(self, text, expected, timeout, | `Write Until Expected Output` | lsof -c python27\\n | expected=myscript.py | timeout=5s | retry_interval=0.5s | """ - self._read_and_log(loglevel, self.current.write_until_expected, text, - expected, timeout, retry_interval) + self._read_and_log( + loglevel, + self.current.write_until_expected, + text, + expected, + timeout, + retry_interval, + ) + @keyword(tags=("command",)) def read(self, loglevel=None, delay=None): """Consumes and returns everything available on the server output. @@ -1459,6 +1681,7 @@ def read(self, loglevel=None, delay=None): """ return self._read_and_log(loglevel, self.current.read, delay) + @keyword(tags=("command",)) def read_until(self, expected, loglevel=None): """Consumes and returns the server output until ``expected`` is encountered. @@ -1485,6 +1708,7 @@ def read_until(self, expected, loglevel=None): """ return self._read_and_log(loglevel, self.current.read_until, expected) + @keyword(tags=("command",)) def read_until_prompt(self, loglevel=None, strip_prompt=False): """Consumes and returns the server output until the prompt is found. @@ -1520,8 +1744,11 @@ def read_until_prompt(self, loglevel=None, strip_prompt=False): ``strip_prompt`` argument is new in SSHLibrary 3.2.0. """ - return self._read_and_log(loglevel, self.current.read_until_prompt, is_truthy(strip_prompt)) + return self._read_and_log( + loglevel, self.current.read_until_prompt, is_truthy(strip_prompt) + ) + @keyword(tags=("command",)) def read_until_regexp(self, regexp, loglevel=None): """Consumes and returns the server output until a match to ``regexp`` is found. @@ -1550,8 +1777,7 @@ def read_until_regexp(self, regexp, loglevel=None): details about reading and writing in general, see the `Interactive shells` section. """ - return self._read_and_log(loglevel, self.current.read_until_regexp, - regexp) + return self._read_and_log(loglevel, self.current.read_until_regexp, regexp) def _read_and_log(self, loglevel, reader, *args): try: @@ -1568,11 +1794,14 @@ def _read_and_log(self, loglevel, reader, *args): @staticmethod def _escape_ansi_sequences(output): - ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]', flags=re.IGNORECASE) - output = ansi_escape.sub('', output) - return (f"{output!r}")[1:-1].encode().decode('unicode-escape') + ansi_escape = re.compile( + r"(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]", flags=re.IGNORECASE + ) + output = ansi_escape.sub("", output) + return (f"{output!r}")[1:-1].encode().decode("unicode-escape") - def get_file(self, source, destination='.', scp='OFF', scp_preserve_times=False): + @keyword(tags=("file",)) + def get_file(self, source, destination=".", scp="OFF", scp_preserve_times=False): """Downloads file(s) from the remote machine to the local machine. ``source`` is a path on the remote machine. Both absolute paths and @@ -1625,11 +1854,19 @@ def get_file(self, source, destination='.', scp='OFF', scp_preserve_times=False) ``scp_preserve_times`` is new in SSHLibrary 3.6.0. """ - return self._run_command(self.current.get_file, source, - destination, scp, scp_preserve_times) + return self._run_command( + self.current.get_file, source, destination, scp, scp_preserve_times + ) - def get_directory(self, source, destination='.', recursive=False, - scp='OFF', scp_preserve_times=False): + @keyword(tags=("file",)) + def get_directory( + self, + source, + destination=".", + recursive=False, + scp="OFF", + scp_preserve_times=False, + ): """Downloads a directory, including its content, from the remote machine to the local machine. ``source`` is a path on the remote machine. Both absolute paths and @@ -1674,11 +1911,25 @@ def get_directory(self, source, destination='.', recursive=False, ``scp_preserve_times`` is new in SSHLibrary 3.6.0. """ - return self._run_command(self.current.get_directory, source, - destination, is_truthy(recursive), scp, scp_preserve_times) + return self._run_command( + self.current.get_directory, + source, + destination, + is_truthy(recursive), + scp, + scp_preserve_times, + ) - def put_file(self, source, destination='.', mode='0744', newline='', - scp='OFF', scp_preserve_times=False): + @keyword(tags=("file",)) + def put_file( + self, + source, + destination=".", + mode="0744", + newline="", + scp="OFF", + scp_preserve_times=False, + ): """Uploads file(s) from the local machine to the remote machine. ``source`` is the path on the local machine. Both absolute paths and @@ -1736,11 +1987,27 @@ def put_file(self, source, destination='.', mode='0744', newline='', ``scp_preserve_times`` is new in SSHLibrary 3.6.0. """ - return self._run_command(self.current.put_file, source, - destination, mode, newline, scp, scp_preserve_times) + return self._run_command( + self.current.put_file, + source, + destination, + mode, + newline, + scp, + scp_preserve_times, + ) - def put_directory(self, source, destination='.', mode='0744', newline='', - recursive=False, scp='OFF', scp_preserve_times=False): + @keyword(tags=("file",)) + def put_directory( + self, + source, + destination=".", + mode="0744", + newline="", + recursive=False, + scp="OFF", + scp_preserve_times=False, + ): """Uploads a directory, including its content, from the local machine to the remote machine. ``source`` is the path on the local machine. Both absolute paths and @@ -1792,9 +2059,16 @@ def put_directory(self, source, destination='.', mode='0744', newline='', ``scp_preserve_times`` is new in SSHLibrary 3.6.0. """ - return self._run_command(self.current.put_directory, source, - destination, mode, newline, - is_truthy(recursive), scp, scp_preserve_times) + return self._run_command( + self.current.put_directory, + source, + destination, + mode, + newline, + is_truthy(recursive), + scp, + scp_preserve_times, + ) def _run_command(self, command, *args): try: @@ -1805,6 +2079,7 @@ def _run_command(self, command, *args): for src, dst in files: self._log(f"'{src}' -> '{dst}'", self._config.loglevel) + @keyword(tags=("file",)) def file_should_exist(self, path): """Fails if the given ``path`` does NOT point to an existing file. @@ -1820,6 +2095,7 @@ def file_should_exist(self, path): if not self.current.is_file(path): raise AssertionError(f"File '{path}' does not exist.") + @keyword(tags=("file",)) def file_should_not_exist(self, path): """Fails if the given ``path`` points to an existing file. @@ -1835,6 +2111,7 @@ def file_should_not_exist(self, path): if self.current.is_file(path): raise AssertionError(f"File '{path}' exists.") + @keyword(tags=("file",)) def directory_should_exist(self, path): """Fails if the given ``path`` does not point to an existing directory. @@ -1850,6 +2127,7 @@ def directory_should_exist(self, path): if not self.current.is_dir(path): raise AssertionError(f"Directory '{path}' does not exist.") + @keyword(tags=("file",)) def directory_should_not_exist(self, path): """Fails if the given ``path`` points to an existing directory. @@ -1865,6 +2143,7 @@ def directory_should_not_exist(self, path): if self.current.is_dir(path): raise AssertionError(f"Directory '{path}' exists.") + @keyword(tags=("file",)) def list_directory(self, path, pattern=None, absolute=False): """Returns and logs items in the remote ``path``, optionally filtered with ``pattern``. @@ -1898,10 +2177,15 @@ def list_directory(self, path, pattern=None, absolute=False): items = self.current.list_dir(path, pattern, is_truthy(absolute)) except SSHClientException as msg: raise RuntimeError(msg) - self._log("{0} item{1}:\n{2}".format(len(items), plural_or_not(items), - '\n'.join(items)), self._config.loglevel) + self._log( + "{0} item{1}:\n{2}".format( + len(items), plural_or_not(items), "\n".join(items) + ), + self._config.loglevel, + ) return items + @keyword(tags=("file",)) def list_files_in_directory(self, path, pattern=None, absolute=False): """A wrapper for `List Directory` that returns only files.""" absolute = is_truthy(absolute) @@ -1910,26 +2194,46 @@ def list_files_in_directory(self, path, pattern=None, absolute=False): except SSHClientException as msg: raise RuntimeError(msg) files = self.current.list_files_in_dir(path, pattern, absolute) - self._log('{0} file{1}:\n{2}'.format(len(files), plural_or_not(files), - '\n'.join(files)), self._config.loglevel) + self._log( + "{0} file{1}:\n{2}".format( + len(files), plural_or_not(files), "\n".join(files) + ), + self._config.loglevel, + ) return files + @keyword(tags=("file",)) def list_directories_in_directory(self, path, pattern=None, absolute=False): """A wrapper for `List Directory` that returns only directories.""" try: dirs = self.current.list_dirs_in_dir(path, pattern, is_truthy(absolute)) except SSHClientException as msg: raise RuntimeError(msg) - self._log('{0} director{1}:\n{2}'.format(len(dirs), - 'y' if len(dirs) == 1 else 'ies', - '\n'.join(dirs)), self._config.loglevel) + self._log( + "{0} director{1}:\n{2}".format( + len(dirs), "y" if len(dirs) == 1 else "ies", "\n".join(dirs) + ), + self._config.loglevel, + ) return dirs class _DefaultConfiguration(Configuration): - def __init__(self, timeout, newline, prompt, loglevel, term_type, width, - height, path_separator, encoding, escape_ansi, encoding_errors): + def __init__( + self, + timeout, + newline, + prompt, + loglevel, + term_type, + width, + height, + path_separator, + encoding, + escape_ansi, + encoding_errors, + ): super(_DefaultConfiguration, self).__init__( timeout=TimeEntry(timeout), newline=NewlineEntry(newline), @@ -1941,5 +2245,5 @@ def __init__(self, timeout, newline, prompt, loglevel, term_type, width, path_separator=StringEntry(path_separator), encoding=StringEntry(encoding), escape_ansi=StringEntry(escape_ansi), - encoding_errors=StringEntry(encoding_errors) + encoding_errors=StringEntry(encoding_errors), )