Skip to content

Commit

Permalink
Merge pull request #120 from analyst-collective/feature/ssh-forwarding
Browse files Browse the repository at this point in the history
ssh forwarding
  • Loading branch information
drewbanin authored Aug 20, 2016
2 parents 84013d5 + 90eed51 commit f237f7d
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 4 deletions.
8 changes: 7 additions & 1 deletion dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def __init__(self, project, target_path, graph_type):
self.graph_type = graph_type

self.target = RedshiftTarget(self.project.run_environment())
self.target.open_tunnel_if_needed()

self.schema = dbt.schema.Schema(self.project, self.target)

def deserialize_graph(self):
Expand Down Expand Up @@ -350,7 +352,11 @@ def run_from_graph(self, runner, limit_to):
model_dependency_list = self.as_concurrent_dep_list(linker, relevant_compiled_models, existing, self.target, specified_models)

on_failure = self.on_model_failure(linker, relevant_compiled_models)
return self.execute_models(runner, model_dependency_list, on_failure)
results = self.execute_models(runner, model_dependency_list, on_failure)

self.target.cleanup()

return results

# ------------------------------------

Expand Down
24 changes: 24 additions & 0 deletions dbt/ssh_forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
import logging

# Modules are only imported once, so the module-level `server` variable below
# acts as a singleton. Make sure that we never create more than one tunnel,
# because subsequent tunnels will block waiting to acquire the port.

server = None

def get_or_create_tunnel(host, port, user, remote_host, remote_port):
    """Return the module-wide SSH tunnel, creating and starting it on first use.

    The tunnel is stored in the module-level ``server`` variable. Once a
    tunnel has been started successfully, every later call returns that same
    object and ignores its arguments.

    Args:
        host: address of the SSH server to connect through.
        port: port of the SSH server.
        user: username passed to the forwarder as ``ssh_username``.
        remote_host: host the tunnel forwards to (remote bind address).
        remote_port: port the tunnel forwards to (remote bind address).

    Returns:
        The started ``SSHTunnelForwarder`` instance.

    Raises:
        RuntimeError: if the forwarder fails to start.
    """
    global server
    if server is not None:
        return server

    logger = logging.getLogger(__name__)

    bind_from = (host, port)
    bind_to = (remote_host, remote_port)

    # Build the forwarder in a local first: if start() fails we must not
    # leave a dead, never-started tunnel cached in the module global, or
    # every later call would silently hand back the broken object.
    tunnel = SSHTunnelForwarder(bind_from, ssh_username=user, remote_bind_address=bind_to, logger=logger)
    try:
        tunnel.start()
    except BaseSSHTunnelForwarderError as e:
        raise RuntimeError("Problem connecting through {}:{}: {}".format(host, port, str(e)))

    # Publish the singleton only after the tunnel is known to be up.
    server = tunnel
    return server
51 changes: 48 additions & 3 deletions dbt/targets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@

import psycopg2
import os

from paramiko import SSHConfig
import dbt.ssh_forward

THREAD_MIN = 1
THREAD_MAX = 10
THREAD_MAX = 8

BAD_THREADS_ERROR = """Invalid value given for "threads" in active run-target.
Value given was {supplied} but it should be an int between {min_val} and {max_val}"""
Expand All @@ -18,12 +22,53 @@ def __init__(self, cfg):
self.schema = cfg['schema']
self.threads = self.__get_threads(cfg)

self.ssh_host = cfg.get('ssh-host', None)
self.handle = None

def get_tunnel_config(self):
    """Look up this target's ``ssh_host`` in the user's ``~/.ssh/config``.

    Returns:
        The options returned by ``SSHConfig.lookup`` for ``self.ssh_host``.

    Raises:
        IOError/OSError: if ``~/.ssh/config`` cannot be opened.
    """
    config = SSHConfig()

    config_filepath = os.path.join(os.path.expanduser('~'), '.ssh/config')
    # Use a context manager so the file handle is closed as soon as parsing
    # finishes, instead of being leaked until garbage collection.
    with open(config_filepath) as config_file:
        config.parse(config_file)

    options = config.lookup(self.ssh_host)
    return options

