diff --git a/crates/libs/core/src/com_object.rs b/crates/libs/core/src/com_object.rs index eb11a38f60..2249569ac3 100644 --- a/crates/libs/core/src/com_object.rs +++ b/crates/libs/core/src/com_object.rs @@ -330,3 +330,68 @@ impl Borrow for ComObject { self.get() } } + +/// Enables applications to define COM objects using static storage. This is useful for factory +/// objects, stateless objects, or objects which use need to contain or use mutable global state. +/// +/// COM objects that are defined using `StaticComObject` have their storage placed directly in +/// static storage; they are not stored in the heap. +/// +/// COM objects defined using `StaticComObject` do have a reference count and this reference +/// count is adjusted when owned COM interface references (e.g. `IFoo` and `IUnknown`) are created +/// for the object. The reference count is initialized to 1. +/// +/// # Example +/// +/// ```rust,ignore +/// #[implement(IFoo)] +/// struct MyApp { +/// // ... +/// } +/// +/// static MY_STATIC_APP: StaticComObject = MyApp { ... }.into_static(); +/// +/// fn get_my_static_ifoo() -> IFoo { +/// MY_STATIC_APP.to_interface() +/// } +/// ``` +pub struct StaticComObject +where + T: ComObjectInner, +{ + outer: T::Outer, +} + +// IMPORTANT: Do not expose any methods that return mutable access to the contents of StaticComObject. +// Doing so would violate our safety invariants. For example, we provide a Deref impl but it would +// be unsound to provide a DerefMut impl. +impl StaticComObject +where + T: ComObjectInner, +{ + /// Wraps `outer` in a `StaticComObject`. + pub const fn from_outer(outer: T::Outer) -> Self { + Self { outer } + } +} + +impl StaticComObject +where + T: ComObjectInner, +{ + /// Gets access to the contained value. + pub const fn get(&'static self) -> &'static T::Outer { + &self.outer + } +} + +impl core::ops::Deref for StaticComObject +where + T: ComObjectInner, +{ + type Target = T::Outer; + + fn deref(&self) -> &Self::Target { + &self.outer + } +} diff --git a/crates/libs/core/src/imp/ref_count.rs b/crates/libs/core/src/imp/ref_count.rs index 1dc3a97578..97e1f2a670 100644 --- a/crates/libs/core/src/imp/ref_count.rs +++ b/crates/libs/core/src/imp/ref_count.rs @@ -6,7 +6,7 @@ pub struct RefCount(pub(crate) AtomicI32); impl RefCount { /// Creates a new `RefCount` with an initial value of `1`. - pub fn new(count: u32) -> Self { + pub const fn new(count: u32) -> Self { Self(AtomicI32::new(count as i32)) } diff --git a/crates/libs/core/src/imp/weak_ref_count.rs b/crates/libs/core/src/imp/weak_ref_count.rs index 6346eb733b..d4ebfe21c8 100644 --- a/crates/libs/core/src/imp/weak_ref_count.rs +++ b/crates/libs/core/src/imp/weak_ref_count.rs @@ -10,7 +10,7 @@ use core::sync::atomic::{AtomicIsize, Ordering}; pub struct WeakRefCount(AtomicIsize); impl WeakRefCount { - pub fn new() -> Self { + pub const fn new() -> Self { Self(AtomicIsize::new(1)) } diff --git a/crates/libs/core/src/weak.rs b/crates/libs/core/src/weak.rs index be34d02868..f89a23156e 100644 --- a/crates/libs/core/src/weak.rs +++ b/crates/libs/core/src/weak.rs @@ -7,7 +7,7 @@ pub struct Weak(Option, PhantomData); impl Weak { /// Creates a new `Weak` object without any backing object. - pub fn new() -> Self { + pub const fn new() -> Self { Self(None, PhantomData) } diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 5964da41b0..024263085b 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -179,6 +179,38 @@ pub fn implement( const IDENTITY: ::windows_core::IInspectable_Vtbl = ::windows_core::IInspectable_Vtbl::new::(); } + impl #generics #original_ident::#generics where #constraints { + /// This converts a partially-constructed COM object (in the sense that it contains + /// application state but does not yet have vtable and reference count constructed) + /// into a `StaticComObject`. This allows the COM object to be stored in static + /// (global) variables. + pub const fn into_static(self) -> ::windows_core::StaticComObject { + ::windows_core::StaticComObject::from_outer(self.into_outer()) + } + + // This constructs an "outer" object. This should only be used by the implementation + // of the outer object, never by application code. + // + // The callers of this function (`into_static` and `into_object`) are both responsible + // for maintaining one of our invariants: Application code never has an owned instance + // of the outer (implementation) type. into_static() maintains this invariant by + // returning a wrapped StaticComObject value, which owns its contents but never gives + // application code a way to mutably access its contents. This prevents the refcount + // shearing problem. + // + // TODO: Make it impossible for app code to call this function, by placing it in a + // module and marking this as private to the module. + #[inline(always)] + const fn into_outer(self) -> #impl_ident::#generics { + #impl_ident::#generics { + identity: &#impl_ident::#generics::IDENTITY, + vtables: (#(&#impl_ident::#generics::VTABLES.#offset,)*), + this: self, + count: ::windows_core::imp::WeakRefCount::new(), + } + } + } + impl #generics ::windows_core::ComObjectInner for #original_ident::#generics where #constraints { type Outer = #impl_ident::#generics; @@ -191,12 +223,7 @@ pub fn implement( // This is why this function returns ComObject instead of returning #impl_ident. fn into_object(self) -> ::windows_core::ComObject { - let boxed = ::windows_core::imp::Box::new(#impl_ident::#generics { - identity: &#impl_ident::#generics::IDENTITY, - vtables: (#(&#impl_ident::#generics::VTABLES.#offset,)*), - this: self, - count: ::windows_core::imp::WeakRefCount::new(), - }); + let boxed = ::windows_core::imp::Box::<#impl_ident::#generics>::new(self.into_outer()); unsafe { let ptr = ::windows_core::imp::Box::into_raw(boxed); ::windows_core::ComObject::from_raw( diff --git a/crates/tests/implement_core/src/lib.rs b/crates/tests/implement_core/src/lib.rs index aa8f3bec53..5b3ca9b210 100644 --- a/crates/tests/implement_core/src/lib.rs +++ b/crates/tests/implement_core/src/lib.rs @@ -5,3 +5,4 @@ mod com_chain; mod com_object; +mod static_com_object; diff --git a/crates/tests/implement_core/src/static_com_object.rs b/crates/tests/implement_core/src/static_com_object.rs new file mode 100644 index 0000000000..68f568fc83 --- /dev/null +++ b/crates/tests/implement_core/src/static_com_object.rs @@ -0,0 +1,77 @@ +//! Unit tests for `windows_core::StaticComObject` + +use std::sync::atomic::{AtomicU32, Ordering::SeqCst}; +use windows_core::{ + implement, interface, ComObject, IUnknown, IUnknownImpl, IUnknown_Vtbl, InterfaceRef, + StaticComObject, +}; + +#[interface("818f2fd1-d479-4398-b286-a93c4c7904d1")] +unsafe trait INumberFactory: IUnknown { + fn next(&self) -> u32; + + fn add(&self, x: u32, y: u32) -> u32; +} + +#[implement(INumberFactory)] +struct MyFactory { + x: AtomicU32, +} + +impl INumberFactory_Impl for MyFactory_Impl { + unsafe fn next(&self) -> u32 { + self.x.fetch_add(1, SeqCst) + } + + unsafe fn add(&self, x: u32, y: u32) -> u32 { + x + y + } +} + +static NUMBER_FACTORY_INSTANCE: StaticComObject = MyFactory { + x: AtomicU32::new(100), +} +.into_static(); + +#[test] +fn as_interface() { + let factory_outer: &MyFactory_Impl = NUMBER_FACTORY_INSTANCE.get(); + let ifactory: InterfaceRef = factory_outer.as_interface::(); + + // Produce the next number. We don't verify the value since tests are multi-threaded. + // This just demonstrates that you can have shared state with interior mutability (such as + // atomics) in a static COM object. + let n = unsafe { ifactory.next() }; + println!("n = {n:?}"); + + assert_eq!(unsafe { ifactory.add(333, 444) }, 777); +} + +// This tests that we can safely AddRef/Release a StaticComObject. +#[test] +fn to_interface() { + let factory_outer: &MyFactory_Impl = NUMBER_FACTORY_INSTANCE.get(); + let ifactory: INumberFactory = factory_outer.to_interface::(); + assert_eq!(unsafe { ifactory.add(333, 444) }, 777); + drop(ifactory); +} + +#[test] +fn to_object() { + let factory_outer: &MyFactory_Impl = NUMBER_FACTORY_INSTANCE.get(); + let factory_object: ComObject = factory_outer.to_object(); + assert_eq!(unsafe { factory_object.add(333, 444) }, 777); +} + +// This tests the behavior when dropping a StaticComObject. Since static variables are never +// dropped, this isn't relevant to normal usage. However, if app code constructs a StaticComObject +// in local variables (not statics) and them drops them, then we still need well-defined behavior. +// Basically, we are testing that the refererence-count field does not panic when being dropped +// with a non-zero reference count. +#[test] +fn drop_half_constructed() { + let _static_com_object: StaticComObject = MyFactory { + x: AtomicU32::new(0), + } + .into_static(); +}