@wrap_with(threading.Lock())
@memoize
def ask_password(host, user):
    """Prompt for the password of ``user`` on ``host``.

    The function is wrapped with ``memoize`` and with a single
    ``threading.Lock`` (via ``wrap_with``), in that order.

    Args:
        host: host name shown in the prompt.
        user: user name shown in the prompt.

    Returns:
        Whatever ``prompt.password`` returns for the formatted message.
    """
    message = "Enter a password for host '{host}' user '{user}'".format(
        host=host, user=user
    )
    return prompt.password(message)
def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
    """POST the contents of ``from_file`` to ``to_info.url`` chunk by chunk.

    Args:
        from_file: path of the local file to send.
        to_info: destination; only its ``url`` attribute is used.
        name: progress-bar label; ``to_info.url`` is used when ``None``.
        no_progress_bar: disable the bar and skip the file-size lookup.

    Raises:
        HTTPError: if the response status is neither 200 nor 201.
    """
    total = None if no_progress_bar else os.path.getsize(from_file)
    label = to_info.url if name is None else name

    with Tqdm(
        total=total,
        leave=False,
        bytes=True,
        desc=label,
        disable=no_progress_bar,
    ) as pbar:

        def chunks():
            # Read lazily so the file is never held in memory as a whole;
            # the bar advances as each chunk is handed to the request.
            with open(from_file, "rb") as fd:
                for piece in iter(lambda: fd.read(self.CHUNK_SIZE), b""):
                    pbar.update(len(piece))
                    yield piece

        response = self._request("POST", to_info.url, data=chunks())
        if response.status_code not in (200, 201):
            raise HTTPError(response.status_code, response.reason)


def auth_method(self, path_info=None):
    """Return the ``requests`` auth object for the configured scheme.

    Args:
        path_info: object providing ``host`` and ``user``; defaults to
            ``self.path_info``.

    Returns:
        ``HTTPBasicAuth`` for ``"basic"``, ``HTTPDigestAuth`` for
        ``"digest"``, otherwise ``None``. For ``"custom"`` with a
        ``custom_auth_header`` set, the password is stored under that
        header name in ``self.headers`` and ``None`` is returned.
    """
    from requests.auth import HTTPBasicAuth, HTTPDigestAuth

    info = self.path_info if path_info is None else path_info

    if not self.auth:
        return None

    # Prompt only when asked to and no password is known yet; the answer is
    # kept on the instance for later calls.
    if self.ask_password and self.password is None:
        host, user = info.host, info.user
        self.password = ask_password(host, user)

    if self.auth == "basic":
        return HTTPBasicAuth(info.user, self.password)
    if self.auth == "digest":
        return HTTPDigestAuth(info.user, self.password)
    if self.auth == "custom" and self.custom_auth_header:
        self.headers.update({self.custom_auth_header: self.password})
    return None
# tests/conftest.py
@pytest.fixture
def http_server(tmp_dir):
    """Yield a ``StaticFileServer`` that uses ``PushRequestHandler``."""
    server = StaticFileServer(handler_class=PushRequestHandler)
    with server as httpd:
        yield httpd


# tests/func/test_data_cloud.py
@pytest.mark.usefixtures("http_server")
class TestRemoteHTTP(HTTP, TestDataCloudBase):
    """Combine the ``HTTP`` remote helper with ``TestDataCloudBase``."""

    @pytest.fixture(autouse=True)
    def setup_method_fixture(self, request, http_server):
        # Keep the current test's name and the server on the instance.
        self.method_name = request.function.__name__
        self.http_server = http_server

    def get_url(self):
        """Return the remote URL for the port of the running server."""
        port = self.http_server.server_port
        return super().get_url(port)

    def _get_cloud_class(self):
        return RemoteHTTP


# tests/remotes.py
class HTTP:
    """Remote helper that builds loopback HTTP URLs."""

    should_test = always_test

    @staticmethod
    def get_url(port):
        """Return the loopback URL for ``port``."""
        return "http://127.0.0.1:{}".format(port)


# tests/unit/remote/test_http.py
def test_public_auth_method(dvc):
    """With no ``auth`` entry in the config, no auth object is produced."""
    remote = RemoteHTTP(
        dvc,
        dict(
            url="http://example.com/",
            path_info="file.html",
            user="",
            password="",
        ),
    )

    assert remote.auth_method() is None


def test_basic_auth_method(dvc):
    """``auth: basic`` yields an equal ``HTTPBasicAuth`` instance."""
    from requests.auth import HTTPBasicAuth

    user, password = "username", "password"
    expected = HTTPBasicAuth(user, password)

    remote = RemoteHTTP(
        dvc,
        dict(
            url="http://example.com/",
            path_info="file.html",
            auth="basic",
            user=user,
            password=password,
        ),
    )

    assert remote.auth_method() == expected
    assert isinstance(remote.auth_method(), HTTPBasicAuth)


def test_digest_auth_method(dvc):
    """``auth: digest`` yields an equal ``HTTPDigestAuth`` instance."""
    from requests.auth import HTTPDigestAuth

    user, password = "username", "password"
    expected = HTTPDigestAuth(user, password)

    remote = RemoteHTTP(
        dvc,
        dict(
            url="http://example.com/",
            path_info="file.html",
            auth="digest",
            user=user,
            password=password,
        ),
    )

    assert remote.auth_method() == expected
    assert isinstance(remote.auth_method(), HTTPDigestAuth)


def test_custom_auth_method(dvc):
    """``auth: custom`` puts the password under the configured header."""
    header, password = "Custom-Header", "password"

    remote = RemoteHTTP(
        dvc,
        dict(
            url="http://example.com/",
            path_info="file.html",
            auth="custom",
            custom_auth_header=header,
            password=password,
        ),
    )

    assert remote.auth_method() is None
    assert header in remote.headers
    assert remote.headers[header] == password
class PushRequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler that also stores request bodies sent via POST.

    The body is written to the filesystem path that ``translate_path`` maps
    the request URL to. Bodies framed with ``Content-Length`` and bodies
    sent with ``Transfer-Encoding: chunked`` are both accepted.
    """

    def _chunks(self):
        """Yield the payload of every chunk of a chunked request body.

        Stops at the zero-size chunk.

        Raises:
            ValueError: if a chunk-size line is not a hexadecimal number.
        """
        while True:
            size_line = self.rfile.readline(65537)
            # A size line may carry ";extension" suffixes after the hex
            # number; only the part before the first ";" is the size.
            size_field = size_line.split(b";", 1)[0].strip()
            chunk_size = int(size_field, 16)
            if chunk_size == 0:
                return
            yield self.rfile.read(chunk_size)
            # Every chunk payload is followed by a CRLF that is not data.
            self.rfile.read(2)

    def do_POST(self):
        """Write the request body to the translated path.

        Replies 200 on success. If creating the directory or writing the
        file raises ``OSError``, replies 500 with the error text and sends
        nothing else.
        """
        encoding = self.headers.get("Transfer-Encoding", "")
        chunked = encoding.lower() == "chunked"
        path = self.translate_path(self.path)
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "wb") as fd:
                if chunked:
                    for chunk in self._chunks():
                        fd.write(chunk)
                else:
                    size = int(self.headers.get("Content-Length", 0))
                    fd.write(self.rfile.read(size))
        except OSError as e:
            self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
            # The error response is complete; do not follow it with a 200.
            return
        self.send_response(HTTPStatus.OK)
        self.end_headers()