Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ class RelPath(str):
"shared": All(Lower, Choices("group")),
Optional("slow_link_warning", default=True): Bool,
}
# Options specific to HTTP(S) remotes. They are merged with REMOTE_COMMON
# to form the "http" and "https" entries of SCHEMA.
HTTP_COMMON = {
# Authentication scheme; lower-cased, then restricted to these choices.
"auth": All(Lower, Choices("basic", "digest", "custom")),
# Header name for the "custom" scheme (presumably the header that carries
# the password; confirm against the HTTP remote implementation).
"custom_auth_header": str,
# Credentials for the chosen scheme.
"user": str,
"password": str,
# Boolean flag; presumably enables an interactive password prompt when no
# password is configured (confirm against the HTTP remote implementation).
"ask_password": Bool,
}
SCHEMA = {
"core": {
"remote": Lower,
Expand Down Expand Up @@ -169,8 +176,8 @@ class RelPath(str):
"gdrive_user_credentials_file": str,
**REMOTE_COMMON,
},
"http": REMOTE_COMMON,
"https": REMOTE_COMMON,
"http": {**HTTP_COMMON, **REMOTE_COMMON},
"https": {**HTTP_COMMON, **REMOTE_COMMON},
"remote": {str: object}, # Any of the above options are valid
}
)
Expand Down
75 changes: 72 additions & 3 deletions dvc/remote/http.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os.path
import threading

from funcy import cached_property, wrap_prop
from funcy import cached_property, memoize, wrap_prop, wrap_with

import dvc.prompt as prompt
from dvc.config import ConfigError
from dvc.exceptions import DvcException, HTTPError
from dvc.progress import Tqdm
Expand All @@ -12,6 +14,15 @@
logger = logging.getLogger(__name__)


@wrap_with(threading.Lock())
@memoize
def ask_password(host, user):
    """Prompt for the password of ``user`` on ``host`` and return it.

    The function is wrapped with ``memoize`` and with a module-level lock,
    so the intent appears to be one prompt per ``(host, user)`` pair even
    when several threads ask at once.
    """
    message = "Enter a password for host '{host}' user '{user}'".format(
        host=host, user=user
    )
    return prompt.password(message)


class RemoteHTTP(RemoteBASE):
scheme = Schemes.HTTP
SESSION_RETRIES = 5
Expand All @@ -24,14 +35,26 @@ def __init__(self, repo, config):
super().__init__(repo, config)

url = config.get("url")
self.path_info = self.path_cls(url) if url else None
if url:
self.path_info = self.path_cls(url)
user = config.get("user", None)
if user:
self.path_info.user = user
else:
self.path_info = None

if not self.no_traverse:
raise ConfigError(
"HTTP doesn't support traversing the remote to list existing "
"files. Use: `dvc remote modify <name> no_traverse true`"
)

self.auth = config.get("auth", None)
self.custom_auth_header = config.get("custom_auth_header", None)
self.password = config.get("password", None)
self.ask_password = config.get("ask_password", False)
self.headers = {}

def _download(self, from_info, to_file, name=None, no_progress_bar=False):
response = self._request("GET", from_info.url, stream=True)
if response.status_code != 200:
Expand All @@ -48,6 +71,28 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False):
fd.write(chunk)
pbar.update(len(chunk))

def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
    """Upload ``from_file`` to ``to_info.url`` with a streamed POST request.

    The file is read in ``self.CHUNK_SIZE`` pieces and handed to the
    request as a generator; the progress bar advances as each piece is
    produced.

    Raises:
        HTTPError: if the response status is neither 200 nor 201.
    """
    total = None if no_progress_bar else os.path.getsize(from_file)
    desc = to_info.url if name is None else name

    with Tqdm(
        total=total,
        leave=False,
        bytes=True,
        desc=desc,
        disable=no_progress_bar,
    ) as pbar:

        def chunks():
            with open(from_file, "rb") as fd:
                # An empty read marks end-of-file and stops the iteration.
                for chunk in iter(lambda: fd.read(self.CHUNK_SIZE), b""):
                    pbar.update(len(chunk))
                    yield chunk

        # The request consumes the generator, so it must run while the
        # progress bar is still open.
        response = self._request("POST", to_info.url, data=chunks())
        if response.status_code in (200, 201):
            return
        raise HTTPError(response.status_code, response.reason)

def exists(self, path_info):
return bool(self._request("HEAD", path_info.url))

Expand All @@ -74,6 +119,24 @@ def get_file_checksum(self, path_info):

return etag

