diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 33c95ee52d..3434dfb92a 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -8,6 +8,7 @@ use tokens::*; pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro::TokenStream) -> proc_macro::TokenStream { let attributes = syn::parse_macro_input!(attributes as ImplementAttributes); let generics = attributes.generics(); + let interfaces_len = Literal::usize_unsuffixed(attributes.implement.len()); let constraints = quote! { #(#generics: ::windows::core::RuntimeType + 'static,)* @@ -144,6 +145,7 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } } impl <#constraints> #original_ident::<#(#generics,)*> { + /// Box and pin `self` and then try to cast it as the supplied interface fn alloc(self) -> ::windows::core::Result { let this = #impl_ident::<#(#generics,)*>::new(self); let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this)); @@ -157,6 +159,18 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } result } + + /// Try casting as the provided interface + /// + /// # Safety + /// + /// This function can only be safely called if `self` has been heap allocated and pinned using + /// the mechanisms provided by `implement` macro. + unsafe fn cast(&self) -> ::windows::core::Result { + let boxed = (self as *const _ as *const ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>; + let mut result = None; + <#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &I::IID, &mut result as *mut _ as _).and_some(result) + } } impl <#constraints> ::windows::core::Compose for #original_ident::<#(#generics,)*> { unsafe fn compose<'a>(implementation: Self) -> (::windows::core::IInspectable, &'a mut ::core::option::Option<::windows::core::IInspectable>) { diff --git a/crates/tests/nightly_implement/tests/cast_self.rs b/crates/tests/nightly_implement/tests/cast_self.rs index 3810bf3d32..057a53d600 100644 --- a/crates/tests/nightly_implement/tests/cast_self.rs +++ b/crates/tests/nightly_implement/tests/cast_self.rs @@ -10,7 +10,7 @@ struct App; #[allow(non_snake_case)] impl IApplicationOverrides_Impl for App { fn OnLaunched(&self, _: &Option) -> Result<()> { - let app: Application = self.cast()?; + let app: Application = unsafe { self.cast()? }; assert!(app.FocusVisualKind()? == FocusVisualKind::DottedLine); Ok(()) } diff --git a/crates/tests/nightly_implement/tests/into_impl.rs b/crates/tests/nightly_implement/tests/into_impl.rs index c829cc96d5..2acf424a8b 100644 --- a/crates/tests/nightly_implement/tests/into_impl.rs +++ b/crates/tests/nightly_implement/tests/into_impl.rs @@ -54,7 +54,7 @@ where #[allow(non_snake_case)] impl IIterable_Impl for Iterable { fn First(&self) -> Result> { - Ok(Iterator::((self.cast()?, 0).into()).into()) + Ok(Iterator::((unsafe { self.cast()? }, 0).into()).into()) } } diff --git a/crates/tests/nightly_vector/tests/test.rs b/crates/tests/nightly_vector/tests/test.rs index 2649be362c..4af7412d5a 100644 --- a/crates/tests/nightly_vector/tests/test.rs +++ b/crates/tests/nightly_vector/tests/test.rs @@ -61,7 +61,7 @@ impl IVector_Impl for Vector { self.Size() } fn GetView(&self) -> Result> { - self.cast() + unsafe { self.cast() } } fn IndexOf(&self, value: &T::DefaultType, result: &mut u32) -> Result { self.IndexOf(value, result)