Skip to content

Commit

Permalink
[AIRFLOW-1762] Implement key_file support in ssh_hook create_tunnel
Browse files Browse the repository at this point in the history
Switched to using the sshtunnel package instead of the
popen approach

Closes #3473 from NielsZeilemaker/ssh_hook
  • Loading branch information
NielsZeilemaker authored and Fokko Driesprong committed Jul 24, 2018
1 parent b889522 commit 53933c0
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 261 deletions.
290 changes: 148 additions & 142 deletions airflow/contrib/hooks/ssh_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

import getpass
import os
import warnings

import paramiko
from paramiko.config import SSH_PORT
from sshtunnel import SSHTunnelForwarder

from contextlib import contextmanager
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(self,
username=None,
password=None,
key_file=None,
port=SSH_PORT,
port=None,
timeout=10,
keepalive_interval=30
):
Expand All @@ -75,162 +76,167 @@ def __init__(self,
self.username = username
self.password = password
self.key_file = key_file
self.port = port
self.timeout = timeout
self.keepalive_interval = keepalive_interval

# Default values, overridable from Connection
self.compress = True
self.no_host_key_check = True
self.host_proxy = None

# Placeholder for deprecated __enter__
self.client = None
self.port = port

# Use connection to override defaults
if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
if self.username is None:
self.username = conn.login
if self.password is None:
self.password = conn.password
if self.remote_host is None:
self.remote_host = conn.host
if self.port is None:
self.port = conn.port
if conn.extra is not None:
extra_options = conn.extra_dejson
self.key_file = extra_options.get("key_file")

if "timeout" in extra_options:
self.timeout = int(extra_options["timeout"], 10)

if "compress" in extra_options\
and str(extra_options["compress"]).lower() == 'false':
self.compress = False
if "no_host_key_check" in extra_options\
and\
str(extra_options["no_host_key_check"]).lower() == 'false':
self.no_host_key_check = False

if not self.remote_host:
raise AirflowException("Missing required param: remote_host")

# Auto detecting username values from system
if not self.username:
self.log.debug(
"username to ssh to host: %s is not specified for connection id"
" %s. Using system's default provided by getpass.getuser()",
self.remote_host, self.ssh_conn_id
)
self.username = getpass.getuser()

user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
if os.path.isfile(user_ssh_config_filename):
ssh_conf = paramiko.SSHConfig()
ssh_conf.parse(open(user_ssh_config_filename))
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get('proxycommand'):
self.host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))

if not (self.password or self.key_file):
if host_info and host_info.get('identityfile'):
self.key_file = host_info.get('identityfile')[0]

self.port = self.port or SSH_PORT

def get_conn(self):
if not self.client:
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
if self.username is None:
self.username = conn.login
if self.password is None:
self.password = conn.password
if self.remote_host is None:
self.remote_host = conn.host
if conn.port is not None:
self.port = conn.port
if conn.extra is not None:
extra_options = conn.extra_dejson
self.key_file = extra_options.get("key_file")

if "timeout" in extra_options:
self.timeout = int(extra_options["timeout"], 10)

if "compress" in extra_options \
and str(extra_options["compress"]).lower() == 'false':
self.compress = False
if "no_host_key_check" in extra_options \
and \
str(extra_options["no_host_key_check"]).lower() == 'false':
self.no_host_key_check = False

if not self.remote_host:
raise AirflowException("Missing required param: remote_host")

# Auto detecting username values from system
if not self.username:
self.log.debug(
"username to ssh to host: %s is not specified for connection id"
" %s. Using system's default provided by getpass.getuser()",
self.remote_host, self.ssh_conn_id
)
self.username = getpass.getuser()

host_proxy = None
user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
if os.path.isfile(user_ssh_config_filename):
ssh_conf = paramiko.SSHConfig()
ssh_conf.parse(open(user_ssh_config_filename))
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get('proxycommand'):
host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))

if not (self.password or self.key_file):
if host_info and host_info.get('identityfile'):
self.key_file = host_info.get('identityfile')[0]

try:
client = paramiko.SSHClient()
client.load_system_host_keys()
if self.no_host_key_check:
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

if self.password and self.password.strip():
client.connect(hostname=self.remote_host,
username=self.username,
password=self.password,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=host_proxy)
else:
client.connect(hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=host_proxy)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)

self.client = client
except paramiko.AuthenticationException as auth_error:
self.log.error(
"Auth failed while connecting to host: %s, error: %s",
self.remote_host, auth_error
)
except paramiko.SSHException as ssh_error:
self.log.error(
"Failed connecting to host: %s, error: %s",
self.remote_host, ssh_error
)
except Exception as error:
self.log.error(
"Error connecting to host: %s, error: %s",
self.remote_host, error
)
return self.client

@contextmanager
def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
"""
Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.
Remember to close() the returned "tunnel" object in order to clean up
after yourself when you are done with the tunnel.
Opens a ssh connection to the remote host.
:param local_port:
:type local_port: int
:param remote_port:
:type remote_port: int
:param remote_host:
:type remote_host: str
:return:
:return paramiko.SSHClient object
"""

import subprocess
# this will ensure the connection to the ssh.remote_host from where the tunnel
# is getting created
self.get_conn()

tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port)

