diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 92f801ab6d..f97f5713c4 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -54,10 +54,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for #interface_ident { fn from(this: #original_ident::<#(#generics,)*>) -> Self { let this = #impl_ident::<#(#generics,)*>::new(this); - let mut this = ::std::boxed::Box::new(this); - let vtable_ptr = &mut this.vtables.#offset as *mut *const <#interface_ident as ::windows::core::Interface>::Vtable; - let _ = ::std::boxed::Box::leak(this); - unsafe { ::core::mem::transmute_copy(&vtable_ptr) } + let mut this = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this)); + let vtable_ptr = &this.vtables.#offset; + // SAFETY: interfaces are in-memory equivalent to pointers to their vtables. + unsafe { ::core::mem::transmute(vtable_ptr) } } } impl <#constraints> ::windows::core::AsImpl<#original_ident::<#(#generics,)*>> for #interface_ident { @@ -145,12 +145,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } } impl <#constraints> #original_ident::<#(#generics,)*> { - fn cast(&self) -> ::windows::core::Result { - unsafe { - let boxed = (self as *const #original_ident::<#(#generics,)*> as *mut #original_ident::<#(#generics,)*> as *mut ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>; - let mut result = None; - <#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &ResultType::IID, &mut result as *mut _ as _).and_some(result) - } + /// Try casting as the provided interface + /// + /// # Safety + /// + /// This function can only be safely called if `self` has been heap allocated and pinned using + /// the mechanisms provided by `implement` macro. + unsafe fn cast(&self) -> ::windows::core::Result { + let boxed = (self as *const _ as *const ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>; + let mut result = None; + <#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &I::IID, &mut result as *mut _ as _).and_some(result) } } impl <#constraints> ::windows::core::Compose for #original_ident::<#(#generics,)*> { @@ -163,23 +167,19 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for ::windows::core::IUnknown { fn from(this: #original_ident::<#(#generics,)*>) -> Self { + let this = #impl_ident::<#(#generics,)*>::new(this); + let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this)); unsafe { - let this = #impl_ident::<#(#generics,)*>::new(this); - let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(this)); - ::core::mem::transmute_copy(&::core::ptr::NonNull::new_unchecked( - &mut (*ptr).identity as *mut _ as _ - )) + ::core::mem::transmute(&boxed.identity) } } } impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for ::windows::core::IInspectable { fn from(this: #original_ident::<#(#generics,)*>) -> Self { + let this = #impl_ident::<#(#generics,)*>::new(this); + let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this)); unsafe { - let this = #impl_ident::<#(#generics,)*>::new(this); - let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(this)); - ::core::mem::transmute_copy(&::core::ptr::NonNull::new_unchecked( - &mut (*ptr).identity as *mut _ as _ - )) + ::core::mem::transmute(&boxed.identity) } } } diff --git a/crates/tests/nightly_implement/tests/cast_self.rs b/crates/tests/nightly_implement/tests/cast_self.rs index 7e5bf13bc0..057a53d600 100644 --- a/crates/tests/nightly_implement/tests/cast_self.rs +++ b/crates/tests/nightly_implement/tests/cast_self.rs @@ -5,12 +5,12 @@ use windows::UI::Xaml::*; // TODO: This is a compile-only test for now until #81 is further along and can provide composable test classes. #[implement(IApplicationOverrides)] -struct App(); +struct App; #[allow(non_snake_case)] impl IApplicationOverrides_Impl for App { fn OnLaunched(&self, _: &Option) -> Result<()> { - let app: Application = self.cast()?; + let app: Application = unsafe { self.cast()? }; assert!(app.FocusVisualKind()? == FocusVisualKind::DottedLine); Ok(()) } diff --git a/crates/tests/nightly_implement/tests/com.rs b/crates/tests/nightly_implement/tests/com.rs index d2bfcf02b3..c50e63fc0b 100644 --- a/crates/tests/nightly_implement/tests/com.rs +++ b/crates/tests/nightly_implement/tests/com.rs @@ -7,7 +7,7 @@ use windows::Win32::System::WinRT::Composition::*; use windows::Win32::System::WinRT::Display::*; #[implement(windows::Foundation::IStringable, windows::Win32::System::WinRT::Composition::ISwapChainInterop, windows::Win32::System::WinRT::Display::IDisplayPathInterop)] -struct Mix(); +struct Mix; impl IStringable_Impl for Mix { fn ToString(&self) -> Result { @@ -32,13 +32,13 @@ impl IDisplayPathInterop_Impl for Mix { #[test] fn mix() -> Result<()> { - let a: ISwapChainInterop = Mix().into(); + let a: ISwapChainInterop = Mix.into(); unsafe { a.SetSwapChain(None)? }; let b: IStringable = a.cast()?; assert!(b.ToString()? == "Mix"); - let c: IStringable = Mix().into(); + let c: IStringable = Mix.into(); assert!(c.ToString()? == "Mix"); let d: ISwapChainInterop = c.cast()?; diff --git a/crates/tests/nightly_implement/tests/into_impl.rs b/crates/tests/nightly_implement/tests/into_impl.rs index c829cc96d5..2acf424a8b 100644 --- a/crates/tests/nightly_implement/tests/into_impl.rs +++ b/crates/tests/nightly_implement/tests/into_impl.rs @@ -54,7 +54,7 @@ where #[allow(non_snake_case)] impl IIterable_Impl for Iterable { fn First(&self) -> Result> { - Ok(Iterator::((self.cast()?, 0).into()).into()) + Ok(Iterator::((unsafe { self.cast()? }, 0).into()).into()) } } diff --git a/crates/tests/nightly_vector/tests/test.rs b/crates/tests/nightly_vector/tests/test.rs index 2649be362c..4af7412d5a 100644 --- a/crates/tests/nightly_vector/tests/test.rs +++ b/crates/tests/nightly_vector/tests/test.rs @@ -61,7 +61,7 @@ impl IVector_Impl for Vector { self.Size() } fn GetView(&self) -> Result> { - self.cast() + unsafe { self.cast() } } fn IndexOf(&self, value: &T::DefaultType, result: &mut u32) -> Result { self.IndexOf(value, result)