Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

automatically retry supervisor requests #704

Merged
9 commits merged into from
Mar 23, 2021
55 changes: 29 additions & 26 deletions src/agent/onefuzz-supervisor/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
use anyhow::Result;
use downcast_rs::Downcast;
use onefuzz::{auth::AccessToken, http::ResponseExt, process::Output};
use reqwest::{Client, Request, Response, StatusCode};
use reqwest::{Client, RequestBuilder, Response, StatusCode};
use reqwest_retry::{SendRetry, DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS};
use serde::Serialize;
use uuid::Uuid;

Expand Down Expand Up @@ -232,8 +233,14 @@ impl Coordinator {
&mut self,
request_type: RequestType<'a>,
) -> Result<Response> {
let request = self.build_request(request_type.clone())?;
let mut response = self.client.execute(request).await?;
bmc-msft marked this conversation as resolved.
Show resolved Hide resolved
let request = self.get_request_builder(request_type.clone());
let mut response = request
.send_retry(
vec![StatusCode::UNAUTHORIZED],
DEFAULT_RETRY_PERIOD,
MAX_RETRY_ATTEMPTS,
)
.await?;

if response.status() == StatusCode::UNAUTHORIZED {
debug!("access token expired, renewing");
Expand All @@ -244,8 +251,8 @@ impl Coordinator {
debug!("retrying request after refreshing access token");

// And try one more time.
let request = self.build_request(request_type)?;
response = self.client.execute(request).await?;
let request = self.get_request_builder(request_type);
response = request.send_retry_default().await?;
};

// We've retried if we got a `401 Unauthorized`. If it happens again, we
Expand All @@ -255,7 +262,7 @@ impl Coordinator {
Ok(response)
}

fn build_request(&self, request_type: RequestType<'_>) -> Result<Request> {
fn get_request_builder(&self, request_type: RequestType<'_>) -> RequestBuilder {
match request_type {
RequestType::PollCommands => self.poll_commands_request(),
RequestType::ClaimCommand(message_id) => self.claim_command_request(message_id),
Expand All @@ -264,52 +271,49 @@ impl Coordinator {
}
}

fn poll_commands_request(&self) -> Result<Request> {
fn poll_commands_request(&self) -> RequestBuilder {
let request = PollCommandsRequest {
machine_id: self.registration.machine_id,
};

let url = self.registration.dynamic_config.commands_url.clone();
let request = self
let request_builder = self
.client
.get(url)
.bearer_auth(self.token.secret().expose_ref())
.json(&request)
.build()?;
.json(&request);

Ok(request)
request_builder
}

fn claim_command_request(&self, message_id: String) -> Result<Request> {
fn claim_command_request(&self, message_id: String) -> RequestBuilder {
let request = ClaimNodeCommandRequest {
machine_id: self.registration.machine_id,
message_id,
};

let url = self.registration.dynamic_config.commands_url.clone();
let request = self
let request_builder = self
.client
.delete(url)
.bearer_auth(self.token.secret().expose_ref())
.json(&request)
.build()?;
.json(&request);

Ok(request)
request_builder
}

fn emit_event_request(&self, event: &NodeEventEnvelope) -> Result<Request> {
fn emit_event_request(&self, event: &NodeEventEnvelope) -> RequestBuilder {
let url = self.registration.dynamic_config.events_url.clone();
let request = self
let request_builder = self
.client
.post(url)
.bearer_auth(self.token.secret().expose_ref())
.json(event)
.build()?;
.json(event);

Ok(request)
request_builder
}

fn can_schedule_request(&self, work_set: &WorkSet) -> Result<Request> {
fn can_schedule_request(&self, work_set: &WorkSet) -> RequestBuilder {
// Temporary: assume one work unit per work set.
//
// In the future, we will probably want the same behavior, but we will
Expand All @@ -325,14 +329,13 @@ impl Coordinator {

let mut url = self.registration.config.onefuzz_url.clone();
url.set_path("/api/agents/can_schedule");
let request = self
let request_builder = self
.client
.post(url)
.bearer_auth(self.token.secret().expose_ref())
.json(&can_schedule)
.build()?;
.json(&can_schedule);

Ok(request)
request_builder
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/agent/onefuzz/src/syncdir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use anyhow::{Context, Result};
use futures::stream::StreamExt;
use onefuzz_telemetry::{Event, EventData};
use reqwest::StatusCode;
use reqwest_retry::SendRetry;
use reqwest_retry::{SendRetry, DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS};
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, str, time::Duration};
use tokio::fs;
Expand Down Expand Up @@ -142,7 +142,11 @@ impl SyncedDir {
// Conditional PUT, only if-not-exists.
// https://docs.microsoft.com/en-us/rest/api/storageservices/specifying-conditional-headers-for-blob-service-operations
.header("If-None-Match", "*")
.send_retry_default()
.send_retry(
vec![StatusCode::CONFLICT],
DEFAULT_RETRY_PERIOD,
MAX_RETRY_ATTEMPTS,
)
.await
.context("Uploading blob")?;

Expand Down
20 changes: 16 additions & 4 deletions src/agent/onefuzz/src/uploader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use std::path::Path;
use anyhow::Result;
use futures::stream::TryStreamExt;
use reqwest as r;
use reqwest_retry::{send_retry_reqwest_default, SendRetry};
use reqwest_retry::{
send_retry_reqwest_default, SendRetry, DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS,
};
use serde::Serialize;
use tokio::{fs, io};
use tokio_util::codec;
Expand Down Expand Up @@ -41,9 +43,19 @@ impl BlobUploader {
};

// Check if the file already exists before uploading
let head = self.client.head(url.clone()).send_retry_default().await?;
if head.status() == reqwest::StatusCode::OK {
return Ok(head);
if let Ok(head) = self
.client
.head(url.clone())
.send_retry(
vec![reqwest::StatusCode::NOT_FOUND],
DEFAULT_RETRY_PERIOD,
MAX_RETRY_ATTEMPTS,
)
.await
{
if head.status() == reqwest::StatusCode::OK {
return Ok(head);
}
}

let content_length = format!("{}", file_len);
Expand Down
64 changes: 49 additions & 15 deletions src/agent/reqwest-retry/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,56 @@ use anyhow::{Context, Result};
use async_trait::async_trait;
use backoff::{self, future::retry_notify, ExponentialBackoff};
use onefuzz_telemetry::warn;
use reqwest::Response;
use reqwest::{Response, StatusCode};
use std::{
sync::atomic::{AtomicUsize, Ordering},
time::Duration,
};

const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5);
const MAX_RETRY_ATTEMPTS: usize = 5;
pub const DEFAULT_RETRY_PERIOD: Duration = Duration::from_secs(5);
pub const MAX_RETRY_ATTEMPTS: usize = 5;

pub async fn send_retry_reqwest_default<
F: Fn() -> Result<reqwest::RequestBuilder> + Send + Sync,
>(
build_request: F,
) -> Result<Response> {
send_retry_reqwest(build_request, DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS).await
send_retry_reqwest(build_request, [], DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS).await
}

pub async fn send_retry_reqwest<F: Fn() -> Result<reqwest::RequestBuilder> + Send + Sync>(
build_request: F,
fail_fast_status: impl AsRef<[StatusCode]>,
retry_period: Duration,
max_retry: usize,
) -> Result<Response> {
let counter = AtomicUsize::new(0);
let op = || async {
let attempt_count = counter.fetch_add(1, Ordering::SeqCst);
let request = build_request().map_err(backoff::Error::Permanent)?;
let request = build_request().map_err(|err| backoff::Error::Permanent(Err(err)))?;
let result = request
.send()
.await
.with_context(|| format!("request attempt {} failed", attempt_count + 1));

match result {
Ok(x) => Ok(x),
Err(x) => {
if attempt_count >= max_retry {
Err(backoff::Error::Permanent(x))
Err(backoff::Error::Permanent(Err(x)))
} else {
Err(backoff::Error::Transient(x))
Err(backoff::Error::Transient(Err(x)))
}
}
Ok(x) => {
if x.status().is_success() {
Ok(x)
} else {
let fail_fast = fail_fast_status.as_ref().contains(&x.status());
if attempt_count >= max_retry || fail_fast {
Err(backoff::Error::Permanent(Ok(x)))
} else {
Err(backoff::Error::Transient(Ok(x)))
}
}
}
}
Expand All @@ -54,32 +66,54 @@ pub async fn send_retry_reqwest<F: Fn() -> Result<reqwest::RequestBuilder> + Sen
..ExponentialBackoff::default()
},
op,
|err, dur| warn!("request attempt failed after {:?}: {:?}", dur, err),
|err: Result<Response, anyhow::Error>, dur| match err {
Ok(response) => {
if let Err(err) = response.error_for_status() {
warn!("request attempt failed after {:?}: {:?}", dur, err)
}
}
err => warn!("request attempt failed after {:?}: {:?}", dur, err),
},
)
.await?;
Ok(result)
.await;

match result {
Ok(response) | Err(Ok(response)) => Ok(response),
Err(Err(err)) => Err(err),
}
}

#[async_trait]
pub trait SendRetry {
async fn send_retry(self, retry_period: Duration, max_retry: usize) -> Result<Response>;
async fn send_retry(
self,
fail_fast_status: Vec<StatusCode>,
retry_period: Duration,
max_retry: usize,
) -> Result<Response>;
async fn send_retry_default(self) -> Result<Response>;
}

#[async_trait]
impl SendRetry for reqwest::RequestBuilder {
async fn send_retry_default(self) -> Result<Response> {
self.send_retry(DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS)
self.send_retry(vec![], DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS)
.await
}

async fn send_retry(self, retry_period: Duration, max_retry: usize) -> Result<Response> {
async fn send_retry(
self,
fail_fast_status: Vec<StatusCode>,
retry_period: Duration,
max_retry: usize,
) -> Result<Response> {
let result = send_retry_reqwest(
|| {
self.try_clone().ok_or_else(|| {
anyhow::Error::msg("This request cannot be retried because it cannot be cloned")
})
},
fail_fast_status,
retry_period,
max_retry,
)
Expand Down Expand Up @@ -109,7 +143,7 @@ mod test {
let invalid_url = "http://127.0.0.1:81/test.txt";
let resp = reqwest::Client::new()
.get(invalid_url)
.send_retry(Duration::from_millis(1), 3)
.send_retry(vec![], Duration::from_millis(1), 3)
.await;

if let Err(err) = &resp {
Expand Down