ssh_cmd = ["ssh", "{0}@{1}".format(self.username, self.remote_host),
"-o", "ControlMaster=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "StrictHostKeyChecking=no"]

ssh_tunnel_cmd = ["-L", tunnel_host,
"echo -n ready && cat"
]

ssh_cmd += ssh_tunnel_cmd
self.log.debug("Creating tunnel with cmd: %s", ssh_cmd)

proc = subprocess.Popen(ssh_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
close_fds=True)
ready = proc.stdout.read(5)
assert ready == b"ready", \
"Did not get 'ready' from remote, got '{0}' instead".format(ready)
yield
proc.communicate()
assert proc.returncode == 0, \
"Tunnel process did unclean exit (returncode {}".format(proc.returncode)
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
client = paramiko.SSHClient()
client.load_system_host_keys()
if self.no_host_key_check:
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

if self.password and self.password.strip():
client.connect(hostname=self.remote_host,
username=self.username,
password=self.password,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy)
else:
client.connect(hostname=self.remote_host,
username=self.username,
key_filename=self.key_file,
timeout=self.timeout,
compress=self.compress,
port=self.port,
sock=self.host_proxy)

if self.keepalive_interval:
client.get_transport().set_keepalive(self.keepalive_interval)

self.client = client
return client

def __enter__(self):
    """
    Enter the deprecated context-manager form of the hook.

    Emits a ``DeprecationWarning`` and returns the hook itself. No
    connection is opened by this method.

    :return: this hook instance
    """
    # Each fragment ends with a space so the implicitly concatenated
    # sentences do not run together in the emitted warning text.
    warnings.warn('The contextmanager of SSHHook is deprecated. '
                  'Please use get_conn() as a contextmanager instead. '
                  'This method will be removed in Airflow 2.0',
                  category=DeprecationWarning)
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Leave the deprecated context-manager form of the hook.

    Closes the client stored on ``self.client`` (if any) and clears the
    reference. Returns ``None``, so an exception raised inside the ``with``
    body is not suppressed.

    :param exc_type: exception type raised in the ``with`` body, or None
    :param exc_val: exception instance raised in the ``with`` body, or None
    :param exc_tb: traceback of that exception, or None
    """
    if self.client is not None:
        try:
            self.client.close()
        finally:
            # Drop the reference even when close() raises, so the hook
            # never keeps pointing at a client that failed to shut down.
            self.client = None

def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
    """
    Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.

    The forwarder is only constructed and returned; this method does not
    call ``start()`` on it.

    :param remote_port: The remote port to create a tunnel to
    :type remote_port: int
    :param remote_host: The remote host to create a tunnel to (default localhost)
    :type remote_host: str
    :param local_port: The local port to attach the tunnel to
    :type local_port: int
    :return: sshtunnel.SSHTunnelForwarder object
    """
    if local_port:
        local_bind_address = ('localhost', local_port)
    else:
        # Host-only bind address; presumably sshtunnel then picks the
        # local port itself -- confirm against the sshtunnel docs.
        local_bind_address = ('localhost',)

    # Arguments shared by the password and key-file cases.
    tunnel_kwargs = dict(
        ssh_port=self.port,
        ssh_username=self.username,
        ssh_pkey=self.key_file,
        ssh_proxy=self.host_proxy,
        local_bind_address=local_bind_address,
        remote_bind_address=(remote_host, remote_port),
        logger=self.log,
    )

    if self.password and self.password.strip():
        tunnel_kwargs['ssh_password'] = self.password
    else:
        # NOTE(review): an empty list presumably stops sshtunnel from
        # scanning its default private-key directories -- confirm.
        tunnel_kwargs['host_pkey_directories'] = []

    return SSHTunnelForwarder(self.remote_host, **tunnel_kwargs)

def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
    """
    Deprecated wrapper around ``get_tunnel``.

    Emits a ``DeprecationWarning`` and forwards to ``get_tunnel``,
    reordering the arguments to that method's
    ``(remote_port, remote_host, local_port)`` order.

    :param local_port: The local port to attach the tunnel to
    :type local_port: int
    :param remote_port: The remote port to create a tunnel to
    :type remote_port: int
    :param remote_host: The remote host to create a tunnel to (default localhost)
    :type remote_host: str
    :return: whatever ``get_tunnel`` returns
    """
    # Each fragment ends with a space so the implicitly concatenated
    # pieces do not run together ("Pleaseuse", "theorder", "changedThis").
    warnings.warn('SSHHook.create_tunnel is deprecated, Please '
                  'use get_tunnel() instead. But please note that the '
                  'order of the parameters have changed. '
                  'This method will be removed in Airflow 2.0',
                  category=DeprecationWarning)

    return self.get_tunnel(remote_port, remote_host, local_port)
28 changes: 14 additions & 14 deletions airflow/contrib/operators/sftp_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ def execute(self, context):
if self.remote_host is not None:
self.ssh_hook.remote_host = self.remote_host

ssh_client = self.ssh_hook.get_conn()
sftp_client = ssh_client.open_sftp()
if self.operation.lower() == SFTPOperation.GET:
file_msg = "from {0} to {1}".format(self.remote_filepath,
self.local_filepath)
self.log.debug("Starting to transfer %s", file_msg)
sftp_client.get(self.remote_filepath, self.local_filepath)
else:
file_msg = "from {0} to {1}".format(self.local_filepath,
self.remote_filepath)
self.log.debug("Starting to transfer file %s", file_msg)
sftp_client.put(self.local_filepath,
self.remote_filepath,
confirm=self.confirm)
with self.ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
if self.operation.lower() == SFTPOperation.GET:
file_msg = "from {0} to {1}".format(self.remote_filepath,
self.local_filepath)
self.log.debug("Starting to transfer %s", file_msg)
sftp_client.get(self.remote_filepath, self.local_filepath)
else:
file_msg = "from {0} to {1}".format(self.local_filepath,
self.remote_filepath)
self.log.debug("Starting to transfer file %s", file_msg)
sftp_client.put(self.local_filepath,
self.remote_filepath,
confirm=self.confirm)

except Exception as e:
raise AirflowException("Error while transferring {0}, error: {1}"
Expand Down
Loading

0 comments on commit 53933c0

Please sign in to comment.