Skip to content

Commit

Permalink
sound: add QueueIdx enum for virtio queue indices
Browse files Browse the repository at this point in the history
Add type safe enum to use instead of raw u16 values, which we have to
validate every time we use them.

Signed-off-by: Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
  • Loading branch information
epilys committed Dec 13, 2023
1 parent af87658 commit 985d9c4
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 27 deletions.
51 changes: 26 additions & 25 deletions staging/vhost-device-sound/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ use crate::{
audio_backends::{alloc_audio_backend, AudioBackend},
stream::{Buffer, Error as StreamError, Stream},
virtio_sound::*,
ControlMessageKind, Direction, Error, IOMessage, Result, SoundConfig,
ControlMessageKind, Direction, Error, IOMessage, QueueIdx, Result, SoundConfig,
};

pub struct VhostUserSoundThread {
mem: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
event_idx: bool,
chmaps: Arc<RwLock<Vec<VirtioSoundChmapInfo>>>,
jacks: Arc<RwLock<Vec<VirtioSoundJackInfo>>>,
queue_indexes: Vec<u16>,
queue_indexes: Vec<QueueIdx>,
streams: Arc<RwLock<Vec<Stream>>>,
streams_no: usize,
}
Expand All @@ -49,11 +49,11 @@ impl VhostUserSoundThread {
pub fn new(
chmaps: Arc<RwLock<Vec<VirtioSoundChmapInfo>>>,
jacks: Arc<RwLock<Vec<VirtioSoundJackInfo>>>,
mut queue_indexes: Vec<u16>,
mut queue_indexes: Vec<QueueIdx>,
streams: Arc<RwLock<Vec<Stream>>>,
streams_no: usize,
) -> Result<Self> {
queue_indexes.sort();
queue_indexes.sort_by_key(|idx| *idx as u16);

Ok(Self {
event_idx: false,
Expand All @@ -70,7 +70,7 @@ impl VhostUserSoundThread {
let mut queues_per_thread = 0u64;

for idx in self.queue_indexes.iter() {
queues_per_thread |= 1u64 << idx
queues_per_thread |= 1u64 << *idx as u16
}

queues_per_thread
Expand All @@ -94,7 +94,10 @@ impl VhostUserSoundThread {
let vring = &vrings
.get(device_event as usize)
.ok_or_else(|| Error::HandleUnknownEvent(device_event))?;
let queue_idx = self.queue_indexes[device_event as usize];
let queue_idx = self
.queue_indexes
.get(device_event as usize)
.ok_or_else(|| Error::HandleUnknownEvent(device_event))?;
if self.event_idx {
// vm-virtio's Queue implementation only checks avail_index
// once, so to properly support EVENT_IDX we need to keep
Expand All @@ -103,11 +106,10 @@ impl VhostUserSoundThread {
loop {
vring.disable_notification().unwrap();
match queue_idx {
CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend),
EVENT_QUEUE_IDX => self.process_event(vring),
TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output),
RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input),
_ => Err(Error::HandleUnknownEvent(queue_idx).into()),
QueueIdx::Control => self.process_control(vring, audio_backend),
QueueIdx::Event => self.process_event(vring),
QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output),
QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input),
}?;
if !vring.enable_notification().unwrap() {
break;
Expand All @@ -116,11 +118,10 @@ impl VhostUserSoundThread {
} else {
// Without EVENT_IDX, a single call is enough.
match queue_idx {
CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend),
EVENT_QUEUE_IDX => self.process_event(vring),
TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output),
RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input),
_ => Err(Error::HandleUnknownEvent(queue_idx).into()),
QueueIdx::Control => self.process_control(vring, audio_backend),
QueueIdx::Event => self.process_event(vring),
QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output),
QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input),
}?;
}
Ok(())
Expand Down Expand Up @@ -635,21 +636,21 @@ impl VhostUserSoundBackend {
RwLock::new(VhostUserSoundThread::new(
chmaps.clone(),
jacks.clone(),
vec![CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX],
vec![QueueIdx::Control, QueueIdx::Event],
streams.clone(),
streams_no,
)?),
RwLock::new(VhostUserSoundThread::new(
chmaps.clone(),
jacks.clone(),
vec![TX_QUEUE_IDX],
vec![QueueIdx::Tx],
streams.clone(),
streams_no,
)?),
RwLock::new(VhostUserSoundThread::new(
chmaps,
jacks,
vec![RX_QUEUE_IDX],
vec![QueueIdx::Rx],
streams.clone(),
streams_no,
)?),
Expand All @@ -659,10 +660,10 @@ impl VhostUserSoundBackend {
chmaps,
jacks,
vec![
CONTROL_QUEUE_IDX,
EVENT_QUEUE_IDX,
TX_QUEUE_IDX,
RX_QUEUE_IDX,
QueueIdx::Control,
QueueIdx::Event,
QueueIdx::Tx,
QueueIdx::Rx,
],
streams.clone(),
streams_no,
Expand Down Expand Up @@ -832,7 +833,7 @@ mod tests {

let chmaps = Arc::new(RwLock::new(vec![]));
let jacks = Arc::new(RwLock::new(vec![]));
let queue_indexes = vec![1, 2, 3];
let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx];
let streams = vec![Stream::default()];
let streams_no = streams.len();
let streams = Arc::new(RwLock::new(streams));
Expand Down Expand Up @@ -927,7 +928,7 @@ mod tests {

let chmaps = Arc::new(RwLock::new(vec![]));
let jacks = Arc::new(RwLock::new(vec![]));
let queue_indexes = vec![1, 2, 3];
let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx];
let streams = Arc::new(RwLock::new(vec![]));
let streams_no = 0;
let thread =
Expand Down
59 changes: 57 additions & 2 deletions staging/vhost-device-sound/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,43 @@ impl TryFrom<u8> for Direction {
Ok(match val {
virtio_sound::VIRTIO_SND_D_OUTPUT => Self::Output,
virtio_sound::VIRTIO_SND_D_INPUT => Self::Input,
other => return Err(Error::InvalidMessageValue(stringify!(Direction), other)),
other => {
return Err(Error::InvalidMessageValue(
stringify!(Direction),
other.into(),
))
}
})
}
}

/// Queue index.
///
/// Type safe enum for CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX, TX_QUEUE_IDX,
/// RX_QUEUE_IDX.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum QueueIdx {
#[doc(alias = "CONTROL_QUEUE_IDX")]
Control = virtio_sound::CONTROL_QUEUE_IDX,
#[doc(alias = "EVENT_QUEUE_IDX")]
Event = virtio_sound::EVENT_QUEUE_IDX,
#[doc(alias = "TX_QUEUE_IDX")]
Tx = virtio_sound::TX_QUEUE_IDX,
#[doc(alias = "RX_QUEUE_IDX")]
Rx = virtio_sound::RX_QUEUE_IDX,
}

