diff --git a/src/SSHLibrary/deco.py b/src/SSHLibrary/deco.py
deleted file mode 100644
index b5e24782c..000000000
--- a/src/SSHLibrary/deco.py
+++ /dev/null
@@ -1,37 +0,0 @@
-
-def keyword(types=()):
- """Decorator to set custom argument types to keywords.
-
- This decorator creates ``robot_types`` attribute on the decorated
- keyword method or function based on the provided arguments.
- Robot Framework checks them to determine the keyword's
- argument types.
-
- Types must be given as a dictionary mapping argument names to types or as a list
- (or tuple) of types mapped to arguments based on position. It is OK to
- specify types only to some arguments, and setting ``types`` to ``None``
- disables type conversion altogether.
-
- Examples::
-
- @keyword(types={'length': int, 'case_insensitive': bool})
- def types_as_dict(length, case_insensitive=False):
- # ...
-
- @keyword(types=[int, bool])
- def types_as_list(length, case_insensitive=False):
- # ...
-
- @keyword(types=None])
- def no_conversion(length, case_insensitive=False):
- # ...
-
- @keyword
- def func():
- # ...
- """
-
- def decorator(func):
- func.robot_types = types
- return func
- return decorator
diff --git a/src/SSHLibrary/library.py b/src/SSHLibrary/library.py
index 697cd26a1..aba4026de 100644
--- a/src/SSHLibrary/library.py
+++ b/src/SSHLibrary/library.py
@@ -16,25 +16,27 @@
from __future__ import print_function
import re
-from .deco import keyword
-try:
- from robot.api import logger
-except ImportError:
- logger = None
+from .logger import logger
from robot.utils import is_string, is_truthy, plural_or_not
+from robot.api.deco import keyword, library
from .sshconnectioncache import SSHConnectionCache
from .client import SSHClientException
from .client import SSHClient
-from .config import (Configuration, IntegerEntry, LogLevelEntry, NewlineEntry,
- StringEntry, TimeEntry)
+from .config import (
+ Configuration,
+ IntegerEntry,
+ LogLevelEntry,
+ NewlineEntry,
+ StringEntry,
+ TimeEntry,
+)
from .version import VERSION
-__version__ = VERSION
-
-class SSHLibrary(object):
+@library(version=VERSION, scope="GLOBAL")
+class SSHLibrary:
"""SSHLibrary is a Robot Framework test library for SSH and SFTP.
This document explains how to use keywords provided by SSHLibrary.
@@ -435,33 +437,33 @@ class SSHLibrary(object):
| `Run Keyword And Expect Error` | Non-existing index or alias 'conn'. | `Switch Connection` | conn |
"""
- ROBOT_LIBRARY_SCOPE = 'GLOBAL'
- ROBOT_LIBRARY_VERSION = __version__
- DEFAULT_TIMEOUT = '3 seconds'
- DEFAULT_NEWLINE = 'LF'
+ DEFAULT_TIMEOUT = "3 seconds"
+ DEFAULT_NEWLINE = "LF"
DEFAULT_PROMPT = None
- DEFAULT_LOGLEVEL = 'INFO'
- DEFAULT_TERM_TYPE = 'vt100'
+ DEFAULT_LOGLEVEL = "INFO"
+ DEFAULT_TERM_TYPE = "vt100"
DEFAULT_TERM_WIDTH = 80
DEFAULT_TERM_HEIGHT = 24
- DEFAULT_PATH_SEPARATOR = '/'
- DEFAULT_ENCODING = 'UTF-8'
+ DEFAULT_PATH_SEPARATOR = "/"
+ DEFAULT_ENCODING = "UTF-8"
DEFAULT_ESCAPE_ANSI = False
- DEFAULT_ENCODING_ERRORS = 'strict'
-
- def __init__(self,
- timeout=DEFAULT_TIMEOUT,
- newline=DEFAULT_NEWLINE,
- prompt=DEFAULT_PROMPT,
- loglevel=DEFAULT_LOGLEVEL,
- term_type=DEFAULT_TERM_TYPE,
- width=DEFAULT_TERM_WIDTH,
- height=DEFAULT_TERM_HEIGHT,
- path_separator=DEFAULT_PATH_SEPARATOR,
- encoding=DEFAULT_ENCODING,
- escape_ansi=DEFAULT_ESCAPE_ANSI,
- encoding_errors=DEFAULT_ENCODING_ERRORS):
+ DEFAULT_ENCODING_ERRORS = "strict"
+
+ def __init__(
+ self,
+ timeout=DEFAULT_TIMEOUT,
+ newline=DEFAULT_NEWLINE,
+ prompt=DEFAULT_PROMPT,
+ loglevel=DEFAULT_LOGLEVEL,
+ term_type=DEFAULT_TERM_TYPE,
+ width=DEFAULT_TERM_WIDTH,
+ height=DEFAULT_TERM_HEIGHT,
+ path_separator=DEFAULT_PATH_SEPARATOR,
+ encoding=DEFAULT_ENCODING,
+ escape_ansi=DEFAULT_ESCAPE_ANSI,
+ encoding_errors=DEFAULT_ENCODING_ERRORS,
+ ):
"""SSHLibrary allows some import time `configuration`.
If the library is imported without any arguments, the library
@@ -499,7 +501,7 @@ def __init__(self,
path_separator or self.DEFAULT_PATH_SEPARATOR,
encoding or self.DEFAULT_ENCODING,
escape_ansi or self.DEFAULT_ESCAPE_ANSI,
- encoding_errors or self.DEFAULT_ENCODING_ERRORS
+ encoding_errors or self.DEFAULT_ENCODING_ERRORS,
)
self._last_commands = dict()
@@ -507,11 +509,21 @@ def __init__(self,
def current(self):
return self._connections.current
- @keyword(types=None)
- def set_default_configuration(self, timeout=None, newline=None, prompt=None,
- loglevel=None, term_type=None, width=None,
- height=None, path_separator=None,
- encoding=None, escape_ansi=None, encoding_errors=None):
+ @keyword(tags=("configuration",))
+ def set_default_configuration(
+ self,
+ timeout=None,
+ newline=None,
+ prompt=None,
+ loglevel=None,
+ term_type=None,
+ width=None,
+ height=None,
+ path_separator=None,
+ encoding=None,
+ escape_ansi=None,
+ encoding_errors=None,
+ ):
"""Update the default `configuration`.
Please note that using this keyword does not affect the already
@@ -542,14 +554,34 @@ def set_default_configuration(self, timeout=None, newline=None, prompt=None,
| `Should Be Equal As Integers` | ${emea.timeout} | 20 |
| `Should Be Equal As Integers` | ${apac.timeout} | 20 |
"""
- self._config.update(timeout=timeout, newline=newline, prompt=prompt,
- loglevel=loglevel, term_type=term_type, width=width,
- height=height, path_separator=path_separator,
- encoding=encoding, escape_ansi=escape_ansi, encoding_errors=encoding_errors)
-
- def set_client_configuration(self, timeout=None, newline=None, prompt=None,
- term_type=None, width=None, height=None,
- path_separator=None, encoding=None, escape_ansi=None, encoding_errors=None):
+ self._config.update(
+ timeout=timeout,
+ newline=newline,
+ prompt=prompt,
+ loglevel=loglevel,
+ term_type=term_type,
+ width=width,
+ height=height,
+ path_separator=path_separator,
+ encoding=encoding,
+ escape_ansi=escape_ansi,
+ encoding_errors=encoding_errors,
+ )
+
+ @keyword(tags=("configuration",))
+ def set_client_configuration(
+ self,
+ timeout=None,
+ newline=None,
+ prompt=None,
+ term_type=None,
+ width=None,
+ height=None,
+ path_separator=None,
+ encoding=None,
+ escape_ansi=None,
+ encoding_errors=None,
+ ):
"""Update the `configuration` of the current connection.
Only parameters whose value is other than ``None`` are updated.
@@ -578,13 +610,20 @@ def set_client_configuration(self, timeout=None, newline=None, prompt=None,
| `Open Connection` | 192.168.1.1 |
| `Set Client Configuration` | term_type=ansi | width=40 |
"""
- self.current.config.update(timeout=timeout, newline=newline,
- prompt=prompt, term_type=term_type,
- width=width, height=height,
- path_separator=path_separator,
- encoding=encoding, escape_ansi=escape_ansi,
- encoding_errors=encoding_errors)
+ self.current.config.update(
+ timeout=timeout,
+ newline=newline,
+ prompt=prompt,
+ term_type=term_type,
+ width=width,
+ height=height,
+ path_separator=path_separator,
+ encoding=encoding,
+ escape_ansi=escape_ansi,
+ encoding_errors=encoding_errors,
+ )
+ @keyword(tags=("configuration",))
def enable_ssh_logging(self, logfile):
"""Enables logging of SSH protocol output to given ``logfile``.
@@ -603,12 +642,25 @@ def enable_ssh_logging(self, logfile):
| # Check myserver.log for detailed debug information |
"""
if SSHClient.enable_logging(logfile):
- self._log(f'SSH log is written to file.',
- 'HTML')
-
- def open_connection(self, host, alias=None, port=22, timeout=None,
- newline=None, prompt=None, term_type=None, width=None,
- height=None, path_separator=None, encoding=None, escape_ansi=None, encoding_errors=None):
+ self._log(f'SSH log is written to file.', "HTML")
+
+ @keyword(tags=("connection",))
+ def open_connection(
+ self,
+ host,
+ alias=None,
+ port=22,
+ timeout=None,
+ newline=None,
+ prompt=None,
+ term_type=None,
+ width=None,
+ height=None,
+ path_separator=None,
+ encoding=None,
+ escape_ansi=None,
+ encoding_errors=None,
+ ):
"""Opens a new SSH connection to the given ``host`` and ``port``.
The new connection is made active. Possible existing connections
@@ -676,12 +728,26 @@ def open_connection(self, host, alias=None, port=22, timeout=None,
encoding = encoding or self._config.encoding
escape_ansi = escape_ansi or self._config.escape_ansi
encoding_errors = encoding_errors or self._config.encoding_errors
- client = SSHClient(host, alias, port, timeout, newline, prompt,
- term_type, width, height, path_separator, encoding, escape_ansi, encoding_errors)
+ client = SSHClient(
+ host,
+ alias,
+ port,
+ timeout,
+ newline,
+ prompt,
+ term_type,
+ width,
+ height,
+ path_separator,
+ encoding,
+ escape_ansi,
+ encoding_errors,
+ )
connection_index = self._connections.register(client, alias)
client.config.update(index=connection_index)
return connection_index
+ @keyword(tags=("connection",))
def switch_connection(self, index_or_alias):
"""Switches the active connection by index or alias.
@@ -714,6 +780,7 @@ def switch_connection(self, index_or_alias):
self._connections.switch(index_or_alias)
return old_index
+ @keyword(tags=("connection",))
def close_connection(self):
"""Closes the current connection.
@@ -730,6 +797,7 @@ def close_connection(self):
connections = self._connections
connections.close_current()
+ @keyword(tags=("connection",))
def close_all_connections(self):
"""Closes all open connections.
@@ -748,10 +816,23 @@ def close_all_connections(self):
"""
self._connections.close_all()
- def get_connection(self, index_or_alias=None, index=False, host=False,
- alias=False, port=False, timeout=False, newline=False,
- prompt=False, term_type=False, width=False, height=False,
- encoding=False, escape_ansi=False):
+ @keyword(tags=("connection",))
+ def get_connection(
+ self,
+ index_or_alias=None,
+ index=False,
+ host=False,
+ alias=False,
+ port=False,
+ timeout=False,
+ newline=False,
+ prompt=False,
+ term_type=False,
+ width=False,
+ height=False,
+ encoding=False,
+ escape_ansi=False,
+ ):
"""Returns information about the connection.
Connection is not changed by this keyword, use `Switch Connection` to
@@ -840,38 +921,70 @@ def get_connection(self, index_or_alias=None, index=False, host=False,
except AttributeError:
config = SSHClient(None).config
self._log(str(config), self._config.loglevel)
- return_values = tuple(self._get_config_values(config, index, host,
- alias, port, timeout,
- newline, prompt,
- term_type, width, height,
- encoding, escape_ansi))
+ return_values = tuple(
+ self._get_config_values(
+ config,
+ index,
+ host,
+ alias,
+ port,
+ timeout,
+ newline,
+ prompt,
+ term_type,
+ width,
+ height,
+ encoding,
+ escape_ansi,
+ )
+ )
if not return_values:
return config
if len(return_values) == 1:
return return_values[0]
return return_values
- def _log(self, msg, level='INFO'):
+ def _log(self, msg, level="INFO"):
level = self._active_loglevel(level)
- if level != 'NONE':
+ if level != "NONE":
msg = msg.strip()
if not msg:
return
if logger:
logger.write(msg, level)
else:
- print(f'*{level}* {msg}')
+ print(f"*{level}* {msg}")
def _active_loglevel(self, level):
if level is None:
return self._config.loglevel
- if is_string(level) and \
- level.upper() in ['TRACE', 'DEBUG', 'INFO', 'WARN', 'HTML', 'NONE']:
+ if is_string(level) and level.upper() in [
+ "TRACE",
+ "DEBUG",
+ "INFO",
+ "WARN",
+ "HTML",
+ "NONE",
+ ]:
return level.upper()
raise AssertionError(f"Invalid log level '{level}'.")
- def _get_config_values(self, config, index, host, alias, port, timeout,
- newline, prompt, term_type, width, height, encoding, escape_ansi):
+ def _get_config_values(
+ self,
+ config,
+ index,
+ host,
+ alias,
+ port,
+ timeout,
+ newline,
+ prompt,
+ term_type,
+ width,
+ height,
+ encoding,
+ escape_ansi,
+ ):
if is_truthy(index):
yield config.index
if is_truthy(host):
@@ -897,6 +1010,7 @@ def _get_config_values(self, config, index, host, alias, port, timeout,
if is_truthy(escape_ansi):
yield config.escape_ansi
+ @keyword(tags=("connection",))
def get_connections(self):
"""Returns information about all the open connections.
@@ -920,8 +1034,19 @@ def get_connections(self):
self._log(str(c), self._config.loglevel)
return configs
- def login(self, username=None, password=None, allow_agent=False, look_for_keys=False, delay='0.5 seconds',
- proxy_cmd=None, read_config=False, jumphost_index_or_alias=None, keep_alive_interval='0 seconds'):
+ @keyword(tags=("login",))
+ def login(
+ self,
+ username=None,
+ password=None,
+ allow_agent=False,
+ look_for_keys=False,
+ delay="0.5 seconds",
+ proxy_cmd=None,
+ read_config=False,
+ jumphost_index_or_alias=None,
+ keep_alive_interval="0 seconds",
+ ):
"""Logs into the SSH server with the given ``username`` and ``password``.
Connection must be opened before using this keyword.
@@ -980,20 +1105,44 @@ def login(self, username=None, password=None, allow_agent=False, look_for_keys=F
| `Open Connection` | linux.server.com |
| `Login` | johndoe | allow_agent=True |
"""
- jumphost_connection_conf = self.get_connection(index_or_alias=jumphost_index_or_alias) \
- if jumphost_index_or_alias else None
- jumphost_connection = self._connections.connections[jumphost_connection_conf.index - 1] \
- if jumphost_connection_conf and jumphost_connection_conf.index else None
-
- return self._login(self.current.login, username, password, is_truthy(allow_agent),
- is_truthy(look_for_keys), delay, proxy_cmd, is_truthy(read_config),
- jumphost_connection, keep_alive_interval)
-
- def login_with_public_key(self, username=None, keyfile=None, password='',
- allow_agent=False, look_for_keys=False,
- delay='0.5 seconds', proxy_cmd=None,
- jumphost_index_or_alias=None,
- read_config=False, keep_alive_interval='0 seconds'):
+ jumphost_connection_conf = (
+ self.get_connection(index_or_alias=jumphost_index_or_alias)
+ if jumphost_index_or_alias
+ else None
+ )
+ jumphost_connection = (
+ self._connections.connections[jumphost_connection_conf.index - 1]
+ if jumphost_connection_conf and jumphost_connection_conf.index
+ else None
+ )
+
+ return self._login(
+ self.current.login,
+ username,
+ password,
+ is_truthy(allow_agent),
+ is_truthy(look_for_keys),
+ delay,
+ proxy_cmd,
+ is_truthy(read_config),
+ jumphost_connection,
+ keep_alive_interval,
+ )
+
+ @keyword(tags=("login",))
+ def login_with_public_key(
+ self,
+ username=None,
+ keyfile=None,
+ password="",
+ allow_agent=False,
+ look_for_keys=False,
+ delay="0.5 seconds",
+ proxy_cmd=None,
+ jumphost_index_or_alias=None,
+ read_config=False,
+ keep_alive_interval="0 seconds",
+ ):
"""Logs into the SSH server using key-based authentication.
Connection must be opened before using this keyword.
@@ -1049,29 +1198,48 @@ def login_with_public_key(self, username=None, keyfile=None, password='',
``keep_alive_interval`` is new in SSHLibrary 3.7.0.
"""
if proxy_cmd and jumphost_index_or_alias:
- raise ValueError("`proxy_cmd` and `jumphost_connection` are mutually exclusive SSH features.")
- jumphost_connection_conf = self.get_connection(
- index_or_alias=jumphost_index_or_alias) if jumphost_index_or_alias else None
- jumphost_connection = self._connections.connections[
- jumphost_connection_conf.index - 1] if jumphost_connection_conf and jumphost_connection_conf.index else None
- return self._login(self.current.login_with_public_key, username,
- keyfile, password, is_truthy(allow_agent),
- is_truthy(look_for_keys), delay, proxy_cmd,
- jumphost_connection, is_truthy(read_config), keep_alive_interval)
+ raise ValueError(
+ "`proxy_cmd` and `jumphost_connection` are mutually exclusive SSH features."
+ )
+ jumphost_connection_conf = (
+ self.get_connection(index_or_alias=jumphost_index_or_alias)
+ if jumphost_index_or_alias
+ else None
+ )
+ jumphost_connection = (
+ self._connections.connections[jumphost_connection_conf.index - 1]
+ if jumphost_connection_conf and jumphost_connection_conf.index
+ else None
+ )
+ return self._login(
+ self.current.login_with_public_key,
+ username,
+ keyfile,
+ password,
+ is_truthy(allow_agent),
+ is_truthy(look_for_keys),
+ delay,
+ proxy_cmd,
+ jumphost_connection,
+ is_truthy(read_config),
+ keep_alive_interval,
+ )
def _login(self, login_method, username, *args):
self._log(
f"Logging into '{self.current.config.host}:{self.current.config.port}' as '{self.current.config.host}'.",
- self._config.loglevel)
+ self._config.loglevel,
+ )
try:
login_output = login_method(username, *args)
if is_truthy(self.current.config.escape_ansi):
login_output = self._escape_ansi_sequences(login_output)
- self._log(f'Read output: {login_output}', self._config.loglevel)
+ self._log(f"Read output: {login_output}", self._config.loglevel)
return login_output
except SSHClientException as e:
raise RuntimeError(e)
+ @keyword(tags=("login",))
def get_pre_login_banner(self, host=None, port=22):
"""Returns the banner supplied by the server upon connect.
@@ -1099,12 +1267,26 @@ def get_pre_login_banner(self, host=None, port=22):
elif self.current:
banner = self.current.get_banner()
else:
- raise RuntimeError("'host' argument is mandatory if there is no open connection.")
+ raise RuntimeError(
+ "'host' argument is mandatory if there is no open connection."
+ )
return banner.decode(self.DEFAULT_ENCODING)
- def execute_command(self, command, return_stdout=True, return_stderr=False,
- return_rc=False, sudo=False, sudo_password=None, timeout=None, output_during_execution=False,
- output_if_timeout=False, invoke_subsystem=False, forward_agent=False):
+ @keyword(tags=("command",))
+ def execute_command(
+ self,
+ command,
+ return_stdout=True,
+ return_stderr=False,
+ return_rc=False,
+ sudo=False,
+ sudo_password=None,
+ timeout=None,
+ output_during_execution=False,
+ output_if_timeout=False,
+ invoke_subsystem=False,
+ forward_agent=False,
+ ):
"""Executes ``command`` on the remote machine and returns its outputs.
This keyword executes the ``command`` and returns after the execution
@@ -1175,14 +1357,28 @@ def execute_command(self, command, return_stdout=True, return_stderr=False,
self._log(f"Executing command '{command}'.", self._config.loglevel)
else:
self._log(f"Executing command 'sudo {command}'.", self._config.loglevel)
- opts = self._legacy_output_options(return_stdout, return_stderr,
- return_rc)
- stdout, stderr, rc = self.current.execute_command(command, sudo, sudo_password,
- timeout, output_during_execution, output_if_timeout,
- is_truthy(invoke_subsystem), forward_agent)
+ opts = self._legacy_output_options(return_stdout, return_stderr, return_rc)
+ stdout, stderr, rc = self.current.execute_command(
+ command,
+ sudo,
+ sudo_password,
+ timeout,
+ output_during_execution,
+ output_if_timeout,
+ is_truthy(invoke_subsystem),
+ forward_agent,
+ )
return self._return_command_output(stdout, stderr, rc, *opts)
- def start_command(self, command, sudo=False, sudo_password=None, invoke_subsystem=False, forward_agent=False):
+ @keyword(tags=("command",))
+ def start_command(
+ self,
+ command,
+ sudo=False,
+ sudo_password=None,
+ invoke_subsystem=False,
+ forward_agent=False,
+ ):
"""Starts execution of the ``command`` on the remote machine and returns immediately.
This keyword returns nothing and does not wait for the ``command``
@@ -1234,10 +1430,18 @@ def start_command(self, command, sudo=False, sudo_password=None, invoke_subsyste
else:
temp_dict = {self.current.config.index: command}
self._last_commands.update(temp_dict)
- self.current.start_command(command, sudo, sudo_password, is_truthy(invoke_subsystem), is_truthy(forward_agent))
+ self.current.start_command(
+ command,
+ sudo,
+ sudo_password,
+ is_truthy(invoke_subsystem),
+ is_truthy(forward_agent),
+ )
- def read_command_output(self, return_stdout=True, return_stderr=False,
- return_rc=False, timeout=None):
+ @keyword(tags=("command",))
+ def read_command_output(
+ self, return_stdout=True, return_stderr=False, return_rc=False, timeout=None
+ ):
"""Returns outputs of the most recent started command.
At least one command must have been started using `Start Command`
@@ -1285,17 +1489,21 @@ def read_command_output(self, return_stdout=True, return_stderr=False,
This keyword logs the read command with log level ``INFO``.
"""
- self._log(f"Reading output of command '{self._last_commands.get(self.current.config.index)}'.",
- self._config.loglevel)
- opts = self._legacy_output_options(return_stdout, return_stderr,
- return_rc)
+ self._log(
+ f"Reading output of command '{self._last_commands.get(self.current.config.index)}'.",
+ self._config.loglevel,
+ )
+ opts = self._legacy_output_options(return_stdout, return_stderr, return_rc)
try:
stdout, stderr, rc = self.current.read_command_output(timeout=timeout)
except SSHClientException as msg:
raise RuntimeError(msg)
return self._return_command_output(stdout, stderr, rc, *opts)
- def create_local_ssh_tunnel(self, local_port, remote_host, remote_port=22, bind_address=None):
+ @keyword(tags=("connection",))
+ def create_local_ssh_tunnel(
+ self, local_port, remote_host, remote_port=22, bind_address=None
+ ):
"""
The keyword uses the existing connection to set up local port forwarding
(the openssh -L option) from a local port through a tunneled
@@ -1324,32 +1532,36 @@ def create_local_ssh_tunnel(self, local_port, remote_host, remote_port=22, bind_
``bind_address`` is new in SSHLibrary 3.3.0.
"""
- self.current.create_local_ssh_tunnel(local_port, remote_host, remote_port, bind_address)
+ self.current.create_local_ssh_tunnel(
+ local_port, remote_host, remote_port, bind_address
+ )
def _legacy_output_options(self, stdout, stderr, rc):
if not is_string(stdout):
return stdout, stderr, rc
stdout = stdout.lower()
- if stdout == 'stderr':
+ if stdout == "stderr":
return False, True, rc
- if stdout == 'both':
+ if stdout == "both":
return True, True, rc
return stdout, stderr, rc
- def _return_command_output(self, stdout, stderr, rc, return_stdout,
- return_stderr, return_rc):
+ def _return_command_output(
+ self, stdout, stderr, rc, return_stdout, return_stderr, return_rc
+ ):
self._log(f"Command exited with return code {rc}.", self._config.loglevel)
ret = []
if is_truthy(return_stdout):
- ret.append(stdout.rstrip('\n'))
+ ret.append(stdout.rstrip("\n"))
if is_truthy(return_stderr):
- ret.append(stderr.rstrip('\n'))
+ ret.append(stderr.rstrip("\n"))
if is_truthy(return_rc):
ret.append(rc)
if len(ret) == 1:
return ret[0]
return ret
+ @keyword(tags=("command",))
def write(self, text, loglevel=None):
"""Writes the given ``text`` on the remote machine and appends a newline.
@@ -1377,6 +1589,7 @@ def write(self, text, loglevel=None):
self._write(text, add_newline=True)
return self._read_and_log(loglevel, self.current.read_until_newline)
+ @keyword(tags=("command",))
def write_bare(self, text):
"""Writes the given ``text`` on the remote machine without appending a newline.
@@ -1402,8 +1615,10 @@ def _write(self, text, add_newline=False):
except SSHClientException as e:
raise RuntimeError(e)
- def write_until_expected_output(self, text, expected, timeout,
- retry_interval, loglevel=None):
+ @keyword(tags=("command",))
+ def write_until_expected_output(
+ self, text, expected, timeout, retry_interval, loglevel=None
+ ):
"""Writes the given ``text`` repeatedly until ``expected`` appears in the server output.
This keyword returns nothing.
@@ -1427,9 +1642,16 @@ def write_until_expected_output(self, text, expected, timeout,
| `Write Until Expected Output` | lsof -c python27\\n | expected=myscript.py | timeout=5s | retry_interval=0.5s |
"""
- self._read_and_log(loglevel, self.current.write_until_expected, text,
- expected, timeout, retry_interval)
+ self._read_and_log(
+ loglevel,
+ self.current.write_until_expected,
+ text,
+ expected,
+ timeout,
+ retry_interval,
+ )
+ @keyword(tags=("command",))
def read(self, loglevel=None, delay=None):
"""Consumes and returns everything available on the server output.
@@ -1459,6 +1681,7 @@ def read(self, loglevel=None, delay=None):
"""
return self._read_and_log(loglevel, self.current.read, delay)
+ @keyword(tags=("command",))
def read_until(self, expected, loglevel=None):
"""Consumes and returns the server output until ``expected`` is encountered.
@@ -1485,6 +1708,7 @@ def read_until(self, expected, loglevel=None):
"""
return self._read_and_log(loglevel, self.current.read_until, expected)
+ @keyword(tags=("command",))
def read_until_prompt(self, loglevel=None, strip_prompt=False):
"""Consumes and returns the server output until the prompt is found.
@@ -1520,8 +1744,11 @@ def read_until_prompt(self, loglevel=None, strip_prompt=False):
``strip_prompt`` argument is new in SSHLibrary 3.2.0.
"""
- return self._read_and_log(loglevel, self.current.read_until_prompt, is_truthy(strip_prompt))
+ return self._read_and_log(
+ loglevel, self.current.read_until_prompt, is_truthy(strip_prompt)
+ )
+ @keyword(tags=("command",))
def read_until_regexp(self, regexp, loglevel=None):
"""Consumes and returns the server output until a match to ``regexp`` is found.
@@ -1550,8 +1777,7 @@ def read_until_regexp(self, regexp, loglevel=None):
details about reading and writing in general, see the `Interactive
shells` section.
"""
- return self._read_and_log(loglevel, self.current.read_until_regexp,
- regexp)
+ return self._read_and_log(loglevel, self.current.read_until_regexp, regexp)
def _read_and_log(self, loglevel, reader, *args):
try:
@@ -1568,11 +1794,14 @@ def _read_and_log(self, loglevel, reader, *args):
@staticmethod
def _escape_ansi_sequences(output):
- ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]', flags=re.IGNORECASE)
- output = ansi_escape.sub('', output)
- return (f"{output!r}")[1:-1].encode().decode('unicode-escape')
+ ansi_escape = re.compile(
+ r"(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]", flags=re.IGNORECASE
+ )
+ output = ansi_escape.sub("", output)
+ return (f"{output!r}")[1:-1].encode().decode("unicode-escape")
- def get_file(self, source, destination='.', scp='OFF', scp_preserve_times=False):
+ @keyword(tags=("file",))
+ def get_file(self, source, destination=".", scp="OFF", scp_preserve_times=False):
"""Downloads file(s) from the remote machine to the local machine.
``source`` is a path on the remote machine. Both absolute paths and
@@ -1625,11 +1854,19 @@ def get_file(self, source, destination='.', scp='OFF', scp_preserve_times=False)
``scp_preserve_times`` is new in SSHLibrary 3.6.0.
"""
- return self._run_command(self.current.get_file, source,
- destination, scp, scp_preserve_times)
+ return self._run_command(
+ self.current.get_file, source, destination, scp, scp_preserve_times
+ )
- def get_directory(self, source, destination='.', recursive=False,
- scp='OFF', scp_preserve_times=False):
+ @keyword(tags=("file",))
+ def get_directory(
+ self,
+ source,
+ destination=".",
+ recursive=False,
+ scp="OFF",
+ scp_preserve_times=False,
+ ):
"""Downloads a directory, including its content, from the remote machine to the local machine.
``source`` is a path on the remote machine. Both absolute paths and
@@ -1674,11 +1911,25 @@ def get_directory(self, source, destination='.', recursive=False,
``scp_preserve_times`` is new in SSHLibrary 3.6.0.
"""
- return self._run_command(self.current.get_directory, source,
- destination, is_truthy(recursive), scp, scp_preserve_times)
+ return self._run_command(
+ self.current.get_directory,
+ source,
+ destination,
+ is_truthy(recursive),
+ scp,
+ scp_preserve_times,
+ )
- def put_file(self, source, destination='.', mode='0744', newline='',
- scp='OFF', scp_preserve_times=False):
+ @keyword(tags=("file",))
+ def put_file(
+ self,
+ source,
+ destination=".",
+ mode="0744",
+ newline="",
+ scp="OFF",
+ scp_preserve_times=False,
+ ):
"""Uploads file(s) from the local machine to the remote machine.
``source`` is the path on the local machine. Both absolute paths and
@@ -1736,11 +1987,27 @@ def put_file(self, source, destination='.', mode='0744', newline='',
``scp_preserve_times`` is new in SSHLibrary 3.6.0.
"""
- return self._run_command(self.current.put_file, source,
- destination, mode, newline, scp, scp_preserve_times)
+ return self._run_command(
+ self.current.put_file,
+ source,
+ destination,
+ mode,
+ newline,
+ scp,
+ scp_preserve_times,
+ )
- def put_directory(self, source, destination='.', mode='0744', newline='',
- recursive=False, scp='OFF', scp_preserve_times=False):
+ @keyword(tags=("file",))
+ def put_directory(
+ self,
+ source,
+ destination=".",
+ mode="0744",
+ newline="",
+ recursive=False,
+ scp="OFF",
+ scp_preserve_times=False,
+ ):
"""Uploads a directory, including its content, from the local machine to the remote machine.
``source`` is the path on the local machine. Both absolute paths and
@@ -1792,9 +2059,16 @@ def put_directory(self, source, destination='.', mode='0744', newline='',
``scp_preserve_times`` is new in SSHLibrary 3.6.0.
"""
- return self._run_command(self.current.put_directory, source,
- destination, mode, newline,
- is_truthy(recursive), scp, scp_preserve_times)
+ return self._run_command(
+ self.current.put_directory,
+ source,
+ destination,
+ mode,
+ newline,
+ is_truthy(recursive),
+ scp,
+ scp_preserve_times,
+ )
def _run_command(self, command, *args):
try:
@@ -1805,6 +2079,7 @@ def _run_command(self, command, *args):
for src, dst in files:
self._log(f"'{src}' -> '{dst}'", self._config.loglevel)
+ @keyword(tags=("file",))
def file_should_exist(self, path):
"""Fails if the given ``path`` does NOT point to an existing file.
@@ -1820,6 +2095,7 @@ def file_should_exist(self, path):
if not self.current.is_file(path):
raise AssertionError(f"File '{path}' does not exist.")
+ @keyword(tags=("file",))
def file_should_not_exist(self, path):
"""Fails if the given ``path`` points to an existing file.
@@ -1835,6 +2111,7 @@ def file_should_not_exist(self, path):
if self.current.is_file(path):
raise AssertionError(f"File '{path}' exists.")
+ @keyword(tags=("file",))
def directory_should_exist(self, path):
"""Fails if the given ``path`` does not point to an existing directory.
@@ -1850,6 +2127,7 @@ def directory_should_exist(self, path):
if not self.current.is_dir(path):
raise AssertionError(f"Directory '{path}' does not exist.")
+ @keyword(tags=("file",))
def directory_should_not_exist(self, path):
"""Fails if the given ``path`` points to an existing directory.
@@ -1865,6 +2143,7 @@ def directory_should_not_exist(self, path):
if self.current.is_dir(path):
raise AssertionError(f"Directory '{path}' exists.")
+ @keyword(tags=("file",))
def list_directory(self, path, pattern=None, absolute=False):
"""Returns and logs items in the remote ``path``, optionally filtered with ``pattern``.
@@ -1898,10 +2177,15 @@ def list_directory(self, path, pattern=None, absolute=False):
items = self.current.list_dir(path, pattern, is_truthy(absolute))
except SSHClientException as msg:
raise RuntimeError(msg)
- self._log("{0} item{1}:\n{2}".format(len(items), plural_or_not(items),
- '\n'.join(items)), self._config.loglevel)
+ self._log(
+ "{0} item{1}:\n{2}".format(
+ len(items), plural_or_not(items), "\n".join(items)
+ ),
+ self._config.loglevel,
+ )
return items
+ @keyword(tags=("file",))
def list_files_in_directory(self, path, pattern=None, absolute=False):
"""A wrapper for `List Directory` that returns only files."""
absolute = is_truthy(absolute)
@@ -1910,26 +2194,46 @@ def list_files_in_directory(self, path, pattern=None, absolute=False):
except SSHClientException as msg:
raise RuntimeError(msg)
files = self.current.list_files_in_dir(path, pattern, absolute)
- self._log('{0} file{1}:\n{2}'.format(len(files), plural_or_not(files),
- '\n'.join(files)), self._config.loglevel)
+ self._log(
+ "{0} file{1}:\n{2}".format(
+ len(files), plural_or_not(files), "\n".join(files)
+ ),
+ self._config.loglevel,
+ )
return files
+ @keyword(tags=("file",))
def list_directories_in_directory(self, path, pattern=None, absolute=False):
"""A wrapper for `List Directory` that returns only directories."""
try:
dirs = self.current.list_dirs_in_dir(path, pattern, is_truthy(absolute))
except SSHClientException as msg:
raise RuntimeError(msg)
- self._log('{0} director{1}:\n{2}'.format(len(dirs),
- 'y' if len(dirs) == 1 else 'ies',
- '\n'.join(dirs)), self._config.loglevel)
+ self._log(
+ "{0} director{1}:\n{2}".format(
+ len(dirs), "y" if len(dirs) == 1 else "ies", "\n".join(dirs)
+ ),
+ self._config.loglevel,
+ )
return dirs
class _DefaultConfiguration(Configuration):
- def __init__(self, timeout, newline, prompt, loglevel, term_type, width,
- height, path_separator, encoding, escape_ansi, encoding_errors):
+ def __init__(
+ self,
+ timeout,
+ newline,
+ prompt,
+ loglevel,
+ term_type,
+ width,
+ height,
+ path_separator,
+ encoding,
+ escape_ansi,
+ encoding_errors,
+ ):
super(_DefaultConfiguration, self).__init__(
timeout=TimeEntry(timeout),
newline=NewlineEntry(newline),
@@ -1941,5 +2245,5 @@ def __init__(self, timeout, newline, prompt, loglevel, term_type, width,
path_separator=StringEntry(path_separator),
encoding=StringEntry(encoding),
escape_ansi=StringEntry(escape_ansi),
- encoding_errors=StringEntry(encoding_errors)
+ encoding_errors=StringEntry(encoding_errors),
)