From fef1f404108a8f43481c73181beffdbcab1aa4dc Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Sat, 20 Aug 2016 02:35:32 -0400 Subject: [PATCH 1/4] working! needs more testing and error handling ---
Review notes (reviewer commentary placed after the '---' separator; not part of the commit message):
- This copy of the series has had its line breaks collapsed, so indentation and blank context lines are not visible; the points below are read from the visible tokens only.
- dbt/targets.py: cleanup() reads self.ssh_tunnel, which this patch assigns only inside open_tunnel_if_needed(). The __init__ hunk adds self.ssh_host but no self.ssh_tunnel, so cleanup() on a target that never had open_tunnel_if_needed() called presumably raises AttributeError -- TODO confirm, and consider initialising self.ssh_tunnel = None in __init__.
- dbt/targets.py: get_tunnel_config() passes open(config_filepath) directly to config.parse() and never closes the file; nothing visible handles a missing ~/.ssh/config.
- dbt/ssh_forward.py: the module-level `server` is assigned before server.start() is attempted and is never reset to None anywhere in this series, neither when start() raises nor after cleanup() calls stop(). A later get_or_create_tunnel() call therefore finds `server` already set -- TODO confirm whether an unstarted or stopped forwarder can safely be handed back.
- dbt/ssh_forward.py: get_or_create_tunnel() only tests `server is None` before creating a forwarder; its host/port/user/remote arguments are not compared with the forwarder that already exists.
- dbt/runner.py: self.target.cleanup() is reached only after execute_models() returns; the hunk shows no try/finally around it.
- dbt/targets.py: THREAD_MAX drops from 10 to 8 and the run_target="..." argument is removed from the BAD_THREADS_ERROR format call; neither change is mentioned in the subject line.
- dbt/targets.py: the bare `import ssh_forward` added here is replaced by `import dbt.ssh_forward` in PATCH 4/4.
dbt/runner.py | 8 +++++++- dbt/ssh_forward.py | 24 +++++++++++++++++++++++ dbt/targets.py | 49 ++++++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 2 ++ 4 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 dbt/ssh_forward.py diff --git a/dbt/runner.py b/dbt/runner.py index 4cb382868b7..9b2899899cd 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -203,6 +203,8 @@ def __init__(self, project, target_path, graph_type): self.graph_type = graph_type self.target = RedshiftTarget(self.project.run_environment()) + self.target.open_tunnel_if_needed() + self.schema = dbt.schema.Schema(self.project, self.target) def deserialize_graph(self): @@ -350,7 +352,11 @@ def run_from_graph(self, runner, limit_to): model_dependency_list = self.as_concurrent_dep_list(linker, relevant_compiled_models, existing, self.target, specified_models) on_failure = self.on_model_failure(linker, relevant_compiled_models) - return self.execute_models(runner, model_dependency_list, on_failure) + results = self.execute_models(runner, model_dependency_list, on_failure) + + self.target.cleanup() + + return results # ------------------------------------ diff --git a/dbt/ssh_forward.py b/dbt/ssh_forward.py new file mode 100644 index 00000000000..3f8dd269928 --- /dev/null +++ b/dbt/ssh_forward.py @@ -0,0 +1,24 @@ + +from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError +import logging + +# modules are only imported once -- make sure that we don't have > 1 +# because subsequent tunnels will block waiting to acquire the port + +server = None + +def get_or_create_tunnel(host, port, user, remote_host, remote_port): + global server + if server is None: + logger = logging.getLogger(__name__) + + bind_from = (host, port) + bind_to = (remote_host, remote_port) + + server = 
SSHTunnelForwarder(bind_from, ssh_username=user, remote_bind_address=bind_to, logger=logger) + try: + server.start() + except BaseSSHTunnelForwarderError as e: + raise RuntimeError("Problem connecting through {}:{}: {}".format(host, port, str(e))) + + return server diff --git a/dbt/targets.py b/dbt/targets.py index c8842eec21c..882d32c1721 100644 --- a/dbt/targets.py +++ b/dbt/targets.py @@ -1,8 +1,12 @@ import psycopg2 +import os + +from paramiko import SSHConfig +import ssh_forward THREAD_MIN = 1 -THREAD_MAX = 10 +THREAD_MAX = 8 BAD_THREADS_ERROR = """Invalid value given for "threads" in active run-target. Value given was {supplied} but it should be an int between {min_val} and {max_val}""" @@ -18,12 +22,53 @@ def __init__(self, cfg): self.schema = cfg['schema'] self.threads = self.__get_threads(cfg) + self.ssh_host = cfg.get('ssh-host', None) self.handle = None + def get_tunnel_config(self): + config = SSHConfig() + + config_filepath = os.path.join(os.path.expanduser('~'), '.ssh/config') + config.parse(open(config_filepath)) + options = config.lookup(self.ssh_host) + return options + + def __open_tunnel(self): + config = self.get_tunnel_config() + host = config.get('hostname') + port = int(config.get('port', '22')) + user = config.get('user') + + if host is None: + raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'hostname' field".format(self.ssh_host)) + if user is None: + raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'user' field".format(self.ssh_host)) + + # modules are only imported once -- this singleton makes sure we don't try to bind to the host twice (and lock) + server = ssh_forward.get_or_create_tunnel(host, port, user, self.host, self.port) + + # rebind the pg host and port + self.host = 'localhost' + self.port = server.local_bind_port + + return server + + # make the user explicitly call this function to enable the ssh tunnel + # we don't want it to be automatically opened any time someone makes a RedshiftTarget() 
+ def open_tunnel_if_needed(self): + if self.ssh_host is None: + self.ssh_tunnel = None + else: + self.ssh_tunnel = self.__open_tunnel() + + def cleanup(self): + if self.ssh_tunnel is not None: + self.ssh_tunnel.stop() + def __get_threads(self, cfg): supplied = cfg.get('threads', 1) - bad_threads_error = RuntimeError(BAD_THREADS_ERROR.format(run_target="...", supplied=supplied, min_val=THREAD_MIN, max_val=THREAD_MAX)) + bad_threads_error = RuntimeError(BAD_THREADS_ERROR.format(supplied=supplied, min_val=THREAD_MIN, max_val=THREAD_MAX)) if type(supplied) != int: raise bad_threads_error diff --git a/requirements.txt b/requirements.txt index fdc077ea165..3d09fbf309c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ psycopg2==2.6.1 sqlparse==0.1.19 networkx==1.11 csvkit==0.9.1 +paramiko==2.0.1 +sshtunnel==0.0.8.2 From 90e9e50dd88c6ee3334f4331b7890075df796e0b Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Sat, 20 Aug 2016 03:35:02 -0400 Subject: [PATCH 2/4] add requirements --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index f8b68f1e90e..7cadbfe019a 100644 --- a/setup.py +++ b/setup.py @@ -29,5 +29,7 @@ 'sqlparse==0.1.19', 'networkx==1.11', 'csvkit==0.9.1', + 'paramiko==2.0.1', + 'sshtunnel==0.0.8.2' ], ) From fe42c7d1ea8a787008958e14f40bdf54a2bc32f2 Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Sat, 20 Aug 2016 03:45:09 -0400 Subject: [PATCH 3/4] add 10 second pg timeout --- dbt/targets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/targets.py b/dbt/targets.py index 882d32c1721..91760b823bb 100644 --- a/dbt/targets.py +++ b/dbt/targets.py @@ -79,7 +79,7 @@ def __get_threads(self, cfg): raise bad_threads_error def __get_spec(self): - return "dbname='{}' user='{}' host='{}' password='{}' port='{}'".format( + return "dbname='{}' user='{}' host='{}' password='{}' port='{}' connect_timeout=10".format( self.dbname, self.user, self.host, From 90eed518a2ec4fe39c80935cd9ba2024d1dd3199 Mon Sep 
17 00:00:00 2001 From: Drew Banin Date: Sat, 20 Aug 2016 03:49:36 -0400 Subject: [PATCH 4/4] qualify ssh_forward with dbt prefix --- dbt/targets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/targets.py b/dbt/targets.py index 91760b823bb..e7cf277f728 100644 --- a/dbt/targets.py +++ b/dbt/targets.py @@ -3,7 +3,7 @@ import os from paramiko import SSHConfig -import ssh_forward +import dbt.ssh_forward THREAD_MIN = 1 THREAD_MAX = 8 @@ -45,7 +45,7 @@ def __open_tunnel(self): raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'user' field".format(self.ssh_host)) # modules are only imported once -- this singleton makes sure we don't try to bind to the host twice (and lock) - server = ssh_forward.get_or_create_tunnel(host, port, user, self.host, self.port) + server = dbt.ssh_forward.get_or_create_tunnel(host, port, user, self.host, self.port) # rebind the pg host and port self.host = 'localhost'