impl TryFrom<u16> for QueueIdx {
type Error = Error;

fn try_from(val: u16) -> std::result::Result<Self, Self::Error> {
Ok(match val {
virtio_sound::CONTROL_QUEUE_IDX => Self::Control,
virtio_sound::EVENT_QUEUE_IDX => Self::Event,
virtio_sound::TX_QUEUE_IDX => Self::Tx,
virtio_sound::RX_QUEUE_IDX => Self::Rx,
other => return Err(Error::InvalidMessageValue(stringify!(QueueIdx), other)),
})
}
}
Expand All @@ -117,7 +153,7 @@ pub enum Error {
#[error("Invalid control message code {0}")]
InvalidControlMessage(u32),
#[error("Invalid value in {0}: {1}")]
InvalidMessageValue(&'static str, u8),
InvalidMessageValue(&'static str, u16),
#[error("Failed to create a new EventFd")]
EventFdCreate(IoError),
#[error("Request missing data buffer")]
Expand Down Expand Up @@ -389,6 +425,25 @@ mod tests {

let val = 42;
Direction::try_from(val).unwrap_err();

assert_eq!(
QueueIdx::try_from(virtio_sound::CONTROL_QUEUE_IDX).unwrap(),
QueueIdx::Control
);
assert_eq!(
QueueIdx::try_from(virtio_sound::EVENT_QUEUE_IDX).unwrap(),
QueueIdx::Event
);
assert_eq!(
QueueIdx::try_from(virtio_sound::TX_QUEUE_IDX).unwrap(),
QueueIdx::Tx
);
assert_eq!(
QueueIdx::try_from(virtio_sound::RX_QUEUE_IDX).unwrap(),
QueueIdx::Rx
);
let val = virtio_sound::NUM_QUEUES;
QueueIdx::try_from(val).unwrap_err();
}

#[test]
Expand Down

0 comments on commit 985d9c4

Please sign in to comment.