diff --git a/CHANGELOG.md b/CHANGELOG.md index 771128f2..8ffb0e83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- make HTTP/S seeking less strict (PR [#646](https://github.com/RaRe-Technologies/smart_open/pull/646), [@mpenkov](https://github.com/mpenkov)) + # 5.2.0, 18 August 2021 - Work around changes to `urllib.parse.urlsplit` (PR [#633](https://github.com/RaRe-Technologies/smart_open/pull/633), [@judahrand](https://github.com/judahrand)) diff --git a/smart_open/http.py b/smart_open/http.py index a9b5214b..e4439bd1 100644 --- a/smart_open/http.py +++ b/smart_open/http.py @@ -241,13 +241,14 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, logger.debug('self.response: %r, raw: %r', self.response, self.response.raw) - self._seekable = True - self.content_length = int(self.response.headers.get("Content-Length", -1)) - if self.content_length < 0: - self._seekable = False - if self.response.headers.get("Accept-Ranges", "none").lower() != "bytes": - self._seekable = False + # + # We assume the HTTP stream is seekable unless the server explicitly + # tells us it isn't. It's better to err on the side of "seekable" + # because we don't want to prevent users from seeking a stream that + # does not appear to be seekable but really is. + # + self._seekable = self.response.headers.get("Accept-Ranges", "").lower() != "none" self._read_iter = self.response.iter_content(self.buffer_size) self._read_buffer = bytebuffer.ByteBuffer(buffer_size) @@ -270,7 +271,7 @@ def seek(self, offset, whence=0): raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) if not self.seekable(): - raise OSError + raise OSError('stream is not seekable') if whence == constants.WHENCE_START: new_pos = offset @@ -279,7 +280,10 @@ def seek(self, offset, whence=0): elif whence == constants.WHENCE_END: new_pos = self.content_length + offset - new_pos = smart_open.utils.clamp(new_pos, 0, self.content_length) + if self.content_length == -1: + new_pos = smart_open.utils.clamp(new_pos, maxval=None) + else: + new_pos = smart_open.utils.clamp(new_pos, maxval=self.content_length) if self._current_pos == new_pos: return self._current_pos diff --git a/smart_open/tests/test_http.py b/smart_open/tests/test_http.py index d29624b2..f4f8338a 100644 --- a/smart_open/tests/test_http.py +++ b/smart_open/tests/test_http.py @@ -5,9 +5,11 @@ # This code is distributed under the terms and conditions # from the MIT License (MIT). # +import functools import os import unittest +import pytest import responses import smart_open.http @@ -24,19 +26,19 @@ } -def request_callback(request): +def request_callback(request, headers=HEADERS, data=BYTES): try: range_string = request.headers['range'] except KeyError: - return (200, HEADERS, BYTES) + return (200, headers, data) start, end = range_string.replace('bytes=', '').split('-', 1) start = int(start) if end: end = int(end) else: - end = len(BYTES) - return (200, HEADERS, BYTES[start:end]) + end = len(data) + return (200, headers, data[start:end]) @unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason') @@ -158,3 +160,28 @@ def test_timeout_attribute(self): reader = smart_open.open(URL, "rb", transport_params={'timeout': timeout}) assert hasattr(reader, 'timeout') assert reader.timeout == timeout + + +@responses.activate +def test_seek_implicitly_enabled(numbytes=10): + """Can we seek even if the server hasn't explicitly allowed it?""" + callback = functools.partial(request_callback, headers={}) + responses.add_callback(responses.GET, HTTPS_URL, callback=callback) + with smart_open.open(HTTPS_URL, 'rb') as fin: + assert fin.seekable() + first = fin.read(size=numbytes) + fin.seek(-numbytes, whence=smart_open.constants.WHENCE_CURRENT) + second = fin.read(size=numbytes) + assert first == second + + +@responses.activate +def test_seek_implicitly_disabled(): + """Does seeking fail when the server has explicitly disabled it?""" + callback = functools.partial(request_callback, headers={'Accept-Ranges': 'none'}) + responses.add_callback(responses.GET, HTTPS_URL, callback=callback) + with smart_open.open(HTTPS_URL, 'rb') as fin: + assert not fin.seekable() + fin.read() + with pytest.raises(OSError): + fin.seek(0) diff --git a/smart_open/tests/test_utils.py b/smart_open/tests/test_utils.py index 17eaf6b6..c6be9a2d 100644 --- a/smart_open/tests/test_utils.py +++ b/smart_open/tests/test_utils.py @@ -5,7 +5,6 @@ # This code is distributed under the terms and conditions # from the MIT License (MIT). # -import unittest import urllib.parse import pytest @@ -13,15 +12,31 @@ import smart_open.utils -class ClampTest(unittest.TestCase): - def test_low(self): - self.assertEqual(smart_open.utils.clamp(5, 0, 10), 5) +@pytest.mark.parametrize( + 'value,minval,maxval,expected', + [ + (5, 0, 10, 5), + (11, 0, 10, 10), + (-1, 0, 10, 0), + (10, 0, None, 10), + (-10, 0, None, 0), + ] +) +def test_clamp(value, minval, maxval, expected): + assert smart_open.utils.clamp(value, minval=minval, maxval=maxval) == expected - def test_high(self): - self.assertEqual(smart_open.utils.clamp(11, 0, 10), 10) - def test_out_of_range(self): - self.assertEqual(smart_open.utils.clamp(-1, 0, 10), 0) +@pytest.mark.parametrize( + 'value,params,expected', + [ + (10, {}, 10), + (-10, {}, 0), + (-10, {'minval': -5}, -5), + (10, {'maxval': 5}, 5), + ] +) +def test_clamp_defaults(value, params, expected): + assert smart_open.utils.clamp(value, **params) == expected def test_check_kwargs(): diff --git a/smart_open/utils.py b/smart_open/utils.py index 00ebd9e0..4fc6aa84 100644 --- a/smart_open/utils.py +++ b/smart_open/utils.py @@ -74,7 +74,7 @@ def check_kwargs(kallable, kwargs): return supported_kwargs -def clamp(value, minval, maxval): +def clamp(value, minval=0, maxval=None): """Clamp a numeric value to a specific range. Parameters @@ -94,7 +94,10 @@ def clamp(value, minval, maxval): The clamped value. It will be in the range ``[minval, maxval]``. """ - return max(min(value, maxval), minval) + if maxval is not None: + value = min(value, maxval) + value = max(value, minval) + return value def make_range_string(start=None, stop=None): @@ -119,7 +122,9 @@ def make_range_string(start=None, stop=None): # if start is None and stop is None: raise ValueError("make_range_string requires either a stop or start value") - return 'bytes=%s-%s' % ('' if start is None else start, '' if stop is None else stop) + start_str = '' if start is None else str(start) + stop_str = '' if stop is None else str(stop) + return 'bytes=%s-%s' % (start_str, stop_str) def parse_content_range(content_range):