Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use winrt-rs instead of winrt-rust #56

Merged
merged 1 commit into from
Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ bitflags = "1.2"
memalloc = "0.1.0"
jack-sys = { version = "0.1.0", optional = true }
libc = { version = "0.2.21", optional = true }
winrt = { version = "0.6.0", features = ["windows-devices", "windows-storage"], optional = true}
winrt = { version = "0.7.0", optional = true}

[target.'cfg(target_os = "linux")'.dependencies]
alsa = "0.2"
Expand Down
156 changes: 110 additions & 46 deletions src/backend/winrt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@ extern crate winrt;

use std::sync::{Arc, Mutex};

use self::winrt::{ComPtr, HString, RtAsyncOperation, RtDefaultConstructible, IMemoryBufferByteAccess};
use self::winrt::windows::foundation::*;
use self::winrt::windows::devices::enumeration::*;
use self::winrt::windows::devices::midi::*;
use self::winrt::windows::storage::streams::*;

use ::errors::*;
use ::Ignore;

use self::winrt::{AbiTransferable, HString, TryInto};

winrt::import!(
dependencies
os
types
windows::foundation::*
windows::devices::midi::*
windows::devices::enumeration::DeviceInformation
windows::storage::streams::{Buffer, DataWriter}
);

use self::windows::foundation::*;
use self::windows::devices::midi::*;
use self::windows::devices::enumeration::DeviceInformation;
use self::windows::storage::streams::{Buffer, DataWriter};

#[derive(Clone)]
pub struct MidiInputPort {
id: HString
Expand All @@ -23,6 +34,59 @@ pub struct MidiInput {
ignore_flags: Ignore
}

#[repr(C)]
pub struct abi_IMemoryBufferByteAccess {
__base: [usize; 3],
get_buffer: extern "system" fn(
winrt::NonNullRawComPtr<IMemoryBufferByteAccess>,
value: *mut *mut u8,
capacity: *mut u32,
) -> winrt::ErrorCode,
}

unsafe impl winrt::ComInterface for IMemoryBufferByteAccess {
type VTable = abi_IMemoryBufferByteAccess;
fn iid() -> winrt::Guid {
winrt::Guid::from_values(0x5b0d3235, 0x4dba, 0x4d44, [0x86, 0x5e, 0x8f, 0x1d, 0x0e, 0x4f, 0xd0, 0x4d])
}
}

unsafe impl AbiTransferable for IMemoryBufferByteAccess {
type Abi = winrt::RawComPtr<Self>;

fn get_abi(&self) -> Self::Abi {
self.ptr.get_abi()
}

fn set_abi(&mut self) -> *mut Self::Abi {
self.ptr.set_abi()
}
}

#[repr(transparent)]
#[derive(Default, Clone)]
pub struct IMemoryBufferByteAccess {
ptr: winrt::ComPtr<IMemoryBufferByteAccess>,
}

impl IMemoryBufferByteAccess {
pub unsafe fn get_buffer(&self) -> winrt::Result<&[u8]> {
match self.get_abi() {
None => panic!("The `this` pointer was null when calling method"),
Some(ptr) => {
let mut bufptr = std::ptr::null_mut();
let mut capacity: u32 = 0;
(ptr.vtable().get_buffer)(ptr, &mut bufptr, &mut capacity).ok()?;
if capacity == 0 {
bufptr = 1 as *mut u8; // null pointer is not allowed
}
Ok(std::slice::from_raw_parts(bufptr, capacity as usize))
}
}
}
}


unsafe impl Send for MidiInput {} // because HString doesn't ...

impl MidiInput {
Expand All @@ -36,12 +100,11 @@ impl MidiInput {
}

pub(crate) fn ports_internal(&self) -> Vec<::common::MidiInputPort> {
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector.make_reference()).unwrap().blocking_get().expect("find_all_async failed").expect("find_all_async returned null");
let count = device_collection.get_size().expect("get_size failed") as usize;
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector).unwrap().get().expect("find_all_async failed");
let count = device_collection.size().expect("get_size failed") as usize;
let mut result = Vec::with_capacity(count as usize);
for device_info in device_collection.into_iter() {
let device_info = device_info.expect("device_info was null");
let device_id = device_info.get_id().expect("get_id failed");
let device_id = device_info.id().expect("get_id failed");
result.push(::common::MidiInputPort {
imp: MidiInputPort { id: device_id }
});
Expand All @@ -50,29 +113,29 @@ impl MidiInput {
}

pub fn port_count(&self) -> usize {
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector.make_reference()).unwrap().blocking_get().expect("find_all_async failed").expect("find_all_async returned null");
device_collection.get_size().expect("get_size failed") as usize
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector).unwrap().get().expect("find_all_async failed");
device_collection.size().expect("get_size failed") as usize
}

