diff --git a/embedded-service/src/hid/command.rs b/embedded-service/src/hid/command.rs index 2951cf7f..30bb9599 100644 --- a/embedded-service/src/hid/command.rs +++ b/embedded-service/src/hid/command.rs @@ -2,7 +2,7 @@ use core::borrow::Borrow; use super::{Error, ReportId}; -use crate::buffer::SharedRef; +use crate::{buffer::SharedRef, hid::InvalidSizeError}; /// HID report types #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -350,7 +350,10 @@ impl<'a> Command<'a> { if report_id.0 >= EXTENDED_REPORT_ID { if buf.len() < EXTENDED_REPORT_CMD_LEN { - return Err(Error::InvalidSize(EXTENDED_REPORT_CMD_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: EXTENDED_REPORT_CMD_LEN, + actual: buf.len(), + })); } val |= EXTENDED_REPORT_ID as u16; @@ -363,7 +366,10 @@ impl<'a> Command<'a> { Ok((EXTENDED_REPORT_CMD_LEN, &mut buf[EXTENDED_REPORT_CMD_LEN..])) } else { if buf.len() < STANDARD_REPORT_CMD_LEN { - return Err(Error::InvalidSize(STANDARD_REPORT_CMD_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: STANDARD_REPORT_CMD_LEN, + actual: buf.len(), + })); } val |= report_id.0 as u16; @@ -377,7 +383,10 @@ impl<'a> Command<'a> { /// Returns the number of bytes written and the remaining buffer fn encode_basic_op(buf: &mut [u8], opcode: Opcode) -> Result<(usize, &mut [u8]), Error> { if buf.len() < BASIC_CMD_LEN { - return Err(Error::InvalidSize(BASIC_CMD_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: BASIC_CMD_LEN, + actual: buf.len(), + })); } buf[0..BASIC_CMD_LEN].copy_from_slice(&>::into(opcode).to_le_bytes()); @@ -389,7 +398,10 @@ impl<'a> Command<'a> { fn encode_register(buf: &mut [u8], reg: Option) -> Result<(usize, &mut [u8]), Error> { if let Some(reg) = reg { if buf.len() < REGISTER_LEN { - return Err(Error::InvalidSize(REGISTER_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: REGISTER_LEN, + actual: buf.len(), + })); } buf[0..REGISTER_LEN].copy_from_slice(®.to_le_bytes()); Ok((REGISTER_LEN, &mut buf[REGISTER_LEN..])) @@ -402,7 +414,10 @@ impl<'a> Command<'a> { /// Returns the number of bytes written and the remaining buffer fn encode_value>(buf: &mut [u8], value: T) -> Result<(usize, &mut [u8]), Error> { if buf.len() < LENGTH_VALUE_LEN { - return Err(Error::InvalidSize(LENGTH_VALUE_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: LENGTH_VALUE_LEN, + actual: buf.len(), + })); } // Length value includes the size of the length as well buf[0..VALUE_LEN].copy_from_slice(&4u16.to_le_bytes()); @@ -416,7 +431,10 @@ impl<'a> Command<'a> { // +2 to encode the length of the data let total_len = data.len() + 2; if buf.len() < total_len { - return Err(Error::InvalidSize(total_len, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: total_len, + actual: buf.len(), + })); } buf[0..VALUE_LEN].copy_from_slice(&(total_len as u16).to_le_bytes()); diff --git a/embedded-service/src/hid/mod.rs b/embedded-service/src/hid/mod.rs index b637262a..941ebbe8 100644 --- a/embedded-service/src/hid/mod.rs +++ b/embedded-service/src/hid/mod.rs @@ -14,6 +14,16 @@ pub use command::*; /// HID descriptor length pub const DESCRIPTOR_LEN: usize = 30; +/// Data for [`Error::InvalidSize`] +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct InvalidSizeError { + /// Expected size + pub expected: usize, + /// Actual size + pub actual: usize, +} + /// HID errors #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -21,7 +31,7 @@ pub enum Error { /// Invalid data InvalidData, /// Invalid size: expected and actual sizes - InvalidSize(usize, usize), + InvalidSize(InvalidSizeError), /// Invalid register address InvalidRegisterAddress, /// Invalid device @@ -68,7 +78,10 @@ impl Descriptor { /// Serializes a descriptor into the slice pub fn encode_into_slice(&self, buf: &mut [u8]) -> Result { if buf.len() < DESCRIPTOR_LEN { - return Err(Error::InvalidSize(DESCRIPTOR_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: DESCRIPTOR_LEN, + actual: buf.len(), + })); } buf[0..2].copy_from_slice(&self.w_hid_desc_length.to_le_bytes()); @@ -93,7 +106,10 @@ impl Descriptor { /// Deserializes a descriptor from the slice pub fn decode_from_slice(buf: &[u8]) -> Result { if buf.len() < DESCRIPTOR_LEN { - return Err(Error::InvalidSize(DESCRIPTOR_LEN, buf.len())); + return Err(Error::InvalidSize(InvalidSizeError { + expected: DESCRIPTOR_LEN, + actual: buf.len(), + })); } // Reserved bytes must be zero diff --git a/hid-service/src/i2c/device.rs b/hid-service/src/i2c/device.rs index d7efd06b..0182d8e9 100644 --- a/hid-service/src/i2c/device.rs +++ b/hid-service/src/i2c/device.rs @@ -2,7 +2,7 @@ use core::borrow::BorrowMut; use embassy_sync::mutex::Mutex; use embedded_hal_async::i2c::{AddressMode, I2c}; -use embedded_services::hid::{DeviceContainer, Opcode, Response}; +use embedded_services::hid::{DeviceContainer, InvalidSizeError, Opcode, Response}; use embedded_services::{GlobalRawMutex, buffer::*}; use embedded_services::{error, hid, info, trace}; @@ -38,7 +38,13 @@ impl> Device { let mut borrow = self.buffer.borrow_mut(); let mut reg = [0u8; 2]; let buf: &mut [u8] = borrow.borrow_mut(); - let buf = &mut buf[0..hid::DESCRIPTOR_LEN]; + let buf_len = buf.len(); + let buf = buf + .get_mut(0..hid::DESCRIPTOR_LEN) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: hid::DESCRIPTOR_LEN, + actual: buf_len, + })))?; reg.copy_from_slice(&self.device.regs.hid_desc_reg.to_le_bytes()); if let Err(e) = bus.write_read(self.address, ®, buf).await { @@ -47,18 +53,18 @@ impl> Device { } let res = hid::Descriptor::decode_from_slice(buf); - if res.is_err() { - error!("Failed to deseralize HID descriptor"); - return Err(Error::Hid(hid::Error::Serialize)); - } - let desc = res.unwrap(); - info!("HID descriptor: {:#?}", desc); - { - let mut descriptor = self.descriptor.lock().await; - *descriptor = Some(desc); + match res { + Ok(desc) => { + info!("HID descriptor: {:#?}", desc); + let mut descriptor = self.descriptor.lock().await; + *descriptor = Some(desc); + Ok(desc) + } + Err(e) => { + error!("Failed to deserialize HID descriptor: {:?}", e); + Err(Error::Hid(hid::Error::Serialize)) + } } - - Ok(desc) } pub async fn read_hid_descriptor(&self) -> Result, Error> { @@ -78,11 +84,23 @@ impl> Device { let mut borrow = self.buffer.borrow_mut(); let buf: &mut [u8] = borrow.borrow_mut(); + let buffer_len = buf.len(); let reg = desc.w_report_desc_register.to_le_bytes(); let len = desc.w_report_desc_length as usize; let mut bus = self.bus.lock().await; - if let Err(e) = bus.write_read(self.address, ®, &mut buf[0..len]).await { + if let Err(e) = bus + .write_read( + self.address, + ®, + buf.get_mut(0..len) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: len, + actual: buffer_len, + })))?, + ) + .await + { error!("Failed to read report descriptor"); return Err(Error::Bus(e)); } @@ -96,7 +114,13 @@ impl> Device { let mut borrow = self.buffer.borrow_mut(); let buf: &mut [u8] = borrow.borrow_mut(); - let buf = &mut buf[0..desc.w_max_input_length as usize]; + let buffer_len = buf.len(); + let buf = buf + .get_mut(0..desc.w_max_input_length as usize) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: desc.w_max_input_length as usize, + actual: buffer_len, + })))?; let mut bus = self.bus.lock().await; if let Err(e) = bus.read(self.address, buf).await { @@ -115,25 +139,36 @@ impl> Device { let mut borrow = self.buffer.borrow_mut(); let buf: &mut [u8] = borrow.borrow_mut(); + let buffer_len = buf.len(); let opcode: Opcode = cmd.into(); - let res = cmd.encode_into_slice( - buf, - Some(self.device.regs.command_reg), - if opcode.has_response() || opcode.requires_host_data() { - Some(self.device.regs.data_reg) - } else { - None - }, - ); - if res.is_err() { - error!("Failed to serialize command"); - return Err(Error::Hid(hid::Error::Serialize)); - } + let len = cmd + .encode_into_slice( + buf, + Some(self.device.regs.command_reg), + if opcode.has_response() || opcode.requires_host_data() { + Some(self.device.regs.data_reg) + } else { + None + }, + ) + .map_err(|_| { + error!("Failed to serialize command"); + Error::Hid(hid::Error::Serialize) + })?; - let len = res.unwrap(); let mut bus = self.bus.lock().await; - if let Err(e) = bus.write(self.address, &buf[..len]).await { + if let Err(e) = bus + .write( + self.address, + buf.get(..len) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: len, + actual: buffer_len, + })))?, + ) + .await + { error!("Failed to write command"); return Err(Error::Bus(e)); } @@ -168,7 +203,10 @@ impl> Device { Some(hid::Response::InputReport(report)) } hid::Request::Command(cmd) => self.handle_command(&cmd).await?, - _ => unimplemented!(), + _ => { + error!("Unimplemented HID request"); + None + } }; self.device diff --git a/hid-service/src/i2c/host.rs b/hid-service/src/i2c/host.rs index fbedabad..5f008e15 100644 --- a/hid-service/src/i2c/host.rs +++ b/hid-service/src/i2c/host.rs @@ -7,7 +7,7 @@ use embassy_time::{Duration, with_timeout}; use embedded_services::GlobalRawMutex; use embedded_services::buffer::OwnedRef; use embedded_services::comms::{self, Endpoint, EndpointID, External, MailboxDelegate}; -use embedded_services::hid::{self, DeviceId, Opcode}; +use embedded_services::hid::{self, DeviceId, InvalidSizeError, Opcode}; use embedded_services::{error, trace}; use super::{Command as I2cCommand, I2cSlaveAsync}; @@ -43,56 +43,74 @@ impl Host { async fn read_bus(&self, timeout_ms: u64, buffer: &mut [u8]) -> Result<(), Error> { let mut bus = self.bus.lock().await; - let result = with_timeout(Duration::from_millis(timeout_ms), bus.respond_to_write(buffer)).await; - if result.is_err() { - error!("Response timeout"); - return Err(Error::Hid(hid::Error::Timeout)); - } - - if let Err(e) = result.unwrap() { - error!("Failed to read from bus"); - return Err(Error::Bus(e)); - } - - Ok(()) + with_timeout(Duration::from_millis(timeout_ms), bus.respond_to_write(buffer)) + .await + .map_err(|_| { + error!("Response timeout"); + Error::Hid(hid::Error::Timeout) + })? + .map_err(|e| { + error!("Failed to read from bus"); + Error::Bus(e) + }) } async fn write_bus(&self, timeout_ms: u64, buffer: &[u8]) -> Result<(), Error> { let mut bus = self.bus.lock().await; // Send response, timeout if the host doesn't read so we don't get stuck here trace!("Sending {} bytes", buffer.len()); - let result = with_timeout(Duration::from_millis(timeout_ms), bus.respond_to_read(buffer)).await; - if result.is_err() { - error!("Response timeout"); - return Err(Error::Hid(hid::Error::Timeout)); - } - - if let Err(e) = result.unwrap() { - error!("Failed to rwrite to bus"); - return Err(Error::Bus(e)); - } - - trace!("Response sent"); - Ok(()) + with_timeout(Duration::from_millis(timeout_ms), bus.respond_to_read(buffer)) + .await + .map_err(|_| { + error!("Response timeout"); + Error::Hid(hid::Error::Timeout) + })? + .map_err(|e| { + error!("Failed to write to bus"); + Error::Bus(e) + }) } async fn process_output_report(&self) -> Result, Error> { let mut borrow = self.buffer.borrow_mut(); let buffer: &mut [u8] = borrow.borrow_mut(); - - self.read_bus(DATA_READ_TIMEOUT_MS, &mut buffer[..2]).await?; - - let length = u16::from_le_bytes([buffer[0], buffer[1]]); - if buffer.len() < length as usize { - error!("Output report buffer overrun: {}", length); - return Err(Error::Hid(hid::Error::InvalidSize(length as usize, buffer.len()))); - } - + let buffer_len = buffer.len(); + + self.read_bus( + DATA_READ_TIMEOUT_MS, + buffer + .get_mut(..2) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: 2, + actual: buffer_len, + })))?, + ) + .await?; + + let length = u16::from_le_bytes(buffer.get(..2).and_then(|b| <[u8; 2]>::try_from(b).ok()).ok_or( + Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: 2, + actual: buffer_len, + })), + )?); trace!("Reading {} bytes", length); - self.read_bus(DATA_READ_TIMEOUT_MS, &mut buffer[2..length as usize]) - .await?; + self.read_bus( + DATA_READ_TIMEOUT_MS, + buffer + .get_mut(2..length as usize) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: length as usize, + actual: buffer_len, + })))?, + ) + .await?; Ok(hid::Request::OutputReport( - Some(hid::ReportId(buffer[2])), + Some(hid::ReportId(buffer.get(2).copied().ok_or(Error::Hid( + hid::Error::InvalidSize(InvalidSizeError { + expected: 3, + actual: buffer_len, + }), + ))?)), self.buffer.reference().slice(3..length as usize), )) } @@ -103,15 +121,13 @@ impl Host { self.read_bus(DATA_READ_TIMEOUT_MS, &mut cmd).await?; let cmd = u16::from_le_bytes(cmd); - let opcode = Opcode::try_from(cmd); - if let Err(e) = opcode { + let opcode = Opcode::try_from(cmd).map_err(|e| { error!("Invalid command {:#x}", cmd); - return Err(Error::Hid(e)); - } + Error::Hid(e) + })?; trace!("Command {:#x}", cmd); // Get report ID - let opcode = opcode.unwrap(); trace!("Opcode {:?}", opcode); let report_id = if opcode.requires_report_id() { // See if we need to read another byte for the full report ID @@ -145,18 +161,37 @@ impl Host { trace!("Waiting for data"); let mut borrow = self.buffer.borrow_mut(); let buffer: &mut [u8] = borrow.borrow_mut(); - - self.read_bus(DATA_READ_TIMEOUT_MS, &mut buffer[0..2]).await?; - - let length = u16::from_le_bytes([buffer[0], buffer[1]]); - if buffer.len() < length as usize { - error!("Buffer overrun: {}", length); - return Err(Error::Hid(hid::Error::InvalidSize(length as usize, buffer.len()))); - } + let buffer_len = buffer.len(); + + self.read_bus( + DATA_READ_TIMEOUT_MS, + buffer + .get_mut(0..2) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: 2, + actual: buffer_len, + })))?, + ) + .await?; + + let length = u16::from_le_bytes(buffer.get(..2).and_then(|b| <[u8; 2]>::try_from(b).ok()).ok_or( + Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: 2, + actual: buffer_len, + })), + )?); trace!("Reading {} bytes", length); - self.read_bus(DATA_READ_TIMEOUT_MS, &mut buffer[2..length as usize]) - .await?; + self.read_bus( + DATA_READ_TIMEOUT_MS, + buffer + .get_mut(2..length as usize) + .ok_or(Error::Hid(hid::Error::InvalidSize(InvalidSizeError { + expected: length as usize, + actual: buffer_len, + })))?, + ) + .await?; Some(self.buffer.reference().slice(2..length as usize)) } else { None @@ -168,12 +203,13 @@ impl Host { // Create command let report_type = hid::ReportType::try_from(cmd).ok(); let command = hid::Command::new(cmd, opcode, report_type, report_id, buffer); - if let Err(e) = command { - error!("Invalid command {:?}", e); - return Err(Error::Hid(hid::Error::InvalidCommand)); + match command { + Ok(command) => Ok(command), + Err(e) => { + error!("Invalid command {:?}", e); + Err(Error::Hid(hid::Error::InvalidCommand)) + } } - - Ok(command.unwrap()) } /// Handle an access to a specific register diff --git a/hid-service/src/i2c/passthrough/interrupt.rs b/hid-service/src/i2c/passthrough/interrupt.rs index d574e363..fac8d43c 100644 --- a/hid-service/src/i2c/passthrough/interrupt.rs +++ b/hid-service/src/i2c/passthrough/interrupt.rs @@ -14,7 +14,17 @@ pub struct InterruptSignal { signal: Signal, } -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// Failed to read incoming interrupt line + IoRead, + /// Failed to assert/deassert outgoing interrupt line + IoSet, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] enum InterruptState { Idle, Asserted, @@ -57,15 +67,15 @@ impl InterruptSignal { self.signal.signal(()); } - pub async fn process(&self) { + pub async fn process(&self) -> Result<(), Error> { let mut int_in = self.int_in.lock().await; let mut int_out = self.int_out.lock().await; trace!("Waiting for interrupt"); - int_in.wait_for_low().await.unwrap(); + int_in.wait_for_low().await.map_err(|_| Error::IoRead)?; - int_out.set_low().unwrap(); + int_out.set_low().map_err(|_| Error::IoSet)?; { let mut state = self.state.lock().await; *state = InterruptState::Asserted; @@ -73,14 +83,14 @@ impl InterruptSignal { trace!("Interrupt received"); self.signal.wait().await; - int_out.set_high().unwrap(); + int_out.set_high().map_err(|_| Error::IoSet)?; trace!("Interrupt deasserted"); { let mut state = self.state.lock().await; if *state == InterruptState::Reset { *state = InterruptState::Idle; - return; + return Ok(()); } } @@ -91,5 +101,7 @@ impl InterruptSignal { *state = InterruptState::Idle; } trace!("Interrupt cleared"); + + Ok(()) } }