From ef3e473247fe3e1f920bf3c64d715dd896c0ea4d Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 22 Jul 2024 22:18:22 +0200 Subject: [PATCH 01/18] add draft of a vsock driver --- Cargo.toml | 3 +- src/arch/aarch64/mm/paging.rs | 2 +- src/arch/riscv64/mm/paging.rs | 2 +- src/arch/x86_64/mm/paging.rs | 2 +- src/drivers/mod.rs | 17 +++- src/drivers/pci.rs | 19 +++- src/drivers/virtio/mod.rs | 18 ++++ src/drivers/virtio/transport/pci.rs | 20 ++++ src/drivers/vsock/mod.rs | 152 ++++++++++++++++++++++++++++ src/drivers/vsock/pci.rs | 124 +++++++++++++++++++++++ 10 files changed, 346 insertions(+), 13 deletions(-) create mode 100644 src/drivers/vsock/mod.rs create mode 100644 src/drivers/vsock/pci.rs diff --git a/Cargo.toml b/Cargo.toml index d0c1b16aae..0c8f605125 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ name = "measure_startup_time" harness = false [features] -default = ["pci", "pci-ids", "acpi", "fsgsbase", "smp", "tcp", "dhcpv4", "fuse"] +default = ["pci", "pci-ids", "acpi", "fsgsbase", "smp", "tcp", "dhcpv4", "fuse", "vsock"] acpi = [] dhcpv4 = [ "smoltcp", @@ -54,6 +54,7 @@ dhcpv4 = [ ] fs = ["fuse"] fuse = ["pci", "dep:fuse-abi", "fuse-abi/num_enum"] +vsock = [] fsgsbase = [] gem-net = ["tcp", "dep:tock-registers"] newlib = [] diff --git a/src/arch/aarch64/mm/paging.rs b/src/arch/aarch64/mm/paging.rs index 5e03396a35..e9598eca46 100644 --- a/src/arch/aarch64/mm/paging.rs +++ b/src/arch/aarch64/mm/paging.rs @@ -578,7 +578,7 @@ pub fn virtual_to_physical(virtual_address: VirtAddr) -> Option { get_physical_address::(virtual_address) } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "vsock", feature = "tcp", feature = "udp"))] pub fn virt_to_phys(virtual_address: VirtAddr) -> PhysAddr { virtual_to_physical(virtual_address).unwrap() } diff --git a/src/arch/riscv64/mm/paging.rs b/src/arch/riscv64/mm/paging.rs index f44c3256cb..3aaf7c0008 100644 --- a/src/arch/riscv64/mm/paging.rs +++ b/src/arch/riscv64/mm/paging.rs @@ -584,7 +584,7 @@ pub fn virtual_to_physical(virtual_address: VirtAddr) -> Option { panic!("virtual_to_physical should never reach this point"); } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "vsock", feature = "tcp", feature = "udp"))] pub fn virt_to_phys(virtual_address: VirtAddr) -> PhysAddr { virtual_to_physical(virtual_address).unwrap() } diff --git a/src/arch/x86_64/mm/paging.rs b/src/arch/x86_64/mm/paging.rs index bc7cd99f57..b55ba41464 100644 --- a/src/arch/x86_64/mm/paging.rs +++ b/src/arch/x86_64/mm/paging.rs @@ -118,7 +118,7 @@ pub fn virtual_to_physical(virtual_address: VirtAddr) -> Option { } } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "vsock", feature = "tcp", feature = "udp"))] pub fn virt_to_phys(virtual_address: VirtAddr) -> PhysAddr { virtual_to_physical(virtual_address).unwrap() } diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 9d31c91c5c..580171c434 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -10,9 +10,12 @@ pub mod net; pub mod pci; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] pub mod virtio; +#[cfg(feature = "vsock")] +pub mod vsock; /// A common error module for drivers. /// [DriverError](error::DriverError) values will be @@ -26,7 +29,8 @@ pub mod error { use crate::drivers::net::rtl8139::RTL8139Error; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] use crate::drivers::virtio::error::VirtioError; @@ -34,7 +38,8 @@ pub mod error { pub enum DriverError { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] InitVirtioDevFail(VirtioError), #[cfg(feature = "rtl8139")] @@ -45,7 +50,8 @@ pub mod error { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] impl From for DriverError { fn from(err: VirtioError) -> Self { @@ -73,7 +79,8 @@ pub mod error { match *self { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] DriverError::InitVirtioDevFail(ref err) => { write!(f, "Virtio driver failed: {err:?}") diff --git a/src/drivers/pci.rs b/src/drivers/pci.rs index ee6239db4a..7add681a3f 100644 --- a/src/drivers/pci.rs +++ b/src/drivers/pci.rs @@ -4,7 +4,7 @@ use alloc::vec::Vec; use core::fmt; use hermit_sync::without_interrupts; -#[cfg(any(feature = "tcp", feature = "udp", feature = "fuse"))] +#[cfg(any(feature = "tcp", feature = "udp", feature = "fuse", feature = "vsock"))] use hermit_sync::InterruptTicketMutex; use pci_types::capability::CapabilityIterator; use pci_types::{ @@ -22,14 +22,18 @@ use crate::drivers::net::rtl8139::{self, RTL8139Driver}; use crate::drivers::net::virtio::VirtioNetDriver; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] use crate::drivers::virtio::transport::pci as pci_virtio; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] use crate::drivers::virtio::transport::pci::VirtioDriver; +#[cfg(feature = "vsock")] +use crate::drivers::vsock::VirtioVsockDriver; pub(crate) static mut PCI_DEVICES: Vec> = Vec::new(); static mut PCI_DRIVERS: Vec = Vec::new(); @@ -297,6 +301,8 @@ pub(crate) fn print_information() { pub(crate) enum PciDriver { #[cfg(feature = "fuse")] VirtioFs(InterruptTicketMutex), + #[cfg(feature = "vsock")] + VirtioVsock(InterruptTicketMutex), #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] VirtioNet(InterruptTicketMutex), #[cfg(all(feature = "rtl8139", any(feature = "tcp", feature = "udp")))] @@ -373,13 +379,18 @@ pub(crate) fn init_drivers() { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] match pci_virtio::init_device(adapter) { #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] Ok(VirtioDriver::Network(drv)) => { register_driver(PciDriver::VirtioNet(InterruptTicketMutex::new(drv))) } + #[cfg(feature = "vsock")] + Ok(VirtioDriver::Vsock(drv)) => { + register_driver(PciDriver::VirtioVsock(InterruptTicketMutex::new(drv))) + } #[cfg(feature = "fuse")] Ok(VirtioDriver::FileSystem(drv)) => { register_driver(PciDriver::VirtioFs(InterruptTicketMutex::new(drv))) diff --git a/src/drivers/virtio/mod.rs b/src/drivers/virtio/mod.rs index cbd2719cc4..e32cd5f375 100644 --- a/src/drivers/virtio/mod.rs +++ b/src/drivers/virtio/mod.rs @@ -14,6 +14,8 @@ pub mod error { pub use crate::drivers::net::virtio::error::VirtioNetError; #[cfg(feature = "pci")] use crate::drivers::pci::error::PciError; + #[cfg(feature = "vsock")] + pub use crate::drivers::vsock::error::VirtioVsockError; #[allow(dead_code)] #[derive(Debug)] @@ -25,6 +27,8 @@ pub mod error { NetDriver(VirtioNetError), #[cfg(feature = "fuse")] FsDriver(VirtioFsError), + #[cfg(feature = "vsock")] + VsockDriver(VirtioVsockError), #[cfg(not(feature = "pci"))] Unknown, } @@ -71,6 +75,20 @@ pub mod error { VirtioFsError::IncompatibleFeatureSets(driver_features, device_features) => write!(f, "Feature set: {driver_features:?} , is incompatible with the device features: {device_features:?}", ), VirtioFsError::Unknown => write!(f, "Virtio filesystem failed, driver failed due unknown reason!"), }, + #[cfg(feature = "vsock")] + VirtioError::VsockDriver(vsock_error) => match vsock_error { + #[cfg(feature = "pci")] + VirtioVsockError::NoDevCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed device config!"), + #[cfg(feature = "pci")] + VirtioVsockError::NoComCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed common config!"), + #[cfg(feature = "pci")] + VirtioVsockError::NoIsrCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed ISR status config!"), + #[cfg(feature = "pci")] + VirtioVsockError::NoNotifCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed notification config!"), + VirtioVsockError::FailFeatureNeg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, device did not acknowledge negotiated feature set!"), + VirtioVsockError::FeatureRequirementsNotMet(features) => write!(f, "Virtio socket driver tried to set feature bit without setting dependency feature. Feat set: {features:?}"), + VirtioVsockError::IncompatibleFeatureSets(driver_features, device_features) => write!(f, "Feature set: {driver_features:?} , is incompatible with the device features: {device_features:?}"), + }, } } } diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index 7dc5e6770e..b2f119eb16 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -31,6 +31,8 @@ use crate::drivers::net::virtio::VirtioNetDriver; use crate::drivers::pci::error::PciError; use crate::drivers::pci::PciDevice; use crate::drivers::virtio::error::VirtioError; +#[cfg(feature = "vsock")] +use crate::drivers::vsock::VirtioVsockDriver; /// Maps a given device specific pci configuration structure and /// returns a static reference to it. @@ -915,6 +917,20 @@ pub(crate) fn init_device( Err(DriverError::InitVirtioDevFail(virtio_error)) } }, + #[cfg(feature = "vsock")] + virtio::Id::Vsock => match VirtioVsockDriver::init(device) { + Ok(virt_sock_drv) => { + info!("Virtio sock driver initialized."); + Ok(VirtioDriver::Vsock(virt_sock_drv)) + } + Err(virtio_error) => { + error!( + "Virtio sock driver could not be initialized with device: {:x}", + device_id + ); + Err(DriverError::InitVirtioDevFail(virtio_error)) + } + }, #[cfg(feature = "fuse")] virtio::Id::Fs => { // TODO: check subclass @@ -956,6 +972,8 @@ pub(crate) fn init_device( Ok(drv) } + #[cfg(feature = "vsock")] + VirtioDriver::Vsock(_) => Ok(drv), #[cfg(feature = "fuse")] VirtioDriver::FileSystem(_) => Ok(drv), } @@ -967,6 +985,8 @@ pub(crate) fn init_device( pub(crate) enum VirtioDriver { #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] Network(VirtioNetDriver), + #[cfg(feature = "vsock")] + Vsock(VirtioVsockDriver), #[cfg(feature = "fuse")] FileSystem(VirtioFsDriver), } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs new file mode 100644 index 0000000000..decd3b125e --- /dev/null +++ b/src/drivers/vsock/mod.rs @@ -0,0 +1,152 @@ +#[cfg(feature = "pci")] +pub mod pci; + +use alloc::rc::Rc; +use alloc::vec::Vec; + +use virtio::FeatureBits; + +use crate::config::VIRTIO_MAX_QUEUE_SIZE; +use crate::drivers::virtio::error::VirtioVsockError; +#[cfg(feature = "pci")] +use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg}; +use crate::drivers::virtio::virtqueue::split::SplitVq; +use crate::drivers::virtio::virtqueue::{Virtq, VqIndex, VqSize}; +#[cfg(feature = "pci")] +use crate::drivers::vsock::pci::VsockDevCfgRaw; + +/// A wrapper struct for the raw configuration structure. +/// Handling the right access to fields, as some are read-only +/// for the driver. +pub(crate) struct VsockDevCfg { + pub raw: &'static VsockDevCfgRaw, + pub dev_id: u16, + pub features: virtio::net::F, +} + +pub(crate) struct VirtioVsockDriver { + pub(super) dev_cfg: VsockDevCfg, + pub(super) com_cfg: ComCfg, + pub(super) isr_stat: IsrStatus, + pub(super) notif_cfg: NotifCfg, + pub(super) vqueues: Vec>, +} + +impl VirtioVsockDriver { + #[cfg(feature = "pci")] + pub fn get_dev_id(&self) -> u16 { + self.dev_cfg.dev_id + } + + #[cfg(feature = "pci")] + pub fn set_failed(&mut self) { + self.com_cfg.set_failed(); + } + + pub fn disable_interrupts(&self) { + // For send and receive queues? + // Only for receive? Because send is off anyway? + self.vqueues[0].disable_notifs(); + } + + pub fn enable_interrupts(&self) { + // For send and receive queues? + // Only for receive? Because send is off anyway? + self.vqueues[0].enable_notifs(); + } + + /// Negotiates a subset of features, understood and wanted by both the OS + /// and the device. + fn negotiate_features( + &mut self, + driver_features: virtio::vsock::F, + ) -> Result<(), VirtioVsockError> { + let device_features = virtio::vsock::F::from(self.com_cfg.dev_features()); + + if device_features.requirements_satisfied() { + info!("Feature set wanted by vsock driver are in conformance with specification."); + } else { + return Err(VirtioVsockError::FeatureRequirementsNotMet(device_features)); + } + + if device_features.contains(driver_features) { + // If device supports subset of features write feature set to common config + self.com_cfg.set_drv_features(driver_features.into()); + Ok(()) + } else { + Err(VirtioVsockError::IncompatibleFeatureSets( + driver_features, + device_features, + )) + } + } + + /// Initializes the device in adherence to specification. Returns Some(VirtioVsockError) + /// upon failure and None in case everything worked as expected. + /// + /// See Virtio specification v1.1. - 3.1.1. + /// and v1.1. - 5.10.6 + pub(crate) fn init_dev(&mut self) -> Result<(), VirtioVsockError> { + // Reset + self.com_cfg.reset_dev(); + + // Indiacte device, that OS noticed it + self.com_cfg.ack_dev(); + + // Indicate device, that driver is able to handle it + self.com_cfg.set_drv(); + + let features = virtio::vsock::F::VERSION_1; + self.negotiate_features(features)?; + + // Indicates the device, that the current feature set is final for the driver + // and will not be changed. + self.com_cfg.features_ok(); + + // Checks if the device has accepted final set. This finishes feature negotiation. + if self.com_cfg.check_features() { + info!( + "Features have been negotiated between virtio socket device {:x} and driver.", + self.dev_cfg.dev_id + ); + // Set feature set in device config fur future use. + self.dev_cfg.features = features; + } else { + return Err(VirtioVsockError::FailFeatureNeg(self.dev_cfg.dev_id)); + } + + // create the queues and tell device about them + for i in 0..3u16 { + let vq = SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(i), + self.dev_cfg.features.into(), + ) + .unwrap(); + self.vqueues.push(Rc::new(vq)); + } + + Ok(()) + } +} + +/// Error module of virtio socket device driver. +pub mod error { + /// Virtio socket device error enum. + #[derive(Debug, Copy, Clone)] + pub enum VirtioVsockError { + NoDevCfg(u16), + NoComCfg(u16), + NoIsrCfg(u16), + NoNotifCfg(u16), + FailFeatureNeg(u16), + /// Set of features does not adhere to the requirements of features + /// indicated by the specification + FeatureRequirementsNotMet(virtio::net::F), + /// The first u64 contains the feature bits wanted by the driver. + /// but which are incompatible with the device feature set, second u64. + IncompatibleFeatureSets(virtio::net::F, virtio::net::F), + } +} diff --git a/src/drivers/vsock/pci.rs b/src/drivers/vsock/pci.rs new file mode 100644 index 0000000000..1b968a39ec --- /dev/null +++ b/src/drivers/vsock/pci.rs @@ -0,0 +1,124 @@ +use alloc::vec::Vec; + +use crate::arch::pci::PciConfigRegion; +use crate::drivers::pci::PciDevice; +use crate::drivers::virtio::error::{self, VirtioError}; +use crate::drivers::virtio::transport::pci; +use crate::drivers::virtio::transport::pci::{PciCap, UniCapsColl}; +use crate::drivers::vsock::{VirtioVsockDriver, VsockDevCfg}; + +/// Virtio's socket device configuration structure. +/// See specification v1.1. - 5.11.4 +/// +#[derive(Debug, Copy, Clone)] +#[repr(C)] +pub(crate) struct VsockDevCfgRaw { + /// The guest_cid field contains the guest’s context ID, which uniquely identifies the device + /// for its lifetime. The upper 32 bits of the CID are reserved and zeroed. + guest_cid: u64, +} + +impl VirtioVsockDriver { + fn map_cfg(cap: &PciCap) -> Option { + let dev_cfg: &'static VsockDevCfgRaw = match pci::map_dev_cfg::(cap) { + Some(cfg) => cfg, + None => return None, + }; + + Some(VsockDevCfg { + raw: dev_cfg, + dev_id: cap.dev_id(), + features: virtio::net::F::empty(), + }) + } + + /// Instantiates a new VirtioVsockDriver struct, by checking the available + /// configuration structures and moving them into the struct. + pub fn new( + mut caps_coll: UniCapsColl, + device: &PciDevice, + ) -> Result { + let device_id = device.device_id(); + + let com_cfg = match caps_coll.get_com_cfg() { + Some(com_cfg) => com_cfg, + None => { + error!("No common config. Aborting!"); + return Err(error::VirtioVsockError::NoComCfg(device_id)); + } + }; + + let isr_stat = match caps_coll.get_isr_cfg() { + Some(isr_stat) => isr_stat, + None => { + error!("No ISR status config. Aborting!"); + return Err(error::VirtioVsockError::NoIsrCfg(device_id)); + } + }; + + let notif_cfg = match caps_coll.get_notif_cfg() { + Some(notif_cfg) => notif_cfg, + None => { + error!("No notif config. Aborting!"); + return Err(error::VirtioVsockError::NoNotifCfg(device_id)); + } + }; + + let dev_cfg = loop { + match caps_coll.get_dev_cfg() { + Some(cfg) => { + if let Some(dev_cfg) = VirtioVsockDriver::map_cfg(&cfg) { + break dev_cfg; + } + } + None => { + error!("No dev config. Aborting!"); + return Err(error::VirtioVsockError::NoDevCfg(device_id)); + } + } + }; + + Ok(VirtioVsockDriver { + dev_cfg, + com_cfg, + isr_stat, + notif_cfg, + vqueues: Vec::new(), + }) + } + + /// Initializes virtio socket device + /// + /// Returns a driver instance of VirtioVsockDriver. + pub(crate) fn init( + device: &PciDevice, + ) -> Result { + let mut drv = match pci::map_caps(device) { + Ok(caps) => match VirtioVsockDriver::new(caps, device) { + Ok(driver) => driver, + Err(vsock_err) => { + error!("Initializing new virtio socket device driver failed. Aborting!"); + return Err(VirtioError::VsockDriver(vsock_err)); + } + }, + Err(pci_error) => { + error!("Mapping capabilities failed. Aborting!"); + return Err(VirtioError::FromPci(pci_error)); + } + }; + + match drv.init_dev() { + Ok(_) => { + info!( + "Socket device with cid {:x}, has been initialized by driver!", + drv.dev_cfg.raw.guest_cid + ); + Ok(drv) + } + Err(fs_err) => { + drv.set_failed(); + Err(VirtioError::VsockDriver(fs_err)) + } + } + } +} From db20f73ee1974974c53df308058b56223754048f Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 22 Jul 2024 22:21:52 +0200 Subject: [PATCH 02/18] allow dead code in vsock driver --- src/drivers/vsock/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index decd3b125e..61dc4a04c7 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + #[cfg(feature = "pci")] pub mod pci; From 12474e266570b1e7bc63455b18c2e1abdd1d6177 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 22 Jul 2024 22:28:45 +0200 Subject: [PATCH 03/18] allow enum_variant_names to avoid clippy warnings --- src/drivers/pci.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/drivers/pci.rs b/src/drivers/pci.rs index 7add681a3f..1b379ee107 100644 --- a/src/drivers/pci.rs +++ b/src/drivers/pci.rs @@ -298,6 +298,7 @@ pub(crate) fn print_information() { } #[allow(clippy::large_enum_variant)] +#[allow(clippy::enum_variant_names)] pub(crate) enum PciDriver { #[cfg(feature = "fuse")] VirtioFs(InterruptTicketMutex), From 055c6276f325fa8b25cffd55e64fedfbd681c866 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 22 Jul 2024 22:34:10 +0200 Subject: [PATCH 04/18] define vsock driver as pci driver --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0c8f605125..68c073d86d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ dhcpv4 = [ ] fs = ["fuse"] fuse = ["pci", "dep:fuse-abi", "fuse-abi/num_enum"] -vsock = [] +vsock = ["pci"] fsgsbase = [] gem-net = ["tcp", "dep:tock-registers"] newlib = [] From 1ef24e88061c295fcaaa4db56065728a3c19fc24 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Tue, 23 Jul 2024 21:45:46 +0200 Subject: [PATCH 05/18] extend socket interface to support Vsocks - fix typos and use virtio::vsock::F instead of virtio::net::F - use sa_family to check type of the socket address - revise driver to share interrupt handles between virtio devices - remove compiler warnings --- src/arch/x86_64/kernel/interrupts.rs | 4 +- src/config.rs | 2 +- src/drivers/net/gem.rs | 16 +- src/drivers/net/mod.rs | 40 +-- src/drivers/net/rtl8139.rs | 17 +- src/drivers/virtio/transport/mmio.rs | 25 +- src/drivers/virtio/transport/mod.rs | 49 ++++ src/drivers/virtio/transport/pci.rs | 22 +- src/drivers/vsock/mod.rs | 4 +- src/drivers/vsock/pci.rs | 1 + src/fd/mod.rs | 75 +++-- src/fd/socket/mod.rs | 2 + src/fd/socket/tcp.rs | 71 +++-- src/fd/socket/udp.rs | 59 ++-- src/fd/socket/vsock.rs | 96 +++++++ src/syscalls/mod.rs | 2 +- src/syscalls/socket.rs | 415 +++++++++++++++++++-------- 17 files changed, 665 insertions(+), 235 deletions(-) create mode 100644 src/fd/socket/vsock.rs diff --git a/src/arch/x86_64/kernel/interrupts.rs b/src/arch/x86_64/kernel/interrupts.rs index c5709dbd3b..6a3c854755 100644 --- a/src/arch/x86_64/kernel/interrupts.rs +++ b/src/arch/x86_64/kernel/interrupts.rs @@ -9,7 +9,7 @@ use hermit_sync::{InterruptSpinMutex, InterruptTicketMutex}; use x86_64::instructions::interrupts::enable_and_hlt; pub use x86_64::instructions::interrupts::{disable, enable}; use x86_64::set_general_handler; -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp", feature = "vsock"))] use x86_64::structures::idt; use x86_64::structures::idt::InterruptDescriptorTable; pub use x86_64::structures::idt::InterruptStackFrame as ExceptionStackFrame; @@ -155,7 +155,7 @@ pub(crate) fn install() { IRQ_NAMES.lock().insert(7, "FPU"); } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp", feature = "vsock"))] pub fn irq_install_handler(irq_number: u8, handler: idt::HandlerFunc) { debug!("Install handler for interrupt {}", irq_number); diff --git a/src/config.rs b/src/config.rs index 8b32b25ca8..2768e37384 100644 --- a/src/config.rs +++ b/src/config.rs @@ -12,5 +12,5 @@ pub(crate) const VIRTIO_MAX_QUEUE_SIZE: u16 = 2048; pub(crate) const VIRTIO_MAX_QUEUE_SIZE: u16 = 1024; /// Default keep alive interval in milliseconds -#[cfg(any(feature = "tcp", feature = "udp"))] +#[cfg(feature = "tcp")] pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75000; diff --git a/src/drivers/net/gem.rs b/src/drivers/net/gem.rs index 89bb76e362..d9e3ed0a9a 100644 --- a/src/drivers/net/gem.rs +++ b/src/drivers/net/gem.rs @@ -197,6 +197,18 @@ pub enum GEMError { Unknown, } +fn gem_irqhandler() { + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive network interrupt"); + + // PLIC end of interrupt + crate::arch::kernel::interrupts::external_eoi(); + let _ = network_irqhandler(); + + core_scheduler().reschedule(); +} + /// GEM network driver struct. /// /// Struct allows to control device queus as also @@ -674,9 +686,9 @@ pub fn init_device( // Configure Interrupts debug!( "Install interrupt handler for GEM at {:x}", - network_irqhandler as usize + gem_irqhandler as usize ); - irq_install_handler(irq, network_irqhandler); + irq_install_handler(irq, gem_irqhandler); (*gem).int_enable.write(Interrupts::FRAMERX::SET); // + Interrupts::TXCOMPL::SET // Enable the Controller (again?) diff --git a/src/drivers/net/mod.rs b/src/drivers/net/mod.rs index 810db6a447..0927f5469a 100644 --- a/src/drivers/net/mod.rs +++ b/src/drivers/net/mod.rs @@ -7,16 +7,10 @@ pub mod virtio; use smoltcp::phy::ChecksumCapabilities; -#[cfg(target_arch = "x86_64")] -use crate::arch::kernel::apic; #[allow(unused_imports)] use crate::arch::kernel::core_local::*; -#[cfg(target_arch = "x86_64")] -use crate::arch::kernel::interrupts::ExceptionStackFrame; #[cfg(not(feature = "pci"))] use crate::arch::kernel::mmio as hardware; -#[cfg(target_arch = "aarch64")] -use crate::arch::scheduler::State; #[cfg(feature = "pci")] use crate::drivers::pci as hardware; use crate::executor::device::{RxToken, TxToken}; @@ -47,7 +41,7 @@ pub(crate) trait NetworkDriver { } #[inline] -fn _irqhandler() -> bool { +pub(crate) fn network_irqhandler() -> bool { let result = if let Some(driver) = hardware::get_network_driver() { driver.lock().handle_interrupt() } else { @@ -60,35 +54,3 @@ fn _irqhandler() -> bool { result } - -#[cfg(target_arch = "aarch64")] -pub(crate) fn network_irqhandler(_state: &State) -> bool { - debug!("Receive network interrupt"); - _irqhandler() -} - -#[cfg(target_arch = "x86_64")] -pub(crate) extern "x86-interrupt" fn network_irqhandler(stack_frame: ExceptionStackFrame) { - crate::arch::x86_64::swapgs(&stack_frame); - use crate::scheduler::PerCoreSchedulerExt; - - debug!("Receive network interrupt"); - apic::eoi(); - let _ = _irqhandler(); - - core_scheduler().reschedule(); - crate::arch::x86_64::swapgs(&stack_frame); -} - -#[cfg(target_arch = "riscv64")] -pub fn network_irqhandler() { - use crate::scheduler::PerCoreSchedulerExt; - - debug!("Receive network interrupt"); - - // PLIC end of interrupt - crate::arch::kernel::interrupts::external_eoi(); - let _ = _irqhandler(); - - core_scheduler().reschedule(); -} diff --git a/src/drivers/net/rtl8139.rs b/src/drivers/net/rtl8139.rs index cc0dffede0..3366c1739b 100644 --- a/src/drivers/net/rtl8139.rs +++ b/src/drivers/net/rtl8139.rs @@ -9,6 +9,8 @@ use pci_types::{Bar, CommandRegister, InterruptLine, MAX_BARS}; use x86::io::*; use crate::arch::kernel::core_local::increment_irq_counter; +#[cfg(target_arch = "x86_64")] +use crate::arch::kernel::interrupts::ExceptionStackFrame; use crate::arch::kernel::interrupts::*; use crate::arch::mm::paging::virt_to_phys; use crate::arch::mm::VirtAddr; @@ -419,6 +421,19 @@ impl Drop for RTL8139Driver { } } +extern "x86-interrupt" fn rtl8139_irqhandler(stack_frame: ExceptionStackFrame) { + crate::arch::x86_64::swapgs(&stack_frame); + use crate::arch::kernel::core_local::core_scheduler; + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive network interrupt"); + crate::arch::x86_64::kernel::apic::eoi(); + let _ = network_irqhandler(); + + core_scheduler().reschedule(); + crate::arch::x86_64::swapgs(&stack_frame); +} + pub(crate) fn init_device( device: &PciDevice, ) -> Result { @@ -573,7 +588,7 @@ pub(crate) fn init_device( // Install interrupt handler for RTL8139 debug!("Install interrupt handler for RTL8139 at {}", irq); - irq_install_handler(irq, network_irqhandler); + irq_install_handler(irq, rtl8139_irqhandler); add_irq_name(irq, "rtl8139_net"); Ok(RTL8139Driver { diff --git a/src/drivers/virtio/transport/mmio.rs b/src/drivers/virtio/transport/mmio.rs index da826ec3b0..ce29e3bf1d 100644 --- a/src/drivers/virtio/transport/mmio.rs +++ b/src/drivers/virtio/transport/mmio.rs @@ -18,10 +18,9 @@ use crate::arch::kernel::interrupts::*; use crate::arch::mm::PhysAddr; use crate::drivers::error::DriverError; #[cfg(any(feature = "tcp", feature = "udp"))] -use crate::drivers::net::network_irqhandler; -#[cfg(any(feature = "tcp", feature = "udp"))] use crate::drivers::net::virtio::VirtioNetDriver; use crate::drivers::virtio::error::VirtioError; +use crate::drivers::virtio::transport::virtio_irqhandler; pub struct VqCfgHandler<'a> { vq_index: u16, @@ -386,9 +385,9 @@ pub(crate) fn init_device( Ok(virt_net_drv) => { info!("Virtio network driver initialized."); // Install interrupt handler - irq_install_handler(irq_no, network_irqhandler); + irq_install_handler(irq_no, virtio_irqhandler); #[cfg(not(target_arch = "riscv64"))] - add_irq_name(irq_no, "virtio_net"); + add_irq_name(irq_no, "virtio"); Ok(VirtioDriver::Network(virt_net_drv)) } @@ -398,6 +397,24 @@ pub(crate) fn init_device( } } } + #[cfg(feature = "vsock")] + virtio::Id::Vsock => { + match VirtioVsockDriver::init(dev_id, registers, irq_no) { + Ok(virt_net_drv) => { + info!("Virtio sock driver initialized."); + // Install interrupt handler + irq_install_handler(irq_no, virtio_irqhandler); + #[cfg(not(target_arch = "riscv64"))] + add_irq_name(irq_no, "virtio"); + + Ok(VirtioDriver::Vsock(virt_vsock_drv)) + } + Err(virtio_error) => { + error!("Virtio sock driver could not be initialized with device"); + Err(DriverError::InitVirtioDevFail(virtio_error)) + } + } + } device_id => { error!("Device with id {device_id:?} is currently not supported!"); // Return Driver error inidacting device is not supported diff --git a/src/drivers/virtio/transport/mod.rs b/src/drivers/virtio/transport/mod.rs index 50df4b30c9..9082e7d7b4 100644 --- a/src/drivers/virtio/transport/mod.rs +++ b/src/drivers/virtio/transport/mod.rs @@ -8,3 +8,52 @@ pub mod mmio; #[cfg(feature = "pci")] pub mod pci; + +#[cfg(target_arch = "x86_64")] +use crate::arch::kernel::interrupts::ExceptionStackFrame; +#[cfg(target_arch = "aarch64")] +use crate::arch::scheduler::State; +#[cfg(any(feature = "tcp", feature = "udp"))] +use crate::drivers::net::network_irqhandler; + +#[cfg(target_arch = "aarch64")] +pub(crate) fn virtio_irqhandler(_state: &State) -> bool { + debug!("Receive virtio interrupt"); + cfg_if::cfg_if! { + if #[cfg(any(feature = "tcp", feature = "udp"))] { + network_irqhandler() + } else { + false + } + } +} + +#[cfg(target_arch = "x86_64")] +pub(crate) extern "x86-interrupt" fn virtio_irqhandler(stack_frame: ExceptionStackFrame) { + crate::arch::x86_64::swapgs(&stack_frame); + use crate::arch::kernel::core_local::core_scheduler; + use crate::scheduler::PerCoreSchedulerExt; + + info!("Receive virtio interrupt"); + crate::kernel::apic::eoi(); + #[cfg(any(feature = "tcp", feature = "udp"))] + let _ = network_irqhandler(); + + core_scheduler().reschedule(); + crate::arch::x86_64::swapgs(&stack_frame); +} + +#[cfg(target_arch = "riscv64")] +pub(crate) fn virtio_irqhandler() { + use crate::arch::kernel::core_local::core_scheduler; + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive virtio interrupt"); + + // PLIC end of interrupt + crate::arch::kernel::interrupts::external_eoi(); + #[cfg(any(feature = "tcp", feature = "udp"))] + let _ = network_irqhandler(); + + core_scheduler().reschedule(); +} diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index b2f119eb16..adeaa22c0e 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -16,7 +16,10 @@ use virtio::{le16, le32, DeviceStatus}; use volatile::access::ReadOnly; use volatile::{VolatilePtr, VolatileRef}; -#[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] +#[cfg(all( + not(feature = "rtl8139"), + any(feature = "tcp", feature = "udp", feature = "vsock") +))] use crate::arch::kernel::interrupts::*; use crate::arch::memory_barrier; use crate::arch::mm::PhysAddr; @@ -25,12 +28,11 @@ use crate::drivers::error::DriverError; #[cfg(feature = "fuse")] use crate::drivers::fs::virtio_fs::VirtioFsDriver; #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] -use crate::drivers::net::network_irqhandler; -#[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] use crate::drivers::net::virtio::VirtioNetDriver; use crate::drivers::pci::error::PciError; use crate::drivers::pci::PciDevice; use crate::drivers::virtio::error::VirtioError; +use crate::drivers::virtio::transport::virtio_irqhandler; #[cfg(feature = "vsock")] use crate::drivers::vsock::VirtioVsockDriver; @@ -967,13 +969,21 @@ pub(crate) fn init_device( let irq = device.get_irq().unwrap(); info!("Install virtio interrupt handler at line {}", irq); // Install interrupt handler - irq_install_handler(irq, network_irqhandler); - add_irq_name(irq, "virtio_net"); + irq_install_handler(irq, virtio_irqhandler); + add_irq_name(irq, "virtio"); Ok(drv) } #[cfg(feature = "vsock")] - VirtioDriver::Vsock(_) => Ok(drv), + VirtioDriver::Vsock(_) => { + let irq = device.get_irq().unwrap(); + info!("Install virtio interrupt handler at line {}", irq); + // Install interrupt handler + irq_install_handler(irq, virtio_irqhandler); + add_irq_name(irq, "virtio"); + + Ok(drv) + } #[cfg(feature = "fuse")] VirtioDriver::FileSystem(_) => Ok(drv), } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 61dc4a04c7..78d4b86c3c 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -6,6 +6,7 @@ pub mod pci; use alloc::rc::Rc; use alloc::vec::Vec; +use pci_types::InterruptLine; use virtio::FeatureBits; use crate::config::VIRTIO_MAX_QUEUE_SIZE; @@ -23,7 +24,7 @@ use crate::drivers::vsock::pci::VsockDevCfgRaw; pub(crate) struct VsockDevCfg { pub raw: &'static VsockDevCfgRaw, pub dev_id: u16, - pub features: virtio::net::F, + pub features: virtio::vsock::F, } pub(crate) struct VirtioVsockDriver { @@ -31,6 +32,7 @@ pub(crate) struct VirtioVsockDriver { pub(super) com_cfg: ComCfg, pub(super) isr_stat: IsrStatus, pub(super) notif_cfg: NotifCfg, + pub(super) irq: InterruptLine, pub(super) vqueues: Vec>, } diff --git a/src/drivers/vsock/pci.rs b/src/drivers/vsock/pci.rs index 1b968a39ec..14104e6b87 100644 --- a/src/drivers/vsock/pci.rs +++ b/src/drivers/vsock/pci.rs @@ -83,6 +83,7 @@ impl VirtioVsockDriver { com_cfg, isr_stat, notif_cfg, + irq: device.get_irq().unwrap(), vqueues: Vec::new(), }) } diff --git a/src/fd/mod.rs b/src/fd/mod.rs index c2018b2324..c6c893f650 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -16,7 +16,7 @@ use crate::fs::{DirectoryEntry, FileAttr, SeekWhence}; use crate::io; mod eventfd; -#[cfg(any(feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] pub(crate) mod socket; pub(crate) mod stdio; @@ -24,6 +24,36 @@ pub(crate) const STDIN_FILENO: FileDescriptor = 0; pub(crate) const STDOUT_FILENO: FileDescriptor = 1; pub(crate) const STDERR_FILENO: FileDescriptor = 2; +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] +#[allow(clippy::upper_case_acronyms, dead_code)] +#[derive(Debug, Clone, Copy)] +pub(crate) enum AddressFamily { + #[cfg(any(feature = "tcp", feature = "udp"))] + INET, + #[cfg(any(feature = "tcp", feature = "udp"))] + INET6, + #[cfg(feature = "vsock")] + VSOCK, +} + +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] +#[derive(Debug)] +pub(crate) enum Endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Ip(IpEndpoint), + #[cfg(feature = "vsock")] + Vsock(()), +} + +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] +#[derive(Debug)] +pub(crate) enum ListenEndpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Ip(IpListenEndpoint), + #[cfg(feature = "vsock")] + Vsock(socket::vsock::VsockListenEndpoint), +} + #[allow(dead_code)] #[derive(Debug, PartialEq)] pub(crate) enum SocketOption { @@ -186,57 +216,57 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { } /// `accept` a connection on a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn accept(&self) -> io::Result { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn accept(&self) -> io::Result { Err(io::Error::EINVAL) } /// initiate a connection on a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn connect(&self, _endpoint: IpEndpoint) -> io::Result<()> { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn connect(&self, _endpoint: Endpoint) -> io::Result<()> { Err(io::Error::EINVAL) } /// `bind` a name to a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn bind(&self, _name: IpListenEndpoint) -> io::Result<()> { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn bind(&self, _name: ListenEndpoint) -> io::Result<()> { Err(io::Error::EINVAL) } /// `listen` for connections on a socket - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn listen(&self, _backlog: i32) -> io::Result<()> { Err(io::Error::EINVAL) } /// `setsockopt` sets options on sockets - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn setsockopt(&self, _opt: SocketOption, _optval: bool) -> io::Result<()> { Err(io::Error::EINVAL) } /// `getsockopt` gets options on sockets - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn getsockopt(&self, _opt: SocketOption) -> io::Result { Err(io::Error::EINVAL) } /// `getsockname` gets socket name - #[cfg(any(feature = "tcp", feature = "udp"))] - fn getsockname(&self) -> Option { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn getsockname(&self) -> Option { None } /// `getpeername` get address of connected peer - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] #[allow(dead_code)] - fn getpeername(&self) -> Option { + fn getpeername(&self) -> Option { None } /// receive a message from a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, IpEndpoint)> { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { Err(io::Error::ENOSYS) } @@ -247,13 +277,13 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { /// If a peer address has been prespecified, either the message shall /// be sent to the address specified by dest_addr (overriding the pre-specified peer /// address). - #[cfg(any(feature = "tcp", feature = "udp"))] - fn sendto(&self, _buffer: &[u8], _endpoint: IpEndpoint) -> io::Result { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn sendto(&self, _buffer: &[u8], _endpoint: Endpoint) -> io::Result { Err(io::Error::ENOSYS) } /// shut down part of a full-duplex connection - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn shutdown(&self, _how: i32) -> io::Result<()> { Err(io::Error::ENOSYS) } @@ -263,6 +293,13 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { fn ioctl(&self, _cmd: IoCtl, _value: bool) -> io::Result<()> { Err(io::Error::ENOSYS) } + + /// Sockets returns the supported address family + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + #[allow(dead_code)] + fn get_address_family(&self) -> Option { + None + } } pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result { diff --git a/src/fd/socket/mod.rs b/src/fd/socket/mod.rs index 7a41790273..1ccbb1f1bc 100644 --- a/src/fd/socket/mod.rs +++ b/src/fd/socket/mod.rs @@ -2,3 +2,5 @@ pub(crate) mod tcp; #[cfg(feature = "udp")] pub(crate) mod udp; +#[cfg(feature = "vsock")] +pub(crate) mod vsock; diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 79c03b82f4..87c3837ce1 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -8,11 +8,13 @@ use async_trait::async_trait; use smoltcp::iface; use smoltcp::socket::tcp; use smoltcp::time::Duration; -use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; +use smoltcp::wire::{IpEndpoint, IpVersion}; use crate::executor::block_on; use crate::executor::network::{now, Handle, NetworkState, NIC}; -use crate::fd::{IoCtl, ObjectInterface, PollEvent, SocketOption}; +use crate::fd::{ + AddressFamily, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent, SocketOption, +}; use crate::{io, DEFAULT_KEEP_ALIVE_INTERVAL}; /// further receives will be disallowed @@ -304,45 +306,59 @@ impl ObjectInterface for Socket { Ok(pos) } - fn bind(&self, endpoint: IpListenEndpoint) -> io::Result<()> { - self.port.store(endpoint.port, Ordering::Release); - Ok(()) + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let ListenEndpoint::Ip(endpoint) = endpoint { + self.port.store(endpoint.port, Ordering::Release); + Ok(()) + } else { + Err(io::Error::EIO) + } } - fn connect(&self, endpoint: IpEndpoint) -> io::Result<()> { - if self.nonblocking.load(Ordering::Acquire) { - block_on(self.async_connect(endpoint), Some(Duration::ZERO.into())).map_err(|x| { - if x == io::Error::ETIME { - io::Error::EAGAIN - } else { - x - } - }) + fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let Endpoint::Ip(endpoint) = endpoint { + if self.nonblocking.load(Ordering::Acquire) { + block_on(self.async_connect(endpoint), Some(Duration::ZERO.into())).map_err(|x| { + if x == io::Error::ETIME { + io::Error::EAGAIN + } else { + x + } + }) + } else { + block_on(self.async_connect(endpoint), None) + } } else { - block_on(self.async_connect(endpoint), None) + Err(io::Error::EIO) } } - fn accept(&self) -> io::Result { - if self.is_nonblocking() { + fn accept(&self) -> io::Result { + let endpoint = if self.is_nonblocking() { block_on(self.async_accept(), Some(Duration::ZERO.into())).map_err(|x| { if x == io::Error::ETIME { io::Error::EAGAIN } else { x } - }) + })? } else { - block_on(self.async_accept(), None) - } + block_on(self.async_accept(), None)? + }; + + Ok(Endpoint::Ip(endpoint)) } - fn getpeername(&self) -> Option { + fn getpeername(&self) -> Option { self.with(|socket| socket.remote_endpoint()) + .map(Endpoint::Ip) } - fn getsockname(&self) -> Option { + fn getsockname(&self) -> Option { self.with(|socket| socket.local_endpoint()) + .map(Endpoint::Ip) } fn is_nonblocking(&self) -> bool { @@ -411,6 +427,17 @@ impl ObjectInterface for Socket { Err(io::Error::EINVAL) } } + + fn get_address_family(&self) -> Option { + self.with(|socket| { + socket + .local_endpoint() + .map(|endpoint| match endpoint.addr.version() { + IpVersion::Ipv4 => AddressFamily::INET, + IpVersion::Ipv6 => AddressFamily::INET6, + }) + }) + } } impl Clone for Socket { diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index 59e62c04a6..fe5d834991 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -9,11 +9,11 @@ use crossbeam_utils::atomic::AtomicCell; use smoltcp::socket::udp; use smoltcp::socket::udp::UdpMetadata; use smoltcp::time::Duration; -use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; +use smoltcp::wire::{IpEndpoint, IpVersion}; use crate::executor::network::{now, Handle, NetworkState, NIC}; use crate::executor::{block_on, poll_on}; -use crate::fd::{IoCtl, ObjectInterface, PollEvent}; +use crate::fd::{AddressFamily, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io; #[derive(Debug)] @@ -51,7 +51,7 @@ impl Socket { .await } - async fn async_recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, IpEndpoint)> { + async fn async_recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { future::poll_fn(|cx| { self.with(|socket| { if socket.is_open() { @@ -81,6 +81,7 @@ impl Socket { }) }) .await + .map(|(len, endpoint)| (len, Endpoint::Ip(endpoint))) } async fn async_write_with_meta(&self, buffer: &[u8], meta: &UdpMetadata) -> io::Result { @@ -154,29 +155,44 @@ impl ObjectInterface for Socket { .await } - fn bind(&self, endpoint: IpListenEndpoint) -> io::Result<()> { - self.with(|socket| socket.bind(endpoint).map_err(|_| io::Error::EADDRINUSE)) + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let ListenEndpoint::Ip(endpoint) = endpoint { + self.with(|socket| socket.bind(endpoint).map_err(|_| io::Error::EADDRINUSE)) + } else { + Err(io::Error::EIO) + } } - fn connect(&self, endpoint: IpEndpoint) -> io::Result<()> { - self.endpoint.store(Some(endpoint)); - Ok(()) + fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let Endpoint::Ip(endpoint) = endpoint { + self.endpoint.store(Some(endpoint)); + Ok(()) + } else { + Err(io::Error::EIO) + } } - fn sendto(&self, buf: &[u8], endpoint: IpEndpoint) -> io::Result { - let meta = UdpMetadata::from(endpoint); + fn sendto(&self, buf: &[u8], endpoint: Endpoint) -> io::Result { + #[allow(irrefutable_let_patterns)] + if let Endpoint::Ip(endpoint) = endpoint { + let meta = UdpMetadata::from(endpoint); - if self.nonblocking.load(Ordering::Acquire) { - poll_on( - self.async_write_with_meta(buf, &meta), - Some(Duration::ZERO.into()), - ) + if self.nonblocking.load(Ordering::Acquire) { + poll_on( + self.async_write_with_meta(buf, &meta), + Some(Duration::ZERO.into()), + ) + } else { + poll_on(self.async_write_with_meta(buf, &meta), None) + } } else { - poll_on(self.async_write_with_meta(buf, &meta), None) + Err(io::Error::EIO) } } - fn recvfrom(&self, buf: &mut [u8]) -> io::Result<(usize, IpEndpoint)> { + fn recvfrom(&self, buf: &mut [u8]) -> io::Result<(usize, Endpoint)> { if self.nonblocking.load(Ordering::Acquire) { poll_on(self.async_recvfrom(buf), Some(Duration::ZERO.into())).map_err(|x| { if x == io::Error::ETIME { @@ -253,6 +269,15 @@ impl ObjectInterface for Socket { Err(io::Error::EINVAL) } } + + fn get_address_family(&self) -> Option { + self.endpoint + .load() + .map(|endpoint| match endpoint.addr.version() { + IpVersion::Ipv4 => AddressFamily::INET, + IpVersion::Ipv6 => AddressFamily::INET6, + }) + } } impl Clone for Socket { diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs new file mode 100644 index 0000000000..8e05f70fda --- /dev/null +++ b/src/fd/socket/vsock.rs @@ -0,0 +1,96 @@ +use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; + +use async_trait::async_trait; + +use crate::fd::{AddressFamily, Endpoint, IoCtl, ListenEndpoint, ObjectInterface}; +use crate::io; + +#[derive(Debug)] +pub(crate) struct VsockListenEndpoint { + port: u32, + #[allow(dead_code)] + cid: u32, +} + +impl VsockListenEndpoint { + pub const fn new(port: u32, cid: u32) -> Self { + Self { port, cid } + } +} + +#[derive(Debug)] +pub struct Socket { + port: AtomicU32, + nonblocking: AtomicBool, + listen: AtomicBool, +} + +impl Socket { + pub fn new() -> Self { + Self { + port: AtomicU32::new(0), + nonblocking: AtomicBool::new(false), + listen: AtomicBool::new(false), + } + } +} + +#[async_trait] +impl ObjectInterface for Socket { + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + info!("bind {:?}", endpoint); + match endpoint { + ListenEndpoint::Vsock(ep) => { + self.port.store(ep.port, Ordering::Release); + Ok(()) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + _ => Err(io::Error::EINVAL), + } + } + + fn is_nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::Acquire) + } + + fn listen(&self, _backlog: i32) -> io::Result<()> { + info!("listen"); + self.listen.store(true, Ordering::Relaxed); + Ok(()) + } + + fn accept(&self) -> io::Result { + info!("accept"); + Err(io::Error::EINVAL) + } + + fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { + if cmd == IoCtl::NonBlocking { + if value { + trace!("set vsock device to nonblocking mode"); + self.nonblocking.store(true, Ordering::Release); + } else { + trace!("set vsock device to blocking mode"); + self.nonblocking.store(false, Ordering::Release); + } + + Ok(()) + } else { + Err(io::Error::EINVAL) + } + } + + fn get_address_family(&self) -> Option { + Some(AddressFamily::VSOCK) + } +} + +impl Clone for Socket { + fn clone(&self) -> Self { + Self { + port: AtomicU32::new(self.port.load(Ordering::Acquire)), + nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Acquire)), + listen: AtomicBool::new(false), + } + } +} diff --git a/src/syscalls/mod.rs b/src/syscalls/mod.rs index 7c8b662eb0..ed50da6025 100644 --- a/src/syscalls/mod.rs +++ b/src/syscalls/mod.rs @@ -40,7 +40,7 @@ mod processor; #[cfg(feature = "newlib")] mod recmutex; mod semaphore; -#[cfg(any(feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] pub mod socket; mod spinlock; mod system; diff --git a/src/syscalls/socket.rs b/src/syscalls/socket.rs index 3a2a118724..8fe88b806c 100644 --- a/src/syscalls/socket.rs +++ b/src/syscalls/socket.rs @@ -3,22 +3,31 @@ use alloc::sync::Arc; use core::ffi::{c_char, c_void}; use core::mem::size_of; +#[allow(unused_imports)] use core::ops::DerefMut; +use cfg_if::cfg_if; #[cfg(any(feature = "tcp", feature = "udp"))] use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint}; use crate::errno::*; +#[cfg(any(feature = "tcp", feature = "udp"))] use crate::executor::network::{NetworkState, NIC}; #[cfg(feature = "tcp")] use crate::fd::socket::tcp; #[cfg(feature = "udp")] use crate::fd::socket::udp; -use crate::fd::{get_object, insert_object, replace_object, ObjectInterface, SocketOption}; +#[cfg(feature = "vsock")] +use crate::fd::socket::vsock::{self, VsockListenEndpoint}; +use crate::fd::{ + get_object, insert_object, replace_object, Endpoint, ListenEndpoint, ObjectInterface, + SocketOption, +}; use crate::syscalls::IoCtl; pub const AF_INET: i32 = 0; pub const AF_INET6: i32 = 1; +pub const AF_VSOCK: i32 = 2; pub const IPPROTO_IP: i32 = 0; pub const IPPROTO_IPV6: i32 = 41; pub const IPPROTO_TCP: i32 = 6; @@ -92,6 +101,28 @@ pub struct sockaddr { pub sa_data: [c_char; 14], } +#[cfg(feature = "vsock")] +#[repr(C)] +#[derive(Debug, Copy, Clone, Default)] +pub struct sockaddr_vm { + pub svm_len: u8, + pub svm_family: sa_family_t, + pub svm_reserved1: u16, + pub svm_port: u32, + pub svm_cid: u32, + pub svm_zero: [u8; 4], +} + +#[cfg(feature = "vsock")] +impl From for VsockListenEndpoint { + fn from(addr: sockaddr_vm) -> VsockListenEndpoint { + let port = addr.svm_port; + let cid = addr.svm_cid; + + VsockListenEndpoint::new(port, cid) + } +} + #[repr(C)] #[derive(Debug, Default, Copy, Clone)] pub struct sockaddr_in { @@ -352,25 +383,21 @@ pub unsafe extern "C" fn sys_getaddrbyname( #[hermit_macro::system] #[no_mangle] pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 { - debug!( + info!( "sys_socket: domain {}, type {:?}, protocol {}", domain, type_, protocol ); - if (domain != AF_INET && domain != AF_INET6) + if (domain != AF_INET && domain != AF_INET6 && domain != AF_VSOCK) || !type_.intersects(SockType::SOCK_STREAM | SockType::SOCK_DGRAM) || protocol != 0 { -EINVAL } else { - let mut guard = NIC.lock(); - - if let NetworkState::Initialized(nic) = guard.deref_mut() { - #[cfg(feature = "udp")] - if type_.contains(SockType::SOCK_DGRAM) { - let handle = nic.create_udp_handle().unwrap(); - drop(guard); - let socket = udp::Socket::new(handle); + #[cfg(feature = "vsock")] + { + if type_.contains(SockType::SOCK_STREAM) { + let socket = vsock::Socket::new(); if type_.contains(SockType::SOCK_NONBLOCK) { socket.ioctl(IoCtl::NonBlocking, true).unwrap(); @@ -380,26 +407,45 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 return fd; } + } + #[cfg(any(feature = "tcp", feature = "udp"))] + { + let mut guard = NIC.lock(); + + if let NetworkState::Initialized(nic) = guard.deref_mut() { + #[cfg(feature = "udp")] + if type_.contains(SockType::SOCK_DGRAM) { + let handle = nic.create_udp_handle().unwrap(); + drop(guard); + let socket = udp::Socket::new(handle); + + if type_.contains(SockType::SOCK_NONBLOCK) { + socket.ioctl(IoCtl::NonBlocking, true).unwrap(); + } - #[cfg(feature = "tcp")] - if type_.contains(SockType::SOCK_STREAM) { - let handle = nic.create_tcp_handle().unwrap(); - drop(guard); - let socket = tcp::Socket::new(handle); + let fd = insert_object(Arc::new(socket)).expect("FD is already used"); - if type_.contains(SockType::SOCK_NONBLOCK) { - socket.ioctl(IoCtl::NonBlocking, true).unwrap(); + return fd; } - let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + #[cfg(feature = "tcp")] + if type_.contains(SockType::SOCK_STREAM) { + let handle = nic.create_tcp_handle().unwrap(); + drop(guard); + let socket = tcp::Socket::new(handle); - return fd; - } + if type_.contains(SockType::SOCK_NONBLOCK) { + socket.ioctl(IoCtl::NonBlocking, true).unwrap(); + } - -EINVAL - } else { - -EINVAL + let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + + return fd; + } + } } + + -EINVAL } } @@ -412,33 +458,55 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut |v| { (*v).accept().map_or_else( |e| -num::ToPrimitive::to_i32(&e).unwrap(), - |endpoint| { - let new_obj = dyn_clone::clone_box(&*v); - replace_object(fd, Arc::from(new_obj)).unwrap(); - let new_fd = insert_object(v).unwrap(); - - if !addr.is_null() && !addrlen.is_null() { - let addrlen = unsafe { &mut *addrlen }; - - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + |endpoint| match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => { + let new_obj = dyn_clone::clone_box(&*v); + replace_object(fd, Arc::from(new_obj)).unwrap(); + let new_fd = insert_object(v).unwrap(); + + if !addr.is_null() && !addrlen.is_null() { + let addrlen = unsafe { &mut *addrlen }; + + match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } } } } + + new_fd } + #[cfg(feature = "vsock")] + Endpoint::Vsock(_) => { + let new_obj = dyn_clone::clone_box(&*v); + replace_object(fd, Arc::from(new_obj)).unwrap(); + let new_fd = insert_object(v).unwrap(); + + if !addr.is_null() && !addrlen.is_null() { + let addrlen = unsafe { &*addrlen }; + + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_vm) }; + *addr = sockaddr_vm::default(); + } + } - new_fd + warn!("unsupported device"); + + new_fd + } }, ) }, @@ -461,20 +529,44 @@ pub extern "C" fn sys_listen(fd: i32, backlog: i32) -> i32 { #[hermit_macro::system] #[no_mangle] pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: socklen_t) -> i32 { - let endpoint = if namelen == size_of::().try_into().unwrap() { - IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in) }) - } else if namelen == size_of::().try_into().unwrap() { - IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in6) }) - } else { + if name.is_null() { return -crate::errno::EINVAL; - }; + } + + let family: i32 = unsafe { (*name).sa_family.into() }; let obj = get_object(fd); obj.map_or_else( |e| -num::ToPrimitive::to_i32(&e).unwrap(), - |v| { - (*v).bind(endpoint) - .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + |v| match family { + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + let endpoint = IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in) }); + (*v).bind(ListenEndpoint::Ip(endpoint)) + .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET6 => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + let endpoint = IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in6) }); + (*v).bind(ListenEndpoint::Ip(endpoint)) + .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + } + #[cfg(feature = "vsock")] + AF_VSOCK => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + let endpoint = VsockListenEndpoint::from(unsafe { *(name as *const sockaddr_vm) }); + (*v).bind(ListenEndpoint::Vsock(endpoint)) + .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + } + _ => -crate::errno::EINVAL, }, ) } @@ -482,12 +574,37 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl #[hermit_macro::system] #[no_mangle] pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: socklen_t) -> i32 { - let endpoint = if namelen == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(name as *const sockaddr_in) }) - } else if namelen == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(name as *const sockaddr_in6) }) - } else { + if name.is_null() { return -crate::errno::EINVAL; + } + + let sa_family = unsafe { (*name).sa_family as i32 }; + + let endpoint = match sa_family { + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + Endpoint::Ip(IpEndpoint::from(unsafe { *(name as *const sockaddr_in) })) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET6 => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + Endpoint::Ip(IpEndpoint::from(unsafe { *(name as *const sockaddr_in6) })) + } + #[cfg(feature = "vsock")] + AF_VSOCK => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + Endpoint::Vsock(()) + } + _ => { + return -crate::errno::EINVAL; + } }; let obj = get_object(fd); @@ -515,21 +632,33 @@ pub unsafe extern "C" fn sys_getsockname( if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return -crate::errno::EINVAL; + match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + #[cfg(any(feature = "tcp", feature = "udp"))] + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } + } + }, + #[cfg(feature = "vsock")] + Endpoint::Vsock(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + warn!("unsupported device"); } else { return -crate::errno::EINVAL; } @@ -643,21 +772,32 @@ pub unsafe extern "C" fn sys_getpeername( if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return -crate::errno::EINVAL; + match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } + } + }, + #[cfg(feature = "vsock")] + Endpoint::Vsock(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + warn!("unsupported device"); } else { return -crate::errno::EINVAL; } @@ -741,25 +881,52 @@ pub unsafe extern "C" fn sys_sendto( addr: *const sockaddr, addr_len: socklen_t, ) -> isize { - let endpoint = if addr_len == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(addr as *const sockaddr_in) }) - } else if addr_len == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(addr as *const sockaddr_in6) }) - } else { + let endpoint; + + if addr.is_null() || addr_len == 0 { return (-crate::errno::EINVAL).try_into().unwrap(); - }; - let slice = unsafe { core::slice::from_raw_parts(buf, len) }; - let obj = get_object(fd); + } - obj.map_or_else( - |e| -num::ToPrimitive::to_isize(&e).unwrap(), - |v| { - (*v).sendto(slice, endpoint).map_or_else( - |e| -num::ToPrimitive::to_isize(&e).unwrap(), - |v| v.try_into().unwrap(), - ) - }, - ) + cfg_if! { + if #[cfg(any(feature = "tcp", feature = "udp"))] { + let sa_family = unsafe { (*addr).sa_family as i32 }; + + if sa_family == AF_INET { + if addr_len < size_of::().try_into().unwrap() { + return (-crate::errno::EINVAL).try_into().unwrap(); + } + + endpoint = Some(Endpoint::Ip(IpEndpoint::from(unsafe {*(addr as *const sockaddr_in)}))); + } else if sa_family == AF_INET6 { + if addr_len < size_of::().try_into().unwrap() { + return (-crate::errno::EINVAL).try_into().unwrap(); + } + + endpoint = Some(Endpoint::Ip(IpEndpoint::from(unsafe { *(addr as *const sockaddr_in6) }))); + } else { + endpoint = None; + } + } else { + endpoint = None; + } + } + + if let Some(endpoint) = endpoint { + let slice = unsafe { core::slice::from_raw_parts(buf, len) }; + let obj = get_object(fd); + + obj.map_or_else( + |e| -num::ToPrimitive::to_isize(&e).unwrap(), + |v| { + (*v).sendto(slice, endpoint).map_or_else( + |e| -num::ToPrimitive::to_isize(&e).unwrap(), + |v| v.try_into().unwrap(), + ) + }, + ) + } else { + (-crate::errno::EINVAL).try_into().unwrap() + } } #[hermit_macro::system] @@ -781,26 +948,34 @@ pub unsafe extern "C" fn sys_recvfrom( |e| -num::ToPrimitive::to_isize(&e).unwrap(), |(len, endpoint)| { if !addr.is_null() && !addrlen.is_null() { + #[allow(unused_variables)] let addrlen = unsafe { &mut *addrlen }; - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return (-crate::errno::EINVAL).try_into().unwrap(); + match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return (-crate::errno::EINVAL).try_into().unwrap(); + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return (-crate::errno::EINVAL).try_into().unwrap(); + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return (-crate::errno::EINVAL).try_into().unwrap(); + } } + }, + #[cfg(feature = "vsock")] + _ => { + return (-crate::errno::EINVAL).try_into().unwrap(); } } } From b95b137d7cb163584bc6b8f814a93017a02a51c2 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 5 Aug 2024 23:52:42 +0200 Subject: [PATCH 06/18] add dependency to endian-num - vsock depends on endian-num --- Cargo.lock | 1 + Cargo.toml | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 0687394029..2cfedc8fac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -628,6 +628,7 @@ dependencies = [ "cfg-if", "crossbeam-utils", "dyn-clone", + "endian-num", "fdt", "float-cmp", "free-list", diff --git a/Cargo.toml b/Cargo.toml index 68c073d86d..516f5b6193 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ dhcpv4 = [ ] fs = ["fuse"] fuse = ["pci", "dep:fuse-abi", "fuse-abi/num_enum"] -vsock = ["pci"] +vsock = ["pci", "endian-num"] fsgsbase = [] gem-net = ["tcp", "dep:tock-registers"] newlib = [] @@ -88,6 +88,7 @@ build-time = "0.1.3" cfg-if = "1" crossbeam-utils = { version = "0.8", default-features = false } dyn-clone = "1.0" +endian-num = { version = "0.1", optional = true } fdt = { version = "0.1", features = ["pretty-printing"] } free-list = "0.3" fuse-abi = { version = "0.1", features = ["zerocopy"], optional = true } From fecf8dbf716a84b3362ac367e44baf17420d5544 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 5 Aug 2024 23:55:51 +0200 Subject: [PATCH 07/18] move interrupt handler to the device drivers --- src/drivers/net/gem.rs | 8 ++++++-- src/drivers/net/mod.rs | 21 +-------------------- src/drivers/net/rtl8139.rs | 10 ++++++++-- src/drivers/net/virtio/mmio.rs | 5 +---- src/drivers/net/virtio/mod.rs | 21 +++++---------------- src/drivers/net/virtio/pci.rs | 1 - 6 files changed, 21 insertions(+), 45 deletions(-) diff --git a/src/drivers/net/gem.rs b/src/drivers/net/gem.rs index d9e3ed0a9a..19ab8486fc 100644 --- a/src/drivers/net/gem.rs +++ b/src/drivers/net/gem.rs @@ -19,7 +19,7 @@ use crate::arch::kernel::interrupts::*; use crate::arch::mm::paging::virt_to_phys; use crate::arch::mm::VirtAddr; use crate::drivers::error::DriverError; -use crate::drivers::net::{network_irqhandler, NetworkDriver}; +use crate::drivers::net::NetworkDriver; use crate::executor::device::{RxToken, TxToken}; //Base address of the control registers @@ -202,9 +202,13 @@ fn gem_irqhandler() { debug!("Receive network interrupt"); + crate::executor::run(); + // PLIC end of interrupt crate::arch::kernel::interrupts::external_eoi(); - let _ = network_irqhandler(); + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } core_scheduler().reschedule(); } diff --git a/src/drivers/net/mod.rs b/src/drivers/net/mod.rs index 0927f5469a..d4f72652a0 100644 --- a/src/drivers/net/mod.rs +++ b/src/drivers/net/mod.rs @@ -9,10 +9,6 @@ use smoltcp::phy::ChecksumCapabilities; #[allow(unused_imports)] use crate::arch::kernel::core_local::*; -#[cfg(not(feature = "pci"))] -use crate::arch::kernel::mmio as hardware; -#[cfg(feature = "pci")] -use crate::drivers::pci as hardware; use crate::executor::device::{RxToken, TxToken}; /// A trait for accessing the network interface @@ -37,20 +33,5 @@ pub(crate) trait NetworkDriver { /// Enable / disable the polling mode of the network interface fn set_polling_mode(&mut self, value: bool); /// Handle interrupt and check if a packet is available - fn handle_interrupt(&mut self) -> bool; -} - -#[inline] -pub(crate) fn network_irqhandler() -> bool { - let result = if let Some(driver) = hardware::get_network_driver() { - driver.lock().handle_interrupt() - } else { - debug!("Unable to handle interrupt!"); - false - }; - - // TODO: do we need it? - crate::executor::run(); - - result + fn handle_interrupt(&mut self); } diff --git a/src/drivers/net/rtl8139.rs b/src/drivers/net/rtl8139.rs index 3366c1739b..c48637b494 100644 --- a/src/drivers/net/rtl8139.rs +++ b/src/drivers/net/rtl8139.rs @@ -16,7 +16,7 @@ use crate::arch::mm::paging::virt_to_phys; use crate::arch::mm::VirtAddr; use crate::arch::pci::PciConfigRegion; use crate::drivers::error::DriverError; -use crate::drivers::net::{network_irqhandler, NetworkDriver}; +use crate::drivers::net::NetworkDriver; use crate::drivers::pci::PciDevice; use crate::executor::device::{RxToken, TxToken}; @@ -342,6 +342,8 @@ impl NetworkDriver for RTL8139Driver { let ret = (isr_contents & ISR_ROK) == ISR_ROK; + crate::executor::run(); + unsafe { outw( self.iobase + ISR, @@ -428,7 +430,11 @@ extern "x86-interrupt" fn rtl8139_irqhandler(stack_frame: ExceptionStackFrame) { debug!("Receive network interrupt"); crate::arch::x86_64::kernel::apic::eoi(); - let _ = network_irqhandler(); + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } else { + debug!("Unable to handle interrupt!"); + } core_scheduler().reschedule(); crate::arch::x86_64::swapgs(&stack_frame); diff --git a/src/drivers/net/virtio/mmio.rs b/src/drivers/net/virtio/mmio.rs index 8e86c2b796..ff73ef3b8e 100644 --- a/src/drivers/net/virtio/mmio.rs +++ b/src/drivers/net/virtio/mmio.rs @@ -20,7 +20,6 @@ impl VirtioNetDriver { pub fn new( dev_id: u16, mut registers: VolatileRef<'static, DeviceRegisters>, - irq: u8, ) -> Result { let dev_cfg_raw: &'static virtio::net::Config = unsafe { &*registers @@ -58,7 +57,6 @@ impl VirtioNetDriver { recv_vqs, send_vqs, num_vqs: 0, - irq, mtu, checksums: ChecksumCapabilities::default(), }) @@ -79,9 +77,8 @@ impl VirtioNetDriver { pub fn init( dev_id: u16, registers: VolatileRef<'static, DeviceRegisters>, - irq_no: u8, ) -> Result { - if let Ok(mut drv) = VirtioNetDriver::new(dev_id, registers, irq_no) { + if let Ok(mut drv) = VirtioNetDriver::new(dev_id, registers) { match drv.init_dev() { Err(error_code) => Err(VirtioError::NetDriver(error_code)), _ => { diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 83cdf07fbb..6dbbcff019 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -14,7 +14,7 @@ use alloc::boxed::Box; use alloc::vec::Vec; use core::mem::MaybeUninit; -use pci_types::InterruptLine; +use align_address::Align; use smoltcp::phy::{Checksum, ChecksumCapabilities}; use smoltcp::wire::{EthernetFrame, Ipv4Packet, Ipv6Packet, ETHERNET_HEADER_LEN}; use virtio::net::{ConfigVolatileFieldAccess, Hdr, HdrF}; @@ -24,8 +24,6 @@ use volatile::VolatileRef; use self::constants::MAX_NUM_VQ; use self::error::VirtioNetError; -#[cfg(not(target_arch = "riscv64"))] -use crate::arch::kernel::core_local::increment_irq_counter; use crate::config::VIRTIO_MAX_QUEUE_SIZE; use crate::drivers::net::NetworkDriver; #[cfg(not(feature = "pci"))] @@ -249,8 +247,6 @@ pub(crate) struct VirtioNetDriver { pub(super) send_vqs: TxQueues, pub(super) num_vqs: u16, - #[cfg_attr(target_arch = "riscv64", allow(dead_code))] - pub(super) irq: InterruptLine, pub(super) mtu: u16, pub(super) checksums: ChecksumCapabilities, } @@ -394,22 +390,15 @@ impl NetworkDriver for VirtioNetDriver { } } - fn handle_interrupt(&mut self) -> bool { - #[cfg(not(target_arch = "riscv64"))] - increment_irq_counter(32 + self.irq); + fn handle_interrupt(&mut self) { + let _ = self.isr_stat.is_interrupt(); - let result = if self.isr_stat.is_interrupt() { - true - } else if self.isr_stat.is_cfg_change() { + if self.isr_stat.is_cfg_change() { info!("Configuration changes are not possible! Aborting"); todo!("Implement possibility to change config on the fly...") - } else { - false - }; + } self.isr_stat.acknowledge(); - - result } } diff --git a/src/drivers/net/virtio/pci.rs b/src/drivers/net/virtio/pci.rs index bcd68cfab7..a477b6509e 100644 --- a/src/drivers/net/virtio/pci.rs +++ b/src/drivers/net/virtio/pci.rs @@ -97,7 +97,6 @@ impl VirtioNetDriver { recv_vqs, send_vqs, num_vqs: 0, - irq: device.get_irq().unwrap(), mtu, checksums: ChecksumCapabilities::default(), }) From 2472726ac34b1befb25d111420bf5e4eb809a5d3 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 5 Aug 2024 23:59:20 +0200 Subject: [PATCH 08/18] extend executor to support vsock streams - rewrite virtio driver to support the executor --- src/drivers/pci.rs | 20 +- src/drivers/virtio/transport/mod.rs | 59 +++- src/drivers/virtio/transport/pci.rs | 8 + src/drivers/vsock/mod.rs | 434 +++++++++++++++++++++++++++- src/drivers/vsock/pci.rs | 11 +- src/executor/mod.rs | 4 + src/executor/vsock.rs | 174 +++++++++++ src/fd/mod.rs | 21 +- src/fd/socket/tcp.rs | 17 +- src/fd/socket/udp.rs | 11 +- src/fd/socket/vsock.rs | 172 ++++++++++- src/lib.rs | 1 + src/syscalls/socket.rs | 46 ++- 13 files changed, 881 insertions(+), 97 deletions(-) create mode 100644 src/executor/vsock.rs diff --git a/src/drivers/pci.rs b/src/drivers/pci.rs index 1b379ee107..a00f4a8637 100644 --- a/src/drivers/pci.rs +++ b/src/drivers/pci.rs @@ -135,6 +135,10 @@ impl PciDevice { } }; + if address == 0 { + return None; + } + debug!( "Mapping bar {} at {:#x} with length {:#x}", index, address, size @@ -329,6 +333,15 @@ impl PciDriver { } } + #[cfg(feature = "vsock")] + fn get_vsock_driver(&self) -> Option<&InterruptTicketMutex> { + #[allow(unreachable_patterns)] + match self { + Self::VirtioVsock(drv) => Some(drv), + _ => None, + } + } + #[cfg(feature = "fuse")] fn get_filesystem_driver(&self) -> Option<&InterruptTicketMutex> { match self { @@ -355,6 +368,11 @@ pub(crate) fn get_network_driver() -> Option<&'static InterruptTicketMutex Option<&'static InterruptTicketMutex> { + unsafe { PCI_DRIVERS.iter().find_map(|drv| drv.get_vsock_driver()) } +} + #[cfg(feature = "fuse")] pub(crate) fn get_filesystem_driver() -> Option<&'static InterruptTicketMutex> { unsafe { @@ -374,7 +392,7 @@ pub(crate) fn init_drivers() { }) } { info!( - "Found virtio network device with device id {:#x}", + "Found virtio device with device id {:#x}", adapter.device_id() ); diff --git a/src/drivers/virtio/transport/mod.rs b/src/drivers/virtio/transport/mod.rs index 9082e7d7b4..8484037400 100644 --- a/src/drivers/virtio/transport/mod.rs +++ b/src/drivers/virtio/transport/mod.rs @@ -9,22 +9,38 @@ pub mod mmio; #[cfg(feature = "pci")] pub mod pci; +use hermit_sync::OnceCell; + +#[cfg(not(target_arch = "riscv64"))] +use crate::arch::kernel::core_local::increment_irq_counter; #[cfg(target_arch = "x86_64")] use crate::arch::kernel::interrupts::ExceptionStackFrame; +#[cfg(all(feature = "vsock", not(feature = "pci")))] +use crate::arch::kernel::mmio as hardware; #[cfg(target_arch = "aarch64")] use crate::arch::scheduler::State; #[cfg(any(feature = "tcp", feature = "udp"))] -use crate::drivers::net::network_irqhandler; +use crate::drivers::net::NetworkDriver; +#[cfg(all(feature = "vsock", feature = "pci"))] +use crate::drivers::pci as hardware; + +/// All virtio devices share the interrupt number `VIRTIO_IRQ` +static VIRTIO_IRQ: OnceCell = OnceCell::new(); #[cfg(target_arch = "aarch64")] pub(crate) fn virtio_irqhandler(_state: &State) -> bool { debug!("Receive virtio interrupt"); - cfg_if::cfg_if! { - if #[cfg(any(feature = "tcp", feature = "udp"))] { - network_irqhandler() - } else { - false - } + + crate::executor::run(); + + #[cfg(any(feature = "tcp", feature = "udp"))] + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + #[cfg(feature = "vsock")] + if let Some(driver) = hardware::get_vsock_driver() { + driver.lock().handle_interrupt(); } } @@ -34,10 +50,22 @@ pub(crate) extern "x86-interrupt" fn virtio_irqhandler(stack_frame: ExceptionSta use crate::arch::kernel::core_local::core_scheduler; use crate::scheduler::PerCoreSchedulerExt; - info!("Receive virtio interrupt"); + debug!("Receive virtio interrupt"); + + increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); + + crate::executor::run(); crate::kernel::apic::eoi(); + #[cfg(any(feature = "tcp", feature = "udp"))] - let _ = network_irqhandler(); + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + #[cfg(feature = "vsock")] + if let Some(driver) = hardware::get_vsock_driver() { + driver.lock().handle_interrupt(); + } core_scheduler().reschedule(); crate::arch::x86_64::swapgs(&stack_frame); @@ -50,10 +78,21 @@ pub(crate) fn virtio_irqhandler() { debug!("Receive virtio interrupt"); + increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); + + crate::executor::run(); + // PLIC end of interrupt crate::arch::kernel::interrupts::external_eoi(); #[cfg(any(feature = "tcp", feature = "udp"))] - let _ = network_irqhandler(); + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + #[cfg(feature = "vsock")] + if let Some(driver) = hardware::get_vsock_driver() { + driver.lock().handle_interrupt(); + } core_scheduler().reschedule(); } diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index adeaa22c0e..f725b06499 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -966,7 +966,11 @@ pub(crate) fn init_device( match &drv { #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] VirtioDriver::Network(_) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + let irq = device.get_irq().unwrap(); + let _ = VIRTIO_IRQ.try_insert(irq); + info!("Install virtio interrupt handler at line {}", irq); // Install interrupt handler irq_install_handler(irq, virtio_irqhandler); @@ -976,7 +980,11 @@ pub(crate) fn init_device( } #[cfg(feature = "vsock")] VirtioDriver::Vsock(_) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + let irq = device.get_irq().unwrap(); + let _ = VIRTIO_IRQ.try_insert(irq); + info!("Install virtio interrupt handler at line {}", irq); // Install interrupt handler irq_install_handler(irq, virtio_irqhandler); diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 78d4b86c3c..1abb000d86 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -5,8 +5,12 @@ pub mod pci; use alloc::rc::Rc; use alloc::vec::Vec; +use core::cmp::Ordering; +use core::mem; +use align_address::Align; use pci_types::InterruptLine; +use virtio::vsock::{Event, Hdr}; use virtio::FeatureBits; use crate::config::VIRTIO_MAX_QUEUE_SIZE; @@ -14,10 +18,349 @@ use crate::drivers::virtio::error::VirtioVsockError; #[cfg(feature = "pci")] use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg}; use crate::drivers::virtio::virtqueue::split::SplitVq; -use crate::drivers::virtio::virtqueue::{Virtq, VqIndex, VqSize}; +use crate::drivers::virtio::virtqueue::{ + BuffSpec, BufferToken, BufferType, Bytes, Virtq, VqIndex, VqSize, +}; #[cfg(feature = "pci")] use crate::drivers::vsock::pci::VsockDevCfgRaw; +const MTU: usize = 65536; + +pub(crate) struct RxQueue { + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, +} + +impl RxQueue { + pub fn new() -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); + + Self { + vq: None, + poll_sender, + poll_receiver, + } + } + + pub fn add(&mut self, vq: Rc) { + let num_buff: u16 = vq.size().into(); + let rx_size = (MTU + mem::size_of::()) + .align_up(core::mem::size_of::>()); + let spec = BuffSpec::Single(Bytes::new(rx_size).unwrap()); + + for _ in 0..num_buff { + let buff_tkn = match BufferToken::new(None, Some(spec.clone())) { + Ok(tkn) => tkn, + Err(_vq_err) => { + error!("Setup of vsock queue failed, which should not happen!"); + panic!("setup of vsock queue failed!"); + } + }; + + // BufferTokens are directly provided to the queue + // TransferTokens are directly dispatched + // Transfers will be awaited at the queue + match vq.dispatch_await( + buff_tkn, + self.poll_sender.clone(), + false, + BufferType::Direct, + ) { + Ok(_) => (), + Err(_) => { + error!("Descriptor IDs were exhausted earlier than expected."); + break; + } + } + } + + self.vq = Some(vq); + } + + pub fn enable_notifs(&self) { + if let Some(ref vq) = self.vq { + vq.enable_notifs(); + } + } + + pub fn disable_notifs(&self) { + if let Some(ref vq) = self.vq { + vq.disable_notifs(); + } + } + + fn get_next(&mut self) -> Option { + let transfer = self.poll_receiver.try_recv(); + + transfer + .or_else(|_| { + // Check if any not yet provided transfers are in the queue. + self.poll(); + + self.poll_receiver.try_recv() + }) + .ok() + } + + fn poll(&self) { + if let Some(ref vq) = self.vq { + vq.poll(); + } + } + + pub fn process_packet(&mut self, mut f: F) + where + F: FnMut(&Hdr, &[u8]), + { + const HEADER_SIZE: usize = mem::size_of::(); + + while let Some(mut buffer_tkn) = self.get_next() { + let (_, recv_data_opt) = buffer_tkn.as_slices().unwrap(); + let mut recv_data = recv_data_opt.unwrap(); + + if recv_data.len() == 1 { + let packet = recv_data.pop().unwrap(); + + // drop packets with invalid packet size + if packet.len() < HEADER_SIZE { + panic!("Invalid packet size!"); + } + + if let Some(ref vq) = self.vq { + let header = unsafe { + core::mem::transmute::<[u8; HEADER_SIZE], Hdr>( + packet[..HEADER_SIZE].try_into().unwrap(), + ) + }; + + f(&header, &packet[HEADER_SIZE..]); + + buffer_tkn.reset(); + vq.dispatch_await( + buffer_tkn, + self.poll_sender.clone(), + false, + BufferType::Direct, + ) + .unwrap(); + } else { + panic!("Unable to get receive queue"); + } + } else { + panic!("Invalid length of receive queue"); + } + } + } +} + +pub(crate) struct TxQueue { + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, + ready_queue: Vec, +} + +impl TxQueue { + pub fn new() -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); + + Self { + vq: None, + poll_sender, + poll_receiver, + ready_queue: Vec::new(), + } + } + + pub fn add(&mut self, vq: Rc) { + let tx_size = (1514 + mem::size_of::()) + .align_up(core::mem::size_of::>()); + let buff_def = Bytes::new(tx_size).unwrap(); + let spec = BuffSpec::Single(buff_def); + let num_buff: u16 = vq.size().into(); + + for _ in 0..num_buff { + let mut buffer_tkn = BufferToken::new(Some(spec.clone()), None).unwrap(); + buffer_tkn + .write_seq(Some(&Hdr::default()), None::<&Hdr>) + .unwrap(); + self.ready_queue.push(buffer_tkn) + } + + self.vq = Some(vq); + } + + pub fn enable_notifs(&self) { + if let Some(ref vq) = self.vq { + vq.enable_notifs(); + } + } + + pub fn disable_notifs(&self) { + if let Some(ref vq) = self.vq { + vq.disable_notifs(); + } + } + + fn poll(&self) { + if let Some(ref vq) = self.vq { + vq.poll(); + } + } + + /// Returns either a BufferToken and the corresponding index of the + /// virtqueue it is coming from. (Index in the TxQueues.vqs vector) + /// + /// OR returns None, if no BufferToken could be generated + fn get_tkn(&mut self, len: usize) -> Option<(BufferToken, usize)> { + // Check all ready token, for correct size. + // Drop token if not so + // + // All Tokens inside the ready_queue are coming from the main queue with index 0. + while let Some(mut tkn) = self.ready_queue.pop() { + let (send_len, _) = tkn.len(); + + match send_len.cmp(&len) { + Ordering::Less => {} + Ordering::Equal => return Some((tkn, 0)), + Ordering::Greater => { + tkn.restr_size(Some(len), None).unwrap(); + return Some((tkn, 0)); + } + } + } + + if self.poll_receiver.is_empty() { + self.poll(); + } + + while let Ok(mut buffer_token) = self.poll_receiver.try_recv() { + buffer_token.reset(); + let (send_len, _) = buffer_token.len(); + + match send_len.cmp(&len) { + Ordering::Less => {} + Ordering::Equal => return Some((buffer_token, 0)), + Ordering::Greater => { + buffer_token.restr_size(Some(len), None).unwrap(); + return Some((buffer_token, 0)); + } + } + } + + // As usize is currently safe as the minimal usize is defined as 16bit in rust. + let spec = BuffSpec::Single(Bytes::new(len).unwrap()); + + match BufferToken::new(Some(spec), None) { + Ok(tkn) => Some((tkn, 0)), + Err(_) => { + // Here it is possible if multiple queues are enabled to get another buffertoken from them! + // Info the queues are disabled upon initialization and should be enabled somehow! + None + } + } + } + + /// Provides a slice to copy the packet and transfer the packet + /// to the send queue. The caller has to creatde the header + /// for the vsock interface. + pub fn send_packet(&mut self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + if let Some((mut buff_tkn, _vq_index)) = self.get_tkn(len) { + let (send_ptrs, _) = buff_tkn.raw_ptrs(); + let (buff_ptr, _) = send_ptrs.unwrap()[0]; + + let buf_slice: &'static mut [u8] = + unsafe { core::slice::from_raw_parts_mut(buff_ptr, len) }; + let result = f(buf_slice); + + if let Some(ref vq) = self.vq { + vq.dispatch_await( + buff_tkn, + self.poll_sender.clone(), + false, + BufferType::Direct, + ) + .unwrap(); + + result + } else { + panic!("Unable to get token for send queue"); + } + } else { + panic!("Unable to get send queue"); + } + } +} + +pub(crate) struct EventQueue { + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, +} + +impl EventQueue { + pub fn new() -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); + + Self { + vq: None, + poll_sender, + poll_receiver, + } + } + + pub fn add(&mut self, vq: Rc) { + let num_buff: u16 = vq.size().into(); + let event_size = mem::size_of::() + .align_up(core::mem::size_of::>()); + let spec = BuffSpec::Single(Bytes::new(event_size).unwrap()); + + for _ in 0..num_buff { + let buff_tkn = match BufferToken::new(None, Some(spec.clone())) { + Ok(tkn) => tkn, + Err(_vq_err) => { + error!("Setup of vsock queue failed, which should not happen!"); + panic!("setup of vsock queue failed!"); + } + }; + + // BufferTokens are directly provided to the queue + // TransferTokens are directly dispatched + // Transfers will be awaited at the queue + match vq.dispatch_await( + buff_tkn, + self.poll_sender.clone(), + false, + BufferType::Direct, + ) { + Ok(_) => (), + Err(_) => { + error!("Descriptor IDs were exhausted earlier than expected."); + break; + } + } + } + + self.vq = Some(vq); + } + + pub fn enable_notifs(&self) { + if let Some(ref vq) = self.vq { + vq.enable_notifs(); + } + } + + pub fn disable_notifs(&self) { + if let Some(ref vq) = self.vq { + vq.disable_notifs(); + } + } +} + /// A wrapper struct for the raw configuration structure. /// Handling the right access to fields, as some are read-only /// for the driver. @@ -33,7 +376,10 @@ pub(crate) struct VirtioVsockDriver { pub(super) isr_stat: IsrStatus, pub(super) notif_cfg: NotifCfg, pub(super) irq: InterruptLine, - pub(super) vqueues: Vec>, + + pub(super) event_vq: EventQueue, + pub(super) recv_vq: RxQueue, + pub(super) send_vq: TxQueue, } impl VirtioVsockDriver { @@ -42,6 +388,11 @@ impl VirtioVsockDriver { self.dev_cfg.dev_id } + #[inline] + pub fn get_cid(&self) -> u64 { + self.dev_cfg.raw.guest_cid + } + #[cfg(feature = "pci")] pub fn set_failed(&mut self) { self.com_cfg.set_failed(); @@ -50,13 +401,24 @@ impl VirtioVsockDriver { pub fn disable_interrupts(&self) { // For send and receive queues? // Only for receive? Because send is off anyway? - self.vqueues[0].disable_notifs(); + self.recv_vq.disable_notifs(); } pub fn enable_interrupts(&self) { // For send and receive queues? // Only for receive? Because send is off anyway? - self.vqueues[0].enable_notifs(); + self.recv_vq.enable_notifs(); + } + + pub fn handle_interrupt(&mut self) { + let _ = self.isr_stat.is_interrupt(); + + if self.isr_stat.is_cfg_change() { + info!("Configuration changes are not possible! Aborting"); + todo!("Implement possibility to change config on the fly...") + } + + self.isr_stat.acknowledge(); } /// Negotiates a subset of features, understood and wanted by both the OS @@ -90,7 +452,7 @@ impl VirtioVsockDriver { /// /// See Virtio specification v1.1. - 3.1.1. /// and v1.1. - 5.10.6 - pub(crate) fn init_dev(&mut self) -> Result<(), VirtioVsockError> { + pub fn init_dev(&mut self) -> Result<(), VirtioVsockError> { // Reset self.com_cfg.reset_dev(); @@ -120,20 +482,70 @@ impl VirtioVsockDriver { } // create the queues and tell device about them - for i in 0..3u16 { - let vq = SplitVq::new( + self.recv_vq.add(Rc::new( + SplitVq::new( &mut self.com_cfg, &self.notif_cfg, VqSize::from(VIRTIO_MAX_QUEUE_SIZE), - VqIndex::from(i), + VqIndex::from(0u16), self.dev_cfg.features.into(), ) - .unwrap(); - self.vqueues.push(Rc::new(vq)); - } + .unwrap(), + )); + // Interrupt for receiving packets is wanted + self.recv_vq.enable_notifs(); + + self.send_vq.add(Rc::new( + SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(1u16), + self.dev_cfg.features.into(), + ) + .unwrap(), + )); + // Interrupt for communicating that a sended packet left, is not needed + self.send_vq.disable_notifs(); + + // create the queues and tell device about them + self.event_vq.add(Rc::new( + SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(2u16), + self.dev_cfg.features.into(), + ) + .unwrap(), + )); + // Interrupt for event packets is wanted + self.event_vq.enable_notifs(); + + // At this point the device is "live" + self.com_cfg.drv_ok(); Ok(()) } + + #[inline] + pub fn process_packet(&mut self, f: F) + where + F: FnMut(&Hdr, &[u8]), + { + self.recv_vq.process_packet(f) + } + + /// Provides a slice to copy the packet and transfer the packet + /// to the send queue. The caller has to creatde the header + /// for the vsock interface. + #[inline] + pub fn send_packet(&mut self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.send_vq.send_packet(len, f) + } } /// Error module of virtio socket device driver. diff --git a/src/drivers/vsock/pci.rs b/src/drivers/vsock/pci.rs index 14104e6b87..dda855241b 100644 --- a/src/drivers/vsock/pci.rs +++ b/src/drivers/vsock/pci.rs @@ -1,11 +1,9 @@ -use alloc::vec::Vec; - use crate::arch::pci::PciConfigRegion; use crate::drivers::pci::PciDevice; use crate::drivers::virtio::error::{self, VirtioError}; use crate::drivers::virtio::transport::pci; use crate::drivers::virtio::transport::pci::{PciCap, UniCapsColl}; -use crate::drivers::vsock::{VirtioVsockDriver, VsockDevCfg}; +use crate::drivers::vsock::{EventQueue, RxQueue, TxQueue, VirtioVsockDriver, VsockDevCfg}; /// Virtio's socket device configuration structure. /// See specification v1.1. - 5.11.4 @@ -15,7 +13,7 @@ use crate::drivers::vsock::{VirtioVsockDriver, VsockDevCfg}; pub(crate) struct VsockDevCfgRaw { /// The guest_cid field contains the guest’s context ID, which uniquely identifies the device /// for its lifetime. The upper 32 bits of the CID are reserved and zeroed. - guest_cid: u64, + pub guest_cid: u64, } impl VirtioVsockDriver { @@ -84,7 +82,9 @@ impl VirtioVsockDriver { isr_stat, notif_cfg, irq: device.get_irq().unwrap(), - vqueues: Vec::new(), + event_vq: EventQueue::new(), + recv_vq: RxQueue::new(), + send_vq: TxQueue::new(), }) } @@ -114,6 +114,7 @@ impl VirtioVsockDriver { "Socket device with cid {:x}, has been initialized by driver!", drv.dev_cfg.raw.guest_cid ); + Ok(drv) } Err(fs_err) => { diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 232a72224f..5cb48939eb 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -5,6 +5,8 @@ pub(crate) mod device; #[cfg(any(feature = "tcp", feature = "udp"))] pub(crate) mod network; pub(crate) mod task; +#[cfg(feature = "vsock")] +pub(crate) mod vsock; use alloc::sync::Arc; use alloc::task::Wake; @@ -91,6 +93,8 @@ where pub fn init() { #[cfg(any(feature = "tcp", feature = "udp"))] crate::executor::network::init(); + #[cfg(feature = "vsock")] + crate::executor::vsock::init(); } #[inline] diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs new file mode 100644 index 0000000000..124254b603 --- /dev/null +++ b/src/executor/vsock.rs @@ -0,0 +1,174 @@ +use alloc::collections::BTreeMap; +use alloc::vec::Vec; +use core::future; +use core::task::{Poll, Waker}; + +use endian_num::{le16, le32}; +use hermit_sync::InterruptTicketMutex; +use virtio::vsock::{Hdr, Op, Type}; + +#[cfg(not(feature = "pci"))] +use crate::arch::kernel::mmio as hardware; +#[cfg(feature = "pci")] +use crate::drivers::pci as hardware; +use crate::executor::spawn; +use crate::io; +use crate::io::Error::EADDRINUSE; + +pub(crate) static VSOCK_MAP: InterruptTicketMutex = + InterruptTicketMutex::new(VsockMap::new()); + +#[derive(Debug, Copy, Clone, PartialEq)] +pub(crate) enum VsockState { + Listen, + ReceiveRequest, + Connected, + Connecting, + Shutdown, +} + +/// WakerRegistration is derived from smoltcp's +/// implementation. +#[derive(Debug)] +pub(crate) struct WakerRegistration { + waker: Option, +} + +impl WakerRegistration { + pub const fn new() -> Self { + Self { waker: None } + } + + /// Register a waker. Overwrites the previous waker, if any. + pub fn register(&mut self, w: &Waker) { + match self.waker { + // Optimization: If both the old and new Wakers wake the same task, we can simply + // keep the old waker, skipping the clone. + Some(ref w2) if (w2.will_wake(w)) => {} + // In all other cases + // - we have no waker registered + // - we have a waker registered but it's for a different task. + // then clone the new waker and store it + _ => self.waker = Some(w.clone()), + } + } + + /// Wake the registered waker, if any. + pub fn wake(&mut self) { + self.waker.take().map(|w| w.wake()); + } +} + +pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024; + +#[derive(Debug)] +pub(crate) struct RawSocket { + pub remote_cid: u32, + pub remote_port: u32, + pub state: VsockState, + pub waker: WakerRegistration, + pub buffer: Vec, +} + +impl RawSocket { + pub fn new(state: VsockState) -> Self { + Self { + remote_cid: 0, + remote_port: 0, + state, + waker: WakerRegistration::new(), + buffer: Vec::with_capacity(RAW_SOCKET_BUFFER_SIZE), + } + } +} + +async fn vsock_run() { + future::poll_fn(|_cx| { + if let Some(driver) = hardware::get_vsock_driver() { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = driver.lock(); + let mut hdr: Option = None; + + driver_guard.process_packet(|header, data| { + let op = Op::try_from(header.op.to_ne()).unwrap(); + let port = header.dst_port.to_ne(); + let type_ = Type::try_from(header.type_.to_ne()).unwrap(); + let mut vsock_guard = VSOCK_MAP.lock(); + + if let Some(raw) = vsock_guard.get_mut_socket(port) { + if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream + { + raw.state = VsockState::ReceiveRequest; + raw.remote_cid = header.src_cid.to_ne().try_into().unwrap(); + raw.remote_port = header.src_port.to_ne().try_into().unwrap(); + raw.waker.wake(); + } else if (raw.state == VsockState::Connected + || raw.state == VsockState::Shutdown) + && type_ == Type::Stream + { + raw.buffer.extend_from_slice(data); + raw.waker.wake(); + } else { + hdr = Some(*header); + } + } + }); + + if let Some(hdr) = hdr { + // reset connection + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = hdr.dst_cid; + response.dst_cid = hdr.src_cid; + response.src_port = hdr.dst_port; + response.dst_port = hdr.src_port; + response.len = le32::from_ne(0); + response.type_ = hdr.type_; + response.op = le16::from_ne(Op::Rst.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(0); + }); + } + + Poll::Pending + } else { + Poll::Ready(()) + } + }) + .await +} + +pub(crate) struct VsockMap { + port_map: BTreeMap, +} + +impl VsockMap { + pub const fn new() -> Self { + Self { + port_map: BTreeMap::new(), + } + } + + pub fn bind(&mut self, port: u32) -> io::Result<()> { + self.port_map + .try_insert(port, RawSocket::new(VsockState::Listen)) + .map_err(|_| EADDRINUSE)?; + Ok(()) + } + + pub fn get_socket(&self, port: u32) -> Option<&RawSocket> { + self.port_map.get(&port) + } + + pub fn get_mut_socket(&mut self, port: u32) -> Option<&mut RawSocket> { + self.port_map.get_mut(&port) + } +} + +pub(crate) fn init() { + info!("Try to initialize vsock interface!"); + + spawn(vsock_run()); +} diff --git a/src/fd/mod.rs b/src/fd/mod.rs index c6c893f650..1f0c2f652d 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -24,25 +24,13 @@ pub(crate) const STDIN_FILENO: FileDescriptor = 0; pub(crate) const STDOUT_FILENO: FileDescriptor = 1; pub(crate) const STDERR_FILENO: FileDescriptor = 2; -#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] -#[allow(clippy::upper_case_acronyms, dead_code)] -#[derive(Debug, Clone, Copy)] -pub(crate) enum AddressFamily { - #[cfg(any(feature = "tcp", feature = "udp"))] - INET, - #[cfg(any(feature = "tcp", feature = "udp"))] - INET6, - #[cfg(feature = "vsock")] - VSOCK, -} - #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] #[derive(Debug)] pub(crate) enum Endpoint { #[cfg(any(feature = "tcp", feature = "udp"))] Ip(IpEndpoint), #[cfg(feature = "vsock")] - Vsock(()), + Vsock(socket::vsock::VsockEndpoint), } #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] @@ -293,13 +281,6 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { fn ioctl(&self, _cmd: IoCtl, _value: bool) -> io::Result<()> { Err(io::Error::ENOSYS) } - - /// Sockets returns the supported address family - #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] - #[allow(dead_code)] - fn get_address_family(&self) -> Option { - None - } } pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result { diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 87c3837ce1..9aedf492ea 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -8,13 +8,11 @@ use async_trait::async_trait; use smoltcp::iface; use smoltcp::socket::tcp; use smoltcp::time::Duration; -use smoltcp::wire::{IpEndpoint, IpVersion}; +use smoltcp::wire::IpEndpoint; use crate::executor::block_on; use crate::executor::network::{now, Handle, NetworkState, NIC}; -use crate::fd::{ - AddressFamily, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent, SocketOption, -}; +use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent, SocketOption}; use crate::{io, DEFAULT_KEEP_ALIVE_INTERVAL}; /// further receives will be disallowed @@ -427,17 +425,6 @@ impl ObjectInterface for Socket { Err(io::Error::EINVAL) } } - - fn get_address_family(&self) -> Option { - self.with(|socket| { - socket - .local_endpoint() - .map(|endpoint| match endpoint.addr.version() { - IpVersion::Ipv4 => AddressFamily::INET, - IpVersion::Ipv6 => AddressFamily::INET6, - }) - }) - } } impl Clone for Socket { diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index fe5d834991..8004dd68a2 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -13,7 +13,7 @@ use smoltcp::wire::{IpEndpoint, IpVersion}; use crate::executor::network::{now, Handle, NetworkState, NIC}; use crate::executor::{block_on, poll_on}; -use crate::fd::{AddressFamily, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; +use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io; #[derive(Debug)] @@ -269,15 +269,6 @@ impl ObjectInterface for Socket { Err(io::Error::EINVAL) } } - - fn get_address_family(&self) -> Option { - self.endpoint - .load() - .map(|endpoint| match endpoint.addr.version() { - IpVersion::Ipv4 => AddressFamily::INET, - IpVersion::Ipv6 => AddressFamily::INET6, - }) - } } impl Clone for Socket { diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 8e05f70fda..0e8fb46704 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -1,18 +1,40 @@ +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::future; use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use core::task::Poll; use async_trait::async_trait; +use endian_num::{le16, le32, le64}; +use virtio::vsock::{Hdr, Op, Type}; -use crate::fd::{AddressFamily, Endpoint, IoCtl, ListenEndpoint, ObjectInterface}; -use crate::io; +#[cfg(not(feature = "pci"))] +use crate::arch::kernel::mmio as hardware; +#[cfg(feature = "pci")] +use crate::drivers::pci as hardware; +use crate::executor::vsock::{VsockState, VSOCK_MAP}; +use crate::fd::{block_on, Endpoint, IoCtl, ListenEndpoint, ObjectInterface}; +use crate::io::{self, Error}; #[derive(Debug)] pub(crate) struct VsockListenEndpoint { - port: u32, - #[allow(dead_code)] - cid: u32, + pub port: u32, + pub cid: Option, } impl VsockListenEndpoint { + pub const fn new(port: u32, cid: Option) -> Self { + Self { port, cid } + } +} + +#[derive(Debug)] +pub(crate) struct VsockEndpoint { + pub port: u32, + pub cid: u32, +} + +impl VsockEndpoint { pub const fn new(port: u32, cid: u32) -> Self { Self { port, cid } } @@ -21,16 +43,16 @@ impl VsockListenEndpoint { #[derive(Debug)] pub struct Socket { port: AtomicU32, + cid: AtomicU32, nonblocking: AtomicBool, - listen: AtomicBool, } impl Socket { pub fn new() -> Self { Self { port: AtomicU32::new(0), + cid: AtomicU32::new(u32::MAX), nonblocking: AtomicBool::new(false), - listen: AtomicBool::new(false), } } } @@ -38,11 +60,15 @@ impl Socket { #[async_trait] impl ObjectInterface for Socket { fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { - info!("bind {:?}", endpoint); match endpoint { ListenEndpoint::Vsock(ep) => { self.port.store(ep.port, Ordering::Release); - Ok(()) + if let Some(cid) = ep.cid { + self.cid.store(cid, Ordering::Release); + } else { + self.cid.store(u32::MAX, Ordering::Release); + } + VSOCK_MAP.lock().bind(ep.port) } #[cfg(any(feature = "tcp", feature = "udp"))] _ => Err(io::Error::EINVAL), @@ -54,14 +80,68 @@ impl ObjectInterface for Socket { } fn listen(&self, _backlog: i32) -> io::Result<()> { - info!("listen"); - self.listen.store(true, Ordering::Relaxed); Ok(()) } fn accept(&self) -> io::Result { - info!("accept"); - Err(io::Error::EINVAL) + let port = self.port.load(Ordering::Acquire); + let cid = self.cid.load(Ordering::Acquire); + + let endpoint = block_on( + async { + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Listen => { + raw.waker.register(cx.waker()); + Poll::Pending + } + VsockState::ReceiveRequest => { + let result = { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = + unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(0); + response.type_ = le16::from_ne(Type::Stream.into()); + if local_cid != cid.into() && cid != u32::MAX { + response.op = le16::from_ne(Op::Rst.into()) + } else { + response.op = le16::from_ne(Op::Response.into()); + } + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne( + crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, + ); + response.fwd_cnt = le32::from_ne(0); + }); + + raw.state = VsockState::Connected; + + Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid)) + }; + + Poll::Ready(result) + } + _ => Poll::Ready(Err(Error::EBADF)), + } + }) + .await + }, + None, + )?; + + Ok(Endpoint::Vsock(endpoint)) } fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { @@ -80,8 +160,68 @@ impl ObjectInterface for Socket { } } - fn get_address_family(&self) -> Option { - Some(AddressFamily::VSOCK) + // TODO: Remove allow once fixed: + // https://github.com/rust-lang/rust-clippy/issues/11380 + #[allow(clippy::needless_pass_by_ref_mut)] + async fn async_read(&self, buffer: &mut [u8]) -> io::Result { + let port = self.port.load(Ordering::Acquire); + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Connected => { + let len = core::cmp::min(buffer.len(), raw.buffer.len()); + + if len == 0 { + raw.waker.register(cx.waker()); + Poll::Pending + } else { + let tmp: Vec<_> = raw.buffer.drain(..len).collect(); + buffer[..len].copy_from_slice(tmp.as_slice()); + + Poll::Ready(Ok(len)) + } + } + _ => Poll::Ready(Err(Error::EIO)), + } + }) + .await + } + + async fn async_write(&self, buffer: &[u8]) -> io::Result { + let port = self.port.load(Ordering::Acquire); + let guard = VSOCK_MAP.lock(); + let raw = guard.get_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Connected => { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + driver_guard.send_packet(HEADER_SIZE + buffer.len(), |virtio_buffer| { + let response = unsafe { &mut *(virtio_buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(buffer.len().try_into().unwrap()); + response.type_ = le16::from_ne(Type::Stream.into()); + response.op = le16::from_ne(Op::Rw.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = + le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(0); + + virtio_buffer[HEADER_SIZE..].copy_from_slice(buffer); + }); + + Ok(buffer.len()) + } + _ => Err(Error::EIO), + } } } @@ -89,8 +229,8 @@ impl Clone for Socket { fn clone(&self) -> Self { Self { port: AtomicU32::new(self.port.load(Ordering::Acquire)), + cid: AtomicU32::new(self.cid.load(Ordering::Acquire)), nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Acquire)), - listen: AtomicBool::new(false), } } } diff --git a/src/lib.rs b/src/lib.rs index 68d84da879..89eb303463 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ #![feature(asm_const)] #![feature(exposed_provenance)] #![feature(linked_list_cursors)] +#![feature(map_try_insert)] #![feature(maybe_uninit_as_bytes)] #![feature(maybe_uninit_slice)] #![feature(naked_functions)] diff --git a/src/syscalls/socket.rs b/src/syscalls/socket.rs index 8fe88b806c..d49cfb0efa 100644 --- a/src/syscalls/socket.rs +++ b/src/syscalls/socket.rs @@ -18,7 +18,7 @@ use crate::fd::socket::tcp; #[cfg(feature = "udp")] use crate::fd::socket::udp; #[cfg(feature = "vsock")] -use crate::fd::socket::vsock::{self, VsockListenEndpoint}; +use crate::fd::socket::vsock::{self, VsockEndpoint, VsockListenEndpoint}; use crate::fd::{ get_object, insert_object, replace_object, Endpoint, ListenEndpoint, ObjectInterface, SocketOption, @@ -117,12 +117,39 @@ pub struct sockaddr_vm { impl From for VsockListenEndpoint { fn from(addr: sockaddr_vm) -> VsockListenEndpoint { let port = addr.svm_port; - let cid = addr.svm_cid; + let cid = if addr.svm_cid < u32::MAX { + Some(addr.svm_cid) + } else { + None + }; VsockListenEndpoint::new(port, cid) } } +#[cfg(feature = "vsock")] +impl From for VsockEndpoint { + fn from(addr: sockaddr_vm) -> VsockEndpoint { + let port = addr.svm_port; + let cid = addr.svm_cid; + + VsockEndpoint::new(port, cid) + } +} + +#[cfg(feature = "vsock")] +impl From for sockaddr_vm { + fn from(endpoint: VsockEndpoint) -> Self { + Self { + svm_len: core::mem::size_of::().try_into().unwrap(), + svm_family: AF_VSOCK.try_into().unwrap(), + svm_port: endpoint.port, + svm_cid: endpoint.cid, + ..Default::default() + } + } +} + #[repr(C)] #[derive(Debug, Default, Copy, Clone)] pub struct sockaddr_in { @@ -383,7 +410,7 @@ pub unsafe extern "C" fn sys_getaddrbyname( #[hermit_macro::system] #[no_mangle] pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 { - info!( + debug!( "sys_socket: domain {}, type {:?}, protocol {}", domain, type_, protocol ); @@ -489,22 +516,21 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut new_fd } #[cfg(feature = "vsock")] - Endpoint::Vsock(_) => { + Endpoint::Vsock(endpoint) => { let new_obj = dyn_clone::clone_box(&*v); replace_object(fd, Arc::from(new_obj)).unwrap(); let new_fd = insert_object(v).unwrap(); if !addr.is_null() && !addrlen.is_null() { - let addrlen = unsafe { &*addrlen }; + let addrlen = unsafe { &mut *addrlen }; if *addrlen >= size_of::().try_into().unwrap() { let addr = unsafe { &mut *(addr as *mut sockaddr_vm) }; - *addr = sockaddr_vm::default(); + *addr = sockaddr_vm::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); } } - warn!("unsupported device"); - new_fd } }, @@ -600,7 +626,9 @@ pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: so if namelen < size_of::().try_into().unwrap() { return -crate::errno::EINVAL; } - Endpoint::Vsock(()) + Endpoint::Vsock(VsockEndpoint::from(unsafe { + *(name as *const sockaddr_vm) + })) } _ => { return -crate::errno::EINVAL; From 7c24305c8ebb9f229fc0cb5a3e82dab230007986 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Tue, 6 Aug 2024 15:41:34 +0200 Subject: [PATCH 09/18] revise vsock driver to support the new virtio interface --- src/drivers/net/gem.rs | 14 +- src/drivers/net/rtl8139.rs | 11 +- src/drivers/net/virtio/mod.rs | 1 - src/drivers/virtio/transport/mmio.rs | 10 +- src/drivers/virtio/transport/mod.rs | 25 +- src/drivers/vsock/mod.rs | 369 ++++++++++----------------- src/drivers/vsock/pci.rs | 2 +- src/executor/vsock.rs | 22 +- 8 files changed, 195 insertions(+), 259 deletions(-) diff --git a/src/drivers/net/gem.rs b/src/drivers/net/gem.rs index 19ab8486fc..f2060230e4 100644 --- a/src/drivers/net/gem.rs +++ b/src/drivers/net/gem.rs @@ -16,10 +16,14 @@ use tock_registers::{register_bitfields, register_structs}; use crate::arch::kernel::core_local::core_scheduler; use crate::arch::kernel::interrupts::*; +#[cfg(all(any(feature = "tcp", feature = "udp"), not(feature = "pci")))] +use crate::arch::kernel::mmio as hardware; use crate::arch::mm::paging::virt_to_phys; use crate::arch::mm::VirtAddr; use crate::drivers::error::DriverError; use crate::drivers::net::NetworkDriver; +#[cfg(all(any(feature = "tcp", feature = "udp"), feature = "pci"))] +use crate::drivers::pci as hardware; use crate::executor::device::{RxToken, TxToken}; //Base address of the control registers @@ -202,14 +206,14 @@ fn gem_irqhandler() { debug!("Receive network interrupt"); - crate::executor::run(); - // PLIC end of interrupt crate::arch::kernel::interrupts::external_eoi(); if let Some(driver) = hardware::get_network_driver() { driver.lock().handle_interrupt() } + crate::executor::run(); + core_scheduler().reschedule(); } @@ -365,7 +369,7 @@ impl NetworkDriver for GEMDriver { } } - fn handle_interrupt(&mut self) -> bool { + fn handle_interrupt(&mut self) { let int_status = unsafe { (*self.gem).int_status.extract() }; let receive_status = unsafe { (*self.gem).receive_status.extract() }; @@ -409,8 +413,8 @@ impl NetworkDriver for GEMDriver { // handle incoming packets todo!(); } - // increment_irq_counter((32 + self.irq).into()); - ret + + //increment_irq_counter((32 + self.irq).into()); } } diff --git a/src/drivers/net/rtl8139.rs b/src/drivers/net/rtl8139.rs index c48637b494..dddedce8e8 100644 --- a/src/drivers/net/rtl8139.rs +++ b/src/drivers/net/rtl8139.rs @@ -17,6 +17,7 @@ use crate::arch::mm::VirtAddr; use crate::arch::pci::PciConfigRegion; use crate::drivers::error::DriverError; use crate::drivers::net::NetworkDriver; +use crate::drivers::pci as hardware; use crate::drivers::pci::PciDevice; use crate::executor::device::{RxToken, TxToken}; @@ -319,7 +320,7 @@ impl NetworkDriver for RTL8139Driver { } } - fn handle_interrupt(&mut self) -> bool { + fn handle_interrupt(&mut self) { increment_irq_counter(32 + self.irq); let isr_contents = unsafe { inw(self.iobase + ISR) }; @@ -340,18 +341,12 @@ impl NetworkDriver for RTL8139Driver { trace!("RTL88139: RX overflow detected!\n"); } - let ret = (isr_contents & ISR_ROK) == ISR_ROK; - - crate::executor::run(); - unsafe { outw( self.iobase + ISR, isr_contents & (ISR_RXOVW | ISR_TER | ISR_RER | ISR_TOK | ISR_ROK), ); } - - ret } } @@ -436,6 +431,8 @@ extern "x86-interrupt" fn rtl8139_irqhandler(stack_frame: ExceptionStackFrame) { debug!("Unable to handle interrupt!"); } + crate::executor::run(); + core_scheduler().reschedule(); crate::arch::x86_64::swapgs(&stack_frame); } diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 6dbbcff019..ca88d50681 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -14,7 +14,6 @@ use alloc::boxed::Box; use alloc::vec::Vec; use core::mem::MaybeUninit; -use align_address::Align; use smoltcp::phy::{Checksum, ChecksumCapabilities}; use smoltcp::wire::{EthernetFrame, Ipv4Packet, Ipv6Packet, ETHERNET_HEADER_LEN}; use virtio::net::{ConfigVolatileFieldAccess, Hdr, HdrF}; diff --git a/src/drivers/virtio/transport/mmio.rs b/src/drivers/virtio/transport/mmio.rs index ce29e3bf1d..cf8b7f3114 100644 --- a/src/drivers/virtio/transport/mmio.rs +++ b/src/drivers/virtio/transport/mmio.rs @@ -381,13 +381,16 @@ pub(crate) fn init_device( match registers.as_ptr().device_id().read() { #[cfg(any(feature = "tcp", feature = "udp"))] virtio::Id::Net => { - match VirtioNetDriver::init(dev_id, registers, irq_no) { + match VirtioNetDriver::init(dev_id, registers) { Ok(virt_net_drv) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + info!("Virtio network driver initialized."); // Install interrupt handler irq_install_handler(irq_no, virtio_irqhandler); #[cfg(not(target_arch = "riscv64"))] add_irq_name(irq_no, "virtio"); + let _ = VIRTIO_IRQ.try_insert(irq_no); Ok(VirtioDriver::Network(virt_net_drv)) } @@ -399,13 +402,16 @@ pub(crate) fn init_device( } #[cfg(feature = "vsock")] virtio::Id::Vsock => { - match VirtioVsockDriver::init(dev_id, registers, irq_no) { + match VirtioVsockDriver::init(dev_id, registers) { Ok(virt_net_drv) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + info!("Virtio sock driver initialized."); // Install interrupt handler irq_install_handler(irq_no, virtio_irqhandler); #[cfg(not(target_arch = "riscv64"))] add_irq_name(irq_no, "virtio"); + let _ = VIRTIO_IRQ.try_insert(irq_no); Ok(VirtioDriver::Vsock(virt_vsock_drv)) } diff --git a/src/drivers/virtio/transport/mod.rs b/src/drivers/virtio/transport/mod.rs index 8484037400..26df517264 100644 --- a/src/drivers/virtio/transport/mod.rs +++ b/src/drivers/virtio/transport/mod.rs @@ -15,13 +15,19 @@ use hermit_sync::OnceCell; use crate::arch::kernel::core_local::increment_irq_counter; #[cfg(target_arch = "x86_64")] use crate::arch::kernel::interrupts::ExceptionStackFrame; -#[cfg(all(feature = "vsock", not(feature = "pci")))] +#[cfg(all( + any(feature = "vsock", feature = "tcp", feature = "udp"), + not(feature = "pci") +))] use crate::arch::kernel::mmio as hardware; #[cfg(target_arch = "aarch64")] use crate::arch::scheduler::State; #[cfg(any(feature = "tcp", feature = "udp"))] use crate::drivers::net::NetworkDriver; -#[cfg(all(feature = "vsock", feature = "pci"))] +#[cfg(all( + any(feature = "vsock", feature = "tcp", feature = "udp"), + feature = "pci" +))] use crate::drivers::pci as hardware; /// All virtio devices share the interrupt number `VIRTIO_IRQ` @@ -31,7 +37,7 @@ static VIRTIO_IRQ: OnceCell = OnceCell::new(); pub(crate) fn virtio_irqhandler(_state: &State) -> bool { debug!("Receive virtio interrupt"); - crate::executor::run(); + increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); #[cfg(any(feature = "tcp", feature = "udp"))] if let Some(driver) = hardware::get_network_driver() { @@ -42,6 +48,10 @@ pub(crate) fn virtio_irqhandler(_state: &State) -> bool { if let Some(driver) = hardware::get_vsock_driver() { driver.lock().handle_interrupt(); } + + crate::executor::run(); + + true } #[cfg(target_arch = "x86_64")] @@ -54,7 +64,6 @@ pub(crate) extern "x86-interrupt" fn virtio_irqhandler(stack_frame: ExceptionSta increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); - crate::executor::run(); crate::kernel::apic::eoi(); #[cfg(any(feature = "tcp", feature = "udp"))] @@ -67,6 +76,8 @@ pub(crate) extern "x86-interrupt" fn virtio_irqhandler(stack_frame: ExceptionSta driver.lock().handle_interrupt(); } + crate::executor::run(); + core_scheduler().reschedule(); crate::arch::x86_64::swapgs(&stack_frame); } @@ -78,10 +89,6 @@ pub(crate) fn virtio_irqhandler() { debug!("Receive virtio interrupt"); - increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); - - crate::executor::run(); - // PLIC end of interrupt crate::arch::kernel::interrupts::external_eoi(); #[cfg(any(feature = "tcp", feature = "udp"))] @@ -94,5 +101,7 @@ pub(crate) fn virtio_irqhandler() { driver.lock().handle_interrupt(); } + crate::executor::run(); + core_scheduler().reschedule(); } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 1abb000d86..f1e2f6d37a 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -3,14 +3,13 @@ #[cfg(feature = "pci")] pub mod pci; -use alloc::rc::Rc; +use alloc::boxed::Box; use alloc::vec::Vec; -use core::cmp::Ordering; use core::mem; +use core::mem::MaybeUninit; -use align_address::Align; use pci_types::InterruptLine; -use virtio::vsock::{Event, Hdr}; +use virtio::vsock::Hdr; use virtio::FeatureBits; use crate::config::VIRTIO_MAX_QUEUE_SIZE; @@ -19,17 +18,59 @@ use crate::drivers::virtio::error::VirtioVsockError; use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg}; use crate::drivers::virtio::virtqueue::split::SplitVq; use crate::drivers::virtio::virtqueue::{ - BuffSpec, BufferToken, BufferType, Bytes, Virtq, VqIndex, VqSize, + AvailBufferToken, BufferElem, BufferType, UsedBufferToken, Virtq, VqIndex, VqSize, }; #[cfg(feature = "pci")] use crate::drivers::vsock::pci::VsockDevCfgRaw; - -const MTU: usize = 65536; +use crate::mm::device_alloc::DeviceAlloc; + +fn fill_queue( + vq: &mut dyn Virtq, + num_packets: u16, + packet_size: u32, + poll_sender: async_channel::Sender, +) { + for _ in 0..num_packets { + let buff_tkn = match AvailBufferToken::new( + vec![], + vec![ + BufferElem::Sized(Box::::new_uninit_in(DeviceAlloc)), + BufferElem::Vector(Vec::with_capacity_in( + packet_size.try_into().unwrap(), + DeviceAlloc, + )), + ], + ) { + Ok(tkn) => tkn, + Err(_vq_err) => { + error!("Setup of network queue failed, which should not happen!"); + panic!("setup of network queue failed!"); + } + }; + + // BufferTokens are directly provided to the queue + // TransferTokens are directly dispatched + // Transfers will be awaited at the queue + match vq.dispatch( + buff_tkn, + Some(poll_sender.clone()), + false, + BufferType::Direct, + ) { + Ok(_) => (), + Err(err) => { + error!("{:#?}", err); + break; + } + } + } +} pub(crate) struct RxQueue { - vq: Option>, - poll_sender: async_channel::Sender, - poll_receiver: async_channel::Receiver, + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, + packet_size: u32, } impl RxQueue { @@ -40,57 +81,37 @@ impl RxQueue { vq: None, poll_sender, poll_receiver, + packet_size: 8192u32 + mem::size_of::() as u32, } } - pub fn add(&mut self, vq: Rc) { - let num_buff: u16 = vq.size().into(); - let rx_size = (MTU + mem::size_of::()) - .align_up(core::mem::size_of::>()); - let spec = BuffSpec::Single(Bytes::new(rx_size).unwrap()); - - for _ in 0..num_buff { - let buff_tkn = match BufferToken::new(None, Some(spec.clone())) { - Ok(tkn) => tkn, - Err(_vq_err) => { - error!("Setup of vsock queue failed, which should not happen!"); - panic!("setup of vsock queue failed!"); - } - }; - - // BufferTokens are directly provided to the queue - // TransferTokens are directly dispatched - // Transfers will be awaited at the queue - match vq.dispatch_await( - buff_tkn, - self.poll_sender.clone(), - false, - BufferType::Direct, - ) { - Ok(_) => (), - Err(_) => { - error!("Descriptor IDs were exhausted earlier than expected."); - break; - } - } - } + pub fn add(&mut self, mut vq: Box) { + const BUFF_PER_PACKET: u16 = 2; + let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; + info!("num_packets {}", num_packets); + fill_queue( + vq.as_mut(), + num_packets, + self.packet_size, + self.poll_sender.clone(), + ); self.vq = Some(vq); } - pub fn enable_notifs(&self) { - if let Some(ref vq) = self.vq { + pub fn enable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { vq.enable_notifs(); } } - pub fn disable_notifs(&self) { - if let Some(ref vq) = self.vq { + pub fn disable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { vq.disable_notifs(); } } - fn get_next(&mut self) -> Option { + fn get_next(&mut self) -> Option { let transfer = self.poll_receiver.try_recv(); transfer @@ -103,8 +124,8 @@ impl RxQueue { .ok() } - fn poll(&self) { - if let Some(ref vq) = self.vq { + fn poll(&mut self) { + if let Some(ref mut vq) = self.vq { vq.poll(); } } @@ -113,40 +134,17 @@ impl RxQueue { where F: FnMut(&Hdr, &[u8]), { - const HEADER_SIZE: usize = mem::size_of::(); - while let Some(mut buffer_tkn) = self.get_next() { - let (_, recv_data_opt) = buffer_tkn.as_slices().unwrap(); - let mut recv_data = recv_data_opt.unwrap(); - - if recv_data.len() == 1 { - let packet = recv_data.pop().unwrap(); - - // drop packets with invalid packet size - if packet.len() < HEADER_SIZE { - panic!("Invalid packet size!"); - } - - if let Some(ref vq) = self.vq { - let header = unsafe { - core::mem::transmute::<[u8; HEADER_SIZE], Hdr>( - packet[..HEADER_SIZE].try_into().unwrap(), - ) - }; - - f(&header, &packet[HEADER_SIZE..]); - - buffer_tkn.reset(); - vq.dispatch_await( - buffer_tkn, - self.poll_sender.clone(), - false, - BufferType::Direct, - ) - .unwrap(); - } else { - panic!("Unable to get receive queue"); - } + let header = buffer_tkn + .used_recv_buff + .pop_front_downcast::() + .unwrap(); + let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap(); + + if let Some(ref mut vq) = self.vq { + f(&header, &packet[..]); + + fill_queue(vq.as_mut(), 1, self.packet_size, self.poll_sender.clone()); } else { panic!("Invalid length of receive queue"); } @@ -155,141 +153,70 @@ impl RxQueue { } pub(crate) struct TxQueue { - vq: Option>, - poll_sender: async_channel::Sender, - poll_receiver: async_channel::Receiver, - ready_queue: Vec, + vq: Option>, + /// Indicates, whether the Driver/Device are using multiple + /// queues for communication. + packet_length: u32, } impl TxQueue { pub fn new() -> Self { - let (poll_sender, poll_receiver) = async_channel::unbounded(); - Self { vq: None, - poll_sender, - poll_receiver, - ready_queue: Vec::new(), + packet_length: 8192u32 + mem::size_of::() as u32, } } - pub fn add(&mut self, vq: Rc) { - let tx_size = (1514 + mem::size_of::()) - .align_up(core::mem::size_of::>()); - let buff_def = Bytes::new(tx_size).unwrap(); - let spec = BuffSpec::Single(buff_def); - let num_buff: u16 = vq.size().into(); - - for _ in 0..num_buff { - let mut buffer_tkn = BufferToken::new(Some(spec.clone()), None).unwrap(); - buffer_tkn - .write_seq(Some(&Hdr::default()), None::<&Hdr>) - .unwrap(); - self.ready_queue.push(buffer_tkn) - } - + pub fn add(&mut self, vq: Box) { self.vq = Some(vq); } - pub fn enable_notifs(&self) { - if let Some(ref vq) = self.vq { + pub fn enable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { vq.enable_notifs(); } } - pub fn disable_notifs(&self) { - if let Some(ref vq) = self.vq { + pub fn disable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { vq.disable_notifs(); } } - fn poll(&self) { - if let Some(ref vq) = self.vq { + fn poll(&mut self) { + if let Some(ref mut vq) = self.vq { vq.poll(); } } - /// Returns either a BufferToken and the corresponding index of the - /// virtqueue it is coming from. (Index in the TxQueues.vqs vector) - /// - /// OR returns None, if no BufferToken could be generated - fn get_tkn(&mut self, len: usize) -> Option<(BufferToken, usize)> { - // Check all ready token, for correct size. - // Drop token if not so - // - // All Tokens inside the ready_queue are coming from the main queue with index 0. - while let Some(mut tkn) = self.ready_queue.pop() { - let (send_len, _) = tkn.len(); - - match send_len.cmp(&len) { - Ordering::Less => {} - Ordering::Equal => return Some((tkn, 0)), - Ordering::Greater => { - tkn.restr_size(Some(len), None).unwrap(); - return Some((tkn, 0)); - } - } - } - - if self.poll_receiver.is_empty() { - self.poll(); - } - - while let Ok(mut buffer_token) = self.poll_receiver.try_recv() { - buffer_token.reset(); - let (send_len, _) = buffer_token.len(); - - match send_len.cmp(&len) { - Ordering::Less => {} - Ordering::Equal => return Some((buffer_token, 0)), - Ordering::Greater => { - buffer_token.restr_size(Some(len), None).unwrap(); - return Some((buffer_token, 0)); - } - } - } - - // As usize is currently safe as the minimal usize is defined as 16bit in rust. - let spec = BuffSpec::Single(Bytes::new(len).unwrap()); - - match BufferToken::new(Some(spec), None) { - Ok(tkn) => Some((tkn, 0)), - Err(_) => { - // Here it is possible if multiple queues are enabled to get another buffertoken from them! - // Info the queues are disabled upon initialization and should be enabled somehow! - None - } - } - } - /// Provides a slice to copy the packet and transfer the packet - /// to the send queue. The caller has to creatde the header + /// to the send queue. The caller has to create the header /// for the vsock interface. pub fn send_packet(&mut self, len: usize, f: F) -> R where F: FnOnce(&mut [u8]) -> R, { - if let Some((mut buff_tkn, _vq_index)) = self.get_tkn(len) { - let (send_ptrs, _) = buff_tkn.raw_ptrs(); - let (buff_ptr, _) = send_ptrs.unwrap()[0]; - - let buf_slice: &'static mut [u8] = - unsafe { core::slice::from_raw_parts_mut(buff_ptr, len) }; - let result = f(buf_slice); - - if let Some(ref vq) = self.vq { - vq.dispatch_await( - buff_tkn, - self.poll_sender.clone(), - false, - BufferType::Direct, - ) - .unwrap(); + // We need to poll to get the queue to remove elements from the table and make space for + // what we are about to add + if let Some(ref mut vq) = self.vq { + vq.poll(); + assert!(len < usize::try_from(self.packet_length).unwrap()); + let mut packet = Vec::with_capacity_in(len, DeviceAlloc); + let result = unsafe { + let result = f(MaybeUninit::slice_assume_init_mut( + packet.spare_capacity_mut(), + )); + packet.set_len(len); result - } else { - panic!("Unable to get token for send queue"); - } + }; + + let buff_tkn = AvailBufferToken::new(vec![BufferElem::Vector(packet)], vec![]).unwrap(); + + vq.dispatch(buff_tkn, None, false, BufferType::Direct) + .unwrap(); + + result } else { panic!("Unable to get send queue"); } @@ -297,9 +224,10 @@ impl TxQueue { } pub(crate) struct EventQueue { - vq: Option>, - poll_sender: async_channel::Sender, - poll_receiver: async_channel::Receiver, + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, + packet_size: u32, } impl EventQueue { @@ -310,52 +238,33 @@ impl EventQueue { vq: None, poll_sender, poll_receiver, + packet_size: 1024u32, } } - pub fn add(&mut self, vq: Rc) { - let num_buff: u16 = vq.size().into(); - let event_size = mem::size_of::() - .align_up(core::mem::size_of::>()); - let spec = BuffSpec::Single(Bytes::new(event_size).unwrap()); - - for _ in 0..num_buff { - let buff_tkn = match BufferToken::new(None, Some(spec.clone())) { - Ok(tkn) => tkn, - Err(_vq_err) => { - error!("Setup of vsock queue failed, which should not happen!"); - panic!("setup of vsock queue failed!"); - } - }; - - // BufferTokens are directly provided to the queue - // TransferTokens are directly dispatched - // Transfers will be awaited at the queue - match vq.dispatch_await( - buff_tkn, - self.poll_sender.clone(), - false, - BufferType::Direct, - ) { - Ok(_) => (), - Err(_) => { - error!("Descriptor IDs were exhausted earlier than expected."); - break; - } - } - } - + /// Adds a given queue to the underlying vector and populates the queue with RecvBuffers. + /// + /// Queues are all populated according to Virtio specification v1.1. - 5.1.6.3.1 + fn add(&mut self, mut vq: Box) { + const BUFF_PER_PACKET: u16 = 2; + let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; + fill_queue( + vq.as_mut(), + num_packets, + self.packet_size, + self.poll_sender.clone(), + ); self.vq = Some(vq); } - pub fn enable_notifs(&self) { - if let Some(ref vq) = self.vq { + pub fn enable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { vq.enable_notifs(); } } - pub fn disable_notifs(&self) { - if let Some(ref vq) = self.vq { + pub fn disable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { vq.disable_notifs(); } } @@ -398,13 +307,13 @@ impl VirtioVsockDriver { self.com_cfg.set_failed(); } - pub fn disable_interrupts(&self) { + pub fn disable_interrupts(&mut self) { // For send and receive queues? // Only for receive? Because send is off anyway? self.recv_vq.disable_notifs(); } - pub fn enable_interrupts(&self) { + pub fn enable_interrupts(&mut self) { // For send and receive queues? // Only for receive? Because send is off anyway? self.recv_vq.enable_notifs(); @@ -482,7 +391,7 @@ impl VirtioVsockDriver { } // create the queues and tell device about them - self.recv_vq.add(Rc::new( + self.recv_vq.add(Box::new( SplitVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -495,7 +404,7 @@ impl VirtioVsockDriver { // Interrupt for receiving packets is wanted self.recv_vq.enable_notifs(); - self.send_vq.add(Rc::new( + self.send_vq.add(Box::new( SplitVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -509,7 +418,7 @@ impl VirtioVsockDriver { self.send_vq.disable_notifs(); // create the queues and tell device about them - self.event_vq.add(Rc::new( + self.event_vq.add(Box::new( SplitVq::new( &mut self.com_cfg, &self.notif_cfg, @@ -560,9 +469,9 @@ pub mod error { FailFeatureNeg(u16), /// Set of features does not adhere to the requirements of features /// indicated by the specification - FeatureRequirementsNotMet(virtio::net::F), + FeatureRequirementsNotMet(virtio::vsock::F), /// The first u64 contains the feature bits wanted by the driver. /// but which are incompatible with the device feature set, second u64. - IncompatibleFeatureSets(virtio::net::F, virtio::net::F), + IncompatibleFeatureSets(virtio::vsock::F, virtio::vsock::F), } } diff --git a/src/drivers/vsock/pci.rs b/src/drivers/vsock/pci.rs index dda855241b..536e4700de 100644 --- a/src/drivers/vsock/pci.rs +++ b/src/drivers/vsock/pci.rs @@ -26,7 +26,7 @@ impl VirtioVsockDriver { Some(VsockDevCfg { raw: dev_cfg, dev_id: cap.dev_id(), - features: virtio::net::F::empty(), + features: virtio::vsock::F::empty(), }) } diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 124254b603..5036709645 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -55,7 +55,9 @@ impl WakerRegistration { /// Wake the registered waker, if any. pub fn wake(&mut self) { - self.waker.take().map(|w| w.wake()); + if let Some(w) = self.waker.take() { + w.wake() + } } } @@ -88,6 +90,7 @@ async fn vsock_run() { const HEADER_SIZE: usize = core::mem::size_of::(); let mut driver_guard = driver.lock(); let mut hdr: Option = None; + let mut fwd_cnt: Option = None; driver_guard.process_packet(|header, data| { let op = Op::try_from(header.op.to_ne()).unwrap(); @@ -100,7 +103,7 @@ async fn vsock_run() { { raw.state = VsockState::ReceiveRequest; raw.remote_cid = header.src_cid.to_ne().try_into().unwrap(); - raw.remote_port = header.src_port.to_ne().try_into().unwrap(); + raw.remote_port = header.src_port.to_ne(); raw.waker.wake(); } else if (raw.state == VsockState::Connected || raw.state == VsockState::Shutdown) @@ -110,12 +113,14 @@ async fn vsock_run() { raw.waker.wake(); } else { hdr = Some(*header); + if op == Op::CreditRequest { + fwd_cnt = Some(raw.buffer.len().try_into().unwrap()); + } } } }); if let Some(hdr) = hdr { - // reset connection driver_guard.send_packet(HEADER_SIZE, |buffer| { let response = unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; @@ -125,10 +130,17 @@ async fn vsock_run() { response.dst_port = hdr.src_port; response.len = le32::from_ne(0); response.type_ = hdr.type_; - response.op = le16::from_ne(Op::Rst.into()); + if let Some(fwd_cnt) = fwd_cnt { + // update fwd_cnt + response.op = le16::from_ne(Op::CreditUpdate.into()); + response.fwd_cnt = le32::from_ne(fwd_cnt); + } else { + // reset connection + response.op = le16::from_ne(Op::Rst.into()); + response.fwd_cnt = le32::from_ne(0); + } response.flags = le32::from_ne(0); response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32); - response.fwd_cnt = le32::from_ne(0); }); } From d82fba86e138a0cc11f47797f6e4e3cacd386a5b Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Wed, 7 Aug 2024 00:03:15 +0200 Subject: [PATCH 10/18] add support of Vsock shutdown and credit update command --- src/executor/vsock.rs | 9 +++++++++ src/fd/socket/vsock.rs | 22 +++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 5036709645..a3f215a474 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -108,9 +108,14 @@ async fn vsock_run() { } else if (raw.state == VsockState::Connected || raw.state == VsockState::Shutdown) && type_ == Type::Stream + && op == Op::Rw { raw.buffer.extend_from_slice(data); raw.waker.wake(); + } else if op == Op::CreditUpdate { + debug!("CrediteUpdate currently not supported: {:?}", header); + } else if op == Op::Shutdown { + raw.state = VsockState::Shutdown; } else { hdr = Some(*header); if op == Op::CreditRequest { @@ -177,6 +182,10 @@ impl VsockMap { pub fn get_mut_socket(&mut self, port: u32) -> Option<&mut RawSocket> { self.port_map.get_mut(&port) } + + pub fn remove_socket(&mut self, port: u32) { + let _ = self.port_map.remove(&port); + } } pub(crate) fn init() { diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 0e8fb46704..afd3cf3698 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -183,6 +183,18 @@ impl ObjectInterface for Socket { Poll::Ready(Ok(len)) } } + VsockState::Shutdown => { + let len = core::cmp::min(buffer.len(), raw.buffer.len()); + + if len == 0 { + Poll::Ready(Ok(0)) + } else { + let tmp: Vec<_> = raw.buffer.drain(..len).collect(); + buffer[..len].copy_from_slice(tmp.as_slice()); + + Poll::Ready(Ok(len)) + } + } _ => Poll::Ready(Err(Error::EIO)), } }) @@ -213,7 +225,7 @@ impl ObjectInterface for Socket { response.flags = le32::from_ne(0); response.buf_alloc = le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); - response.fwd_cnt = le32::from_ne(0); + response.fwd_cnt = le32::from_ne(raw.buffer.len().try_into().unwrap()); virtio_buffer[HEADER_SIZE..].copy_from_slice(buffer); }); @@ -234,3 +246,11 @@ impl Clone for Socket { } } } + +impl Drop for Socket { + fn drop(&mut self) { + let port = self.port.load(Ordering::Acquire); + let mut guard = VSOCK_MAP.lock(); + guard.remove_socket(port); + } +} From 7190b40d4c0ee790022893d0bb5f37a42cdedb92 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Wed, 7 Aug 2024 15:03:45 +0200 Subject: [PATCH 11/18] revise buffer management compliant with the virtio standard --- src/config.rs | 3 ++ src/drivers/vsock/mod.rs | 6 +-- src/executor/vsock.rs | 40 ++++++++++++------- src/fd/socket/vsock.rs | 83 ++++++++++++++++++++++++---------------- 4 files changed, 83 insertions(+), 49 deletions(-) diff --git a/src/config.rs b/src/config.rs index 2768e37384..805acb44b6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,3 +14,6 @@ pub(crate) const VIRTIO_MAX_QUEUE_SIZE: u16 = 1024; /// Default keep alive interval in milliseconds #[cfg(feature = "tcp")] pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75000; + +#[cfg(feature = "vsock")] +pub(crate) const VSOCK_PACKET_SIZE: u32 = 8192; diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index f1e2f6d37a..e906220671 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -81,7 +81,7 @@ impl RxQueue { vq: None, poll_sender, poll_receiver, - packet_size: 8192u32 + mem::size_of::() as u32, + packet_size: crate::VSOCK_PACKET_SIZE + mem::size_of::() as u32, } } @@ -163,7 +163,7 @@ impl TxQueue { pub fn new() -> Self { Self { vq: None, - packet_length: 8192u32 + mem::size_of::() as u32, + packet_length: crate::VSOCK_PACKET_SIZE + mem::size_of::() as u32, } } @@ -238,7 +238,7 @@ impl EventQueue { vq: None, poll_sender, poll_receiver, - packet_size: 1024u32, + packet_size: 128u32, } } diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index a3f215a474..e1d429af4e 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -67,8 +67,13 @@ pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024; pub(crate) struct RawSocket { pub remote_cid: u32, pub remote_port: u32, + pub fwd_cnt: u32, + pub peer_fwd_cnt: u32, + pub peer_buf_alloc: u32, + pub tx_cnt: u32, pub state: VsockState, - pub waker: WakerRegistration, + pub rx_waker: WakerRegistration, + pub tx_waker: WakerRegistration, pub buffer: Vec, } @@ -77,8 +82,13 @@ impl RawSocket { Self { remote_cid: 0, remote_port: 0, + fwd_cnt: 0, + peer_fwd_cnt: 0, + peer_buf_alloc: 0, + tx_cnt: 0, state, - waker: WakerRegistration::new(), + rx_waker: WakerRegistration::new(), + tx_waker: WakerRegistration::new(), buffer: Vec::with_capacity(RAW_SOCKET_BUFFER_SIZE), } } @@ -90,7 +100,7 @@ async fn vsock_run() { const HEADER_SIZE: usize = core::mem::size_of::(); let mut driver_guard = driver.lock(); let mut hdr: Option = None; - let mut fwd_cnt: Option = None; + let mut fwd_cnt: u32 = 0; driver_guard.process_packet(|header, data| { let op = Op::try_from(header.op.to_ne()).unwrap(); @@ -104,23 +114,28 @@ async fn vsock_run() { raw.state = VsockState::ReceiveRequest; raw.remote_cid = header.src_cid.to_ne().try_into().unwrap(); raw.remote_port = header.src_port.to_ne(); - raw.waker.wake(); + raw.peer_buf_alloc = header.buf_alloc.to_ne(); + raw.rx_waker.wake(); } else if (raw.state == VsockState::Connected || raw.state == VsockState::Shutdown) && type_ == Type::Stream && op == Op::Rw { raw.buffer.extend_from_slice(data); - raw.waker.wake(); + raw.fwd_cnt = raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + raw.rx_waker.wake(); + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; } else if op == Op::CreditUpdate { - debug!("CrediteUpdate currently not supported: {:?}", header); + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); } else if op == Op::Shutdown { raw.state = VsockState::Shutdown; } else { hdr = Some(*header); - if op == Op::CreditRequest { - fwd_cnt = Some(raw.buffer.len().try_into().unwrap()); - } + fwd_cnt = raw.fwd_cnt; } } }); @@ -135,17 +150,16 @@ async fn vsock_run() { response.dst_port = hdr.src_port; response.len = le32::from_ne(0); response.type_ = hdr.type_; - if let Some(fwd_cnt) = fwd_cnt { - // update fwd_cnt + if hdr.op.to_ne() == Op::CreditRequest.into() || hdr.op.to_ne() == Op::Rw.into() + { response.op = le16::from_ne(Op::CreditUpdate.into()); - response.fwd_cnt = le32::from_ne(fwd_cnt); } else { // reset connection response.op = le16::from_ne(Op::Rst.into()); - response.fwd_cnt = le32::from_ne(0); } response.flags = le32::from_ne(0); response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(fwd_cnt); }); } diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index afd3cf3698..e1f56a8e49 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -95,7 +95,7 @@ impl ObjectInterface for Socket { match raw.state { VsockState::Listen => { - raw.waker.register(cx.waker()); + raw.rx_waker.register(cx.waker()); Poll::Pending } VsockState::ReceiveRequest => { @@ -123,7 +123,7 @@ impl ObjectInterface for Socket { response.buf_alloc = le32::from_ne( crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, ); - response.fwd_cnt = le32::from_ne(0); + response.fwd_cnt = le32::from_ne(raw.fwd_cnt); }); raw.state = VsockState::Connected; @@ -174,7 +174,7 @@ impl ObjectInterface for Socket { let len = core::cmp::min(buffer.len(), raw.buffer.len()); if len == 0 { - raw.waker.register(cx.waker()); + raw.rx_waker.register(cx.waker()); Poll::Pending } else { let tmp: Vec<_> = raw.buffer.drain(..len).collect(); @@ -203,37 +203,54 @@ impl ObjectInterface for Socket { async fn async_write(&self, buffer: &[u8]) -> io::Result { let port = self.port.load(Ordering::Acquire); - let guard = VSOCK_MAP.lock(); - let raw = guard.get_socket(port).ok_or(Error::EINVAL)?; - - match raw.state { - VsockState::Connected => { - const HEADER_SIZE: usize = core::mem::size_of::(); - let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); - let local_cid = driver_guard.get_cid(); - - driver_guard.send_packet(HEADER_SIZE + buffer.len(), |virtio_buffer| { - let response = unsafe { &mut *(virtio_buffer.as_mut_ptr() as *mut Hdr) }; - - response.src_cid = le64::from_ne(local_cid); - response.dst_cid = le64::from_ne(raw.remote_cid as u64); - response.src_port = le32::from_ne(port); - response.dst_port = le32::from_ne(raw.remote_port); - response.len = le32::from_ne(buffer.len().try_into().unwrap()); - response.type_ = le16::from_ne(Type::Stream.into()); - response.op = le16::from_ne(Op::Rw.into()); - response.flags = le32::from_ne(0); - response.buf_alloc = - le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); - response.fwd_cnt = le32::from_ne(raw.buffer.len().try_into().unwrap()); - - virtio_buffer[HEADER_SIZE..].copy_from_slice(buffer); - }); - - Ok(buffer.len()) + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt); + + match raw.state { + VsockState::Connected => { + if diff >= raw.peer_buf_alloc { + raw.tx_waker.register(cx.waker()); + Poll::Pending + } else { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + let len = core::cmp::min( + buffer.len(), + usize::try_from(raw.peer_buf_alloc - diff).unwrap(), + ); + + driver_guard.send_packet(HEADER_SIZE + len, |virtio_buffer| { + let response = + unsafe { &mut *(virtio_buffer.as_mut_ptr() as *mut Hdr) }; + + raw.tx_cnt = raw.tx_cnt.wrapping_add(len.try_into().unwrap()); + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(len.try_into().unwrap()); + response.type_ = le16::from_ne(Type::Stream.into()); + response.op = le16::from_ne(Op::Rw.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne( + crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, + ); + response.fwd_cnt = le32::from_ne(raw.fwd_cnt); + + virtio_buffer[HEADER_SIZE..HEADER_SIZE + len] + .copy_from_slice(&buffer[..len]); + }); + + Poll::Ready(Ok(len)) + } + } + _ => Poll::Ready(Err(Error::EIO)), } - _ => Err(Error::EIO), - } + }) + .await } } From 36375c61fae7a3c9484847bbc1e23d4973e01209 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Wed, 7 Aug 2024 23:14:58 +0200 Subject: [PATCH 12/18] add support of the system call `poll` for the vsock interface --- src/fd/socket/vsock.rs | 77 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index e1f56a8e49..b52320b55e 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -13,7 +13,7 @@ use crate::arch::kernel::mmio as hardware; #[cfg(feature = "pci")] use crate::drivers::pci as hardware; use crate::executor::vsock::{VsockState, VSOCK_MAP}; -use crate::fd::{block_on, Endpoint, IoCtl, ListenEndpoint, ObjectInterface}; +use crate::fd::{block_on, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io::{self, Error}; #[derive(Debug)] @@ -59,6 +59,77 @@ impl Socket { #[async_trait] impl ObjectInterface for Socket { + async fn poll(&self, event: PollEvent) -> io::Result { + let port = self.port.load(Ordering::Acquire); + + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Shutdown | VsockState::ReceiveRequest => { + let available = PollEvent::POLLOUT + | PollEvent::POLLWRNORM + | PollEvent::POLLWRBAND + | PollEvent::POLLIN + | PollEvent::POLLRDNORM + | PollEvent::POLLRDBAND; + + let ret = event & available; + + if ret.is_empty() { + Poll::Ready(Ok(PollEvent::POLLHUP)) + } else { + Poll::Ready(Ok(ret)) + } + } + VsockState::Listen | VsockState::Connecting => { + raw.rx_waker.register(cx.waker()); + raw.tx_waker.register(cx.waker()); + Poll::Pending + } + VsockState::Connected => { + let mut available = PollEvent::empty(); + + if !raw.buffer.is_empty() { + // In case, we just establish a fresh connection in non-blocking mode, we try to read data. + available.insert( + PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND, + ); + } + + let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt); + if diff < raw.peer_buf_alloc { + available.insert( + PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND, + ); + } + + let ret = event & available; + + if ret.is_empty() { + if event.intersects( + PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND, + ) { + raw.rx_waker.register(cx.waker()); + } + + if event.intersects( + PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND, + ) { + raw.tx_waker.register(cx.waker()); + } + + Poll::Pending + } else { + Poll::Ready(Ok(ret)) + } + } + } + }) + .await + } + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { match endpoint { ListenEndpoint::Vsock(ep) => { @@ -144,6 +215,10 @@ impl ObjectInterface for Socket { Ok(Endpoint::Vsock(endpoint)) } + fn shutdown(&self, _how: i32) -> io::Result<()> { + Ok(()) + } + fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { if cmd == IoCtl::NonBlocking { if value { From ebf94b7cda267d47a4ccc67c41c57676f1e02629 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Thu, 8 Aug 2024 11:30:16 +0200 Subject: [PATCH 13/18] add support of getpeername and getsockname --- src/fd/socket/vsock.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index b52320b55e..336a52c00c 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -146,6 +146,26 @@ impl ObjectInterface for Socket { } } + fn getpeername(&self) -> Option { + let port = self.port.load(Ordering::Acquire); + let guard = VSOCK_MAP.lock(); + let raw = guard.get_socket(port)?; + + Some(Endpoint::Vsock(VsockEndpoint::new( + raw.remote_port, + raw.remote_cid, + ))) + } + + fn getsockname(&self) -> Option { + let local_cid = hardware::get_vsock_driver().unwrap().lock().get_cid(); + + Some(Endpoint::Vsock(VsockEndpoint::new( + self.port.load(Ordering::Acquire), + local_cid.try_into().unwrap(), + ))) + } + fn is_nonblocking(&self) -> bool { self.nonblocking.load(Ordering::Acquire) } From 8b80a6d1b9bada487a6e52606fb3b276e0d80fb0 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Thu, 8 Aug 2024 18:19:28 +0200 Subject: [PATCH 14/18] In case of vsock, cloning of the socket is invalid After accepting a connection, the new socket has to use the same port like the listening stream. => Accepting of new connection is not possible. --- src/syscalls/socket.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/syscalls/socket.rs b/src/syscalls/socket.rs index d49cfb0efa..5137e56590 100644 --- a/src/syscalls/socket.rs +++ b/src/syscalls/socket.rs @@ -517,9 +517,9 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut } #[cfg(feature = "vsock")] Endpoint::Vsock(endpoint) => { - let new_obj = dyn_clone::clone_box(&*v); - replace_object(fd, Arc::from(new_obj)).unwrap(); - let new_fd = insert_object(v).unwrap(); + //let new_obj = dyn_clone::clone_box(&*v); + //replace_object(fd, Arc::from(new_obj)).unwrap(); + let new_fd = insert_object(v.clone()).unwrap(); if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; From 27d5c7ae24ca64fec4acce9c5222e23a6031a912 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Sun, 11 Aug 2024 19:26:58 +0200 Subject: [PATCH 15/18] check the source of the received message --- src/executor/vsock.rs | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index e1d429af4e..4178f7e9d4 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -107,12 +107,13 @@ async fn vsock_run() { let port = header.dst_port.to_ne(); let type_ = Type::try_from(header.type_.to_ne()).unwrap(); let mut vsock_guard = VSOCK_MAP.lock(); + let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); if let Some(raw) = vsock_guard.get_mut_socket(port) { if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream { raw.state = VsockState::ReceiveRequest; - raw.remote_cid = header.src_cid.to_ne().try_into().unwrap(); + raw.remote_cid = header_cid; raw.remote_port = header.src_port.to_ne(); raw.peer_buf_alloc = header.buf_alloc.to_ne(); raw.rx_waker.wake(); @@ -121,19 +122,32 @@ async fn vsock_run() { && type_ == Type::Stream && op == Op::Rw { - raw.buffer.extend_from_slice(data); - raw.fwd_cnt = raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); - raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); - raw.tx_waker.wake(); - raw.rx_waker.wake(); - hdr = Some(*header); - fwd_cnt = raw.fwd_cnt; + if raw.remote_cid == header_cid { + raw.buffer.extend_from_slice(data); + raw.fwd_cnt = + raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + raw.rx_waker.wake(); + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; + } else { + trace!("Receive message from invalid source {}", header_cid); + } } else if op == Op::CreditUpdate { - raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); - raw.tx_waker.wake(); + if raw.remote_cid == header_cid { + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + } else { + trace!("Receive message from invalid source {}", header_cid); + } } else if op == Op::Shutdown { - raw.state = VsockState::Shutdown; - } else { + if raw.remote_cid == header_cid { + raw.state = VsockState::Shutdown; + } else { + trace!("Receive message from invalid source {}", header_cid); + } + } else if raw.remote_cid == header_cid { hdr = Some(*header); fwd_cnt = raw.fwd_cnt; } From 8aee66c79b6468512f3a572bd33af58ce68c9b1f Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 12 Aug 2024 13:17:15 +0200 Subject: [PATCH 16/18] add support of the system call connect --- src/executor/vsock.rs | 18 ++++++++++++++ src/fd/socket/vsock.rs | 55 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 4178f7e9d4..902cebef6c 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -147,6 +147,10 @@ async fn vsock_run() { } else { trace!("Receive message from invalid source {}", header_cid); } + } else if op == Op::Response && type_ == Type::Stream { + if raw.remote_cid == header_cid && raw.state == VsockState::Connecting { + raw.state = VsockState::Connected; + } } else if raw.remote_cid == header_cid { hdr = Some(*header); fwd_cnt = raw.fwd_cnt; @@ -203,6 +207,20 @@ impl VsockMap { Ok(()) } + pub fn connect(&mut self, port: u32, cid: u32) -> io::Result { + for i in u32::MAX / 4..u32::MAX { + let mut raw = RawSocket::new(VsockState::Connecting); + raw.remote_cid = cid; + raw.remote_port = port; + + if self.port_map.try_insert(i, raw).is_ok() { + return Ok(i); + } + } + + Err(io::Error::EBADF) + } + pub fn get_socket(&self, port: u32) -> Option<&RawSocket> { self.port_map.get(&port) } diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 336a52c00c..19718887f9 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -3,6 +3,7 @@ use alloc::vec::Vec; use core::future; use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use core::task::Poll; +use core::time::Duration; use async_trait::async_trait; use endian_num::{le16, le32, le64}; @@ -146,6 +147,60 @@ impl ObjectInterface for Socket { } } + fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + match endpoint { + Endpoint::Vsock(ep) => { + const HEADER_SIZE: usize = core::mem::size_of::(); + let port = VSOCK_MAP.lock().connect(ep.port, ep.cid)?; + self.port.store(port, Ordering::Release); + self.port.store(ep.cid, Ordering::Release); + + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(ep.cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(ep.port); + response.len = le32::from_ne(0); + response.type_ = le16::from_ne(Type::Stream.into()); + response.op = le16::from_ne(Op::Request.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = + le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(0); + }); + + drop(driver_guard); + + block_on( + async { + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Connected => Poll::Ready(Ok(())), + VsockState::Connecting => { + raw.rx_waker.register(cx.waker()); + Poll::Pending + } + _ => Poll::Ready(Err(io::Error::EBADF)), + } + }) + .await + }, + Some(Duration::from_millis(1000)), + ) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + _ => Err(io::Error::EINVAL), + } + } + fn getpeername(&self) -> Option { let port = self.port.load(Ordering::Acquire); let guard = VSOCK_MAP.lock(); From ae48d4e4250e55014c4982e41faf9aad07f1c198 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Tue, 13 Aug 2024 08:40:03 +0200 Subject: [PATCH 17/18] use busy waiting in case of connect --- src/fd/socket/vsock.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 19718887f9..d41b24d09d 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -14,7 +14,7 @@ use crate::arch::kernel::mmio as hardware; #[cfg(feature = "pci")] use crate::drivers::pci as hardware; use crate::executor::vsock::{VsockState, VSOCK_MAP}; -use crate::fd::{block_on, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; +use crate::fd::{block_on, poll_on, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io::{self, Error}; #[derive(Debug)] @@ -176,7 +176,7 @@ impl ObjectInterface for Socket { drop(driver_guard); - block_on( + poll_on( async { future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); From affd88cb833eac2d7697c2fc3f4faae57dad0bd0 Mon Sep 17 00:00:00 2001 From: Stefan Lankes Date: Mon, 19 Aug 2024 20:06:02 +0200 Subject: [PATCH 18/18] remove dependency to the crate endian-num --- Cargo.lock | 1 - Cargo.toml | 3 +-- src/executor/vsock.rs | 2 +- src/fd/socket/vsock.rs | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2cfedc8fac..0687394029 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -628,7 +628,6 @@ dependencies = [ "cfg-if", "crossbeam-utils", "dyn-clone", - "endian-num", "fdt", "float-cmp", "free-list", diff --git a/Cargo.toml b/Cargo.toml index 516f5b6193..68c073d86d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ dhcpv4 = [ ] fs = ["fuse"] fuse = ["pci", "dep:fuse-abi", "fuse-abi/num_enum"] -vsock = ["pci", "endian-num"] +vsock = ["pci"] fsgsbase = [] gem-net = ["tcp", "dep:tock-registers"] newlib = [] @@ -88,7 +88,6 @@ build-time = "0.1.3" cfg-if = "1" crossbeam-utils = { version = "0.8", default-features = false } dyn-clone = "1.0" -endian-num = { version = "0.1", optional = true } fdt = { version = "0.1", features = ["pretty-printing"] } free-list = "0.3" fuse-abi = { version = "0.1", features = ["zerocopy"], optional = true } diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 902cebef6c..4cd43ea4d3 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -3,9 +3,9 @@ use alloc::vec::Vec; use core::future; use core::task::{Poll, Waker}; -use endian_num::{le16, le32}; use hermit_sync::InterruptTicketMutex; use virtio::vsock::{Hdr, Op, Type}; +use virtio::{le16, le32}; #[cfg(not(feature = "pci"))] use crate::arch::kernel::mmio as hardware; diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index d41b24d09d..402c57d130 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -6,8 +6,8 @@ use core::task::Poll; use core::time::Duration; use async_trait::async_trait; -use endian_num::{le16, le32, le64}; use virtio::vsock::{Hdr, Op, Type}; +use virtio::{le16, le32, le64}; #[cfg(not(feature = "pci"))] use crate::arch::kernel::mmio as hardware;