def auth_method(self, path_info=None):
    """Return the ``requests`` auth handler for the configured scheme.

    Args:
        path_info: location whose ``host``/``user`` are used; defaults to
            ``self.path_info``.

    Returns:
        ``HTTPBasicAuth`` for "basic", ``HTTPDigestAuth`` for "digest",
        otherwise ``None``. For "custom" with a configured header name,
        the password is stored in ``self.headers`` under that name and
        ``None`` is returned.
    """
    from requests.auth import HTTPBasicAuth, HTTPDigestAuth

    target = self.path_info if path_info is None else path_info

    if not self.auth:
        return None

    # Prompt only when asked to and no password is known yet; the answer is
    # kept on the instance for later calls.
    if self.ask_password and self.password is None:
        self.password = ask_password(target.host, target.user)

    if self.auth == "basic":
        return HTTPBasicAuth(target.user, self.password)
    if self.auth == "digest":
        return HTTPDigestAuth(target.user, self.password)
    if self.auth == "custom" and self.custom_auth_header:
        self.headers[self.custom_auth_header] = self.password
    return None

@wrap_prop(threading.Lock())
@cached_property
def _session(self):
Expand All @@ -100,7 +163,13 @@ def _request(self, method, url, **kwargs):
kwargs.setdefault("timeout", self.REQUEST_TIMEOUT)

try:
res = self._session.request(method, url, **kwargs)
res = self._session.request(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmrowla, I'm unable to make it work with Digest auth (I'm using the script that you provided on the gist).

Copy link
Collaborator

@skshetry skshetry Feb 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, when I change this to:

