diff --git a/dvc/exceptions.py b/dvc/exceptions.py index 1a167117ad..62c0a34669 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -348,3 +348,8 @@ def __init__(self, path, external_repo_path, external_repo_url): relpath(path, external_repo_path), external_repo_url ) ) + + +class HTTPError(DvcException): + def __init__(self, code, reason): + super(HTTPError, self).__init__("'{} {}'".format(code, reason)) diff --git a/dvc/remote/http.py b/dvc/remote/http.py index c224166254..98a1e50f55 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -7,7 +7,7 @@ from dvc.config import Config from dvc.config import ConfigError -from dvc.exceptions import DvcException +from dvc.exceptions import DvcException, HTTPError from dvc.progress import Tqdm from dvc.remote.base import RemoteBASE from dvc.scheme import Schemes @@ -37,30 +37,26 @@ def __init__(self, repo, config): ) def _download(self, from_info, to_file, name=None, no_progress_bar=False): - request = self._request("GET", from_info.url, stream=True) + response = self._request("GET", from_info.url, stream=True) + if response.status_code != 200: + raise HTTPError(response.status_code, response.reason) with Tqdm( - total=None if no_progress_bar else self._content_length(from_info), + total=None if no_progress_bar else self._content_length(response), leave=False, bytes=True, desc=from_info.url if name is None else name, disable=no_progress_bar, ) as pbar: with open(to_file, "wb") as fd: - for chunk in request.iter_content(chunk_size=self.CHUNK_SIZE): + for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE): fd.write(chunk) - fd.flush() pbar.update(len(chunk)) def exists(self, path_info): return bool(self._request("HEAD", path_info.url)) - def _content_length(self, url_or_request): - headers = getattr( - url_or_request, - "headers", - self._request("HEAD", url_or_request).headers, - ) - res = headers.get("Content-Length") + def _content_length(self, response): + res = response.headers.get("Content-Length") return int(res) if res else None def get_file_checksum(self, path_info): diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index ae0727c144..4e46fdde90 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -43,7 +43,7 @@ from tests.func.test_data_cloud import get_ssh_url from tests.func.test_data_cloud import TEST_AWS_REPO_BUCKET from tests.func.test_data_cloud import TEST_GCP_REPO_BUCKET -from tests.utils.httpd import StaticFileServer +from tests.utils.httpd import StaticFileServer, ContentMD5Handler class TestRepro(TestDvc): @@ -1176,7 +1176,7 @@ def test(self): self.dvc.remove("imported_file.dvc") - with StaticFileServer(handler="Content-MD5") as httpd: + with StaticFileServer(handler_class=ContentMD5Handler) as httpd: import_url = urljoin(self.get_remote(httpd.server_port), self.FOO) import_output = "imported_file" import_stage = self.dvc.imp_url(import_url, import_output) diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index 2b1ecdb07a..dbd54bcd15 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -1,7 +1,10 @@ import pytest from dvc.config import ConfigError +from dvc.exceptions import HTTPError +from dvc.path_info import URLInfo from dvc.remote.http import RemoteHTTP +from tests.utils.httpd import StaticFileServer def test_no_traverse_compatibility(dvc_repo): @@ -13,3 +16,14 @@ def test_no_traverse_compatibility(dvc_repo): with pytest.raises(ConfigError): RemoteHTTP(dvc_repo, config) + + +def test_download_fails_on_error_code(dvc_repo): + with StaticFileServer() as httpd: + url = "http://localhost:{}/".format(httpd.server_port) + config = {"url": url} + + remote = RemoteHTTP(dvc_repo, config) + + with pytest.raises(HTTPError): + remote._download(URLInfo(url) / "missing.txt", "missing.txt") diff --git a/tests/utils/httpd.py b/tests/utils/httpd.py index ff902e71a1..92027ae751 100644 --- a/tests/utils/httpd.py +++ b/tests/utils/httpd.py @@ -45,9 +45,8 @@ class ContentMD5Handler(TestRequestHandler): class StaticFileServer: _lock = threading.Lock() - def __init__(self, handler="etag"): + def __init__(self, handler_class=ETagHandler): self._lock.acquire() - handler_class = ETagHandler if handler == "etag" else ContentMD5Handler self._httpd = HTTPServer(("localhost", 0), handler_class) self._thread = None