def __open_tunnel(self):
    """Open (or reuse) the SSH tunnel and point this target at its local end.

    Reads the SSH server's hostname, port (default 22) and user from the
    options returned by ``get_tunnel_config()``, obtains the shared tunnel
    from ``dbt.ssh_forward``, then rewrites ``self.host`` / ``self.port`` so
    that connections go through the tunnel's local bind port.

    Returns:
        The tunnel object returned by ``dbt.ssh_forward.get_or_create_tunnel``.

    Raises:
        RuntimeError: if the ssh config entry lacks a 'hostname' or 'user'.
    """
    ssh_options = self.get_tunnel_config()
    gateway_host = ssh_options.get('hostname')
    gateway_port = int(ssh_options.get('port', '22'))
    gateway_user = ssh_options.get('user')

    if gateway_host is None:
        raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'hostname' field".format(self.ssh_host))
    if gateway_user is None:
        raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'user' field".format(self.ssh_host))

    # The forwarder lives in a module-level singleton, so asking for it again
    # reuses the existing tunnel rather than binding a second time.
    tunnel = dbt.ssh_forward.get_or_create_tunnel(
        gateway_host,
        gateway_port,
        gateway_user,
        self.host,
        self.port,
    )

    # From here on the database is reached via the local end of the tunnel.
    self.host = 'localhost'
    self.port = tunnel.local_bind_port

    return tunnel

# make the user explicitly call this function to enable the ssh tunnel
# we don't want it to be automatically opened any time someone makes a RedshiftTarget()
def open_tunnel_if_needed(self):
    """Set ``self.ssh_tunnel``: a tunnel when ``ssh_host`` is set, else None.

    The tunnel is opened only when this method is called explicitly.
    """
    self.ssh_tunnel = None if self.ssh_host is None else self.__open_tunnel()

def cleanup(self):
    """Stop the SSH tunnel held by this target, if there is one.

    Safe to call when ``open_tunnel_if_needed()`` was never invoked (the
    ``ssh_tunnel`` attribute may not exist yet) and safe to call repeatedly:
    the tunnel reference is cleared after it has been stopped.
    """
    # getattr guards against the attribute never having been assigned.
    tunnel = getattr(self, 'ssh_tunnel', None)
    if tunnel is not None:
        tunnel.stop()
        # Drop the reference so a second cleanup() does not stop it again.
        self.ssh_tunnel = None

def __get_threads(self, cfg):
supplied = cfg.get('threads', 1)

bad_threads_error = RuntimeError(BAD_THREADS_ERROR.format(run_target="...", supplied=supplied, min_val=THREAD_MIN, max_val=THREAD_MAX))
bad_threads_error = RuntimeError(BAD_THREADS_ERROR.format(supplied=supplied, min_val=THREAD_MIN, max_val=THREAD_MAX))

if type(supplied) != int:
raise bad_threads_error
Expand All @@ -34,7 +79,7 @@ def __get_threads(self, cfg):
raise bad_threads_error

def __get_spec(self):
return "dbname='{}' user='{}' host='{}' password='{}' port='{}'".format(
return "dbname='{}' user='{}' host='{}' password='{}' port='{}' connect_timeout=10".format(
self.dbname,
self.user,
self.host,
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ psycopg2==2.6.1
sqlparse==0.1.19
networkx==1.11
csvkit==0.9.1
paramiko==2.0.1
sshtunnel==0.0.8.2
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@
'sqlparse==0.1.19',
'networkx==1.11',
'csvkit==0.9.1',
'paramiko==2.0.1',
'sshtunnel==0.0.8.2'
],
)

0 comments on commit f237f7d

Please sign in to comment.