Suggested change
res = self._session.request(
res = requests.Session().request(

it works. Something's wrong in our cached session perhaps?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I figured it out. When I try to push, we send a HEAD request, after which the script sets a cookie. And then all POST requests fail. So, I first tried clearing cookies with self._session.cookies.clear() and it worked.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably we need to set auth when creating the session?

Copy link
Contributor Author

@pmrowla pmrowla Feb 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this is a side effect of how my test app in the gist was configured?

By default, flask uses client-side cookies for session data, so flask-httpauth does as well. But for a web server running digest auth to be properly secured, it needs to be handled server-side: https://flask-httpauth.readthedocs.io/en/latest/#security-concerns-with-digest-authentication

If I modify the test app to actually use server-side sessions (https://gist.github.com/pmrowla/0615f162d1308cab4f429b6efafe276a) the existing remote code works without needing to clear any cached cookies

If I set session.auth in the http remote at the time we first create the session instead of setting it per request, I still see the same issue making requests against the original test app. So I'm not sure if it's just that requests.auth.HTTPDigestAuth won't work properly when talking to improperly configured flask apps?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I set session.auth in the http remote at the time we first create the session instead of setting it per request, I still see the same issue making requests against the original test app.

@pmrowla, I tried that too, but it does not work. It's unclear what the best thing to do here is.

If it's probably only flask-httpauth, let's ignore it for now then.

method,
url,
auth=self.auth_method(),
headers=self.headers,
**kwargs,
)

redirect_no_location = (
kwargs["allow_redirects"]
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from dvc.remote.ssh.connection import SSHConnection
from tests.utils.httpd import PushRequestHandler, StaticFileServer
from .dir_helpers import * # noqa


Expand Down Expand Up @@ -57,3 +58,9 @@ def _close_pools():

yield
close_pools()


@pytest.fixture
def http_server(tmp_dir):
    """Yield a ``StaticFileServer`` that handles requests with ``PushRequestHandler``.

    The server is used as a context manager, so it is torn down when the
    test finishes.
    """
    with StaticFileServer(handler_class=PushRequestHandler) as server:
        yield server
15 changes: 15 additions & 0 deletions tests/func/test_data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
GCP,
GDrive,
HDFS,
HTTP,
Local,
S3,
SSHMocked,
Expand Down Expand Up @@ -290,6 +291,20 @@ def _get_cloud_class(self):
return RemoteHDFS


@pytest.mark.usefixtures("http_server")
class TestRemoteHTTP(HTTP, TestDataCloudBase):
@pytest.fixture(autouse=True)
def setup_method_fixture(self, request, http_server):
self.http_server = http_server
self.method_name = request.function.__name__

def get_url(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think DS has a valid point here: the original method is static. Shouldn't we make all of them non-static, now that this particular change requires it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it's worth, I wrote it like this because it was done the same way in TestRemoteSSHMocked (which overrides static SSHMocked.get_url() from tests/remotes.py).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it was only a question; it's not severe, because any potential problems should be detected at the build stage. For me, it can stay as is.

return super().get_url(self.http_server.server_port)

def _get_cloud_class(self):
return RemoteHTTP


class TestDataCloudCLIBase(TestDvc):
def main(self, args):
ret = main(args)
Expand Down
8 changes: 8 additions & 0 deletions tests/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,11 @@ def get_url():
return "hdfs://{}@127.0.0.1{}".format(
getpass.getuser(), Local.get_storagepath()
)


class HTTP:
    """Remote helper that builds URLs for an HTTP server on 127.0.0.1."""

    should_test = always_test

    @staticmethod
    def get_url(port):
        """Return the loopback HTTP URL for the given ``port``."""
        return "http://127.0.0.1:{port}".format(port=port)
71 changes: 71 additions & 0 deletions tests/unit/remote/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,74 @@ def test_download_fails_on_error_code(dvc):

with pytest.raises(HTTPError):
remote._download(URLInfo(url) / "missing.txt", "missing.txt")


def test_public_auth_method(dvc):
    """Without an ``auth`` option, ``auth_method`` yields no auth handler."""
    remote = RemoteHTTP(
        dvc,
        {
            "url": "http://example.com/",
            "path_info": "file.html",
            "user": "",
            "password": "",
        },
    )

    assert remote.auth_method() is None


def test_basic_auth_method(dvc):
    """``auth = basic`` produces an ``HTTPBasicAuth`` built from user/password."""
    from requests.auth import HTTPBasicAuth

    credentials = {"user": "username", "password": "password"}
    expected = HTTPBasicAuth(credentials["user"], credentials["password"])

    remote = RemoteHTTP(
        dvc,
        dict(
            credentials,
            url="http://example.com/",
            path_info="file.html",
            auth="basic",
        ),
    )

    assert remote.auth_method() == expected
    assert isinstance(remote.auth_method(), HTTPBasicAuth)


def test_digest_auth_method(dvc):
    """``auth = digest`` produces an ``HTTPDigestAuth`` built from user/password."""
    from requests.auth import HTTPDigestAuth

    credentials = {"user": "username", "password": "password"}
    expected = HTTPDigestAuth(credentials["user"], credentials["password"])

    remote = RemoteHTTP(
        dvc,
        dict(
            credentials,
            url="http://example.com/",
            path_info="file.html",
            auth="digest",
        ),
    )

    assert remote.auth_method() == expected
    assert isinstance(remote.auth_method(), HTTPDigestAuth)


def test_custom_auth_method(dvc):
    """``auth = custom`` returns no handler and puts the password in a header."""
    header_name, secret = "Custom-Header", "password"

    remote = RemoteHTTP(
        dvc,
        {
            "url": "http://example.com/",
            "path_info": "file.html",
            "auth": "custom",
            "custom_auth_header": header_name,
            "password": secret,
        },
    )

    assert remote.auth_method() is None
    assert header_name in remote.headers
    assert remote.headers[header_name] == secret
32 changes: 31 additions & 1 deletion tests/utils/httpd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import hashlib
import os
import threading
from http.server import HTTPServer
from http import HTTPStatus
from http.server import HTTPServer, SimpleHTTPRequestHandler
from RangeHTTPServer import RangeRequestHandler


Expand Down Expand Up @@ -35,6 +36,35 @@ class ContentMD5Handler(TestRequestHandler):
checksum_header = "Content-MD5"


class PushRequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler that also stores the body of POST requests.

    The request path is mapped to a local file with ``translate_path`` and
    the request body (plain or ``Transfer-Encoding: chunked``) is written
    to that file.
    """

    def _chunks(self):
        """Yield the payload of each chunk of a chunked request body."""
        while True:
            # Size line: hexadecimal chunk length followed by CRLF.
            data = self.rfile.readline(65537)
            chunk_size = int(data[:-2], 16)
            if chunk_size == 0:
                # A zero-length chunk terminates the body.
                return
            data = self.rfile.read(chunk_size)
            yield data
            # Discard the CRLF that follows the chunk data.
            self.rfile.read(2)

    def do_POST(self):
        """Write the request body to the translated path.

        Replies 200 on success. If creating the directory or writing the
        file raises ``OSError``, replies 500 with the error text and sends
        nothing further.
        """
        chunked = self.headers.get("Transfer-Encoding", "") == "chunked"
        path = self.translate_path(self.path)
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "wb") as fd:
                if chunked:
                    for chunk in self._chunks():
                        fd.write(chunk)
                else:
                    size = int(self.headers.get("Content-Length", 0))
                    fd.write(self.rfile.read(size))
        except OSError as e:
            self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
            # send_error already produced a complete response; stop here so
            # a second "200 OK" is not written after it.
            return
        self.send_response(HTTPStatus.OK)
        self.end_headers()


class StaticFileServer:
_lock = threading.Lock()

Expand Down