pub fn port_name(&self, port: &MidiInputPort) -> Result<String, PortInfoError> {
let device_info_async = DeviceInformation::create_from_id_async(&port.id.make_reference()).map_err(|_| PortInfoError::InvalidPort)?;
let device_info = device_info_async.blocking_get().map_err(|_| PortInfoError::InvalidPort)?.expect("device_info was null");
let device_name = device_info.get_name().map_err(|_| PortInfoError::CannotRetrievePortName)?;
let device_info_async = DeviceInformation::create_from_id_async(&port.id).map_err(|_| PortInfoError::InvalidPort)?;
let device_info = device_info_async.get().map_err(|_| PortInfoError::InvalidPort)?;
let device_name = device_info.name().map_err(|_| PortInfoError::CannotRetrievePortName)?;
Ok(device_name.to_string())
}

fn handle_input<T>(args: &MidiMessageReceivedEventArgs, handler_data: &mut HandlerData<T>) {
let ignore = handler_data.ignore_flags;
let data = &mut handler_data.user_data.as_mut().unwrap();
let timestamp;
let byte_access;
let byte_access: IMemoryBufferByteAccess;
let message_bytes;
let message = args.get_message().expect("get_message failed").expect("get_message returned null");
timestamp = message.get_timestamp().expect("get_timestamp failed").Duration as u64 / 10;
let buffer = message.get_raw_data().expect("get_raw_data failed").expect("get_raw_data returned null");
let membuffer = Buffer::create_memory_buffer_over_ibuffer(&buffer).expect("create_memory_buffer_over_ibuffer failed").expect("create_memory_buffer_over_ibuffer returned null");
byte_access = membuffer.create_reference().expect("create_reference failed").expect("create_reference returned null").query_interface::<IMemoryBufferByteAccess>().unwrap();
message_bytes = unsafe { byte_access.get_buffer() }; // TODO: somehow make sure that the buffer is not invalidated while we're reading from it ...
let message = args.message().expect("get_message failed");
timestamp = message.timestamp().expect("get_timestamp failed").duration as u64 / 10;
let buffer = message.raw_data().expect("get_raw_data failed");
let membuffer = Buffer::create_memory_buffer_over_ibuffer(&buffer).expect("create_memory_buffer_over_ibuffer failed");
byte_access = membuffer.create_reference().expect("create_reference failed").try_into().unwrap();
message_bytes = unsafe { byte_access.get_buffer().expect("get_buffer failed") }; // TODO: somehow make sure that the buffer is not invalidated while we're reading from it ...

// The first byte in the message is the status
let status = message_bytes[0];
Expand All @@ -91,9 +154,9 @@ impl MidiInput {
) -> Result<MidiInputConnection<T>, ConnectError<MidiInput>>
where F: FnMut(u64, &[u8], &mut T) + Send + 'static {

let in_port = match MidiInPort::from_id_async(&port.id.make_reference()) {
Ok(port_async) => match port_async.blocking_get() {
Ok(Some(port)) => port,
let in_port = match MidiInPort::from_id_async(&port.id) {
Ok(port_async) => match port_async.get() {
Ok(port) => port,
_ => return Err(ConnectError::new(ConnectErrorKind::InvalidPort, self))
}
Err(_) => return Err(ConnectError::new(ConnectErrorKind::InvalidPort, self))
Expand All @@ -106,18 +169,18 @@ impl MidiInput {
}));
let handler_data2 = handler_data.clone();

let handler = TypedEventHandler::new(move |_sender, args: *mut MidiMessageReceivedEventArgs| {
unsafe { MidiInput::handle_input(&*args, &mut *handler_data2.lock().unwrap()) };
let handler = TypedEventHandler::new(move |_sender, args| {
MidiInput::handle_input(args, &mut *handler_data2.lock().unwrap());
Ok(())
});

let event_token = in_port.add_message_received(&handler).expect("add_message_received failed");
let event_token = in_port.message_received(&handler).expect("add_message_received failed");

Ok(MidiInputConnection { port: RtMidiInPort(in_port), event_token: event_token, handler_data: handler_data })
}
}

struct RtMidiInPort(ComPtr<MidiInPort>);
struct RtMidiInPort(MidiInPort);
unsafe impl Send for RtMidiInPort {}

pub struct MidiInputConnection<T> {
Expand All @@ -134,7 +197,8 @@ pub struct MidiInputConnection<T> {
impl<T> MidiInputConnection<T> {
pub fn close(self) -> (MidiInput, T) {
let _ = self.port.0.remove_message_received(self.event_token);
let _ = self.port.0.query_interface::<IClosable>().unwrap().close();
let closable: IClosable = self.port.0.try_into().unwrap();
let _ = closable.close();
let device_selector = MidiInPort::get_device_selector().expect("get_device_selector failed"); // probably won't ever fail here, because it worked previously
let mut handler_data_locked = self.handler_data.lock().unwrap();
(MidiInput {
Expand Down Expand Up @@ -175,12 +239,11 @@ impl MidiOutput {
}

pub(crate) fn ports_internal(&self) -> Vec<::common::MidiOutputPort> {
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector.make_reference()).unwrap().blocking_get().expect("find_all_async failed").expect("find_all_async returned null");
let count = device_collection.get_size().expect("get_size failed") as usize;
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector).unwrap().get().expect("find_all_async failed");
let count = device_collection.size().expect("get_size failed") as usize;
let mut result = Vec::with_capacity(count as usize);
for device_info in device_collection.into_iter() {
let device_info = device_info.expect("device_info was null");
let device_id = device_info.get_id().expect("get_id failed");
let device_id = device_info.id().expect("get_id failed");
result.push(::common::MidiOutputPort {
imp: MidiOutputPort { id: device_id }
});
Expand All @@ -189,21 +252,21 @@ impl MidiOutput {
}

pub fn port_count(&self) -> usize {
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector.make_reference()).unwrap().blocking_get().expect("find_all_async failed").expect("find_all_async returned null");
device_collection.get_size().expect("get_size failed") as usize
let device_collection = DeviceInformation::find_all_async_aqs_filter(&self.selector).unwrap().get().expect("find_all_async failed");
device_collection.size().expect("get_size failed") as usize
}

pub fn port_name(&self, port: &MidiOutputPort) -> Result<String, PortInfoError> {
let device_info_async = DeviceInformation::create_from_id_async(&port.id.make_reference()).map_err(|_| PortInfoError::InvalidPort)?;
let device_info = device_info_async.blocking_get().map_err(|_| PortInfoError::InvalidPort)?.expect("device_info_async was null");
let device_name = device_info.get_name().map_err(|_| PortInfoError::CannotRetrievePortName)?;
let device_info_async = DeviceInformation::create_from_id_async(&port.id).map_err(|_| PortInfoError::InvalidPort)?;
let device_info = device_info_async.get().map_err(|_| PortInfoError::InvalidPort)?;
let device_name = device_info.name().map_err(|_| PortInfoError::CannotRetrievePortName)?;
Ok(device_name.to_string())
}

pub fn connect(self, port: &MidiOutputPort, _port_name: &str) -> Result<MidiOutputConnection, ConnectError<MidiOutput>> {
let out_port = match MidiOutPort::from_id_async(&port.id.make_reference()) {
Ok(port_async) => match port_async.blocking_get() {
Ok(Some(port)) => port,
let out_port = match MidiOutPort::from_id_async(&port.id) {
Ok(port_async) => match port_async.get() {
Ok(port) => port,
_ => return Err(ConnectError::new(ConnectErrorKind::InvalidPort, self))
}
Err(_) => return Err(ConnectError::new(ConnectErrorKind::InvalidPort, self))
Expand All @@ -213,22 +276,23 @@ impl MidiOutput {
}

pub struct MidiOutputConnection {
port: ComPtr<IMidiOutPort>
port: IMidiOutPort
}

unsafe impl Send for MidiOutputConnection {}

impl MidiOutputConnection {
pub fn close(self) -> MidiOutput {
let _ = self.port.query_interface::<IClosable>().unwrap().close();
let closable: IClosable = self.port.try_into().unwrap();
let _ = closable.close();
let device_selector = MidiOutPort::get_device_selector().expect("get_device_selector failed"); // probably won't ever fail here, because it worked previously
MidiOutput { selector: device_selector }
}

pub fn send(&mut self, message: &[u8]) -> Result<(), SendError> {
let data_writer: ComPtr<DataWriter> = DataWriter::new();
let data_writer = DataWriter::new().unwrap();
data_writer.write_bytes(message).map_err(|_| SendError::Other("write_bytes failed"))?;
let buffer = data_writer.detach_buffer().map_err(|_| SendError::Other("detach_buffer failed"))?.expect("detach buffer returned null");
let buffer = data_writer.detach_buffer().map_err(|_| SendError::Other("detach_buffer failed"))?;
self.port.send_buffer(&buffer).map_err(|_| SendError::Other("send_buffer failed"))?;
Ok(())
}
Expand Down