Retry the streaming portion of downloads, as well as the setup. (#17298)
As reported in #17294, if an HTTP stream is interrupted after it has opened, the retry that was added in #16798 won't kick in.

This change moves the retry up a level to wrap the entire download attempt, and adds a test of recovering from "post-header" errors.

Fixes #17294.
stuhood authored Oct 20, 2022
1 parent 91e021d commit a0aa15b
Showing 2 changed files with 174 additions and 88 deletions.
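
In outline, the change moves the tokio_retry loop out one level: it used to wrap only the initial request inside NetDownload::start, so an error after the response headers arrived could not trigger a retry; it now wraps attempt_download, which performs the request, streams the body, and hashes it. A minimal sketch of that shape, with the attempt body stubbed out and signatures simplified (only the tokio and tokio_retry crates are assumed):

use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::RetryIf;

#[derive(Debug)]
#[allow(dead_code)]
enum StreamingError {
    Retryable(String),
    Permanent(String),
}

// Stand-in for the real attempt: open the connection, stream the body, and
// hash it. A failure at any of those stages surfaces as Retryable, so the
// loop below re-runs the whole attempt, not just the connection setup.
async fn attempt_download() -> Result<(), StreamingError> {
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), StreamingError> {
    // Jittered exponential backoff with four retries, as in the diff below.
    let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
    RetryIf::spawn(
        retry_strategy,
        attempt_download,
        // Permanent errors (e.g. HTTP 4xx) are not worth repeating.
        |err: &StreamingError| matches!(err, StreamingError::Retryable(_)),
    )
    .await
}

Classifying each failure as Retryable or Permanent up front is what lets the predicate passed to RetryIf::spawn stay a one-line matches! check.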
src/python/pants/engine/fs_test.py: 76 changes (67 additions, 9 deletions)
@@ -1,5 +1,8 @@
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import annotations

import hashlib
import os
import pkgutil
@@ -857,6 +860,40 @@ def send_headers(self):
        self.end_headers()


def stub_erroring_handler(error_count_value: int) -> type[BaseHTTPRequestHandler]:
    """Return a handler that errors `error_count_value` times mid-download before succeeding.

    This function returns a fresh handler class on each call so that each call gets its own
    error counter.
    """
    error_num = 1

    class StubErroringHandler(BaseHTTPRequestHandler):
        error_count = error_count_value
        response_text = b"Hello, client!"

        def do_HEAD(self):
            self.send_headers()

        def do_GET(self):
            self.send_headers()
            nonlocal error_num
            if error_num <= self.error_count:
                msg = f"Returning error {error_num}"
                error_num += 1
                raise Exception(msg)
            self.wfile.write(self.response_text)

        def send_headers(self):
            code = 200 if self.path == "/file.txt" else 404
            self.send_response(code)
            self.send_header("Content-Type", "text/utf-8")
            self.send_header("Content-Length", f"{len(self.response_text)}")
            self.end_headers()

    return StubErroringHandler


DOWNLOADS_FILE_DIGEST = FileDigest(
    "8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
)
@@ -886,6 +923,24 @@ def test_download_missing_file(downloads_rule_runner: RuleRunner) -> None:
    assert "404" in str(exc.value)


def test_download_body_error_retry(downloads_rule_runner: RuleRunner) -> None:
    with http_server(stub_erroring_handler(1)) as port:
        snapshot = downloads_rule_runner.request(
            Snapshot, [DownloadFile(f"http://localhost:{port}/file.txt", DOWNLOADS_FILE_DIGEST)]
        )
    assert snapshot.files == ("file.txt",)
    assert snapshot.digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST


def test_download_body_error_retry_eventually_fails(downloads_rule_runner: RuleRunner) -> None:
    # Returns one more error than the retry will allow.
    with http_server(stub_erroring_handler(5)) as port:
        with pytest.raises(Exception):
            _ = downloads_rule_runner.request(
                Snapshot, [DownloadFile(f"http://localhost:{port}/file.txt", DOWNLOADS_FILE_DIGEST)]
            )


def test_download_wrong_digest(downloads_rule_runner: RuleRunner) -> None:
    file_digest = FileDigest(
        DOWNLOADS_FILE_DIGEST.fingerprint, DOWNLOADS_FILE_DIGEST.serialized_bytes_length + 1
@@ -1010,18 +1065,21 @@ def test_write_digest_workspace(rule_runner: RuleRunner) -> None:
    assert path2.read_text() == "goodbye"


def test_workspace_in_goal_rule() -> None:
class WorkspaceGoalSubsystem(GoalSubsystem):
name = "workspace-goal"
@dataclass(frozen=True)
class DigestRequest:
create_digest: CreateDigest


class WorkspaceGoal(Goal):
subsystem_cls = WorkspaceGoalSubsystem
environment_behavior = Goal.EnvironmentBehavior.LOCAL_ONLY
class WorkspaceGoalSubsystem(GoalSubsystem):
name = "workspace-goal"

@dataclass(frozen=True)
class DigestRequest:
create_digest: CreateDigest

class WorkspaceGoal(Goal):
subsystem_cls = WorkspaceGoalSubsystem
environment_behavior = Goal.EnvironmentBehavior.LOCAL_ONLY


def test_workspace_in_goal_rule() -> None:
@rule
def digest_request_singleton() -> DigestRequest:
fc = FileContent(path="a.txt", content=b"hello")
Expand Down
src/rust/engine/src/downloads.rs: 186 changes (107 additions, 79 deletions)
@@ -1,13 +1,14 @@
// Copyright 2021 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).

use std::io::Write;
use std::io::{self, Write};
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::{BufMut, Bytes};
use futures::stream::StreamExt;
use hashing::Digest;
use humansize::{file_size_opts, FileSize};
use reqwest::Error;
use tokio_retry::strategy::{jitter, ExponentialBackoff};
@@ -17,6 +18,19 @@ use url::Url;
use crate::context::Core;
use workunit_store::{in_workunit, Level};

enum StreamingError {
Retryable(String),
Permanent(String),
}

impl From<StreamingError> for String {
fn from(err: StreamingError) -> Self {
match err {
StreamingError::Retryable(s) | StreamingError::Permanent(s) => s,
}
}
}

#[async_trait]
trait StreamingDownload: Send {
async fn next(&mut self) -> Option<Result<Bytes, String>>;
@@ -27,41 +41,36 @@ struct NetDownload {
}

impl NetDownload {
async fn start(core: &Arc<Core>, url: Url, file_name: String) -> Result<NetDownload, String> {
let try_download = || async {
core
async fn start(
core: &Arc<Core>,
url: Url,
file_name: String,
) -> Result<NetDownload, StreamingError> {
let response = core
.http_client
.get(url.clone())
.send()
.await
.map_err(|err| (format!("Error downloading file: {}", err), true))
.map_err(|err| StreamingError::Retryable(format!("Error downloading file: {err}")))
.and_then(|res|
// Handle common HTTP errors.
if res.status().is_server_error() {
Err((format!(
Err(StreamingError::Retryable(format!(
"Server error ({}) downloading file {} from {}",
res.status().as_str(),
file_name,
url,
), true))
)))
} else if res.status().is_client_error() {
Err((format!(
Err(StreamingError::Permanent(format!(
"Client error ({}) downloading file {} from {}",
res.status().as_str(),
file_name,
url,
), false))
)))
} else {
Ok(res)
})
};

// TODO: Allow the retry strategy to be configurable?
// For now we retry after 10ms, 100ms, 1s, and 10s.
let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
let response = RetryIf::spawn(retry_strategy, try_download, |err: &(String, bool)| err.1)
.await
.map_err(|(err, _)| err)?;
})?;

let byte_stream = Pin::new(Box::new(response.bytes_stream()));
Ok(NetDownload {
@@ -86,12 +95,18 @@ struct FileDownload {
}

impl FileDownload {
async fn start(path: &str, file_name: String) -> Result<FileDownload, String> {
async fn start(path: &str, file_name: String) -> Result<FileDownload, StreamingError> {
let file = tokio::fs::File::open(path).await.map_err(|e| {
format!(
let msg = format!(
"Error ({}) opening file at {} for download to {}",
e, path, file_name
)
);
// Fail quickly for non-existent files.
if e.kind() == io::ErrorKind::NotFound {
StreamingError::Permanent(msg)
} else {
StreamingError::Retryable(msg)
}
})?;
let stream = tokio_util::io::ReaderStream::new(file);
Ok(FileDownload { stream })
@@ -109,25 +124,72 @@ impl StreamingDownload for FileDownload {
}
}

async fn start_download(
async fn attempt_download(
core: &Arc<Core>,
url: Url,
url: &Url,
file_name: String,
) -> Result<Box<dyn StreamingDownload>, String> {
if url.scheme() == "file" {
if let Some(host) = url.host_str() {
return Err(format!(
"The file Url `{}` has a host component. Instead, use `file:$path`, \
which in this case might be either `file:{}{}` or `file:{}`.",
url,
host,
url.path(),
url.path(),
));
expected_digest: Digest,
) -> Result<(Digest, Bytes), StreamingError> {
let mut response_stream: Box<dyn StreamingDownload> = {
if url.scheme() == "file" {
if let Some(host) = url.host_str() {
return Err(StreamingError::Permanent(format!(
"The file Url `{}` has a host component. Instead, use `file:$path`, \
which in this case might be either `file:{}{}` or `file:{}`.",
url,
host,
url.path(),
url.path(),
)));
}
Box::new(FileDownload::start(url.path(), file_name).await?)
} else {
Box::new(NetDownload::start(core, url.clone(), file_name).await?)
}
return Ok(Box::new(FileDownload::start(url.path(), file_name).await?));
};

struct SizeLimiter<W: std::io::Write> {
writer: W,
written: usize,
size_limit: usize,
}
Ok(Box::new(NetDownload::start(core, url, file_name).await?))

impl<W: std::io::Write> Write for SizeLimiter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
let new_size = self.written + buf.len();
if new_size > self.size_limit {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Downloaded file was larger than expected digest",
))
} else {
self.written = new_size;
self.writer.write_all(buf)?;
Ok(buf.len())
}
}

fn flush(&mut self) -> Result<(), std::io::Error> {
self.writer.flush()
}
}

let mut hasher = hashing::WriterHasher::new(SizeLimiter {
writer: bytes::BytesMut::with_capacity(expected_digest.size_bytes).writer(),
written: 0,
size_limit: expected_digest.size_bytes,
});

while let Some(next_chunk) = response_stream.next().await {
let chunk = next_chunk.map_err(|err| {
StreamingError::Retryable(format!("Error reading URL fetch response: {err}"))
})?;
hasher.write_all(&chunk).map_err(|err| {
StreamingError::Retryable(format!("Error hashing/capturing URL fetch response: {err}"))
})?;
}
let (digest, bytewriter) = hasher.finish();
Ok((digest, bytewriter.writer.into_inner().freeze()))
}

pub async fn download(
@@ -148,49 +210,15 @@ pub async fn download(
.unwrap()
)),
|_workunit| async move {
let mut response_stream = start_download(&core2, url, file_name).await?;
struct SizeLimiter<W: std::io::Write> {
writer: W,
written: usize,
size_limit: usize,
}

impl<W: std::io::Write> Write for SizeLimiter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
let new_size = self.written + buf.len();
if new_size > self.size_limit {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Downloaded file was larger than expected digest",
))
} else {
self.written = new_size;
self.writer.write_all(buf)?;
Ok(buf.len())
}
}

fn flush(&mut self) -> Result<(), std::io::Error> {
self.writer.flush()
}
}

let mut hasher = hashing::WriterHasher::new(SizeLimiter {
writer: bytes::BytesMut::with_capacity(expected_digest.size_bytes).writer(),
written: 0,
size_limit: expected_digest.size_bytes,
});

while let Some(next_chunk) = response_stream.next().await {
let chunk =
next_chunk.map_err(|err| format!("Error reading URL fetch response: {}", err))?;
hasher
.write_all(&chunk)
.map_err(|err| format!("Error hashing/capturing URL fetch response: {}", err))?;
}
let (digest, bytewriter) = hasher.finish();
let res: Result<_, String> = Ok((digest, bytewriter.writer.into_inner().freeze()));
res
// TODO: Allow the retry strategy to be configurable?
// For now we retry after 10ms, 100ms, 1s, and 10s.
let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
RetryIf::spawn(
retry_strategy,
|| attempt_download(&core2, &url, file_name.clone(), expected_digest),
|err: &StreamingError| matches!(err, StreamingError::Retryable(_)),
)
.await
}
)
.await?;
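As a footnote to the TODO comment above, the 10ms/100ms/1s/10s schedule falls straight out of the strategy: ExponentialBackoff::from_millis(10) yields successive powers of ten milliseconds, and jitter (omitted below because it randomizes each delay) only rescales them. A standalone check, assuming tokio_retry's documented iterator behavior:

use std::time::Duration;
use tokio_retry::strategy::ExponentialBackoff;

fn main() {
    // take(4) caps the retries at four, giving the 10ms/100ms/1s/10s
    // schedule named in the comment above.
    let delays: Vec<Duration> = ExponentialBackoff::from_millis(10).take(4).collect();
    assert_eq!(
        delays,
        vec![
            Duration::from_millis(10),
            Duration::from_millis(100),
            Duration::from_millis(1_000),
            Duration::from_millis(10_000),
        ]
    );
}

Four retries on top of the initial attempt means five attempts in total, which is why stub_erroring_handler(5) in the new test exhausts them all while stub_erroring_handler(1) recovers.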

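A closing note on the SizeLimiter adapter that attempt_download wraps around its hasher: it is a plain std::io::Write wrapper that rejects any write which would push the running total past the expected digest's size, so an over-long response fails the attempt instead of buffering unbounded data. A standalone sketch of the same pattern against a toy five-byte budget (the struct mirrors the diff; main is illustrative only):

use std::io::{self, Write};

struct SizeLimiter<W: Write> {
    writer: W,
    written: usize,
    size_limit: usize,
}

impl<W: Write> Write for SizeLimiter<W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let new_size = self.written + buf.len();
        if new_size > self.size_limit {
            // Refuse the write rather than truncate: the caller sees an
            // InvalidData error and the download attempt fails.
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Downloaded file was larger than expected digest",
            ));
        }
        self.written = new_size;
        self.writer.write_all(buf)?;
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}

fn main() {
    let mut limiter = SizeLimiter {
        writer: Vec::new(),
        written: 0,
        size_limit: 5,
    };
    assert!(limiter.write_all(b"12345").is_ok());
    assert!(limiter.write_all(b"6").is_err()); // budget exhausted
}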