diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index 35c71cd5..64538e5e 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -51,6 +51,7 @@ S3_MIN_PART_SIZE = 50 * 1024**2 # minimum part size for S3 multipart uploads WEBHDFS_MIN_PART_SIZE = 50 * 1024**2 # minimum part size for HDFS multipart uploads +_SSH={} # place to put ssh connections, if necessary. def smart_open(uri, mode="rb", **kw): """ @@ -117,6 +118,24 @@ def smart_open(uri, mode="rb", **kw): # local files -- both read & write supported # compression, if any, is determined by the filename extension (.gz, .bz2) return file_smart_open(parsed_uri.uri_path, mode) + + elif parsed_uri.scheme in ("ssh", "scp", "sftp"): + def SSH(hostname, username): + import paramiko + ssh = _SSH.get( (hostname,username) ) + if ssh is None: + ssh = _SSH[ (hostname,username) ] = paramiko.client.SSHClient() + ssh.load_system_host_keys() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(hostname, 22, username) + pass + return ssh + def SFTP(): return SSH(parsed_uri.host, + parsed_uri.user).get_transport().open_sftp_client() + def sopen(filename, mode): return SFTP().open(filename, mode) + + return sopen(parsed_uri.uri_path[1:], mode) + elif parsed_uri.scheme in ("s3", "s3n"): s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret) bucket = s3_connection.get_bucket(parsed_uri.bucket_id) @@ -177,6 +196,8 @@ class ParseUri(object): * ./local/path/file.gz * file:///home/user/file * file:///home/user/file.bz2 + * [ssh|scp|sftp]://username@host//path/file + * [ssh|scp|sftp]://username@host/path/file """ def __init__(self, uri, default_scheme="file"): @@ -198,6 +219,9 @@ def __init__(self, uri, default_scheme="file"): if not self.uri_path: raise RuntimeError("invalid HDFS URI: %s" % uri) + elif self.scheme in ("ssh", "scp", "sftp"): + self.user, self.host = parsed_uri.netloc.split('@') + self.uri_path = parsed_uri.path elif self.scheme == "webhdfs": self.uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path