Refactor SSHOperator so a subclass can run many commands (#10874) #17378

Merged
merged 2 commits into from
Oct 13, 2021
204 changes: 109 additions & 95 deletions airflow/providers/ssh/operators/ssh.py
@@ -19,7 +19,9 @@
 import warnings
 from base64 import b64encode
 from select import select
-from typing import Optional, Union
+from typing import Optional, Tuple, Union

+from paramiko.client import SSHClient
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -107,103 +109,115 @@ def __init__(
                 stacklevel=1,
             )

-    def execute(self, context) -> Union[bytes, str, bool]:
+    def get_hook(self) -> SSHHook:
+        if self.ssh_conn_id:
+            if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+                self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
+            else:
+                self.log.info("ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook.")
+                self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)
+
+        if not self.ssh_hook:
+            raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
+
+        if self.remote_host is not None:
+            self.log.info(
+                "remote_host is provided explicitly. "
+                "It will replace the remote_host which was defined "
+                "in ssh_hook or predefined in connection of ssh_conn_id."
+            )
+            self.ssh_hook.remote_host = self.remote_host
+
+        return self.ssh_hook
+
+    def get_ssh_client(self) -> SSHClient:
+        # Remember to use context manager or call .close() on this when done
+        self.log.info('Creating ssh_client')
+        return self.get_hook().get_conn()
+
+    def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> Tuple[int, bytes, bytes]:
@uranusjr (Member), Aug 11, 2021:

This feels like the wrong abstraction to me: self is used here only for logging, and the caller is entirely responsible for passing in the correct SSHClient instance, which the operator should be able to manage itself.

Would something like this make more sense?

@property
def client(self):
    if self._client is None:
        raise RuntimeError("Outside of a create_ssh_client() context")
    return self._client

def execute(self, context=None) -> Union[bytes, str]:
    with self.create_ssh_client():  # This sets self._client so it can be used by other methods.
        self.run_remote_command(command)
        # On exit, close self._client and set self._client to None.
    # Error handling and serialization etc. afterward omitted for brevity.
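
The stateful pattern suggested above can be made runnable roughly as follows. This is only a sketch under stated assumptions: `FakeSSHClient` and `StatefulOperator` are hypothetical stand-ins (not paramiko or Airflow classes), and `create_ssh_client()` and `run_remote_command()` are the names proposed in the comment, not methods that exist in the merged code.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


class FakeSSHClient:
    """Hypothetical stand-in for an SSH client; records commands instead of running them."""

    def __init__(self) -> None:
        self.closed = False
        self.commands: List[str] = []

    def run(self, command: str) -> bytes:
        self.commands.append(command)
        return f"ran: {command}".encode("utf-8")

    def close(self) -> None:
        self.closed = True


class StatefulOperator:
    """Sketch of an operator that owns its client for the duration of a context."""

    def __init__(self, command: str) -> None:
        self.command = command
        self._client: Optional[FakeSSHClient] = None

    @property
    def client(self) -> FakeSSHClient:
        if self._client is None:
            raise RuntimeError("Outside of a create_ssh_client() context")
        return self._client

    @contextmanager
    def create_ssh_client(self) -> Iterator[FakeSSHClient]:
        # Set self._client on entry so other methods can reach it via self.client.
        self._client = FakeSSHClient()
        try:
            yield self._client
        finally:
            # On exit, close the client and clear the reference.
            self._client.close()
            self._client = None

    def run_remote_command(self, command: str) -> bytes:
        return self.client.run(command)

    def execute(self, context=None) -> bytes:
        with self.create_ssh_client():
            return self.run_remote_command(self.command)
```

The trade-off is that the operator carries connection state between method calls, so helpers no longer need a client argument but can only be called inside the context.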

Contributor Author:

Thanks for the feedback, @uranusjr.

I've pushed another commit along these lines; please take a look.
(I know the build is failing. Please ignore that for now; I'll work on it.)

The thing is, we don't want to have to call super().execute() from a subclass.
So I put the error handling etc. outside of execute(), where a subclass can reuse it when needed.
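
To illustrate the point about subclasses, here is a minimal sketch of a subclass that runs several commands over one connection without calling `super().execute()`. It is an assumption-laden stand-in, not Airflow code: `SketchSSHOperator` only mirrors the helper method names from this diff, `FakeSSHClient` replaces the real paramiko client, `RuntimeError` stands in for `AirflowException`, and `MultiCommandOperator` is a hypothetical name.

```python
from typing import List, Tuple


class FakeSSHClient:
    """Hypothetical stand-in for paramiko's SSHClient, usable as a context manager."""

    def __init__(self) -> None:
        self.closed = False

    def __enter__(self) -> "FakeSSHClient":
        return self

    def __exit__(self, *exc_info) -> None:
        self.closed = True


class SketchSSHOperator:
    """Stand-in that mirrors the helper methods added to SSHOperator in this diff."""

    def __init__(self, command: str) -> None:
        self.command = command

    def get_ssh_client(self) -> FakeSSHClient:
        return FakeSSHClient()

    def exec_ssh_client_command(self, ssh_client, command: str) -> Tuple[int, bytes, bytes]:
        # Assumption for the sketch: any command containing "fail" exits non-zero.
        if "fail" in command:
            return 1, b"", b"boom"
        return 0, command.upper().encode("utf-8"), b""

    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
        if exit_status != 0:
            raise RuntimeError(f"error running cmd, error: {stderr.decode('utf-8')}")

    def run_ssh_client_command(self, ssh_client, command: str) -> bytes:
        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(ssh_client, command)
        self.raise_for_status(exit_status, agg_stderr)
        return agg_stdout


class MultiCommandOperator(SketchSSHOperator):
    """Hypothetical subclass: one connection, many commands."""

    def __init__(self, commands: List[str]) -> None:
        super().__init__(command=commands[0])
        self.commands = commands

    def execute(self, context=None) -> List[bytes]:
        # Reuses the helpers directly; the parent's execute() is never called.
        with self.get_ssh_client() as ssh_client:
            return [self.run_ssh_client_command(ssh_client, c) for c in self.commands]
```

Because the error handling lives in `raise_for_status()`, a failing command in the middle of the list aborts the run in the same way a single failing command would.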

Contributor Author:

The build has passed. The earlier failure was flaky CI.

Member:

I am OK with keeping it as is (re: abstraction). Passing the client in is the stateless approach that some of our hooks already use (others keep the connection state inside the hook).

I have no strong opinion on which is better. The stateful approach is better from an OO perspective and gives more meaning to the Hook as also being a 'session', but that is not really necessary. A Hook (and it is a bad name) is more of a "nice API" for an operator to (re-)use, one that understands a "connection" and reads credentials from it.

I think we never agreed on whether a Hook should map 1<->1 to a session/client, and maybe it does not really matter. I think the most important capability of a Hook is mapping a connection into credentials behind a simple Python API, so that you can easily use it from an Operator.

But adding _client as a field is also OK for me.

Contributor Author:

Hey @potiuk, thanks for the feedback.
Hey @uranusjr, please review again and let me know :)

Contributor Author:

Hey @potiuk and @uranusjr, any feedback? :)

+        self.log.info("Running command: %s", command)
+
+        # set timeout taken as params
+        stdin, stdout, stderr = ssh_client.exec_command(
+            command=command,
+            get_pty=self.get_pty,
+            timeout=self.timeout,
+            environment=self.environment,
+        )
+        # get channels
+        channel = stdout.channel
+
+        # closing stdin
+        stdin.close()
+        channel.shutdown_write()
+
+        agg_stdout = b''
+        agg_stderr = b''
+
+        # capture any initial output in case channel is closed already
+        stdout_buffer_length = len(stdout.channel.in_buffer)
+
+        if stdout_buffer_length > 0:
+            agg_stdout += stdout.channel.recv(stdout_buffer_length)
+
+        # read from both stdout and stderr
+        while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
+            readq, _, _ = select([channel], [], [], self.cmd_timeout)
+            for recv in readq:
+                if recv.recv_ready():
+                    line = stdout.channel.recv(len(recv.in_buffer))
+                    agg_stdout += line
+                    self.log.info(line.decode('utf-8', 'replace').strip('\n'))
+                if recv.recv_stderr_ready():
+                    line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
+                    agg_stderr += line
+                    self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
+            if (
+                stdout.channel.exit_status_ready()
+                and not stderr.channel.recv_stderr_ready()
+                and not stdout.channel.recv_ready()
+            ):
+                stdout.channel.shutdown_read()
+                try:
+                    stdout.channel.close()
+                except Exception:
+                    # there is a race that when shutdown_read has been called and when
+                    # you try to close the connection, the socket is already closed
+                    # We should ignore such errors (but we should log them with warning)
+                    self.log.warning("Ignoring exception on close", exc_info=True)
+                break
+
+        stdout.close()
+        stderr.close()
+
+        exit_status = stdout.channel.recv_exit_status()
+
+        return exit_status, agg_stdout, agg_stderr
+
+    def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
+        if exit_status != 0:
+            error_msg = stderr.decode('utf-8')
+            raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")
+
+    def run_ssh_client_command(self, ssh_client: SSHClient, command: str) -> bytes:
+        exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(ssh_client, command)
+        self.raise_for_status(exit_status, agg_stderr)
+        return agg_stdout
+
+    def execute(self, context=None) -> Union[bytes, str]:
+        result = None
+        if self.command is None:
+            raise AirflowException("SSH operator error: SSH command not specified. Aborting.")
         try:
-            if self.ssh_conn_id:
-                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
-                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
-                else:
-                    self.log.info(
-                        "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
-                    )
-                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, conn_timeout=self.conn_timeout)
-
-            if not self.ssh_hook:
-                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
-
-            if self.remote_host is not None:
-                self.log.info(
-                    "remote_host is provided explicitly. "
-                    "It will replace the remote_host which was defined "
-                    "in ssh_hook or predefined in connection of ssh_conn_id."
-                )
-                self.ssh_hook.remote_host = self.remote_host
-
-            if not self.command:
-                raise AirflowException("SSH command not specified. Aborting.")
-
-            with self.ssh_hook.get_conn() as ssh_client:
-                self.log.info("Running command: %s", self.command)
-
-                # set timeout taken as params
-                stdin, stdout, stderr = ssh_client.exec_command(
-                    command=self.command,
-                    get_pty=self.get_pty,
-                    timeout=self.cmd_timeout,
-                    environment=self.environment,
-                )
-                # get channels
-                channel = stdout.channel
-
-                # closing stdin
-                stdin.close()
-                channel.shutdown_write()
-
-                agg_stdout = b''
-                agg_stderr = b''
-
-                # capture any initial output in case channel is closed already
-                stdout_buffer_length = len(stdout.channel.in_buffer)
-
-                if stdout_buffer_length > 0:
-                    agg_stdout += stdout.channel.recv(stdout_buffer_length)
-
-                # read from both stdout and stderr
-                while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
-                    readq, _, _ = select([channel], [], [], self.cmd_timeout)
-                    for recv in readq:
-                        if recv.recv_ready():
-                            line = stdout.channel.recv(len(recv.in_buffer))
-                            agg_stdout += line
-                            self.log.info(line.decode('utf-8', 'replace').strip('\n'))
-                        if recv.recv_stderr_ready():
-                            line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
-                            agg_stderr += line
-                            self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
-                    if (
-                        stdout.channel.exit_status_ready()
-                        and not stderr.channel.recv_stderr_ready()
-                        and not stdout.channel.recv_ready()
-                    ):
-                        stdout.channel.shutdown_read()
-                        try:
-                            stdout.channel.close()
-                        except Exception:
-                            # there is a race that when shutdown_read has been called and when
-                            # you try to close the connection, the socket is already closed
-                            # We should ignore such errors (but we should log them with warning)
-                            self.log.warning("Ignoring exception on close", exc_info=True)
-                        break
-
-                stdout.close()
-                stderr.close()
-
-                exit_status = stdout.channel.recv_exit_status()
-                if exit_status == 0:
-                    enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
-                    if enable_pickling:
-                        return agg_stdout
-                    else:
-                        return b64encode(agg_stdout).decode('utf-8')
-
-                else:
-                    error_msg = agg_stderr.decode('utf-8')
-                    raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")
-
+            with self.get_ssh_client() as ssh_client:
+                result = self.run_ssh_client_command(ssh_client, self.command)
         except Exception as e:
             raise AirflowException(f"SSH operator error: {str(e)}")
-
-        return True
+        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
+        if not enable_pickling:
+            result = b64encode(result).decode('utf-8')
+        return result

     def tunnel(self) -> None:
         """Get ssh tunnel"""
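
The end of the new execute() in the diff returns the raw stdout bytes when `enable_xcom_pickling` is set and a base64-encoded string otherwise. The sketch below shows that round trip in isolation. `encode_result` and `decode_result` are hypothetical helper names introduced only for illustration, and the flag is passed as a plain argument instead of being read from the Airflow configuration.

```python
from base64 import b64decode, b64encode
from typing import Union


def encode_result(result: bytes, enable_pickling: bool) -> Union[bytes, str]:
    """Mirror the return logic at the end of execute() in the diff."""
    if not enable_pickling:
        # Without pickling the value must be JSON-friendly, so encode bytes as text.
        return b64encode(result).decode("utf-8")
    return result


def decode_result(value: Union[bytes, str], enable_pickling: bool) -> bytes:
    """Hypothetical helper for a downstream task that consumes the value."""
    if enable_pickling:
        return value
    return b64decode(value)


stdout = b"hello\n"
encoded = encode_result(stdout, enable_pickling=False)
print(encoded)  # aGVsbG8K
print(decode_result(encoded, enable_pickling=False))  # b'hello\n'
```

A consumer therefore needs to know which mode the deployment uses before interpreting the value.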