Retry the streaming portion of downloads, as well as the setup. (Cherry-pick of #17298) #17302

Merged: 1 commit, Oct 21, 2022
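Before this change, only the request setup in `NetDownload::start` was wrapped in a retry loop; an error while streaming the response body was fatal. The diff below hoists the retry so that it wraps the whole attempt (setup and streaming) and classifies each failure as retryable or permanent. Here is a minimal standalone sketch of that pattern using the same `tokio_retry` schedule as the diff; the failing `attempt_download` stub and its counter are invented for illustration and are not the engine's real code:

```rust
use std::sync::atomic::{AtomicU32, Ordering};

use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::RetryIf;

enum StreamingError {
    Retryable(String),
    Permanent(String),
}

// Stand-in for the real download: fails twice with a retryable error,
// then succeeds on the third attempt.
async fn attempt_download(tries: &AtomicU32) -> Result<String, StreamingError> {
    let n = tries.fetch_add(1, Ordering::SeqCst) + 1;
    if n <= 2 {
        Err(StreamingError::Retryable(format!("attempt {n} failed mid-stream")))
    } else {
        Ok("Hello, client!".to_string())
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // Same schedule as the diff: retries after roughly 10ms, 100ms, 1s, 10s.
    let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
    let tries = AtomicU32::new(0);
    let body = RetryIf::spawn(
        retry_strategy,
        || attempt_download(&tries),
        // Only Retryable errors (e.g. 5xx, dropped streams) are retried;
        // Permanent ones (e.g. 4xx) fail immediately.
        |err: &StreamingError| matches!(err, StreamingError::Retryable(_)),
    )
    .await
    .map_err(|err| match err {
        StreamingError::Retryable(s) | StreamingError::Permanent(s) => s,
    })?;
    println!("downloaded: {body}");
    Ok(())
}
```

With four retries in the schedule, a fifth consecutive error exhausts the strategy and the download fails, which is exactly what the second new Python test below exercises.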
74 changes: 66 additions & 8 deletions src/python/pants/engine/fs_test.py
@@ -1,5 +1,8 @@
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import annotations

import hashlib
import os
import pkgutil
@@ -857,6 +860,40 @@ def send_headers(self):
self.end_headers()


def stub_erroring_handler(error_count_value: int) -> type[BaseHTTPRequestHandler]:
"""Return a handler that errors once mid-download before succeeding for the next GET.

This function returns an anonymous class so that each call can create a new instance with its
own error counter.
"""
error_num = 1

class StubErroringHandler(BaseHTTPRequestHandler):
error_count = error_count_value
response_text = b"Hello, client!"

def do_HEAD(self):
self.send_headers()

def do_GET(self):
self.send_headers()
nonlocal error_num
if error_num <= self.error_count:
msg = f"Returning error {error_num}"
error_num += 1
raise Exception(msg)
self.wfile.write(self.response_text)

def send_headers(self):
code = 200 if self.path == "/file.txt" else 404
self.send_response(code)
self.send_header("Content-Type", "text/utf-8")
self.send_header("Content-Length", f"{len(self.response_text)}")
self.end_headers()

return StubErroringHandler


DOWNLOADS_FILE_DIGEST = FileDigest(
"8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
)
@@ -886,6 +923,24 @@ def test_download_missing_file(downloads_rule_runner: RuleRunner) -> None:
assert "404" in str(exc.value)


def test_download_body_error_retry(downloads_rule_runner: RuleRunner) -> None:
with http_server(stub_erroring_handler(1)) as port:
snapshot = downloads_rule_runner.request(
Snapshot, [DownloadFile(f"http://localhost:{port}/file.txt", DOWNLOADS_FILE_DIGEST)]
)
assert snapshot.files == ("file.txt",)
assert snapshot.digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST


def test_download_body_error_retry_eventually_fails(downloads_rule_runner: RuleRunner) -> None:
# Returns one more error than the retry will allow.
with http_server(stub_erroring_handler(5)) as port:
with pytest.raises(Exception):
_ = downloads_rule_runner.request(
Snapshot, [DownloadFile(f"http://localhost:{port}/file.txt", DOWNLOADS_FILE_DIGEST)]
)


def test_download_wrong_digest(downloads_rule_runner: RuleRunner) -> None:
file_digest = FileDigest(
DOWNLOADS_FILE_DIGEST.fingerprint, DOWNLOADS_FILE_DIGEST.serialized_bytes_length + 1
@@ -1010,17 +1065,20 @@ def test_write_digest_workspace(rule_runner: RuleRunner) -> None:
assert path2.read_text() == "goodbye"


def test_workspace_in_goal_rule() -> None:
class WorkspaceGoalSubsystem(GoalSubsystem):
name = "workspace-goal"
@dataclass(frozen=True)
class DigestRequest:
create_digest: CreateDigest


class WorkspaceGoal(Goal):
subsystem_cls = WorkspaceGoalSubsystem
class WorkspaceGoalSubsystem(GoalSubsystem):
name = "workspace-goal"

@dataclass(frozen=True)
class DigestRequest:
create_digest: CreateDigest

class WorkspaceGoal(Goal):
subsystem_cls = WorkspaceGoalSubsystem


def test_workspace_in_goal_rule() -> None:
@rule
def digest_request_singleton() -> DigestRequest:
fc = FileContent(path="a.txt", content=b"hello")
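The new tests drive the stub handler above: the headers (200 OK with a `Content-Length` of 14) go out successfully, and only then does `do_GET` raise, so the client observes a truncated body rather than an HTTP error status. A small sketch of what that failure mode looks like from the consumer's side; the in-memory stream here is a stand-in for a real response body:

```rust
use bytes::Bytes;
use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // A body that yields one good chunk and then errors mid-download,
    // mirroring what the stub server's raised exception does on the wire.
    let mut body = stream::iter(vec![
        Ok(Bytes::from_static(b"Hello, ")),
        Err("connection reset mid-body".to_string()),
    ]);

    let mut received: Vec<u8> = Vec::new();
    while let Some(chunk) = body.next().await {
        match chunk {
            Ok(bytes) => received.extend_from_slice(&bytes),
            Err(err) => {
                // The server already responded 200 OK, so only a retry
                // wrapped around this streaming loop can recover.
                eprintln!("stream failed after {} bytes: {err}", received.len());
                break;
            }
        }
    }
}
```

Because the failure arrives after a successful status line, a retry around just the request setup cannot recover it; that is what the Rust change below fixes.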
186 changes: 107 additions & 79 deletions src/rust/engine/src/downloads.rs
@@ -1,13 +1,14 @@
// Copyright 2021 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).

use std::io::Write;
use std::io::{self, Write};
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::{BufMut, Bytes};
use futures::stream::StreamExt;
use hashing::Digest;
use humansize::{file_size_opts, FileSize};
use reqwest::Error;
use tokio_retry::strategy::{jitter, ExponentialBackoff};
@@ -17,6 +18,19 @@ use url::Url;
use crate::context::Core;
use workunit_store::{in_workunit, Level};

enum StreamingError {
Retryable(String),
Permanent(String),
}

impl From<StreamingError> for String {
fn from(err: StreamingError) -> Self {
match err {
StreamingError::Retryable(s) | StreamingError::Permanent(s) => s,
}
}
}

#[async_trait]
trait StreamingDownload: Send {
async fn next(&mut self) -> Option<Result<Bytes, String>>;
@@ -27,41 +41,36 @@ struct NetDownload {
}

impl NetDownload {
async fn start(core: &Arc<Core>, url: Url, file_name: String) -> Result<NetDownload, String> {
let try_download = || async {
core
async fn start(
core: &Arc<Core>,
url: Url,
file_name: String,
) -> Result<NetDownload, StreamingError> {
let response = core
.http_client
.get(url.clone())
.send()
.await
.map_err(|err| (format!("Error downloading file: {}", err), true))
.map_err(|err| StreamingError::Retryable(format!("Error downloading file: {err}")))
.and_then(|res|
// Handle common HTTP errors.
if res.status().is_server_error() {
Err((format!(
Err(StreamingError::Retryable(format!(
"Server error ({}) downloading file {} from {}",
res.status().as_str(),
file_name,
url,
), true))
)))
} else if res.status().is_client_error() {
Err((format!(
Err(StreamingError::Permanent(format!(
"Client error ({}) downloading file {} from {}",
res.status().as_str(),
file_name,
url,
), false))
)))
} else {
Ok(res)
})
};

// TODO: Allow the retry strategy to be configurable?
// For now we retry after 10ms, 100ms, 1s, and 10s.
let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
let response = RetryIf::spawn(retry_strategy, try_download, |err: &(String, bool)| err.1)
.await
.map_err(|(err, _)| err)?;
})?;

let byte_stream = Pin::new(Box::new(response.bytes_stream()));
Ok(NetDownload {
@@ -86,12 +95,18 @@ struct FileDownload {
}

impl FileDownload {
async fn start(path: &str, file_name: String) -> Result<FileDownload, String> {
async fn start(path: &str, file_name: String) -> Result<FileDownload, StreamingError> {
let file = tokio::fs::File::open(path).await.map_err(|e| {
format!(
let msg = format!(
"Error ({}) opening file at {} for download to {}",
e, path, file_name
)
);
// Fail quickly for non-existent files.
if e.kind() == io::ErrorKind::NotFound {
StreamingError::Permanent(msg)
} else {
StreamingError::Retryable(msg)
}
})?;
let stream = tokio_util::io::ReaderStream::new(file);
Ok(FileDownload { stream })
@@ -109,25 +124,72 @@ impl StreamingDownload for FileDownload {
}
}

async fn start_download(
async fn attempt_download(
core: &Arc<Core>,
url: Url,
url: &Url,
file_name: String,
) -> Result<Box<dyn StreamingDownload>, String> {
if url.scheme() == "file" {
if let Some(host) = url.host_str() {
return Err(format!(
"The file Url `{}` has a host component. Instead, use `file:$path`, \
which in this case might be either `file:{}{}` or `file:{}`.",
url,
host,
url.path(),
url.path(),
));
expected_digest: Digest,
) -> Result<(Digest, Bytes), StreamingError> {
let mut response_stream: Box<dyn StreamingDownload> = {
if url.scheme() == "file" {
if let Some(host) = url.host_str() {
return Err(StreamingError::Permanent(format!(
"The file Url `{}` has a host component. Instead, use `file:$path`, \
which in this case might be either `file:{}{}` or `file:{}`.",
url,
host,
url.path(),
url.path(),
)));
}
Box::new(FileDownload::start(url.path(), file_name).await?)
} else {
Box::new(NetDownload::start(core, url.clone(), file_name).await?)
}
return Ok(Box::new(FileDownload::start(url.path(), file_name).await?));
};

struct SizeLimiter<W: std::io::Write> {
writer: W,
written: usize,
size_limit: usize,
}
Ok(Box::new(NetDownload::start(core, url, file_name).await?))

impl<W: std::io::Write> Write for SizeLimiter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
let new_size = self.written + buf.len();
if new_size > self.size_limit {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Downloaded file was larger than expected digest",
))
} else {
self.written = new_size;
self.writer.write_all(buf)?;
Ok(buf.len())
}
}

fn flush(&mut self) -> Result<(), std::io::Error> {
self.writer.flush()
}
}

let mut hasher = hashing::WriterHasher::new(SizeLimiter {
writer: bytes::BytesMut::with_capacity(expected_digest.size_bytes).writer(),
written: 0,
size_limit: expected_digest.size_bytes,
});

while let Some(next_chunk) = response_stream.next().await {
let chunk = next_chunk.map_err(|err| {
StreamingError::Retryable(format!("Error reading URL fetch response: {err}"))
})?;
hasher.write_all(&chunk).map_err(|err| {
StreamingError::Retryable(format!("Error hashing/capturing URL fetch response: {err}"))
})?;
}
let (digest, bytewriter) = hasher.finish();
Ok((digest, bytewriter.writer.into_inner().freeze()))
}

pub async fn download(
@@ -148,49 +210,15 @@ pub async fn download(
.unwrap()
)),
|_workunit| async move {
let mut response_stream = start_download(&core2, url, file_name).await?;
struct SizeLimiter<W: std::io::Write> {
writer: W,
written: usize,
size_limit: usize,
}

impl<W: std::io::Write> Write for SizeLimiter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
let new_size = self.written + buf.len();
if new_size > self.size_limit {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Downloaded file was larger than expected digest",
))
} else {
self.written = new_size;
self.writer.write_all(buf)?;
Ok(buf.len())
}
}

fn flush(&mut self) -> Result<(), std::io::Error> {
self.writer.flush()
}
}

let mut hasher = hashing::WriterHasher::new(SizeLimiter {
writer: bytes::BytesMut::with_capacity(expected_digest.size_bytes).writer(),
written: 0,
size_limit: expected_digest.size_bytes,
});

while let Some(next_chunk) = response_stream.next().await {
let chunk =
next_chunk.map_err(|err| format!("Error reading URL fetch response: {}", err))?;
hasher
.write_all(&chunk)
.map_err(|err| format!("Error hashing/capturing URL fetch response: {}", err))?;
}
let (digest, bytewriter) = hasher.finish();
let res: Result<_, String> = Ok((digest, bytewriter.writer.into_inner().freeze()));
res
// TODO: Allow the retry strategy to be configurable?
// For now we retry after 10ms, 100ms, 1s, and 10s.
let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
RetryIf::spawn(
retry_strategy,
|| attempt_download(&core2, &url, file_name.clone(), expected_digest),
|err: &StreamingError| matches!(err, StreamingError::Retryable(_)),
)
.await
}
)
.await?;
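For reference, the `SizeLimiter` that moved into `attempt_download` is an ordinary `std::io::Write` adapter: it caps the bytes accepted at the expected digest's size, so an over-long response fails fast instead of being hashed without bound. A self-contained sketch of the same idea; the `Vec<u8>` buffer and the 5-byte limit are invented for illustration:

```rust
use std::io::Write;

// Write adapter that refuses bytes past a fixed limit, as in the diff above.
struct SizeLimiter<W: Write> {
    writer: W,
    written: usize,
    size_limit: usize,
}

impl<W: Write> Write for SizeLimiter<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let new_size = self.written + buf.len();
        if new_size > self.size_limit {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Downloaded file was larger than expected digest",
            ));
        }
        self.written = new_size;
        self.writer.write_all(buf)?;
        Ok(buf.len())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.writer.flush()
    }
}

fn main() {
    let mut limiter = SizeLimiter { writer: Vec::new(), written: 0, size_limit: 5 };
    // Exactly at the limit: accepted.
    assert!(limiter.write_all(b"hello").is_ok());
    // One byte past the expected size: the limiter refuses the write.
    assert!(limiter.write_all(b"!").is_err());
    println!("limited to {} bytes", limiter.written);
}
```

Since the limiter sits under `hashing::WriterHasher` inside each attempt, a size overrun surfaces as a write error in the streaming loop and is classified there like any other mid-stream failure.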