From aa003ebf81e14e27ed70130a7a94ad8f2f219462 Mon Sep 17 00:00:00 2001 From: Cheick Keita Date: Fri, 5 Feb 2021 16:57:36 -0800 Subject: [PATCH] WIP local queue implementation --- .../tasks/generic/input_poller/callback.rs | 2 +- .../onefuzz-agent/src/tasks/merge/generic.rs | 4 +- .../src/tasks/merge/libfuzzer_merge.rs | 2 +- src/agent/onefuzz-supervisor/src/work.rs | 20 ++++---- src/agent/storage-queue/src/azure_queue.rs | 2 +- src/agent/storage-queue/src/lib.rs | 42 ++++++++--------- src/agent/storage-queue/src/local_queue.rs | 46 +++++++++++++------ 7 files changed, 66 insertions(+), 52 deletions(-) diff --git a/src/agent/onefuzz-agent/src/tasks/generic/input_poller/callback.rs b/src/agent/onefuzz-agent/src/tasks/generic/input_poller/callback.rs index eb461bb8bf9..8518e1a23fa 100644 --- a/src/agent/onefuzz-agent/src/tasks/generic/input_poller/callback.rs +++ b/src/agent/onefuzz-agent/src/tasks/generic/input_poller/callback.rs @@ -98,7 +98,7 @@ where P: Processor + Send, { fn parse(&mut self, msg: &Message) -> Result { - let url= msg.parse(|data| { + let url = msg.parse(|data| { let data = std::str::from_utf8(data)?; Ok(Url::parse(data)?) })?; diff --git a/src/agent/onefuzz-agent/src/tasks/merge/generic.rs b/src/agent/onefuzz-agent/src/tasks/merge/generic.rs index 5a334d6aaed..418ddd0b241 100644 --- a/src/agent/onefuzz-agent/src/tasks/merge/generic.rs +++ b/src/agent/onefuzz-agent/src/tasks/merge/generic.rs @@ -56,7 +56,7 @@ pub async fn spawn(config: Arc) -> Result<()> { verbose!("tmp dir reset"); utils::reset_tmp_dir(&tmp_dir).await?; config.unique_inputs.sync_pull().await?; - let mut queue = QueueClient::new(config.input_queue.clone()); + let queue = QueueClient::new(config.input_queue.clone()); if let Some(msg) = queue.pop().await? { let input_url = msg.parse(utils::parse_url_data); let input_url = match input_url { @@ -89,7 +89,7 @@ pub async fn spawn(config: Arc) -> Result<()> { } else { warn!("no new candidate inputs found, sleeping"); delay_with_jitter(EMPTY_QUEUE_DELAY).await; - } + }; } } diff --git a/src/agent/onefuzz-agent/src/tasks/merge/libfuzzer_merge.rs b/src/agent/onefuzz-agent/src/tasks/merge/libfuzzer_merge.rs index 75b23d7d622..50f0e49918e 100644 --- a/src/agent/onefuzz-agent/src/tasks/merge/libfuzzer_merge.rs +++ b/src/agent/onefuzz-agent/src/tasks/merge/libfuzzer_merge.rs @@ -93,7 +93,7 @@ async fn process_message(config: Arc, mut input_queue: QueueClient) -> R utils::reset_tmp_dir(tmp_dir).await?; if let Some(msg) = input_queue.pop().await? { - let input_url= msg.parse(|data| { + let input_url = msg.parse(|data| { let data = std::str::from_utf8(data)?; Ok(Url::parse(data)?) }); diff --git a/src/agent/onefuzz-supervisor/src/work.rs b/src/agent/onefuzz-supervisor/src/work.rs index 7c45af049d7..e95f259ba21 100644 --- a/src/agent/onefuzz-supervisor/src/work.rs +++ b/src/agent/onefuzz-supervisor/src/work.rs @@ -170,16 +170,16 @@ impl WorkQueue { } pub async fn poll(&mut self) -> Result> { - let mut msg = self.queue.pop().await; - - // If we had an auth err, renew our registration and retry once, in case - // it was just due to a stale SAS URL. - if let Err(err) = &msg { - if is_auth_error(err) { - self.renew().await?; - msg = self.queue.pop().await; - } - } + let msg = self.queue.pop().await; + + // // If we had an auth err, renew our registration and retry once, in case + // // it was just due to a stale SAS URL. + // if let Err(err) = &msg { + // if is_auth_error(err) { + // self.renew().await?; + // msg = self.queue.pop().await; + // } + // } // Now we've had a chance to ensure our SAS URL is fresh. For any other // error, including another auth error, bail. diff --git a/src/agent/storage-queue/src/azure_queue.rs b/src/agent/storage-queue/src/azure_queue.rs index b77ae694708..8413eb0207f 100644 --- a/src/agent/storage-queue/src/azure_queue.rs +++ b/src/agent/storage-queue/src/azure_queue.rs @@ -131,7 +131,7 @@ impl AzureQueueClient { Ok(()) } - pub async fn pop(&mut self) -> Result> { + pub async fn pop(&self) -> Result> { let response = self .http .get(self.messages_url.clone()) diff --git a/src/agent/storage-queue/src/lib.rs b/src/agent/storage-queue/src/lib.rs index 348d6ccab63..01afccaa489 100644 --- a/src/agent/storage-queue/src/lib.rs +++ b/src/agent/storage-queue/src/lib.rs @@ -2,8 +2,7 @@ // Licensed under the MIT License. use anyhow::{bail, Result}; -use reqwest::{Client, Url}; -use reqwest_retry::SendRetry; +use reqwest::Url; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::time::Duration; use uuid::Uuid; @@ -13,9 +12,11 @@ pub mod azure_queue; pub mod local_queue; use azure_queue::{AzureQueueClient, AzureQueueMessage}; +use local_queue::{LocalQueueClient, LocalQueueMessage}; pub enum QueueClient { AzureQueue(AzureQueueClient), + LocalQueue(LocalQueueClient), } impl QueueClient { @@ -26,15 +27,20 @@ impl QueueClient { pub async fn enqueue(&self, data: impl Serialize) -> Result<()> { match self { QueueClient::AzureQueue(queue_client) => queue_client.enqueue(data).await, + QueueClient::LocalQueue(queue_client) => queue_client.enqueue(data).await, } } - pub async fn pop(&mut self) -> Result> { + pub async fn pop(&self) -> Result> { match self { QueueClient::AzureQueue(queue_client) => { let message = queue_client.pop().await?; Ok(message.map(Message::QueueMessage)) } + QueueClient::LocalQueue(queue_client) => { + let message = queue_client.pop().await?; + Ok(message.map(Message::LocalQueueMessage)) + } } } } @@ -42,7 +48,7 @@ impl QueueClient { // #[derive(Clone)] pub enum Message { QueueMessage(AzureQueueMessage), - // LocalQueueMessage(LocalQueueMessage<'a, T>) + LocalQueueMessage(LocalQueueMessage), } #[derive(Clone, Debug, Eq, PartialEq)] @@ -61,39 +67,31 @@ impl Message { let data = message.get()?; Ok(data) } - // Message::LocalQueueMessage(message) => { - // Ok(serde_json::from_slice(&*message.data)?) - // } + Message::LocalQueueMessage(message) => Ok(serde_json::from_slice(&*message.data)?), } } pub async fn claim(self) -> Result { match self { - Message::QueueMessage(message) => Ok(message.claim().await?), // Message::LocalQueueMessage(message) => { - // let value = message.data.into_inner(); - // Ok(serde_json::from_slice(value)) - - // } + Message::QueueMessage(message) => Ok(message.claim().await?), + Message::LocalQueueMessage(message) => Ok(serde_json::from_slice(&message.data)?), } } pub async fn delete(self) -> Result<()> { match self { - Message::QueueMessage(message) => Ok(message.delete().await?), // Message::LocalQueueMessage(message) => { - // let value = message.data.into_inner(); - // Ok(serde_json::from_slice(value)) - - // } + Message::QueueMessage(message) => Ok(message.delete().await?), + Message::LocalQueueMessage(_) => { + // message.data.commit(); + Ok(()) + } } } pub fn parse(&self, parser: impl FnOnce(&[u8]) -> Result) -> Result { match self { - Message::QueueMessage(message) => message.parse(parser), // Message::LocalQueueMessage(message) => { - // let value = message.data.into_inner(); - // Ok(serde_json::from_slice(value)) - - // } + Message::QueueMessage(message) => message.parse(parser), + Message::LocalQueueMessage(message) => parser(&*message.data), } } diff --git a/src/agent/storage-queue/src/local_queue.rs b/src/agent/storage-queue/src/local_queue.rs index 13f2887ef78..20f1df824d8 100644 --- a/src/agent/storage-queue/src/local_queue.rs +++ b/src/agent/storage-queue/src/local_queue.rs @@ -4,8 +4,9 @@ use anyhow::{bail, Result}; use reqwest::Url; use serde::{Deserialize, Serialize}; -use std::path::Path; +use std::{borrow::Borrow, path::Path}; use std::{io::Read, time::Duration}; +use tokio::sync::Mutex; use tokio::time::delay_for; use uuid::Uuid; @@ -13,33 +14,48 @@ use yaque::{self, channel, queue::RecvGuard, Sender}; pub const EMPTY_QUEUE_DELAY: Duration = Duration::from_secs(10); -pub struct LocalQueueMessage<'a> { - pub data: RecvGuard<'a, Vec>, +pub struct LocalQueueMessage { + pub data: Vec, } + pub struct LocalQueueClient { - sender: yaque::Sender, - receiver: yaque::Receiver, + sender: Mutex, + receiver: Mutex, } impl LocalQueueClient { pub fn new(queue_url: impl AsRef) -> Result { let (sender, receiver) = yaque::channel(queue_url)?; - Ok(LocalQueueClient { sender, receiver }) + Ok(LocalQueueClient { + sender: Mutex::new(sender), + receiver: Mutex::new(receiver), + }) } - pub async fn enqueue(&mut self, data: impl Serialize) -> Result<()> { + pub async fn enqueue(&self, data: impl Serialize) -> Result<()> { let body = serde_xml_rs::to_string(&data).unwrap(); - self.sender.send(body.as_bytes())?; - Ok(()) + match self.sender.try_lock() { + Ok(ref mut sender) => { + sender.send(body.as_bytes())?; + Ok(()) + } + Err(_) => bail!("cant enqueue"), + } } - pub async fn pop(&mut self) -> Result>>> { - let data = self - .receiver - .recv_timeout(tokio::time::delay_for(Duration::from_secs(1))) - .await?; + pub async fn pop(&self) -> Result> { + match self.receiver.try_lock() { + Ok(ref mut receiver) => { + let data = receiver + .recv_timeout(tokio::time::delay_for(Duration::from_secs(1))) + .await?; - Ok(data) + Ok(data.map(|data| LocalQueueMessage { + data: data.into_inner(), + })) + } + Err(_) => bail!("cant enqueue"), + } } // pub async fn pop(&mut self) -> Result> {