From 13cdf845d3d81c04e376b5f39fc9fb5074d08e76 Mon Sep 17 00:00:00 2001 From: ravenclaw900 <50060110+ravenclaw900@users.noreply.github.com> Date: Wed, 22 Jun 2022 14:43:36 -0500 Subject: [PATCH] refactor(backend): correctly handle errors on socket sending Also allow canceling of creating a large zip file --- src/backend/src/page_handlers.rs | 150 +++++++++++++++-------------- src/backend/src/shared.rs | 4 +- src/backend/src/socket_handlers.rs | 139 +++++++++++++++++--------- 3 files changed, 173 insertions(+), 120 deletions(-) diff --git a/src/backend/src/page_handlers.rs b/src/backend/src/page_handlers.rs index d2cbcdd2..2d2106f9 100644 --- a/src/backend/src/page_handlers.rs +++ b/src/backend/src/page_handlers.rs @@ -49,14 +49,13 @@ pub async fn main_handler(socket_send: &mut SocketSend, data_recv: &mut RecvChan loop { tokio::select! { biased; - Some(data) = data_recv.recv() => if data.is_none() { - break; + data = data_recv.recv() => match data { + Some(Some(_)) => {}, + _ => break, }, - _ = async { - let _send = socket_send - .send(Message::text(SerJson::serialize_json(&handle_error!(main_handler_getter(&mut cpu_collector, &mut net_collector, &mut prev_data).await, shared::SysData::default())))) - .await; - } => {} + Err(_) = socket_send + .send(Message::text(SerJson::serialize_json(&handle_error!(main_handler_getter(&mut cpu_collector, &mut net_collector, &mut prev_data).await, shared::SysData::default())))) + => break, } } } @@ -88,28 +87,27 @@ pub async fn process_handler(socket_send: &mut SocketSend, data_recv: &mut RecvC loop { tokio::select! { biased; - Some(data) = data_recv.recv() => match data { - Some(data) => handle_error!(process_handler_helper(&data)), - None => break, + data = data_recv.recv() => match data { + Some(Some(data)) => handle_error!(process_handler_helper(&data)), + _ => break, }, - _ = async { - let _send = socket_send + Err(_) = async { + let send = socket_send .send(Message::text(SerJson::serialize_json( &shared::ProcessList { processes: handle_error!(systemdata::processes().await, Vec::new()), }, - ))) - .await; + ))).await; sleep(Duration::from_secs(1)).await; - } => {}, + send + } => break, } } } pub async fn software_handler_helper( data: &shared::Request, - socket_send: &mut SocketSend, -) -> anyhow::Result<()> { +) -> anyhow::Result { // We don't just want to run dietpi-software without args anyhow::ensure!(!data.args.is_empty(), "Empty dietpi-software args"); @@ -130,22 +128,16 @@ pub async fn software_handler_helper( .replace('', ""); let software = systemdata::dpsoftware().await?; - let _send = socket_send - .send(Message::text(SerJson::serialize_json( - &shared::DPSoftwareList { - uninstalled: software.0, - installed: software.1, - response: out, - }, - ))) - .await; - - Ok(()) + Ok(shared::DPSoftwareList { + uninstalled: software.0, + installed: software.1, + response: out, + }) } pub async fn software_handler(socket_send: &mut SocketSend, data_recv: &mut RecvChannel) { let software = handle_error!(systemdata::dpsoftware().await, (Vec::new(), Vec::new())); - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json( &shared::DPSoftwareList { uninstalled: software.0, @@ -153,19 +145,37 @@ pub async fn software_handler(socket_send: &mut SocketSend, data_recv: &mut Recv response: String::new(), }, ))) - .await; + .await + .is_err() + { + return; + } while let Some(Some(data)) = data_recv.recv().await { - handle_error!(software_handler_helper(&data, socket_send).await); + let out = handle_error!( + software_handler_helper(&data).await, + shared::DPSoftwareList::default() + ); + if socket_send + .send(Message::text(SerJson::serialize_json(&out))) + .await + .is_err() + { + break; + } } } pub async fn management_handler(socket_send: &mut SocketSend, data_recv: &mut RecvChannel) { - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json(&handle_error!( systemdata::host().await, shared::HostData::default() )))) - .await; + .await + .is_err() + { + return; + } while let Some(Some(data)) = data_recv.recv().await { // Don't care about the Ok value, so remove it to make the type checker happy handle_error!(Command::new(&data.cmd) @@ -176,64 +186,55 @@ pub async fn management_handler(socket_send: &mut SocketSend, data_recv: &mut Re } pub async fn service_handler(socket_send: &mut SocketSend, data_recv: &mut RecvChannel) { - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json( &shared::ServiceList { services: handle_error!(systemdata::services().await, Vec::new()), }, ))) - .await; + .await + .is_err() + { + return; + } while let Some(Some(data)) = data_recv.recv().await { handle_error!(Command::new("systemctl") .args([&data.cmd, data.args[0].as_str()]) .spawn() .map(|_| ()) // Don't care about the Ok value, so remove it to make the type checker happy .with_context(|| format!("Couldn't {} service {}", &data.cmd, &data.args[0]))); - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json( &shared::ServiceList { services: handle_error!(systemdata::services().await, Vec::new()), }, ))) - .await; + .await + .is_err() + { + break; + } } } -async fn browser_refresh( - socket_send: &mut SplitSink, - path: &std::path::Path, -) -> anyhow::Result<()> { +async fn browser_refresh(path: &std::path::Path) -> anyhow::Result { let dir_path = path .parent() .with_context(|| format!("Couldn't get parent of path {}", path.display()))?; - let _send = socket_send - .send(Message::text(SerJson::serialize_json( - &shared::BrowserList { - contents: systemdata::browser_dir(std::path::Path::new(dir_path)).await?, - }, - ))) - .await; - Ok(()) + Ok(shared::BrowserList { + contents: systemdata::browser_dir(std::path::Path::new(dir_path)).await?, + }) } -async fn browser_handler_helper( - data: &shared::Request, - socket_send: &mut SplitSink, -) -> anyhow::Result<()> { +async fn browser_handler_helper(data: shared::Request) -> anyhow::Result { use tokio::fs; match data.cmd.as_str() { "cd" => { - let _send = socket_send - .send(Message::text(SerJson::serialize_json( - &shared::BrowserList { - contents: systemdata::browser_dir(std::path::Path::new(&data.args[0])) - .await?, - }, - ))) - .await; - return Ok(()); + return Ok(shared::BrowserList { + contents: systemdata::browser_dir(std::path::Path::new(&data.args[0])).await?, + }); } "copy" => { let mut num = 2; @@ -279,14 +280,12 @@ async fn browser_handler_helper( _ => {} } - browser_refresh(socket_send, std::path::Path::new(&data.args[0])).await?; - - Ok(()) + browser_refresh(std::path::Path::new(&data.args[0])).await } pub async fn browser_handler(socket_send: &mut SocketSend, data_recv: &mut RecvChannel) { // Get initial listing of $HOME - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json( &shared::BrowserList { contents: handle_error!( @@ -298,18 +297,25 @@ pub async fn browser_handler(socket_send: &mut SocketSend, data_recv: &mut RecvC ), }, ))) - .await; + .await + .is_err() + { + return; + } 'outer: while let Some(Some(mut data)) = data_recv.recv().await { loop { tokio::select! { - res = browser_handler_helper(&data, socket_send) => { - handle_error!(res); + res = browser_handler_helper(data) => { + let list = handle_error!(res, shared::BrowserList::default()); + if socket_send.send(Message::text(SerJson::serialize_json(&list))).await.is_err() { + break 'outer; + } break; }, - Some(recv) = data_recv.recv() => match recv { - Some(data_tmp) => data = data_tmp, - None => break 'outer, + recv = data_recv.recv() => match recv { + Some(Some(data_tmp)) => data = data_tmp, + _ => break 'outer, }, } } diff --git a/src/backend/src/shared.rs b/src/backend/src/shared.rs index 7a798a19..4c675a9a 100644 --- a/src/backend/src/shared.rs +++ b/src/backend/src/shared.rs @@ -76,7 +76,7 @@ pub struct DPSoftwareData { pub docs: String, } -#[derive(SerJson)] +#[derive(SerJson, Default)] pub struct DPSoftwareList { pub installed: Vec, pub uninstalled: Vec, @@ -130,7 +130,7 @@ pub struct BrowserData { pub size: u64, } -#[derive(SerJson)] +#[derive(SerJson, Default)] pub struct BrowserList { pub contents: Vec, } diff --git a/src/backend/src/socket_handlers.rs b/src/backend/src/socket_handlers.rs index 0a58caca..a7c663ec 100644 --- a/src/backend/src/socket_handlers.rs +++ b/src/backend/src/socket_handlers.rs @@ -84,11 +84,15 @@ pub async fn socket_handler(socket: warp::ws::WebSocket) { } }); // Send global message (shown on all pages) - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json( &systemdata::global().await, ))) - .await; + .await + .is_err() + { + return; + } while let Some(Some(message)) = data_recv.recv().await { match message.page.as_str() { "/" => page_handlers::main_handler(&mut socket_send, &mut data_recv).await, @@ -109,11 +113,15 @@ pub async fn socket_handler(socket: warp::ws::WebSocket) { } "/login" => { // Internal poll, see other thread - let _send = socket_send + if socket_send .send(Message::text(SerJson::serialize_json( &shared::TokenError { error: true }, ))) - .await; + .await + .is_err() + { + break; + } } _ => { log::debug!("Got page {}, not handling", message.page); @@ -298,19 +306,23 @@ async fn create_zip_file(req: &shared::FileRequest) -> anyhow::Result> { .with_context(|| format!("Couldn't read file {}", src_path.display())) } +// Not the most elegant solution, but it works +enum FileHandlerHelperReturns { + String(String), + SizeBuf(shared::FileSize, Vec), + StreamUpload(usize, tokio::fs::File), +} + async fn file_handler_helper( req: &shared::FileRequest, - socket: &mut warp::ws::WebSocket, -) -> anyhow::Result<()> { +) -> anyhow::Result> { match req.cmd.as_str() { "open" => { - let _send = socket - .send(Message::text( - tokio::fs::read_to_string(&req.path) - .await - .with_context(|| format!("Couldn't read file {}", &req.path))?, - )) - .await; + return Ok(Some(FileHandlerHelperReturns::String( + tokio::fs::read_to_string(&req.path) + .await + .with_context(|| format!("Couldn't read file {}", &req.path))?, + ))) } // Technically works for both files and directories "dl" => { @@ -322,59 +334,94 @@ async fn file_handler_helper( clippy::cast_possible_truncation )] let size = (buf.len() as f64 / f64::from(1000 * 1000)).ceil() as usize; - let _send = socket - .send(Message::text(SerJson::serialize_json(&shared::FileSize { - size, - }))) - .await; - for i in buf.chunks(1000 * 1000) { - let _feed = socket.feed(Message::binary(i)).await; - } - let _send = socket.flush().await; + return Ok(Some(FileHandlerHelperReturns::SizeBuf( + shared::FileSize { size }, + buf, + ))); } "up" => { - let mut file = tokio::fs::File::create(&req.path) + let file = tokio::fs::File::create(&req.path) .await .with_context(|| format!("Couldn't create file at {}", &req.path))?; - while let Some(Ok(msg)) = socket - .take(req.arg.parse::().context("Invalid max size")?) - .next() - .await - { - file.write_all(msg.as_bytes()).await.with_context(|| { - format!("Couldn't write to file at {}, stopping upload", &req.path) - })?; - } + return Ok(Some(FileHandlerHelperReturns::StreamUpload( + req.arg.parse::().context("Invalid max size")?, + file, + ))); } "save" => tokio::fs::write(&req.path, &req.arg) .await .with_context(|| format!("Couldn't save file {}", &req.path))?, _ => {} } - Ok(()) + Ok(None) } -pub async fn file_handler(mut socket: warp::ws::WebSocket) { +fn get_file_req(data: &warp::ws::Message) -> anyhow::Result { + let data_str = data + .to_str() + .map_err(|_| anyhow::anyhow!("Couldn't convert received data {:?} to text", data))?; + let req = DeJson::deserialize_json(data_str) + .with_context(|| format!("Couldn't parse JSON from {}", data_str))?; + Ok(req) +} + +pub async fn file_handler(socket: warp::ws::WebSocket) { + let (mut socket_send, mut socket_recv) = socket.split(); let mut req: shared::FileRequest; - while let Some(Ok(data)) = socket.next().await { + 'outer: while let Some(Ok(data)) = socket_recv.next().await { if data.is_close() { break; } - let data_str = handle_error!( - data.to_str() - .map_err(|_| anyhow::anyhow!("Couldn't convert received data {:?} to text", data)), - continue - ); - req = handle_error!( - DeJson::deserialize_json(data_str) - .with_context(|| format!("Couldn't parse JSON from {}", data_str)), - continue - ); + req = handle_error!(get_file_req(&data), continue); + if CONFIG.pass && !validate_token(&req.token) { continue; } - handle_error!(file_handler_helper(&req, &mut socket).await, continue); + + loop { + tokio::select! { + result = file_handler_helper(&req) => { + match handle_error!(result, continue) { + Some(FileHandlerHelperReturns::String(file)) => { + if socket_send.send(Message::text(file)).await.is_err() { + break 'outer; + } + } + Some(FileHandlerHelperReturns::SizeBuf(size, buf)) => { + if socket_send + .send(Message::text(SerJson::serialize_json(&size))) + .await + .is_err() + { + break 'outer; + } + for i in buf.chunks(1000 * 1000) { + if socket_send.feed(Message::binary(i)).await.is_err() { + break 'outer; + } + } + if socket_send.flush().await.is_err() { + break 'outer; + } + } + Some(FileHandlerHelperReturns::StreamUpload(size, mut file)) => { + while let Some(Ok(msg)) = (&mut socket_recv).take(size).next().await { + handle_error!(file.write_all(msg.as_bytes()).await.with_context(|| { + format!("Couldn't write to file {}, stopping upload", &req.path) + }), continue 'outer); + } + } + None => {} + } + break; + }, + recv = socket_recv.next() => match recv { + Some(Ok(req_tmp)) => req = handle_error!(get_file_req(&req_tmp), continue 'outer), + _ => break 'outer, + }, + } + } } }