diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index 9454450e..01c5f12e 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -370,7 +370,7 @@ def _open_binary_stream(uri, mode, **kw): # host = kw.pop('host', None) if host is not None: - kw['endpoint_url'] = 'http://' + host + kw['endpoint_url'] = _add_scheme_to_host(host) return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name elif hasattr(uri, 'read'): # simply pass-through if already a file-like @@ -395,7 +395,7 @@ def _s3_open_uri(parsed_uri, mode, **kwargs): # Get an S3 host. It is required for sigv4 operations. host = kwargs.pop('host', None) if host is not None: - kwargs['endpoint_url'] = 'http://' + host + kwargs['endpoint_url'] = _add_scheme_to_host(host) return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs) @@ -615,3 +615,8 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS): else: decoder = codecs.getwriter(encoding) return decoder(fileobj, errors=errors) + +def _add_scheme_to_host(host): + if host.startswith('http://') or host.startswith('https://'): + return host + return 'http://' + host diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 461319a7..cc567d2e 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -1163,6 +1163,31 @@ def test_write_text_gzip(self): actual = fin.read() self.assertEqual(text, actual) +class HostNameTest(unittest.TestCase): + + def test_host_name_with_http(self): + host = 'http://a.com/b' + expected = 'http://a.com/b' + res = smart_open_lib._add_scheme_to_host(host) + self.assertEqual(expected, res) + + def test_host_name_without_http(self): + host = 'a.com/b' + expected = 'http://a.com/b' + res = smart_open_lib._add_scheme_to_host(host) + self.assertEqual(expected, res) + + def test_host_name_with_https(self): + host = 'https://a.com/b' + expected = 'https://a.com/b' + res = smart_open_lib._add_scheme_to_host(host) + self.assertEqual(expected, res) + + def test_host_name_without_http_prefix(self): + host = 'httpa.com/b' + expected = 'http://httpa.com/b' + res = smart_open_lib._add_scheme_to_host(host) + self.assertEqual(expected, res) if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)