diff --git a/Cargo.toml b/Cargo.toml index de692d4..683d24b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,8 @@ log = "0.4.17" num-traits = "0.2.15" num-derive = "0.3.3" rusb = "0.9.1" +async-trait = "0.1.73" +async-scoped = { version = "0.7.1", features = ["use-tokio"] } [dev-dependencies] tokio = { version = "1.22.0", features = ["full"] } diff --git a/examples/cdc_acm_serial.rs b/examples/cdc_acm_serial.rs index 598c512..e6ad2fb 100644 --- a/examples/cdc_acm_serial.rs +++ b/examples/cdc_acm_serial.rs @@ -3,13 +3,15 @@ use std::net::*; use std::sync::{Arc, Mutex}; use std::time::Duration; +use usbip::server::*; + #[tokio::main] async fn main() { env_logger::init(); let handler = Arc::new(Mutex::new(Box::new(usbip::cdc::UsbCdcAcmHandler::new()) as Box)); - let server = Arc::new(usbip::UsbIpServer::new_simulated(vec![ + let server = SyncUsbIpServer::new_simulated(vec![ usbip::UsbDevice::new(0).with_interface( usbip::ClassCode::CDC as u8, usbip::cdc::CDC_ACM_SUBCLASS, @@ -18,9 +20,9 @@ async fn main() { usbip::cdc::UsbCdcAcmHandler::endpoints(), handler.clone(), ), - ])); + ]); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240); - tokio::spawn(usbip::server(addr, server)); + tokio::spawn(server.serve(addr)); loop { // sleep 1s diff --git a/examples/hid_keyboard.rs b/examples/hid_keyboard.rs index e73034a..5e192b2 100644 --- a/examples/hid_keyboard.rs +++ b/examples/hid_keyboard.rs @@ -3,6 +3,8 @@ use std::net::*; use std::sync::{Arc, Mutex}; use std::time::Duration; +use usbip::server::*; + #[tokio::main] async fn main() { env_logger::init(); @@ -10,7 +12,7 @@ async fn main() { Box::new(usbip::hid::UsbHidKeyboardHandler::new_keyboard()) as Box, )); - let server = Arc::new(usbip::UsbIpServer::new_simulated(vec![ + let server = SyncUsbIpServer::new_simulated(vec![ usbip::UsbDevice::new(0).with_interface( usbip::ClassCode::HID as u8, 0x00, @@ -24,9 +26,9 @@ async fn main() { }], handler.clone(), ), - ])); + ]); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240); - tokio::spawn(usbip::server(addr, server)); + tokio::spawn(server.serve(addr)); loop { // sleep 1s diff --git a/examples/host.rs b/examples/host.rs index d1d6f40..98578d1 100644 --- a/examples/host.rs +++ b/examples/host.rs @@ -1,13 +1,15 @@ use std::net::*; -use std::sync::Arc; + use std::time::Duration; +use usbip::server::*; + #[tokio::main] async fn main() { env_logger::init(); - let server = Arc::new(usbip::UsbIpServer::new_from_host()); + let server = usbip::server::SyncUsbIpServer::new_from_host(); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240); - tokio::spawn(usbip::server(addr, server)); + tokio::spawn(server.serve(addr)); loop { // sleep 1s diff --git a/src/lib.rs b/src/lib.rs index 330689a..b625aca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,16 +4,12 @@ use log::*; use num_derive::FromPrimitive; use num_traits::FromPrimitive; use rusb::*; -use std::any::Any; -use std::collections::{HashMap, VecDeque}; -use std::io::{ErrorKind, Result}; -use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::net::TcpListener; -use tokio::sync::RwLock; -use usbip_protocol::UsbIpCommand; +use std::{ + any::Any, + collections::{HashMap, VecDeque}, + io::Result, + sync::{Arc, Mutex}, +}; pub mod cdc; mod consts; @@ -22,9 +18,11 @@ mod endpoint; pub mod hid; mod host; mod interface; +pub mod server; mod setup; pub mod usbip_protocol; mod util; + pub use consts::*; pub use device::*; pub use endpoint::*; @@ -32,695 +30,3 @@ pub use host::*; pub use interface::*; pub use setup::*; pub use util::*; - -use crate::usbip_protocol::{UsbIpResponse, USBIP_RET_SUBMIT, USBIP_RET_UNLINK}; - -/// Main struct of a USB/IP server -#[derive(Default)] -pub struct UsbIpServer { - available_devices: RwLock>, - used_devices: RwLock>, -} - -impl UsbIpServer { - /// Create a [UsbIpServer] with simulated devices - pub fn new_simulated(devices: Vec) -> Self { - Self { - available_devices: RwLock::new(devices), - used_devices: RwLock::new(HashMap::new()), - } - } - - fn with_devices(device_list: Vec>) -> Vec { - let mut devices = vec![]; - - for dev in device_list { - let open_device = match dev.open() { - Ok(dev) => dev, - Err(err) => { - warn!("Impossible to share {:?}: {}", dev, err); - continue; - } - }; - let handle = Arc::new(Mutex::new(open_device)); - let desc = dev.device_descriptor().unwrap(); - let cfg = dev.active_config_descriptor().unwrap(); - let mut interfaces = vec![]; - handle - .lock() - .unwrap() - .set_auto_detach_kernel_driver(true) - .ok(); - for intf in cfg.interfaces() { - // ignore alternate settings - let intf_desc = intf.descriptors().next().unwrap(); - handle - .lock() - .unwrap() - .set_auto_detach_kernel_driver(true) - .ok(); - let mut endpoints = vec![]; - - for ep_desc in intf_desc.endpoint_descriptors() { - endpoints.push(UsbEndpoint { - address: ep_desc.address(), - attributes: ep_desc.transfer_type() as u8, - max_packet_size: ep_desc.max_packet_size(), - interval: ep_desc.interval(), - }); - } - - let handler = Arc::new(Mutex::new(Box::new(UsbHostInterfaceHandler::new( - handle.clone(), - )) - as Box)); - interfaces.push(UsbInterface { - interface_class: intf_desc.class_code(), - interface_subclass: intf_desc.sub_class_code(), - interface_protocol: intf_desc.protocol_code(), - endpoints, - string_interface: intf_desc.description_string_index().unwrap_or(0), - class_specific_descriptor: Vec::from(intf_desc.extra()), - handler, - }); - } - let mut device = UsbDevice { - path: format!( - "/sys/bus/{}/{}/{}", - dev.bus_number(), - dev.address(), - dev.port_number() - ), - bus_id: format!( - "{}-{}-{}", - dev.bus_number(), - dev.address(), - dev.port_number() - ), - bus_num: dev.bus_number() as u32, - dev_num: dev.port_number() as u32, - speed: dev.speed() as u32, - vendor_id: desc.vendor_id(), - product_id: desc.product_id(), - device_class: desc.class_code(), - device_subclass: desc.sub_class_code(), - device_protocol: desc.protocol_code(), - device_bcd: desc.device_version().into(), - configuration_value: cfg.number(), - num_configurations: desc.num_configurations(), - ep0_in: UsbEndpoint { - address: 0x80, - attributes: EndpointAttributes::Control as u8, - max_packet_size: desc.max_packet_size() as u16, - interval: 0, - }, - ep0_out: UsbEndpoint { - address: 0x00, - attributes: EndpointAttributes::Control as u8, - max_packet_size: desc.max_packet_size() as u16, - interval: 0, - }, - interfaces, - device_handler: Some(Arc::new(Mutex::new(Box::new(UsbHostDeviceHandler::new( - handle.clone(), - ))))), - usb_version: desc.usb_version().into(), - ..UsbDevice::default() - }; - - // set strings - if let Some(index) = desc.manufacturer_string_index() { - device.string_manufacturer = device.new_string( - &handle - .lock() - .unwrap() - .read_string_descriptor_ascii(index) - .unwrap(), - ) - } - if let Some(index) = desc.product_string_index() { - device.string_product = device.new_string( - &handle - .lock() - .unwrap() - .read_string_descriptor_ascii(index) - .unwrap(), - ) - } - if let Some(index) = desc.serial_number_string_index() { - device.string_serial = device.new_string( - &handle - .lock() - .unwrap() - .read_string_descriptor_ascii(index) - .unwrap(), - ) - } - devices.push(device); - } - devices - } - - /// Create a [UsbIpServer] exposing devices in the host, and redirect all USB transfers to them using libusb - pub fn new_from_host() -> Self { - match rusb::devices() { - Ok(list) => { - let mut devs = vec![]; - for d in list.iter() { - devs.push(d) - } - Self { - available_devices: RwLock::new(Self::with_devices(devs)), - ..Default::default() - } - } - Err(_) => Default::default(), - } - } - - pub fn new_from_host_with_filter(filter: F) -> Self - where - F: FnMut(&Device) -> bool, - { - match rusb::devices() { - Ok(list) => { - let mut devs = vec![]; - for d in list.iter().filter(filter) { - devs.push(d) - } - Self { - available_devices: RwLock::new(Self::with_devices(devs)), - ..Default::default() - } - } - Err(_) => Default::default(), - } - } - - pub async fn add_device(&self, device: UsbDevice) { - self.available_devices.write().await.push(device); - } - - pub async fn remove_device(&self, bus_id: &str) -> Result<()> { - let mut available_devices = self.available_devices.write().await; - - if let Some(device) = available_devices.iter().position(|d| d.bus_id == bus_id) { - available_devices.remove(device); - Ok(()) - } else if let Some(device) = self - .used_devices - .read() - .await - .values() - .find(|d| d.bus_id == bus_id) - { - Err(std::io::Error::new( - ErrorKind::Other, - format!("Device {} is in use", device.bus_id), - )) - } else { - Err(std::io::Error::new( - ErrorKind::NotFound, - format!("Device {} not found", bus_id), - )) - } - } -} - -pub async fn handler( - mut socket: &mut T, - server: Arc, -) -> Result<()> { - let mut current_import_device_id: Option = None; - loop { - let command = UsbIpCommand::read_from_socket(&mut socket).await; - if let Err(err) = command { - if let Some(dev_id) = current_import_device_id { - let mut used_devices = server.used_devices.write().await; - let mut available_devices = server.available_devices.write().await; - match used_devices.remove(&dev_id) { - Some(dev) => available_devices.push(dev), - None => unreachable!(), - } - } - - if err.kind() == ErrorKind::UnexpectedEof { - info!("Remote closed the connection"); - return Ok(()); - } else { - return Err(err); - } - } - - let used_devices = server.used_devices.read().await; - let mut current_import_device = current_import_device_id - .clone() - .and_then(|ref id| used_devices.get(id)); - - match command.unwrap() { - UsbIpCommand::OpReqDevlist { .. } => { - trace!("Got OP_REQ_DEVLIST"); - let devices = server.available_devices.read().await; - - // OP_REP_DEVLIST - UsbIpResponse::op_rep_devlist(&devices) - .write_to_socket(socket) - .await?; - trace!("Sent OP_REP_DEVLIST"); - } - UsbIpCommand::OpReqImport { busid, .. } => { - trace!("Got OP_REQ_IMPORT"); - - current_import_device_id = None; - current_import_device = None; - std::mem::drop(used_devices); - - let mut used_devices = server.used_devices.write().await; - let mut available_devices = server.available_devices.write().await; - for (i, dev) in available_devices.iter().enumerate() { - let mut expected = dev.bus_id.as_bytes().to_vec(); - expected.resize(32, 0); - if expected.as_slice() == busid { - let dev = available_devices.remove(i); - let dev_id = dev.bus_id.clone(); - used_devices.insert(dev.bus_id.clone(), dev); - current_import_device_id = dev_id.clone().into(); - current_import_device = Some(used_devices.get(&dev_id).unwrap()); - break; - } - } - - let res = if let Some(dev) = current_import_device { - UsbIpResponse::op_rep_import_success(dev) - } else { - UsbIpResponse::op_rep_import_fail() - }; - res.write_to_socket(socket).await?; - trace!("Sent OP_REP_IMPORT"); - } - UsbIpCommand::UsbIpCmdSubmit { - mut header, - transfer_buffer_length, - setup, - data, - .. - } => { - trace!("Got USBIP_CMD_SUBMIT"); - let device = current_import_device.unwrap(); - - let out = header.direction == 0; - let real_ep = if out { header.ep } else { header.ep | 0x80 }; - - header.command = USBIP_RET_SUBMIT.into(); - - let res = match device.find_ep(real_ep as u8) { - None => { - warn!("Endpoint {:02x?} not found", real_ep); - UsbIpResponse::usbip_ret_submit_fail(&header) - } - Some((ep, intf)) => { - trace!("->Endpoint {:02x?}", ep); - trace!("->Setup {:02x?}", setup); - trace!("->Request {:02x?}", data); - let resp = device - .handle_urb( - ep, - intf, - transfer_buffer_length, - SetupPacket::parse(&setup), - &data, - ) - .await?; - - if out { - trace!("<-Wrote {}", data.len()); - } else { - trace!("<-Resp {:02x?}", resp); - } - - UsbIpResponse::usbip_ret_submit_success(&header, 0, 0, resp, vec![]) - } - }; - res.write_to_socket(socket).await?; - trace!("Sent USBIP_RET_SUBMIT"); - } - UsbIpCommand::UsbIpCmdUnlink { - mut header, - unlink_seqnum: _, - } => { - trace!("Got USBIP_CMD_UNLINK"); - - std::mem::drop(used_devices); - - let mut used_devices = server.used_devices.write().await; - let mut available_devices = server.available_devices.write().await; - - let dev = current_import_device_id - .clone() - .and_then(|ref k| used_devices.remove(k)); - - header.command = USBIP_RET_UNLINK.into(); - - let res = match dev { - Some(dev) => { - available_devices.push(dev); - current_import_device_id = None; - UsbIpResponse::usbip_ret_unlink_success(&header) - } - None => { - warn!("Device not found"); - UsbIpResponse::usbip_ret_unlink_fail(&header) - } - }; - res.write_to_socket(socket).await?; - trace!("Sent USBIP_RET_UNLINK"); - } - } - } -} - -/// Spawn a USB/IP server at `addr` using [TcpListener] -pub async fn server(addr: SocketAddr, server: Arc) { - let listener = TcpListener::bind(addr).await.expect("bind to addr"); - - let server = async move { - loop { - match listener.accept().await { - Ok((mut socket, _addr)) => { - info!("Got connection from {:?}", socket.peer_addr()); - let new_server = server.clone(); - tokio::spawn(async move { - let res = handler(&mut socket, new_server).await; - info!("Handler ended with {:?}", res); - }); - } - Err(err) => { - warn!("Got error {:?}", err); - } - } - } - }; - - server.await -} - -#[cfg(test)] -mod tests { - use tokio::{net::TcpStream, task::JoinSet}; - - use super::*; - use crate::{ - usbip_protocol::{UsbIpHeaderBasic, USBIP_CMD_SUBMIT, USBIP_CMD_UNLINK}, - util::tests::*, - }; - - const SINGLE_DEVICE_BUSID: &str = "0-0-0"; - - fn new_server_with_single_device() -> UsbIpServer { - UsbIpServer::new_simulated(vec![UsbDevice::new(0).with_interface( - ClassCode::CDC as u8, - cdc::CDC_ACM_SUBCLASS, - 0x00, - "Test CDC ACM", - cdc::UsbCdcAcmHandler::endpoints(), - Arc::new(Mutex::new( - Box::new(cdc::UsbCdcAcmHandler::new()) as Box - )), - )]) - } - - fn op_req_import(busid: &str) -> Vec { - let mut busid = busid.to_string().as_bytes().to_vec(); - busid.resize(32, 0); - UsbIpCommand::OpReqImport { - status: 0, - busid: busid.try_into().unwrap(), - } - .to_bytes() - } - - async fn attach_device(connection: &mut TcpStream, busid: &str) -> u32 { - let req = op_req_import(busid); - connection.write_all(req.as_slice()).await.unwrap(); - connection.read_u32().await.unwrap(); - let result = connection.read_u32().await.unwrap(); - if result == 0 { - connection.read_exact(&mut vec![0; 0x138]).await.unwrap(); - } - result - } - - #[tokio::test] - async fn req_empty_devlist() { - setup_test_logger(); - let server = UsbIpServer::new_simulated(vec![]); - let req = UsbIpCommand::OpReqDevlist { status: 0 }; - - let mut mock_socket = MockSocket::new(req.to_bytes()); - handler(&mut mock_socket, Arc::new(server)).await.ok(); - - assert_eq!( - mock_socket.output, - UsbIpResponse::op_rep_devlist(&[]).to_bytes(), - ); - } - - #[tokio::test] - async fn req_sample_devlist() { - setup_test_logger(); - let server = new_server_with_single_device(); - let req = UsbIpCommand::OpReqDevlist { status: 0 }; - - let mut mock_socket = MockSocket::new(req.to_bytes()); - handler(&mut mock_socket, Arc::new(server)).await.ok(); - - // OP_REP_DEVLIST - // header: 0xC - // device: 0x138 - // interface: 4 * 0x1 - assert_eq!(mock_socket.output.len(), 0xC + 0x138 + 4); - } - - #[tokio::test] - async fn req_import() { - setup_test_logger(); - let server = new_server_with_single_device(); - - // OP_REQ_IMPORT - let req = op_req_import(SINGLE_DEVICE_BUSID); - let mut mock_socket = MockSocket::new(req); - handler(&mut mock_socket, Arc::new(server)).await.ok(); - // OP_REQ_IMPORT - assert_eq!(mock_socket.output.len(), 0x140); - } - - #[tokio::test] - async fn add_and_remove_10_devices() { - setup_test_logger(); - let server_ = Arc::new(UsbIpServer::new_simulated(vec![])); - let addr = get_free_address().await; - tokio::spawn(server(addr, server_.clone())); - - let mut join_set = JoinSet::new(); - let devices = (0..10).map(UsbDevice::new).collect::>(); - - for device in devices.iter() { - let new_server = server_.clone(); - let new_device = device.clone(); - join_set.spawn(async move { - new_server.add_device(new_device).await; - }); - } - - for device in devices.iter() { - let new_server = server_.clone(); - let new_device = device.clone(); - join_set.spawn(async move { - new_server.remove_device(&new_device.bus_id).await.unwrap(); - }); - } - - while join_set.join_next().await.is_some() {} - - let device_len = server_.clone().available_devices.read().await.len(); - - assert_eq!(device_len, 0); - } - - #[tokio::test] - async fn send_usb_traffic_while_adding_and_removing_devices() { - setup_test_logger(); - let server_ = Arc::new(new_server_with_single_device()); - - let addr = get_free_address().await; - tokio::spawn(server(addr, server_.clone())); - - let cmd_loop_handle = tokio::spawn(async move { - let mut connection = poll_connect(addr).await; - let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 0); - - let cdc_loopback_bulk_cmd = UsbIpCommand::UsbIpCmdSubmit { - header: usbip_protocol::UsbIpHeaderBasic { - command: USBIP_CMD_SUBMIT.into(), - seqnum: 1, - devid: 0, - direction: 0, // OUT - ep: 2, - }, - transfer_flags: 0, - transfer_buffer_length: 8, - start_frame: 0, - number_of_packets: 0, - interval: 0, - setup: [0; 8], - data: vec![1, 2, 3, 4, 5, 6, 7, 8], - iso_packet_descriptor: vec![], - }; - - loop { - connection - .write_all(cdc_loopback_bulk_cmd.to_bytes().as_slice()) - .await - .unwrap(); - let mut result = vec![0; 4 * 12]; - connection.read_exact(&mut result).await.unwrap(); - } - }); - - let add_and_remove_device_handle = tokio::spawn(async move { - let mut join_set = JoinSet::new(); - let devices = (1..4).map(UsbDevice::new).collect::>(); - - loop { - for device in devices.iter() { - let new_server = server_.clone(); - let new_device = device.clone(); - join_set.spawn(async move { - new_server.add_device(new_device).await; - }); - } - - for device in devices.iter() { - let new_server = server_.clone(); - let new_device = device.clone(); - join_set.spawn(async move { - new_server.remove_device(&new_device.bus_id).await.unwrap(); - }); - } - while join_set.join_next().await.is_some() {} - tokio::time::sleep(tokio::time::Duration::from_millis(20)).await; - } - }); - - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - cmd_loop_handle.abort(); - add_and_remove_device_handle.abort(); - } - - #[tokio::test] - async fn only_single_connection_allowed_to_device() { - setup_test_logger(); - let server_ = Arc::new(new_server_with_single_device()); - - let addr = get_free_address().await; - tokio::spawn(server(addr, server_.clone())); - - let mut first_connection = poll_connect(addr).await; - let mut second_connection = TcpStream::connect(addr).await.unwrap(); - - let result = attach_device(&mut first_connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 0); - - let result = attach_device(&mut second_connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 1); - } - - #[tokio::test] - async fn device_gets_released_on_cmd_unlink() { - setup_test_logger(); - let server_ = Arc::new(new_server_with_single_device()); - - let addr = get_free_address().await; - tokio::spawn(server(addr, server_.clone())); - - let mut connection = poll_connect(addr).await; - - let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 0); - - let unlink_req = UsbIpCommand::UsbIpCmdUnlink { - header: UsbIpHeaderBasic { - command: USBIP_CMD_UNLINK.into(), - seqnum: 1, - devid: 0, - direction: 0, - ep: 0, - }, - unlink_seqnum: 0, - } - .to_bytes(); - - connection.write_all(unlink_req.as_slice()).await.unwrap(); - connection.read_exact(&mut [0; 4 * 5]).await.unwrap(); - let result = connection.read_u32().await.unwrap(); - connection.read_exact(&mut [0; 4 * 6]).await.unwrap(); - assert_eq!(result, 0); - - let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 0); - } - - #[tokio::test] - async fn device_gets_released_on_closed_socket() { - setup_test_logger(); - let server_ = Arc::new(new_server_with_single_device()); - - let addr = get_free_address().await; - tokio::spawn(server(addr, server_.clone())); - - let mut connection = poll_connect(addr).await; - let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 0); - - std::mem::drop(connection); - - let mut connection = TcpStream::connect(addr).await.unwrap(); - let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; - assert_eq!(result, 0); - } - - #[tokio::test] - async fn req_import_get_device_desc() { - setup_test_logger(); - let server = new_server_with_single_device(); - - let mut req = op_req_import(SINGLE_DEVICE_BUSID); - req.extend( - UsbIpCommand::UsbIpCmdSubmit { - header: UsbIpHeaderBasic { - command: USBIP_CMD_SUBMIT.into(), - seqnum: 1, - devid: 0, - direction: 1, // IN - ep: 0, - }, - transfer_flags: 0, - transfer_buffer_length: 0, - start_frame: 0, - number_of_packets: 0, - interval: 0, - // GetDescriptor to Device - setup: [0x80, 0x06, 0x00, 0x01, 0x00, 0x00, 0x40, 0x00], - data: vec![], - iso_packet_descriptor: vec![], - } - .to_bytes(), - ); - - let mut mock_socket = MockSocket::new(req); - handler(&mut mock_socket, Arc::new(server)).await.ok(); - // OP_REQ_IMPORT + USBIP_CMD_SUBMIT + Device Descriptor - assert_eq!(mock_socket.output.len(), 0x140 + 0x30 + 0x12); - } -} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..9241487 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,222 @@ +mod async_server; +mod sync_server; + +use crate::{ + usbip_protocol::{UsbIpCommand, UsbIpResponse, USBIP_RET_SUBMIT, USBIP_RET_UNLINK}, + EndpointAttributes, SetupPacket, UsbDevice, UsbEndpoint, UsbHostDeviceHandler, + UsbHostInterfaceHandler, UsbInterface, UsbInterfaceHandler, +}; + +use async_trait::async_trait; +use log::{info, trace, warn}; +use rusb::{Device, GlobalContext}; + +use std::{ + collections::HashMap, + io::{ErrorKind, Result}, + net::SocketAddr, + sync::{Arc, Mutex}, +}; +use tokio::{io::AsyncReadExt, io::AsyncWriteExt, net::TcpListener, sync::RwLock}; + +pub use async_server::*; +pub use sync_server::*; + +/// A USB/IP server. +/// +/// A server that can host USB devices and expose them to clients, +/// using the USB/IP protocol. +#[async_trait] +pub trait UsbIpServer: Send + Clone { + /// Create a [UsbIpServer] with simulated devices + fn new_simulated(devices: Vec) -> Self; + + /// Create a [UsbIpServer] exposing devices in the host, and redirect all USB transfers to them using libusb + fn new_from_host() -> Self { + Self::new_from_host_with_filter(|_| true) + } + + // TODO: This is actually a default impl if Default is implemented for + // UsbIpServer. However, RFC 1210 is not stable yet. When it is, + // it would be nice to provide a specialized impl here. + + /// Create a [UsbIpServer] exposing a filtered set of devices in the host, and redirect all USB transfers to them using libusb + /// + /// Similar to [new_from_host], but only devices that pass the filter will be exposed. + fn new_from_host_with_filter(filter: F) -> Self + where + F: FnMut(&Device) -> bool; + + /// Add a [UsbDevice] to the servers available devices + async fn add_device(&self, device: UsbDevice); + + /// Remove a [UsbDevice] from the server + /// + /// This function will return an error if the device cannot + /// be identified by its `bus_id`, or if the device is currently + /// attached to a client. + async fn remove_device(&self, bus_id: &str) -> Result<()>; + + /// Internal per-device server loop. + /// + /// This is usually called by [serve] for every new connection-device pair, + /// and is responsible for forwarding USB packets between the socket and device. + async fn handler(self, socket: S) -> Result<()> + where + S: AsyncReadExt + AsyncWriteExt + Unpin + Send; + + // Spawn a USB/IP server at `addr` using [TcpListener] + // + // This will host a USB/IP endpoint at the `addr`, + // and spawn a `handler` for every new connection. + async fn serve(self, addr: SocketAddr); +} + +fn get_list_of_real_devices(device_list: Vec>) -> Vec { + let mut devices = vec![]; + + for dev in device_list { + let open_device = match dev.open() { + Ok(dev) => dev, + Err(err) => { + warn!("Impossible to share {:?}: {}, ignoring device", dev, err); + continue; + } + }; + let desc = match dev.device_descriptor() { + Ok(desc) => desc, + Err(err) => { + warn!( + "Impossible to get device descriptor for {:?}: {}, ignoring device", + dev, err + ); + continue; + } + }; + let cfg = match dev.active_config_descriptor() { + Ok(desc) => desc, + Err(err) => { + warn!( + "Impossible to get config descriptor for {:?}: {}, ignoring device", + dev, err + ); + continue; + } + }; + + let handle = Arc::new(Mutex::new(open_device)); + let mut interfaces = vec![]; + handle + .lock() + .unwrap() + .set_auto_detach_kernel_driver(true) + .ok(); + for intf in cfg.interfaces() { + // ignore alternate settings + let intf_desc = intf.descriptors().next().unwrap(); + handle + .lock() + .unwrap() + .set_auto_detach_kernel_driver(true) + .ok(); + let mut endpoints = vec![]; + + for ep_desc in intf_desc.endpoint_descriptors() { + endpoints.push(UsbEndpoint { + address: ep_desc.address(), + attributes: ep_desc.transfer_type() as u8, + max_packet_size: ep_desc.max_packet_size(), + interval: ep_desc.interval(), + }); + } + + let handler = Arc::new(Mutex::new( + Box::new(UsbHostInterfaceHandler::new(handle.clone())) + as Box, + )); + interfaces.push(UsbInterface { + interface_class: intf_desc.class_code(), + interface_subclass: intf_desc.sub_class_code(), + interface_protocol: intf_desc.protocol_code(), + endpoints, + string_interface: intf_desc.description_string_index().unwrap_or(0), + class_specific_descriptor: Vec::from(intf_desc.extra()), + handler, + }); + } + let mut device = UsbDevice { + path: format!( + "/sys/bus/{}/{}/{}", + dev.bus_number(), + dev.address(), + dev.port_number() + ), + bus_id: format!( + "{}-{}-{}", + dev.bus_number(), + dev.address(), + dev.port_number() + ), + bus_num: dev.bus_number() as u32, + dev_num: dev.port_number() as u32, + speed: dev.speed() as u32, + vendor_id: desc.vendor_id(), + product_id: desc.product_id(), + device_class: desc.class_code(), + device_subclass: desc.sub_class_code(), + device_protocol: desc.protocol_code(), + device_bcd: desc.device_version().into(), + configuration_value: cfg.number(), + num_configurations: desc.num_configurations(), + ep0_in: UsbEndpoint { + address: 0x80, + attributes: EndpointAttributes::Control as u8, + max_packet_size: desc.max_packet_size() as u16, + interval: 0, + }, + ep0_out: UsbEndpoint { + address: 0x00, + attributes: EndpointAttributes::Control as u8, + max_packet_size: desc.max_packet_size() as u16, + interval: 0, + }, + interfaces, + device_handler: Some(Arc::new(Mutex::new(Box::new(UsbHostDeviceHandler::new( + handle.clone(), + ))))), + usb_version: desc.usb_version().into(), + ..UsbDevice::default() + }; + + // set strings + if let Some(index) = desc.manufacturer_string_index() { + device.string_manufacturer = device.new_string( + &handle + .lock() + .unwrap() + .read_string_descriptor_ascii(index) + .unwrap(), + ) + } + if let Some(index) = desc.product_string_index() { + device.string_product = device.new_string( + &handle + .lock() + .unwrap() + .read_string_descriptor_ascii(index) + .unwrap(), + ) + } + if let Some(index) = desc.serial_number_string_index() { + device.string_serial = device.new_string( + &handle + .lock() + .unwrap() + .read_string_descriptor_ascii(index) + .unwrap(), + ) + } + devices.push(device); + } + devices +} \ No newline at end of file diff --git a/src/server/async_server.rs b/src/server/async_server.rs new file mode 100644 index 0000000..dff3956 --- /dev/null +++ b/src/server/async_server.rs @@ -0,0 +1,608 @@ +use std::vec; + +use tokio::{ + io::{split, AsyncReadExt, AsyncWriteExt, ErrorKind, Result}, + sync::mpsc, +}; + +use super::*; + +#[derive(Default, Clone)] +pub struct AsyncUsbIpServer { + available_devices: Arc>>, + used_devices: Arc>>, +} + +#[async_trait] +impl UsbIpServer for AsyncUsbIpServer { + fn new_simulated(devices: Vec) -> Self { + Self { + available_devices: Arc::new(RwLock::new(devices)), + used_devices: Arc::new(RwLock::new(HashMap::new())), + } + } + + fn new_from_host_with_filter(filter: F) -> Self + where + F: FnMut(&Device) -> bool, + { + match rusb::devices() { + Ok(list) => { + let mut devs = vec![]; + for d in list.iter().filter(filter) { + devs.push(d) + } + Self::new_simulated(get_list_of_real_devices(devs)) + } + Err(_) => Default::default(), + } + } + + async fn add_device(&self, device: UsbDevice) { + trace!("Adding device {:?}", device.bus_id); + self.available_devices.write().await.push(device); + } + + async fn remove_device(&self, bus_id: &str) -> Result<()> { + let mut available_devices = self.available_devices.write().await; + + if let Some(device) = available_devices.iter().position(|d| d.bus_id == bus_id) { + trace!("Removing device {:?}", bus_id); + available_devices.remove(device); + Ok(()) + } else if let Some(device) = self + .used_devices + .read() + .await + .values() + .find(|d| d.bus_id == bus_id) + { + warn!("Device {} is in use", device.bus_id); + Err(std::io::Error::new( + ErrorKind::Other, + format!("Device {} is in use", device.bus_id), + )) + } else { + warn!("Device {} not found", bus_id); + Err(std::io::Error::new( + ErrorKind::NotFound, + format!("Device {} not found", bus_id), + )) + } + } + + async fn handler( + self, + socket: T, + ) -> Result<()> { + let _ = async_scoped::TokioScope::scope_and_block::, _>(|scope| { + let (internal_tx, mut internal_rx) = mpsc::unbounded_channel::(); + let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>(); + let (mut sock_rx, mut sock_tx) = split(socket); + + // RX Thread + scope.spawn(async move { + let mut current_import_device_id: Option = None; + let internal_tx = Arc::new(Mutex::new(internal_tx)); + + loop { + let command = UsbIpCommand::read_from_socket(&mut sock_rx).await; + if let Err(err) = command { + if let Some(dev_id) = current_import_device_id { + let mut used_devices = self.used_devices.write().await; + let mut available_devices = self.available_devices.write().await; + match used_devices.remove(&dev_id) { + Some(dev) => available_devices.push(dev), + None => unreachable!(), + } + } + + shutdown_tx.send(()).unwrap(); + if err.kind() == ErrorKind::UnexpectedEof { + info!("[RX] Remote closed the connection"); + } else { + warn!("[RX] Exiting due to broken socket"); + warn!("{:?}", err); + } + return; + } + + let cmd = command.unwrap(); + trace!("[RX] Got command: {:?}", cmd); + + let used_devices = self.used_devices.read().await; + let mut current_import_device = current_import_device_id + .clone() + .and_then(|ref id| used_devices.get(id)); + + match cmd { + UsbIpCommand::OpReqDevlist { .. } => { + let device_list = self.available_devices.read().await; + internal_tx + .clone() + .lock() + .unwrap() + .send(UsbIpResponse::op_rep_devlist(&device_list)) + .unwrap(); + } + UsbIpCommand::OpReqImport { busid, .. } => { + current_import_device_id = None; + current_import_device = None; + std::mem::drop(used_devices); + + let mut used_devices = self.used_devices.write().await; + let mut available_devices = self.available_devices.write().await; + + for (i, dev) in available_devices.iter().enumerate() { + let mut expected = dev.bus_id.as_bytes().to_vec(); + expected.resize(32, 0); + if expected.as_slice() == busid { + let dev = available_devices.remove(i); + let dev_id = dev.bus_id.clone(); + used_devices.insert(dev.bus_id.clone(), dev); + current_import_device_id = dev_id.clone().into(); + current_import_device = + Some(used_devices.get(&dev_id).unwrap()); + break; + } + } + + let res = if let Some(dev) = current_import_device { + UsbIpResponse::op_rep_import_success(dev) + } else { + UsbIpResponse::op_rep_import_fail() + }; + + internal_tx.clone().lock().unwrap().send(res).unwrap(); + } + UsbIpCommand::UsbIpCmdUnlink { header, .. } => { + std::mem::drop(used_devices); + let mut used_devices = self.used_devices.write().await; + let mut available_devices = self.available_devices.write().await; + + let dev = current_import_device_id + .clone() + .and_then(|ref k| used_devices.remove(k)); + + let res = match dev { + Some(dev) => { + available_devices.push(dev); + current_import_device_id = None; + UsbIpResponse::usbip_ret_unlink_success(&header) + } + None => { + warn!("Device not found"); + UsbIpResponse::usbip_ret_unlink_fail(&header) + } + }; + internal_tx.clone().lock().unwrap().send(res).unwrap(); + } + + UsbIpCommand::UsbIpCmdSubmit { + mut header, + transfer_buffer_length, + setup, + data, + .. + } => { + let device = current_import_device.unwrap().clone(); + let internal_tx = internal_tx.clone(); + + tokio::spawn(async move { + let out = header.direction == 0; + let real_ep = if out { header.ep } else { header.ep | 0x80 }; + + header.command = USBIP_RET_SUBMIT.into(); + + let res = match device.find_ep(real_ep as u8) { + None => { + warn!("Endpoint {:02x?} not found", real_ep); + UsbIpResponse::usbip_ret_submit_fail(&header) + } + Some((ep, intf)) => { + trace!("->Endpoint {:02x?}", ep); + trace!("->Setup {:02x?}", setup); + trace!("->Request {:02x?}", data); + let resp = device + .handle_urb( + ep, + intf, + transfer_buffer_length, + SetupPacket::parse(&setup), + &data, + ) + .await + .unwrap(); + + if out { + trace!("<-Wrote {}", data.len()); + } else { + trace!("<-Resp {:02x?}", resp); + } + + UsbIpResponse::usbip_ret_submit_success( + &header, + 0, + 0, + resp, + vec![], + ) + } + }; + internal_tx.clone().lock().unwrap().send(res).unwrap(); + }); + } + } + } + }); + + // TX thread + scope.spawn(async move { + loop { + tokio::select! { + Some(res) = internal_rx.recv() => { + trace!("[TX] Staging response: {:?}", res.to_bytes()); + sock_tx.write_all(res.to_bytes().as_slice()).await.unwrap(); + trace!("[TX] Sent response"); + }, + _ = shutdown_rx.recv() => { + warn!("[TX] Exiting due to RX shutdown signal"); + return; + } + } + } + }); + Ok(()) + }); + Ok(()) + } + + async fn serve(self, addr: SocketAddr) { + trace!("Trying to listen on {:?}", addr); + let listener = TcpListener::bind(addr) + .await + .unwrap_or_else(|_| panic!("Could not bind to {}", addr)); + trace!("Listening on {:?}", addr); + + let server = async move { + loop { + match listener.accept().await { + Ok((mut socket, _addr)) => { + info!("Got connection from {:?}", socket.peer_addr()); + let new_server = self.clone(); + tokio::spawn(async move { + let res = new_server.handler(&mut socket).await; + info!("Handler ended with {:?}", res); + }); + } + Err(err) => { + warn!("Got error {:?}", err); + } + } + } + }; + + server.await + } +} + +#[cfg(test)] +mod tests { + use tokio::{net::TcpStream, task::JoinSet}; + + use super::*; + use crate::{ + cdc, + usbip_protocol::{self, UsbIpHeaderBasic, USBIP_CMD_SUBMIT, USBIP_CMD_UNLINK}, + util::tests::*, + ClassCode, UsbDevice, UsbInterfaceHandler, + }; + + const SINGLE_DEVICE_BUSID: &str = "0-0-0"; + + fn new_server_with_single_device() -> AsyncUsbIpServer { + AsyncUsbIpServer::new_simulated(vec![UsbDevice::new(0).with_interface( + ClassCode::CDC as u8, + cdc::CDC_ACM_SUBCLASS, + 0x00, + "Test CDC ACM", + cdc::UsbCdcAcmHandler::endpoints(), + Arc::new(Mutex::new( + Box::new(cdc::UsbCdcAcmHandler::new()) as Box + )), + )]) + } + + fn op_req_import(busid: &str) -> Vec { + let mut busid = busid.to_string().as_bytes().to_vec(); + busid.resize(32, 0); + UsbIpCommand::OpReqImport { + status: 0, + busid: busid.try_into().unwrap(), + } + .to_bytes() + } + + async fn attach_device(connection: &mut TcpStream, busid: &str) -> u32 { + let req = op_req_import(busid); + connection.write_all(req.as_slice()).await.unwrap(); + connection.read_u32().await.unwrap(); + let result = connection.read_u32().await.unwrap(); + if result == 0 { + connection.read_exact(&mut vec![0; 0x138]).await.unwrap(); + } + result + } + + #[tokio::test(flavor = "multi_thread")] + async fn req_empty_devlist() { + setup_test_logger(); + let server = AsyncUsbIpServer::new_simulated(vec![]); + let req = UsbIpCommand::OpReqDevlist { status: 0 }; + + let mut mock_socket = MockSocket::new(req.to_bytes()); + server.handler(&mut mock_socket).await.ok(); + + assert_eq!( + mock_socket.output, + UsbIpResponse::op_rep_devlist(&[]).to_bytes(), + ); + } + + #[tokio::test(flavor = "multi_thread")] + async fn req_sample_devlist() { + setup_test_logger(); + let server = new_server_with_single_device(); + let req = UsbIpCommand::OpReqDevlist { status: 0 }; + + let mut mock_socket = MockSocket::new(req.to_bytes()); + server.handler(&mut mock_socket).await.ok(); + + // OP_REP_DEVLIST + // header: 0xC + // device: 0x138 + // interface: 4 * 0x1 + assert_eq!(mock_socket.output.len(), 0xC + 0x138 + 4); + } + + #[tokio::test(flavor = "multi_thread")] + async fn req_import() { + setup_test_logger(); + let server = new_server_with_single_device(); + + // OP_REQ_IMPORT + let req = op_req_import(SINGLE_DEVICE_BUSID); + let mut mock_socket = MockSocket::new(req); + server.handler(&mut mock_socket).await.ok(); + // OP_REQ_IMPORT + assert_eq!(mock_socket.output.len(), 0x140); + } + + #[tokio::test(flavor = "multi_thread")] + async fn add_and_remove_10_devices() { + setup_test_logger(); + let server = AsyncUsbIpServer::new_simulated(vec![]); + let server_ = server.clone(); + let addr = get_free_address().await; + let server_thread = tokio::spawn(server_.serve(addr)); + + let mut join_set = JoinSet::new(); + let devices = (0..10) + .map(|i| { + let mut device = UsbDevice::new(i); + device.bus_id = format!("0-0-{}", i); + device + }) + .collect::>(); + + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.add_device(new_device).await; + }); + } + + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.remove_device(&new_device.bus_id).await.unwrap(); + }); + } + + while join_set.join_next().await.is_some() {} + + let device_len = server.clone().available_devices.read().await.len(); + + assert_eq!(device_len, 0); + server_thread.abort(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn send_usb_traffic_while_adding_and_removing_devices() { + setup_test_logger(); + let server = new_server_with_single_device(); + let server_ = server.clone(); + + let addr = get_free_address().await; + let server_thread = tokio::spawn(server_.serve(addr)); + + let cmd_loop_handle = tokio::spawn(async move { + let mut connection = poll_connect(addr).await; + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + let cdc_loopback_bulk_cmd = UsbIpCommand::UsbIpCmdSubmit { + header: usbip_protocol::UsbIpHeaderBasic { + command: USBIP_CMD_SUBMIT.into(), + seqnum: 1, + devid: 0, + direction: 0, // OUT + ep: 2, + }, + transfer_flags: 0, + transfer_buffer_length: 8, + start_frame: 0, + number_of_packets: 0, + interval: 0, + setup: [0; 8], + data: vec![1, 2, 3, 4, 5, 6, 7, 8], + iso_packet_descriptor: vec![], + }; + + loop { + connection + .write_all(cdc_loopback_bulk_cmd.to_bytes().as_slice()) + .await + .unwrap(); + let mut result = vec![0; 4 * 12]; + connection.read_exact(&mut result).await.unwrap(); + } + }); + + let add_and_remove_device_handle = tokio::spawn(async move { + let mut join_set = JoinSet::new(); + let devices = (1..4) + .map(|i| { + let mut device = UsbDevice::new(i); + device.bus_id = format!("0-0-{}", i); + device + }) + .collect::>(); + + loop { + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.add_device(new_device).await; + }); + } + + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.remove_device(&new_device.bus_id).await.unwrap(); + }); + } + while join_set.join_next().await.is_some() {} + tokio::time::sleep(tokio::time::Duration::from_millis(20)).await; + } + }); + + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + cmd_loop_handle.abort(); + add_and_remove_device_handle.abort(); + server_thread.abort(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn only_single_connection_allowed_to_device() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let addr = get_free_address().await; + let server_thread = tokio::spawn(server.serve(addr)); + + let mut first_connection = poll_connect(addr).await; + let mut second_connection = TcpStream::connect(addr).await.unwrap(); + + let result = attach_device(&mut first_connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + let result = attach_device(&mut second_connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 1); + server_thread.abort(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn device_gets_released_on_cmd_unlink() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let addr = get_free_address().await; + let server_thread = tokio::spawn(server.serve(addr)); + + let mut connection = poll_connect(addr).await; + + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + let unlink_req = UsbIpCommand::UsbIpCmdUnlink { + header: UsbIpHeaderBasic { + command: USBIP_CMD_UNLINK.into(), + seqnum: 1, + devid: 0, + direction: 0, + ep: 0, + }, + unlink_seqnum: 0, + } + .to_bytes(); + + connection.write_all(unlink_req.as_slice()).await.unwrap(); + connection.read_exact(&mut [0; 4 * 5]).await.unwrap(); + let result = connection.read_u32().await.unwrap(); + connection.read_exact(&mut [0; 4 * 6]).await.unwrap(); + assert_eq!(result, 0); + + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + server_thread.abort(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn device_gets_released_on_closed_socket() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let addr = get_free_address().await; + tokio::spawn(server.serve(addr)); + + let mut connection = poll_connect(addr).await; + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + std::mem::drop(connection); + + let mut connection = TcpStream::connect(addr).await.unwrap(); + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + } + + #[tokio::test(flavor = "multi_thread")] + async fn req_import_get_device_desc() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let mut req = op_req_import(SINGLE_DEVICE_BUSID); + req.extend( + UsbIpCommand::UsbIpCmdSubmit { + header: UsbIpHeaderBasic { + command: USBIP_CMD_SUBMIT.into(), + seqnum: 1, + devid: 0, + direction: 1, // IN + ep: 0, + }, + transfer_flags: 0, + transfer_buffer_length: 0, + start_frame: 0, + number_of_packets: 0, + interval: 0, + // GetDescriptor to Device + setup: [0x80, 0x06, 0x00, 0x01, 0x00, 0x00, 0x40, 0x00], + data: vec![], + iso_packet_descriptor: vec![], + } + .to_bytes(), + ); + + let mut mock_socket = MockSocket::new(req); + server.handler(&mut mock_socket).await.ok(); + // OP_REQ_IMPORT + USBIP_CMD_SUBMIT + Device Descriptor + assert_eq!(mock_socket.output.len(), 0x140 + 0x30 + 0x12); + } +} diff --git a/src/server/sync_server.rs b/src/server/sync_server.rs new file mode 100644 index 0000000..b7fab43 --- /dev/null +++ b/src/server/sync_server.rs @@ -0,0 +1,555 @@ +use super::*; + +#[derive(Default, Clone)] +pub struct SyncUsbIpServer { + available_devices: Arc>>, + used_devices: Arc>>, +} + +#[async_trait] +impl UsbIpServer for SyncUsbIpServer { + fn new_simulated(devices: Vec) -> Self { + Self { + available_devices: Arc::new(RwLock::new(devices)), + used_devices: Arc::new(RwLock::new(HashMap::new())), + } + } + + fn new_from_host_with_filter(filter: F) -> Self + where + F: FnMut(&Device) -> bool, + { + match rusb::devices() { + Ok(list) => { + let mut devs = vec![]; + for d in list.iter().filter(filter) { + devs.push(d) + } + Self::new_simulated(get_list_of_real_devices(devs)) + } + Err(_) => Default::default(), + } + } + + async fn add_device(&self, device: UsbDevice) { + self.available_devices.write().await.push(device); + } + + async fn remove_device(&self, bus_id: &str) -> Result<()> { + let mut available_devices = self.available_devices.write().await; + + if let Some(device) = available_devices.iter().position(|d| d.bus_id == bus_id) { + available_devices.remove(device); + Ok(()) + } else if let Some(device) = self + .used_devices + .read() + .await + .values() + .find(|d| d.bus_id == bus_id) + { + Err(std::io::Error::new( + ErrorKind::Other, + format!("Device {} is in use", device.bus_id), + )) + } else { + Err(std::io::Error::new( + ErrorKind::NotFound, + format!("Device {} not found", bus_id), + )) + } + } + + async fn handler( + self, + mut socket: T, + ) -> Result<()> { + let mut current_import_device_id: Option = None; + loop { + let command = UsbIpCommand::read_from_socket(&mut socket).await; + if let Err(err) = command { + if let Some(dev_id) = current_import_device_id { + let mut used_devices = self.used_devices.write().await; + let mut available_devices = self.available_devices.write().await; + match used_devices.remove(&dev_id) { + Some(dev) => available_devices.push(dev), + None => unreachable!(), + } + } + + if err.kind() == ErrorKind::UnexpectedEof { + info!("Remote closed the connection"); + return Ok(()); + } else { + return Err(err); + } + } + + let used_devices = self.used_devices.read().await; + let mut current_import_device = current_import_device_id + .clone() + .and_then(|ref id| used_devices.get(id)); + + match command.unwrap() { + UsbIpCommand::OpReqDevlist { .. } => { + trace!("Got OP_REQ_DEVLIST"); + let devices = self.available_devices.read().await; + + UsbIpResponse::op_rep_devlist(&devices) + .write_to_socket(&mut socket) + .await?; + trace!("Sent OP_REP_DEVLIST"); + } + UsbIpCommand::OpReqImport { busid, .. } => { + trace!("Got OP_REQ_IMPORT"); + + current_import_device_id = None; + current_import_device = None; + std::mem::drop(used_devices); + + let mut used_devices = self.used_devices.write().await; + let mut available_devices = self.available_devices.write().await; + for (i, dev) in available_devices.iter().enumerate() { + let mut expected = dev.bus_id.as_bytes().to_vec(); + expected.resize(32, 0); + if expected.as_slice() == busid { + let dev = available_devices.remove(i); + let dev_id = dev.bus_id.clone(); + used_devices.insert(dev.bus_id.clone(), dev); + current_import_device_id = dev_id.clone().into(); + current_import_device = Some(used_devices.get(&dev_id).unwrap()); + break; + } + } + + let res = if let Some(dev) = current_import_device { + UsbIpResponse::op_rep_import_success(dev) + } else { + UsbIpResponse::op_rep_import_fail() + }; + res.write_to_socket(&mut socket).await?; + trace!("Sent OP_REP_IMPORT"); + } + UsbIpCommand::UsbIpCmdSubmit { + mut header, + transfer_buffer_length, + setup, + data, + .. + } => { + trace!("Got USBIP_CMD_SUBMIT"); + let device = current_import_device.unwrap(); + + let out = header.direction == 0; + let real_ep = if out { header.ep } else { header.ep | 0x80 }; + + header.command = USBIP_RET_SUBMIT.into(); + + let res = match device.find_ep(real_ep as u8) { + None => { + warn!("Endpoint {:02x?} not found", real_ep); + UsbIpResponse::usbip_ret_submit_fail(&header) + } + Some((ep, intf)) => { + trace!("->Endpoint {:02x?}", ep); + trace!("->Setup {:02x?}", setup); + trace!("->Request {:02x?}", data); + let resp = device + .handle_urb( + ep, + intf, + transfer_buffer_length, + SetupPacket::parse(&setup), + &data, + ) + .await?; + + if out { + trace!("<-Wrote {}", data.len()); + } else { + trace!("<-Resp {:02x?}", resp); + } + + UsbIpResponse::usbip_ret_submit_success(&header, 0, 0, resp, vec![]) + } + }; + res.write_to_socket(&mut socket).await?; + trace!("Sent USBIP_RET_SUBMIT"); + } + UsbIpCommand::UsbIpCmdUnlink { + mut header, + unlink_seqnum: _, + } => { + trace!("Got USBIP_CMD_UNLINK"); + + std::mem::drop(used_devices); + + let mut used_devices = self.used_devices.write().await; + let mut available_devices = self.available_devices.write().await; + + let dev = current_import_device_id + .clone() + .and_then(|ref k| used_devices.remove(k)); + + header.command = USBIP_RET_UNLINK.into(); + + let res = match dev { + Some(dev) => { + available_devices.push(dev); + current_import_device_id = None; + UsbIpResponse::usbip_ret_unlink_success(&header) + } + None => { + warn!("Device not found"); + UsbIpResponse::usbip_ret_unlink_fail(&header) + } + }; + res.write_to_socket(&mut socket).await?; + trace!("Sent USBIP_RET_UNLINK"); + } + } + } + } + + async fn serve(self, addr: SocketAddr) { + let listener = TcpListener::bind(addr) + .await + .unwrap_or_else(|_| panic!("Could not bind to {}", addr)); + + trace!("Listening on {:?}", addr); + + let server_loop = async move { + loop { + match listener.accept().await { + Ok((socket, addr)) => { + info!("Got connection from {:?}", addr); + let server = self.clone(); + tokio::spawn(async move { + match server.handler(socket).await { + Ok(()) => { + trace!("Handler for {} ended successfully", addr); + } + Err(err) => { + warn!("Handler for {} ended with error {:?}", addr, err); + } + } + }); + } + Err(err) => { + warn!( + "Could not establish socket for incoming connection, {:?}", + err + ); + } + } + } + }; + + server_loop.await + } +} + +#[cfg(test)] +mod tests { + use tokio::{net::TcpStream, task::JoinSet}; + + use super::*; + use crate::{ + cdc, + usbip_protocol::{self, UsbIpHeaderBasic, USBIP_CMD_SUBMIT, USBIP_CMD_UNLINK}, + util::tests::*, + ClassCode, UsbDevice, UsbInterfaceHandler, + }; + + const SINGLE_DEVICE_BUSID: &str = "0-0-0"; + + fn new_server_with_single_device() -> SyncUsbIpServer { + SyncUsbIpServer::new_simulated(vec![UsbDevice::new(0).with_interface( + ClassCode::CDC as u8, + cdc::CDC_ACM_SUBCLASS, + 0x00, + "Test CDC ACM", + cdc::UsbCdcAcmHandler::endpoints(), + Arc::new(Mutex::new( + Box::new(cdc::UsbCdcAcmHandler::new()) as Box + )), + )]) + } + + fn op_req_import(busid: &str) -> Vec { + let mut busid = busid.to_string().as_bytes().to_vec(); + busid.resize(32, 0); + UsbIpCommand::OpReqImport { + status: 0, + busid: busid.try_into().unwrap(), + } + .to_bytes() + } + + async fn attach_device(connection: &mut TcpStream, busid: &str) -> u32 { + let req = op_req_import(busid); + connection.write_all(req.as_slice()).await.unwrap(); + connection.read_u32().await.unwrap(); + let result = connection.read_u32().await.unwrap(); + if result == 0 { + connection.read_exact(&mut vec![0; 0x138]).await.unwrap(); + } + result + } + + #[tokio::test] + async fn req_empty_devlist() { + setup_test_logger(); + let server = SyncUsbIpServer::new_simulated(vec![]); + let req = UsbIpCommand::OpReqDevlist { status: 0 }; + + let mut mock_socket = MockSocket::new(req.to_bytes()); + server.handler(&mut mock_socket).await.ok(); + + assert_eq!( + mock_socket.output, + UsbIpResponse::op_rep_devlist(&[]).to_bytes(), + ); + } + + #[tokio::test] + async fn req_sample_devlist() { + setup_test_logger(); + let server = new_server_with_single_device(); + let req = UsbIpCommand::OpReqDevlist { status: 0 }; + + let mut mock_socket = MockSocket::new(req.to_bytes()); + server.handler(&mut mock_socket).await.ok(); + + // OP_REP_DEVLIST + // header: 0xC + // device: 0x138 + // interface: 4 * 0x1 + assert_eq!(mock_socket.output.len(), 0xC + 0x138 + 4); + } + + #[tokio::test] + async fn req_import() { + setup_test_logger(); + let server = new_server_with_single_device(); + + // OP_REQ_IMPORT + let req = op_req_import(SINGLE_DEVICE_BUSID); + let mut mock_socket = MockSocket::new(req); + server.handler(&mut mock_socket).await.ok(); + // OP_REQ_IMPORT + assert_eq!(mock_socket.output.len(), 0x140); + } + + #[tokio::test] + async fn add_and_remove_10_devices() { + setup_test_logger(); + let server = SyncUsbIpServer::new_simulated(vec![]); + let server_ = server.clone(); + let addr = get_free_address().await; + tokio::spawn(server_.serve(addr)); + + let mut join_set = JoinSet::new(); + let devices = (0..10).map(UsbDevice::new).collect::>(); + + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.add_device(new_device).await; + }); + } + + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.remove_device(&new_device.bus_id).await.unwrap(); + }); + } + + while join_set.join_next().await.is_some() {} + + let device_len = server.clone().available_devices.read().await.len(); + + assert_eq!(device_len, 0); + } + + #[tokio::test] + async fn send_usb_traffic_while_adding_and_removing_devices() { + setup_test_logger(); + let server = new_server_with_single_device(); + let server_ = server.clone(); + + let addr = get_free_address().await; + tokio::spawn(server_.serve(addr)); + + let cmd_loop_handle = tokio::spawn(async move { + let mut connection = poll_connect(addr).await; + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + let cdc_loopback_bulk_cmd = UsbIpCommand::UsbIpCmdSubmit { + header: usbip_protocol::UsbIpHeaderBasic { + command: USBIP_CMD_SUBMIT.into(), + seqnum: 1, + devid: 0, + direction: 0, // OUT + ep: 2, + }, + transfer_flags: 0, + transfer_buffer_length: 8, + start_frame: 0, + number_of_packets: 0, + interval: 0, + setup: [0; 8], + data: vec![1, 2, 3, 4, 5, 6, 7, 8], + iso_packet_descriptor: vec![], + }; + + loop { + connection + .write_all(cdc_loopback_bulk_cmd.to_bytes().as_slice()) + .await + .unwrap(); + let mut result = vec![0; 4 * 12]; + connection.read_exact(&mut result).await.unwrap(); + } + }); + + let add_and_remove_device_handle = tokio::spawn(async move { + let mut join_set = JoinSet::new(); + let devices = (1..4).map(UsbDevice::new).collect::>(); + + loop { + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.add_device(new_device).await; + }); + } + + for device in devices.iter() { + let new_server = server.clone(); + let new_device = device.clone(); + join_set.spawn(async move { + new_server.remove_device(&new_device.bus_id).await.unwrap(); + }); + } + while join_set.join_next().await.is_some() {} + tokio::time::sleep(tokio::time::Duration::from_millis(20)).await; + } + }); + + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + cmd_loop_handle.abort(); + add_and_remove_device_handle.abort(); + } + + #[tokio::test] + async fn only_single_connection_allowed_to_device() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let addr = get_free_address().await; + tokio::spawn(server.serve(addr)); + + let mut first_connection = poll_connect(addr).await; + let mut second_connection = TcpStream::connect(addr).await.unwrap(); + + let result = attach_device(&mut first_connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + let result = attach_device(&mut second_connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 1); + } + + #[tokio::test] + async fn device_gets_released_on_cmd_unlink() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let addr = get_free_address().await; + tokio::spawn(server.serve(addr)); + + let mut connection = poll_connect(addr).await; + + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + let unlink_req = UsbIpCommand::UsbIpCmdUnlink { + header: UsbIpHeaderBasic { + command: USBIP_CMD_UNLINK.into(), + seqnum: 1, + devid: 0, + direction: 0, + ep: 0, + }, + unlink_seqnum: 0, + } + .to_bytes(); + + connection.write_all(unlink_req.as_slice()).await.unwrap(); + connection.read_exact(&mut [0; 4 * 5]).await.unwrap(); + let result = connection.read_u32().await.unwrap(); + connection.read_exact(&mut [0; 4 * 6]).await.unwrap(); + assert_eq!(result, 0); + + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + } + + #[tokio::test] + async fn device_gets_released_on_closed_socket() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let addr = get_free_address().await; + tokio::spawn(server.serve(addr)); + + let mut connection = poll_connect(addr).await; + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + + std::mem::drop(connection); + + let mut connection = TcpStream::connect(addr).await.unwrap(); + let result = attach_device(&mut connection, SINGLE_DEVICE_BUSID).await; + assert_eq!(result, 0); + } + + #[tokio::test] + async fn req_import_get_device_desc() { + setup_test_logger(); + let server = new_server_with_single_device(); + + let mut req = op_req_import(SINGLE_DEVICE_BUSID); + req.extend( + UsbIpCommand::UsbIpCmdSubmit { + header: UsbIpHeaderBasic { + command: USBIP_CMD_SUBMIT.into(), + seqnum: 1, + devid: 0, + direction: 1, // IN + ep: 0, + }, + transfer_flags: 0, + transfer_buffer_length: 0, + start_frame: 0, + number_of_packets: 0, + interval: 0, + // GetDescriptor to Device + setup: [0x80, 0x06, 0x00, 0x01, 0x00, 0x00, 0x40, 0x00], + data: vec![], + iso_packet_descriptor: vec![], + } + .to_bytes(), + ); + + let mut mock_socket = MockSocket::new(req); + server.handler(&mut mock_socket).await.ok(); + // OP_REQ_IMPORT + USBIP_CMD_SUBMIT + Device Descriptor + assert_eq!(mock_socket.output.len(), 0x140 + 0x30 + 0x12); + } +} diff --git a/src/usbip_protocol.rs b/src/usbip_protocol.rs index e2bf18f..084d4a0 100644 --- a/src/usbip_protocol.rs +++ b/src/usbip_protocol.rs @@ -346,7 +346,7 @@ impl UsbIpResponse { let mut result = Vec::with_capacity(48 + transfer_buffer.len() + iso_packet_descriptor.len()); - debug_assert!(header.command == USBIP_RET_SUBMIT.into()); + debug_assert!(header.command == >::into(USBIP_RET_SUBMIT.into())); result.extend_from_slice(&header.to_bytes()); result.extend_from_slice(&status.to_be_bytes()); @@ -362,7 +362,7 @@ impl UsbIpResponse { Self::UsbIpRetUnlink { ref header, status } => { let mut result = Vec::with_capacity(48); - debug_assert!(header.command == USBIP_RET_UNLINK.into()); + debug_assert!(header.command == >::into(USBIP_RET_UNLINK.into())); result.extend_from_slice(&header.to_bytes()); result.extend_from_slice(&status.to_be_bytes()); diff --git a/src/util.rs b/src/util.rs index 977c39a..d819ce4 100644 --- a/src/util.rs +++ b/src/util.rs @@ -15,9 +15,10 @@ pub(crate) mod tests { pin::Pin, task::{Context, Poll}, }; + use log::{info, trace}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, - net::{TcpListener, TcpStream}, + net::{TcpListener, TcpStream}, time, }; pub(crate) struct MockSocket { @@ -71,14 +72,21 @@ pub(crate) mod tests { } pub(crate) async fn poll_connect(addr: SocketAddr) -> TcpStream { - loop { + for _ in 0..1000 { + trace!("Trying to connect to {:?}", addr); + if let Ok(stream) = TcpStream::connect(addr).await { return stream; } + + trace!("Failed to connect to {:?}, retrying in 10ms", addr); + time::sleep(time::Duration::from_millis(10)).await; } + panic!("Could not connect to socket in 1000 tries, assuming test failure") } pub(crate) fn setup_test_logger() { let _ = env_logger::builder().is_test(true).try_init(); + info!("Successfully initialized test env logger"); } }