diff --git a/debug-service/src/task.rs b/debug-service/src/task.rs index d88f8592..757c31bb 100644 --- a/debug-service/src/task.rs +++ b/debug-service/src/task.rs @@ -7,11 +7,16 @@ use embedded_services::{ use crate::{debug_service_entry, defmt_ring_logger::DEFMT_BUFFER, frame_available, shared_buffer}; +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Error { + Buffer(embedded_services::buffer::Error), +} + pub async fn debug_service(endpoint: comms::Endpoint) { debug_service_entry(endpoint).await; } -pub async fn defmt_to_host_task() { +pub async fn defmt_to_host_task() -> Result { embedded_services::info!("defmt to host task start"); use crate::debug_service::{host_endpoint_id, response_notify_signal}; use embedded_services::comms::{self, EndpointID, Internal}; @@ -34,7 +39,7 @@ pub async fn defmt_to_host_task() { // destination length to be robust if the staging buffer size changes. let copy_len = core::cmp::min(frame.len(), acpi_owned.len()); { - let mut access = acpi_owned.borrow_mut(); + let mut access = acpi_owned.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = BorrowMut::borrow_mut(&mut access); buf[..copy_len].copy_from_slice(&frame[..copy_len]); @@ -69,7 +74,7 @@ pub async fn defmt_to_host_task() { status: 0, payload: StdHostPayload::DebugGetMsgsResponse { debug_buf: { - let access = shared_buffer().borrow(); + let access = shared_buffer().borrow().map_err(Error::Buffer)?; let slice: &[u8] = access.borrow(); slice.try_into().unwrap() }, @@ -81,14 +86,14 @@ pub async fn defmt_to_host_task() { // Clear the staged portion of the buffer { - let mut access = acpi_owned.borrow_mut(); + let mut access = acpi_owned.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = BorrowMut::borrow_mut(&mut access); buf[..copy_len].fill(0); } } } -pub async fn no_avail_to_host_task() { +pub async fn no_avail_to_host_task() -> Result { embedded_services::define_static_buffer!(no_avail_acpi_buf, u8, [0u8; 12]); embedded_services::info!("no avail to host task start"); @@ -100,7 +105,7 @@ pub async fn no_avail_to_host_task() { let acpi_owned = no_avail_acpi_buf::get_mut().expect("defmt staging buffer already initialized elsewhere"); { - let mut access = acpi_owned.borrow_mut(); + let mut access = acpi_owned.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = BorrowMut::borrow_mut(&mut access); // Use 0xDEADBEEF to signify no frame available buf[4..12].copy_from_slice(&0xDEADBEEFu64.to_be_bytes()); diff --git a/embedded-service/src/buffer.rs b/embedded-service/src/buffer.rs index c4a73481..19221ef4 100644 --- a/embedded-service/src/buffer.rs +++ b/embedded-service/src/buffer.rs @@ -18,11 +18,26 @@ use core::ops::Range; use crate::SyncCell; use core::sync::atomic::AtomicPtr; +/// Buffer error. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// Buffer already borrowed immutably. + BorrowedImmutably, + /// Buffer already borrowed mutably. + BorrowedMutably, + /// Range is out-of-bounds. + InvalidRange, + /// Buffer is poisoned and should be considered no longer valid. + Poisoned, +} + #[derive(Copy, Clone, PartialEq, Eq)] enum Status { None, Mutable, Immutable(u32), + Poisoned, } /// Underlying buffer storage struct @@ -65,24 +80,33 @@ impl<'a, T> Buffer<'a, T> { self.len == 0 } - fn borrow(&self, mutable: bool) { + fn borrow(&self, mutable: bool) -> Result<(), Error> { let status = match (self.status.get(), mutable) { (Status::None, false) => Status::Immutable(1), (Status::None, true) => Status::Mutable, - (Status::Mutable, _) => panic!("Buffer already borrowed mutably"), + (Status::Mutable, _) => return Err(Error::BorrowedMutably), (Status::Immutable(count), false) => Status::Immutable(count + 1), - (Status::Immutable(_), true) => panic!("Buffer already borrowed immutably"), + (Status::Immutable(_), true) => return Err(Error::BorrowedImmutably), + (Status::Poisoned, _) => return Err(Error::Poisoned), }; self.status.set(status); + Ok(()) } + // In the case of invalid status, we can't return an error since this is within the drop handler, + // but don't want to panic either. + // Instead, mark this buffer `Poisoned` to signify it's now in a bad/unexpected state. fn drop_borrow(&self) { let status = match self.status.get() { - Status::None => panic!("Unborrowed buffer dropped"), + // Unborrowed buffer dropped + Status::None => Status::Poisoned, Status::Mutable => Status::None, - Status::Immutable(0) => panic!("Buffer borrow count underflow"), + // Buffer borrow count underflow + Status::Immutable(0) => Status::Poisoned, Status::Immutable(1) => Status::None, Status::Immutable(count) => Status::Immutable(count - 1), + // Buffer already poisoned + Status::Poisoned => Status::Poisoned, }; self.status.set(status); } @@ -94,18 +118,20 @@ pub struct OwnedRef<'a, T>(&'a Buffer<'a, T>); impl<'a, T> OwnedRef<'a, T> { /// Creates an immutable reference to the buffer pub fn reference(&self) -> SharedRef<'a, T> { - SharedRef::new(self.0, 0..self.0.len()) + SharedRef::new_full_len(self.0) } /// Borrows the buffer immutably - /// Panics if the buffer is already borrowed mutably - pub fn borrow(&self) -> Access<'a, T> { + /// + /// Returns an error if the buffer is already borrowed mutably + pub fn borrow(&self) -> Result, Error> { Access::new(self.0, 0..self.0.len()) } /// Borrows the buffer mutably - /// Panics if the buffer is already borrowed - pub fn borrow_mut(&self) -> AccessMut<'a, T> { + /// + /// Returns an error if the buffer is already borrowed + pub fn borrow_mut(&self) -> Result, Error> { AccessMut::new(self.0) } @@ -124,9 +150,9 @@ impl<'a, T> OwnedRef<'a, T> { pub struct AccessMut<'a, T>(&'a Buffer<'a, T>); impl<'a, T> AccessMut<'a, T> { - fn new(buffer: &'a Buffer<'a, T>) -> Self { - buffer.borrow(true); - Self(buffer) + fn new(buffer: &'a Buffer<'a, T>) -> Result { + buffer.borrow(true)?; + Ok(Self(buffer)) } } @@ -160,26 +186,44 @@ pub struct SharedRef<'a, T> { } impl<'a, T> SharedRef<'a, T> { - /// Creates a new immutable buffer refference - pub fn new(buffer: &'a Buffer<'a, T>, slice: Range) -> Self { - Self { buffer, slice } + // Creates a new immutable buffer reference with the same length as the original buffer + // Allows us to make an infallible version of `Self::new()` for `OwnedRef::reference()` + fn new_full_len(buffer: &'a Buffer<'a, T>) -> Self { + Self { + buffer, + slice: 0..buffer.len(), + } + } + + /// Creates a new immutable buffer reference + /// + /// Returns an error if the given slice is out-of-bounds + pub fn new(buffer: &'a Buffer<'a, T>, slice: Range) -> Result { + if slice.start >= buffer.len() || slice.end > buffer.len() { + Err(Error::InvalidRange) + } else { + Ok(Self { buffer, slice }) + } } /// Borrows the buffer immutably - /// Panics if the buffer is already borrowed mutably - pub fn borrow<'s>(&'s self) -> Access<'a, T> { + /// + /// Returns an error if the buffer is already borrowed mutably + pub fn borrow<'s>(&'s self) -> Result, Error> { Access::new(self.buffer, self.slice.clone()) } /// Produces a new slice into the buffer - pub fn slice(&self, range: Range) -> SharedRef<'a, T> { + /// + /// Returns an error if the given range is out-of-bounds + pub fn slice(&self, range: Range) -> Result, Error> { if range.start >= self.slice.len() || range.end > self.slice.len() { - panic!("Slice out of bounds"); + Err(Error::InvalidRange) + } else { + let start = self.slice.start + range.start; + let end = start + range.len(); + SharedRef::new(self.buffer, start..end) } - - let start = self.slice.start + range.start; - let end = start + range.len(); - SharedRef::new(self.buffer, start..end) } /// Returns the length of the buffer @@ -200,9 +244,13 @@ pub struct Access<'a, T> { } impl<'a, T> Access<'a, T> { - fn new(buffer: &'a Buffer<'a, T>, slice: Range) -> Self { - buffer.borrow(false); - Self { buffer, slice } + fn new(buffer: &'a Buffer<'a, T>, slice: Range) -> Result { + if slice.start >= buffer.len() || slice.end > buffer.len() { + Err(Error::InvalidRange) + } else { + buffer.borrow(false)?; + Ok(Self { buffer, slice }) + } } } @@ -215,6 +263,10 @@ impl Borrow<[T]> for Access<'_, T> { self.buffer.len, ) }; + + // Panic safety: The public API prevents a slice from being stored that would + // be outside the bounds of the buffer + #[allow(clippy::indexing_slicing)] &buffer[self.slice.clone()] } } @@ -274,32 +326,32 @@ mod test { // Verify that only one mutable borrow is allowed #[test] - #[should_panic(expected = "Buffer already borrowed mutably")] + #[should_panic] fn test_mut_mut_fail() { define_static_buffer!(buffer, u8, [0; 16]); let buffer = buffer::get_mut().unwrap(); - let _mut_a = buffer.borrow_mut(); - let _mut_b = buffer.borrow_mut(); + let _mut_a = buffer.borrow_mut().unwrap(); + let _mut_b = buffer.borrow_mut().unwrap(); } // Verify that mutable and immutable borrows are not allowed #[test] - #[should_panic(expected = "Buffer already borrowed mutably")] + #[should_panic] fn test_mut_imm_fail() { define_static_buffer!(buffer, u8, [0; 16]); let buffer = buffer::get_mut().unwrap(); - let _mut_a = buffer.borrow_mut(); - let _b = buffer.borrow(); + let _mut_a = buffer.borrow_mut().unwrap(); + let _b = buffer.borrow().unwrap(); } // Verify that mutable and immutable borrows are not allowed #[test] - #[should_panic(expected = "Buffer already borrowed immutably")] + #[should_panic] fn test_imm_mut_fail() { define_static_buffer!(buffer, u8, [0u8; 16]); let buffer = buffer::get_mut().unwrap(); - let _a = buffer.borrow(); - let _mut_b = buffer.borrow_mut(); + let _a = buffer.borrow().unwrap(); + let _mut_b = buffer.borrow_mut().unwrap(); } // Verify that multiple immutable borrows are allowed @@ -307,8 +359,8 @@ mod test { fn test_immutable() { define_static_buffer!(buffer, u8, [0; 16]); let buffer = buffer::get_mut().unwrap(); - let _a = buffer.borrow(); - let _b = buffer.borrow(); + let _a = buffer.borrow().unwrap(); + let _b = buffer.borrow().unwrap(); } // Verify dropping a mutable borrow releases the buffer @@ -316,11 +368,11 @@ mod test { fn test_drop() { define_static_buffer!(buffer, u8, [0; 16]); let buffer = buffer::get_mut().unwrap(); - let mut_a = buffer.borrow_mut(); + let mut_a = buffer.borrow_mut().unwrap(); drop(mut_a); - let mut_b = buffer.borrow_mut(); + let mut_b = buffer.borrow_mut().unwrap(); drop(mut_b); - let _c = buffer.borrow(); + let _c = buffer.borrow().unwrap(); } // Test slicing @@ -329,44 +381,44 @@ mod test { define_static_buffer!(buffer, u8, [0, 1, 2, 3, 4, 5, 6, 7]); let buffer = buffer::get_mut().unwrap(); - let slice = buffer.reference().slice(0..8); - let sliced = slice.borrow(); + let slice = buffer.reference().slice(0..8).unwrap(); + let sliced = slice.borrow().unwrap(); assert_eq!(sliced.borrow(), [0, 1, 2, 3, 4, 5, 6, 7]); - let slice = buffer.reference().slice(0..4); - let sliced = slice.borrow(); + let slice = buffer.reference().slice(0..4).unwrap(); + let sliced = slice.borrow().unwrap(); assert_eq!(sliced.borrow(), [0, 1, 2, 3]); - let slice = buffer.reference().slice(4..8); - let sliced = slice.borrow(); + let slice = buffer.reference().slice(4..8).unwrap(); + let sliced = slice.borrow().unwrap(); assert_eq!(sliced.borrow(), [4, 5, 6, 7]); - let slice = buffer.reference().slice(4..8).slice(1..4); - let sliced = slice.borrow(); + let slice = buffer.reference().slice(4..8).unwrap().slice(1..4).unwrap(); + let sliced = slice.borrow().unwrap(); assert_eq!(sliced.borrow(), [5, 6, 7]); - let slice = buffer.reference().slice(3..7); - let sliced = slice.borrow(); + let slice = buffer.reference().slice(3..7).unwrap(); + let sliced = slice.borrow().unwrap(); assert_eq!(sliced.borrow(), [3, 4, 5, 6]); } // Test slice starting index out of bounds #[test] - #[should_panic(expected = "Slice out of bounds")] + #[should_panic] fn test_slice_bounds_start_fail() { define_static_buffer!(buffer, u8, [0, 1, 2, 3, 4, 5, 6, 7]); let buffer = buffer::get_mut().unwrap(); - let _slice = buffer.reference().slice(8..8); + let _slice = buffer.reference().slice(8..8).unwrap(); } // Test slice ending index out of bounds #[test] - #[should_panic(expected = "Slice out of bounds")] + #[should_panic] fn test_slice_bounds_end_fail() { define_static_buffer!(buffer, u8, [0, 1, 2, 3, 4, 5, 6, 7]); let buffer = buffer::get_mut().unwrap(); - let _slice = buffer.reference().slice(0..9); + let _slice = buffer.reference().slice(0..9).unwrap(); } } diff --git a/embedded-service/src/hid/command.rs b/embedded-service/src/hid/command.rs index c0c5f5d1..89a91ee0 100644 --- a/embedded-service/src/hid/command.rs +++ b/embedded-service/src/hid/command.rs @@ -485,7 +485,7 @@ impl<'a> Command<'a> { len += register_len; } Command::SetReport(report_type, report_id, data) => { - let borrow = data.borrow(); + let borrow = data.borrow().map_err(|_| Error::InvalidData)?; let data: &[u8] = borrow.borrow(); let (command_len, buf) = Self::encode_common(buf, Opcode::SetReport, Some(*report_type), *report_id)?; diff --git a/embedded-service/src/lib.rs b/embedded-service/src/lib.rs index d4bee102..4ce50d61 100644 --- a/embedded-service/src/lib.rs +++ b/embedded-service/src/lib.rs @@ -62,6 +62,14 @@ pub type SyncCell = critical_section_cell::CriticalSectionCell; #[cfg(all(not(test), target_os = "none", target_arch = "arm"))] pub type SyncCell = thread_mode_cell::ThreadModeCell; +/// Until the Never type (`!`) is stable, the best we have is `Infallible`. +/// +/// Although they mean the same thing for the most part from the type system pov, +/// `Never` typically reads better than `Infallible` in some cases. +/// +/// For example, a result that should never return unless there is an error: `Result`. +pub type Never = core::convert::Infallible; + /// initialize all service static interfaces as required. Ideally, this is done before subsystem initialization #[allow(clippy::unused_async)] pub async fn init() { diff --git a/espi-service/src/espi_service.rs b/espi-service/src/espi_service.rs index 53721341..74d9c0e8 100644 --- a/espi-service/src/espi_service.rs +++ b/espi-service/src/espi_service.rs @@ -29,8 +29,9 @@ type HostMsgInternal = (EndpointID, StdHostMsg); #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -enum Error { +pub enum Error { Serialize, + Buffer(embedded_services::buffer::Error), } pub struct Service<'a> { @@ -169,7 +170,7 @@ impl Service<'_> { response: &StdHostRequest, endpoint: EndpointID, ) -> Result<(), Error> { - let mut assembly_buf_access = self.assembly_buf_owned_ref.borrow_mut(); + let mut assembly_buf_access = self.assembly_buf_owned_ref.borrow_mut().map_err(Error::Buffer)?; let pkt_ctx_buf = assembly_buf_access.borrow_mut(); let mut mctp_ctx = mctp_rs::MctpPacketContext::new(mctp_rs::smbus_espi::SmbusEspiMedium, pkt_ctx_buf); @@ -229,7 +230,7 @@ impl Service<'_> { // Immediately service the packet with the ESPI HAL let event = espi.wait_for_event().await; - process_controller_event(espi, self, event).await; + process_controller_event(espi, self, event).await?; } Ok(()) } @@ -314,7 +315,7 @@ pub(crate) async fn process_controller_event( espi: &mut espi::Espi<'static>, espi_service: &Service<'_>, event: Result, -) { +) -> Result<(), Error> { match event { Ok(espi::Event::PeripheralEvent(port_event)) => { info!( @@ -361,7 +362,10 @@ pub(crate) async fn process_controller_event( let endpoint: EndpointID; { - let mut assembly_access = espi_service.assembly_buf_owned_ref.borrow_mut(); + let mut assembly_access = espi_service + .assembly_buf_owned_ref + .borrow_mut() + .map_err(Error::Buffer)?; // let mut comms_access = espi_service.comms_buf_owned_ref.borrow_mut(); let mut mctp_ctx = mctp_rs::MctpPacketContext::::new( SmbusEspiMedium, @@ -409,7 +413,7 @@ pub(crate) async fn process_controller_event( EndpointID::Internal(embedded_services::comms::Internal::Debug), espi, ); - return; + return Err(Error::Serialize); } } } @@ -425,7 +429,7 @@ pub(crate) async fn process_controller_event( EndpointID::Internal(embedded_services::comms::Internal::Debug), espi, ); - return; + return Err(Error::Serialize); } Err(_e) => { // Handle protocol or medium error @@ -439,7 +443,7 @@ pub(crate) async fn process_controller_event( EndpointID::Internal(embedded_services::comms::Internal::Debug), espi, ); - return; + return Err(Error::Serialize); } } } @@ -461,4 +465,5 @@ pub(crate) async fn process_controller_event( error!("eSPI Failed with error: {:?}", e); } } + Ok(()) } diff --git a/espi-service/src/task.rs b/espi-service/src/task.rs index 2719184c..ad1ee7fa 100644 --- a/espi-service/src/task.rs +++ b/espi-service/src/task.rs @@ -4,7 +4,10 @@ use embedded_services::{comms, ec_type, info}; use crate::{ESPI_SERVICE, Service, process_controller_event}; -pub async fn espi_service(mut espi: espi::Espi<'static>, memory_map_buffer: &'static mut [u8]) { +pub async fn espi_service( + mut espi: espi::Espi<'static>, + memory_map_buffer: &'static mut [u8], +) -> Result { info!("Reserved eSPI memory map buffer size: {}", memory_map_buffer.len()); info!("eSPI MemoryMap size: {}", size_of::()); @@ -35,7 +38,7 @@ pub async fn espi_service(mut espi: espi::Espi<'static>, memory_map_buffer: &'st match event { embassy_futures::select::Either::First(controller_event) => { - process_controller_event(&mut espi, espi_service, controller_event).await + process_controller_event(&mut espi, espi_service, controller_event).await? } embassy_futures::select::Either::Second(host_msg) => { espi_service.process_subsystem_msg(&mut espi, host_msg).await diff --git a/examples/rt633/src/bin/espi.rs b/examples/rt633/src/bin/espi.rs index 04cab046..60850559 100644 --- a/examples/rt633/src/bin/espi.rs +++ b/examples/rt633/src/bin/espi.rs @@ -98,8 +98,8 @@ unsafe extern "C" { #[embassy_executor::task] async fn espi_service_task(espi: embassy_imxrt::espi::Espi<'static>, memory_map_buffer: &'static mut [u8]) -> ! { - espi_service::task::espi_service(espi, memory_map_buffer).await; - unreachable!() + let Err(e) = espi_service::task::espi_service(espi, memory_map_buffer).await; + panic!("espi_service_task error: {e:?}"); } #[embassy_executor::main] diff --git a/examples/rt633/src/bin/espi_battery.rs b/examples/rt633/src/bin/espi_battery.rs index 34fe2110..165b5244 100644 --- a/examples/rt633/src/bin/espi_battery.rs +++ b/examples/rt633/src/bin/espi_battery.rs @@ -239,8 +239,8 @@ async fn wrapper_task(wrapper: Wrapper<'static, Bq40z50Controller>) { #[embassy_executor::task] async fn espi_service_task(espi: embassy_imxrt::espi::Espi<'static>, memory_map_buffer: &'static mut [u8]) -> ! { - espi_service::task::espi_service(espi, memory_map_buffer).await; - unreachable!() + let Err(e) = espi_service::task::espi_service(espi, memory_map_buffer).await; + panic!("espi_service_task error: {e:?}"); } #[embassy_executor::task] diff --git a/examples/std/src/bin/buffer.rs b/examples/std/src/bin/buffer.rs index 95b5041e..1d1b14a9 100644 --- a/examples/std/src/bin/buffer.rs +++ b/examples/std/src/bin/buffer.rs @@ -27,7 +27,7 @@ mod sender { pub async fn send(&self, even: bool) { { - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().unwrap(); let data: &mut [u8] = borrow.borrow_mut(); let data = &mut data[0..4]; if even { @@ -69,7 +69,7 @@ mod receiver { .get::>() .ok_or(comms::MailboxDelegateError::MessageNotFound)?; - let borrow = data.borrow(); + let borrow = data.borrow().unwrap(); let data: &[u8] = borrow.borrow(); info!("Received data: {data:?}"); diff --git a/examples/std/src/bin/debug.rs b/examples/std/src/bin/debug.rs index 2d234bc6..9944d98d 100644 --- a/examples/std/src/bin/debug.rs +++ b/examples/std/src/bin/debug.rs @@ -61,7 +61,7 @@ mod espi_service { } HostMsg::Response(acpi) => { // Stage the response bytes into the mock OOB buffer for the host - let mut access = self.resp_owned.borrow_mut(); + let mut access = self.resp_owned.borrow_mut().unwrap(); let buf: &mut [u8] = core::borrow::BorrowMut::borrow_mut(&mut access); if let StdHostPayload::DebugGetMsgsResponse { debug_buf } = acpi.payload { let copy_len = core::cmp::min(debug_buf.len(), buf.len()); @@ -121,7 +121,7 @@ mod espi_service { let request = b"GetDebugBuffer"; let req_len = request.len(); { - let mut access = req_owned.borrow_mut(); + let mut access = req_owned.borrow_mut().unwrap(); let buf: &mut [u8] = BorrowMut::borrow_mut(&mut access); buf[..req_len].copy_from_slice(request); } @@ -143,7 +143,7 @@ mod espi_service { // Wait for the response payload staged by the Debug service, then "forward" it to host let len = wait_response_len().await; let buf = response_buf(); - let access = buf.borrow(); + let access = buf.borrow().unwrap(); let slice: &[u8] = core::borrow::Borrow::borrow(&access); let bytes = &slice[..len.min(slice.len())]; let preview = bytes @@ -194,8 +194,8 @@ async fn debug_service() -> ! { #[embassy_executor::task] async fn defmt_to_host_task() -> ! { - debug_service::task::defmt_to_host_task().await; - unreachable!() + let Err(e) = debug_service::task::defmt_to_host_task().await; + panic!("defmt_to_host_task error: {e:?}"); } fn main() { diff --git a/examples/std/src/bin/keyboard.rs b/examples/std/src/bin/keyboard.rs index 1369f06b..5f107096 100644 --- a/examples/std/src/bin/keyboard.rs +++ b/examples/std/src/bin/keyboard.rs @@ -25,7 +25,7 @@ mod device { pub async fn key_down(&self, key: Key) { { - let mut borrow = self.event_buffer.borrow_mut(); + let mut borrow = self.event_buffer.borrow_mut().unwrap(); let buf: &mut [KeyEvent] = borrow.borrow_mut(); buf[0] = KeyEvent::Make(key); @@ -33,14 +33,17 @@ mod device { keyboard::broadcast_message( self.id, - MessageData::Event(Event::KeyEvent(self.id, self.event_buffer.reference().slice(0..1))), + MessageData::Event(Event::KeyEvent( + self.id, + self.event_buffer.reference().slice(0..1).unwrap(), + )), ) .await; } pub async fn key_up(&self, key: Key) { { - let mut borrow = self.event_buffer.borrow_mut(); + let mut borrow = self.event_buffer.borrow_mut().unwrap(); let buf: &mut [KeyEvent] = borrow.borrow_mut(); buf[0] = KeyEvent::Break(key); @@ -48,7 +51,10 @@ mod device { keyboard::broadcast_message( self.id, - MessageData::Event(Event::KeyEvent(self.id, self.event_buffer.reference().slice(0..1))), + MessageData::Event(Event::KeyEvent( + self.id, + self.event_buffer.reference().slice(0..1).unwrap(), + )), ) .await; } @@ -84,7 +90,7 @@ mod host { match &message.data { MessageData::Event(Event::KeyEvent(id, events)) => { - let borrow = events.borrow(); + let borrow = events.borrow().unwrap(); let buf: &[KeyEvent] = borrow.borrow(); for event in buf { diff --git a/hid-service/src/i2c/device.rs b/hid-service/src/i2c/device.rs index 0182d8e9..5af3870a 100644 --- a/hid-service/src/i2c/device.rs +++ b/hid-service/src/i2c/device.rs @@ -35,7 +35,7 @@ impl> Device { } } let mut bus = self.bus.lock().await; - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let mut reg = [0u8; 2]; let buf: &mut [u8] = borrow.borrow_mut(); let buf_len = buf.len(); @@ -70,19 +70,19 @@ impl> Device { pub async fn read_hid_descriptor(&self) -> Result, Error> { let desc = self.get_hid_descriptor().await?; - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = borrow.borrow_mut(); let len = desc.encode_into_slice(buf).map_err(Error::Hid)?; trace!("HID descriptor length: {}", len); - Ok(self.buffer.reference().slice(0..len)) + self.buffer.reference().slice(0..len).map_err(Error::Buffer) } pub async fn read_report_descriptor(&self) -> Result, Error> { info!("Sending report descriptor"); let desc = self.get_hid_descriptor().await?; - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = borrow.borrow_mut(); let buffer_len = buf.len(); let reg = desc.w_report_desc_register.to_le_bytes(); @@ -105,14 +105,14 @@ impl> Device { return Err(Error::Bus(e)); } - Ok(self.buffer.reference().slice(0..len)) + self.buffer.reference().slice(0..len).map_err(Error::Buffer) } pub async fn handle_input_report(&self) -> Result, Error> { info!("Handling input report"); let desc = self.get_hid_descriptor().await?; - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = borrow.borrow_mut(); let buffer_len = buf.len(); let buf = buf @@ -128,7 +128,10 @@ impl> Device { return Err(Error::Bus(e)); } - Ok(self.buffer.reference().slice(0..desc.w_max_input_length as usize)) + self.buffer + .reference() + .slice(0..desc.w_max_input_length as usize) + .map_err(Error::Buffer) } pub async fn handle_command( @@ -137,7 +140,7 @@ impl> Device { ) -> Result>, Error> { info!("Handling command"); - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let buf: &mut [u8] = borrow.borrow_mut(); let buffer_len = buf.len(); diff --git a/hid-service/src/i2c/host.rs b/hid-service/src/i2c/host.rs index 5f008e15..9de78e26 100644 --- a/hid-service/src/i2c/host.rs +++ b/hid-service/src/i2c/host.rs @@ -72,7 +72,7 @@ impl Host { } async fn process_output_report(&self) -> Result, Error> { - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let buffer: &mut [u8] = borrow.borrow_mut(); let buffer_len = buffer.len(); @@ -111,7 +111,10 @@ impl Host { actual: buffer_len, }), ))?)), - self.buffer.reference().slice(3..length as usize), + self.buffer + .reference() + .slice(3..length as usize) + .map_err(Error::Buffer)?, )) } @@ -159,7 +162,7 @@ impl Host { if opcode.requires_host_data() { trace!("Waiting for data"); - let mut borrow = self.buffer.borrow_mut(); + let mut borrow = self.buffer.borrow_mut().map_err(Error::Buffer)?; let buffer: &mut [u8] = borrow.borrow_mut(); let buffer_len = buffer.len(); @@ -192,7 +195,12 @@ impl Host { })))?, ) .await?; - Some(self.buffer.reference().slice(2..length as usize)) + Some( + self.buffer + .reference() + .slice(2..length as usize) + .map_err(Error::Buffer)?, + ) } else { None } @@ -315,7 +323,7 @@ impl Host { | hid::Response::ReportDescriptor(data) | hid::Response::InputReport(data) | hid::Response::FeatureReport(data) => { - let bytes = data.borrow(); + let bytes = data.borrow().map_err(Error::Buffer)?; self.write_bus(DEVICE_RESPONSE_TIMEOUT_MS, bytes.borrow()).await } hid::Response::Command(cmd) => match cmd { diff --git a/hid-service/src/lib.rs b/hid-service/src/lib.rs index 5b841dc1..96bbe883 100644 --- a/hid-service/src/lib.rs +++ b/hid-service/src/lib.rs @@ -9,5 +9,8 @@ pub mod i2c; pub enum Error { /// Error from the underlying bus Bus(B), + /// HID error Hid(hid::Error), + /// Error from the underlying buffer + Buffer(embedded_services::buffer::Error), } diff --git a/keyboard-service/src/gpio_kb.rs b/keyboard-service/src/gpio_kb.rs index b58f22b3..3b720b42 100644 --- a/keyboard-service/src/gpio_kb.rs +++ b/keyboard-service/src/gpio_kb.rs @@ -526,7 +526,7 @@ impl< match report_type { // Received a set output report for LEDs hid::ReportType::Output if report_id.0 == REPORT_ID => { - let buf = buf.borrow(); + let buf = buf.borrow().map_err(super::KeyboardError::Buffer)?; let leds: &[u8] = buf.borrow(); let flags = LedFlags::from_bits_retain(leds[0]); diff --git a/keyboard-service/src/hid_kb.rs b/keyboard-service/src/hid_kb.rs index af3a5ca4..f72d9a47 100644 --- a/keyboard-service/src/hid_kb.rs +++ b/keyboard-service/src/hid_kb.rs @@ -82,7 +82,7 @@ impl HidI2cReport { let err = match error { super::KeyboardError::Ghosting | super::KeyboardError::Rollover => [ERROR_ROLL_OVER; REPORT_MAX_SZ], - super::KeyboardError::Scan | super::KeyboardError::Command => [ERROR_UNDEFINED; REPORT_MAX_SZ], + _ => [ERROR_UNDEFINED; REPORT_MAX_SZ], }; HidI2cReport::from_report_slice(super::HidReportSlice(&err), max_len) @@ -121,7 +121,7 @@ pub(crate) fn init(reg_file: hid::RegisterFile) -> &'static hid::Device { /// This task handles calling the keyboard `scan` in a loop, while also listening for commands /// from the HID request handler task. To minimize delay between scan loops, we quickly process commands /// and let the HID request handler task handle forwarding the response to the host. -pub async fn handle_keyboard(mut hid_kb: T) { +pub async fn handle_keyboard(mut hid_kb: T) -> Result { let context = CONTEXT.get().await; // Buffer holding immediate report requests @@ -172,7 +172,7 @@ pub async fn handle_keyboard(mut hid_kb: T) { { let report = hid_kb.get_report(report_type, report_id); let report = HidI2cReport::from_report_slice(report, max_input_len).to_bytes(); - let mut buf = owned_buf.borrow_mut(); + let mut buf = owned_buf.borrow_mut().map_err(super::KeyboardError::Buffer)?; let buf: &mut [u8] = buf.borrow_mut(); buf[..report.len()].copy_from_slice(&report); } @@ -250,7 +250,7 @@ pub async fn handle_keyboard(mut hid_kb: T) { /// This is a separate task because we want the main `scan` loop to quickly fire off an available report /// without it being blocked waiting for communication with the host. We also use a queue in case multiple reports /// are available before one is fully processed to prevent any lost key events. -pub async fn handle_reports(mut kb_int: impl OutputPin) { +pub async fn handle_reports(mut kb_int: impl OutputPin) -> Result { let context = CONTEXT.get().await; embedded_services::define_static_buffer!(input_buf, u8, [0u8; INPUT_MAX]); @@ -266,7 +266,7 @@ pub async fn handle_reports(mut kb_int: impl OutputPin) { // Once we have one, copy it to outgoing buffer { - let mut buf = owned_buf.borrow_mut(); + let mut buf = owned_buf.borrow_mut().map_err(super::KeyboardError::Buffer)?; let buf: &mut [u8] = buf.borrow_mut(); buf.copy_from_slice(&report); } @@ -302,6 +302,7 @@ pub async fn handle_host_requests(host: &'static mut hid_service::i2c::Host context.send_complete.signal(()), Err(hid_service::Error::Bus(_)) => error!("Host I2C bus error"), Err(hid_service::Error::Hid(e)) => error!("Host HID error: {:?}", e), + Err(hid_service::Error::Buffer(e)) => error!("Host buffer error: {:?}", e), } } } diff --git a/keyboard-service/src/lib.rs b/keyboard-service/src/lib.rs index 4e7e0cd8..23263d46 100644 --- a/keyboard-service/src/lib.rs +++ b/keyboard-service/src/lib.rs @@ -25,6 +25,8 @@ pub enum KeyboardError { Ghosting, /// Command error Command, + /// Buffer error + Buffer(embedded_services::buffer::Error), } /// A slice of a HID report. diff --git a/keyboard-service/src/task.rs b/keyboard-service/src/task.rs index b4d6d80c..658e78c3 100644 --- a/keyboard-service/src/task.rs +++ b/keyboard-service/src/task.rs @@ -4,11 +4,15 @@ use embedded_services::hid; use crate::hid_kb::{self, CONTEXT}; -pub async fn keyboard_task(keyboard: T) { +pub async fn keyboard_task( + keyboard: T, +) -> Result { crate::hid_kb::handle_keyboard(keyboard).await } -pub async fn reports_task(keyboard_interrupt: T) { +pub async fn reports_task( + keyboard_interrupt: T, +) -> Result { crate::hid_kb::handle_reports(keyboard_interrupt).await } @@ -48,7 +52,7 @@ pub async fn init_and_recv_device_requests_task( hid_descriptor: hid::Descriptor, report_descriptor: &'static [u8], reg_file: hid::RegisterFile, -) { +) -> Result { let device = crate::hid_kb::init(reg_file); hid::register_device(device) .await @@ -60,7 +64,8 @@ pub async fn init_and_recv_device_requests_task( { let mut buf = hid_desc_buf::get_mut() .expect("Must not already be borrowed mutably") - .borrow_mut(); + .borrow_mut() + .map_err(super::KeyboardError::Buffer)?; let buf: &mut [u8] = buf.borrow_mut(); hid_descriptor .encode_into_slice(buf) @@ -72,7 +77,8 @@ pub async fn init_and_recv_device_requests_task( { let mut buf = report_desc_buf::get_mut() .expect("Must not already be borrowed mutably") - .borrow_mut(); + .borrow_mut() + .map_err(super::KeyboardError::Buffer)?; let buf: &mut [u8] = buf.borrow_mut(); buf[..report_descriptor.len()].copy_from_slice(report_descriptor); } @@ -88,7 +94,9 @@ pub async fn init_and_recv_device_requests_task( device.send_response(response).await.expect("Infallible"); } hid::Request::ReportDescriptor => { - let response = report_desc_buf::get().slice(0..report_descriptor.len()); + let response = report_desc_buf::get() + .slice(0..report_descriptor.len()) + .map_err(super::KeyboardError::Buffer)?; let response = Some(hid::Response::ReportDescriptor(response)); device.send_response(response).await.expect("Infallible"); } @@ -99,7 +107,9 @@ pub async fn init_and_recv_device_requests_task( let ipc = context.report_ipc.receive().await; let report = ipc.command.clone(); let response = Some(hid::Response::InputReport( - report.slice(0..hid_descriptor.w_max_input_length as usize), + report + .slice(0..hid_descriptor.w_max_input_length as usize) + .map_err(super::KeyboardError::Buffer)?, )); // Then send it to the host diff --git a/power-policy-service/src/task.rs b/power-policy-service/src/task.rs index 39c79935..916b2909 100644 --- a/power-policy-service/src/task.rs +++ b/power-policy-service/src/task.rs @@ -1,5 +1,3 @@ -use core::convert::Infallible; - use embassy_sync::once_lock::OnceLock; use embedded_services::{comms, error, info}; @@ -14,7 +12,7 @@ pub enum InitError { RegistrationFailed, } -pub async fn task(config: config::Config) -> Result { +pub async fn task(config: config::Config) -> Result { info!("Starting power policy task"); static POLICY: OnceLock = OnceLock::new(); let policy = if let Some(policy) = PowerPolicy::create(config) {