diff --git a/TestComponentCSharp/Class.cpp b/TestComponentCSharp/Class.cpp index ac6248a24..b14a02f30 100644 --- a/TestComponentCSharp/Class.cpp +++ b/TestComponentCSharp/Class.cpp @@ -481,6 +481,34 @@ namespace winrt::TestComponentCSharp::implementation { _objectIterableChanged.remove(token); } + IIterable> Class::IterableOfPointIterablesProperty() + { + return _pointIterableIterable; + } + void Class::IterableOfPointIterablesProperty(IIterable> const& value) + { + for (auto points : value) + { + for (auto point : points) + { + } + } + _pointIterableIterable = value; + } + IIterable> Class::IterableOfObjectIterablesProperty() + { + return _objectIterableIterable; + } + void Class::IterableOfObjectIterablesProperty(IIterable> const& value) + { + for (auto objects : value) + { + for (auto object : objects) + { + } + } + _objectIterableIterable = value; + } Uri Class::UriProperty() { return _uri; diff --git a/TestComponentCSharp/Class.h b/TestComponentCSharp/Class.h index 7eae0b447..c2039f632 100644 --- a/TestComponentCSharp/Class.h +++ b/TestComponentCSharp/Class.h @@ -37,6 +37,8 @@ namespace winrt::TestComponentCSharp::implementation Windows::Foundation::IInspectable _object; winrt::event> _objectChanged; Windows::Foundation::Collections::IIterable _objectIterable; + Windows::Foundation::Collections::IIterable> _pointIterableIterable; + Windows::Foundation::Collections::IIterable> _objectIterableIterable; winrt::event>> _objectIterableChanged; Windows::Foundation::Uri _uri; winrt::event> _uriChanged; @@ -153,6 +155,10 @@ namespace winrt::TestComponentCSharp::implementation void CallForObjectIterable(TestComponentCSharp::ProvideObjectIterable const& provideObjectIterable); winrt::event_token ObjectIterablePropertyChanged(Windows::Foundation::EventHandler> const& handler); void ObjectIterablePropertyChanged(winrt::event_token const& token) noexcept; + Windows::Foundation::Collections::IIterable> IterableOfPointIterablesProperty(); + void IterableOfPointIterablesProperty(Windows::Foundation::Collections::IIterable> const& value); + Windows::Foundation::Collections::IIterable> IterableOfObjectIterablesProperty(); + void IterableOfObjectIterablesProperty(Windows::Foundation::Collections::IIterable> const& value); Windows::Foundation::Uri UriProperty(); void UriProperty(Windows::Foundation::Uri const& value); void RaiseUriChanged(); diff --git a/TestComponentCSharp/TestComponentCSharp.idl b/TestComponentCSharp/TestComponentCSharp.idl index 55815cdae..4d1148cdc 100644 --- a/TestComponentCSharp/TestComponentCSharp.idl +++ b/TestComponentCSharp/TestComponentCSharp.idl @@ -176,6 +176,8 @@ namespace TestComponentCSharp void RaiseObjectIterableChanged(); void CallForObjectIterable(ProvideObjectIterable provideObjectIterable); event Windows.Foundation.EventHandler > ObjectIterablePropertyChanged; + Windows.Foundation.Collections.IIterable > IterableOfPointIterablesProperty; + Windows.Foundation.Collections.IIterable > IterableOfObjectIterablesProperty; Windows.Foundation.Uri UriProperty; void RaiseUriChanged(); diff --git a/UnitTest/TestComponentCSharp_Tests.cs b/UnitTest/TestComponentCSharp_Tests.cs index 69cd6b8d0..474e44bf6 100644 --- a/UnitTest/TestComponentCSharp_Tests.cs +++ b/UnitTest/TestComponentCSharp_Tests.cs @@ -687,7 +687,26 @@ public void TestObjectCasting() var objects = new List() { new ManagedType(), new ManagedType() }; var query = from item in objects select item; - TestObject.ObjectIterableProperty = query; + TestObject.ObjectIterableProperty = query; + + TestObject.ObjectProperty = "test"; + Assert.Equal("test", TestObject.ObjectProperty); + + var objectArray = new ManagedType[] { new ManagedType(), new ManagedType() }; + TestObject.ObjectIterableProperty = objectArray; + Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(objectArray)); + + var strArray = new string[] { "str1", "str2", "str3" }; + TestObject.ObjectIterableProperty = strArray; + Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(strArray)); + + var uriArray = new Uri[] { new Uri("http://aka.ms/cswinrt"), new Uri("http://github.com") }; + TestObject.ObjectIterableProperty = uriArray; + Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(uriArray)); + + var objectUriArray = new object[] { new Uri("http://github.com") }; + TestObject.ObjectIterableProperty = objectUriArray; + Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(objectUriArray)); } [Fact] @@ -2250,6 +2269,26 @@ public void TestIBindableVector() { CustomBindableVectorTest vector = new CustomBindableVectorTest(); Assert.NotNull(vector); + } + + [Fact] + public void TestCovariance() + { + var listOfListOfPoints = new List>() { + new List{ new Point(1, 1), new Point(1, 2), new Point(1, 3) }, + new List{ new Point(2, 1), new Point(2, 2), new Point(2, 3) }, + new List{ new Point(3, 1), new Point(3, 2), new Point(3, 3) } + }; + TestObject.IterableOfPointIterablesProperty = listOfListOfPoints; + Assert.True(TestObject.IterableOfPointIterablesProperty.SequenceEqual(listOfListOfPoints)); + + var listOfListOfUris = new List>() { + new List{ new Uri("http://aka.ms/cswinrt"), new Uri("http://github.com") }, + new List{ new Uri("http://aka.ms/cswinrt") }, + new List{ new Uri("http://aka.ms/cswinrt"), new Uri("http://microsoft.com") } + }; + TestObject.IterableOfObjectIterablesProperty = listOfListOfUris; + Assert.True(TestObject.IterableOfObjectIterablesProperty.SequenceEqual(listOfListOfUris)); } } } \ No newline at end of file diff --git a/WinRT.Runtime/ComWrappersSupport.cs b/WinRT.Runtime/ComWrappersSupport.cs index 347a09247..0f5e8f329 100644 --- a/WinRT.Runtime/ComWrappersSupport.cs +++ b/WinRT.Runtime/ComWrappersSupport.cs @@ -107,14 +107,17 @@ internal static List GetInterfaceTableEntries(object obj) } if (iface.IsConstructedGenericType - && Projections.TryGetCompatibleWindowsRuntimeTypeForVariantType(iface, out var compatibleIface)) + && Projections.TryGetCompatibleWindowsRuntimeTypesForVariantType(iface, out var compatibleIfaces)) { - var compatibleIfaceAbiType = compatibleIface.FindHelperType(); - entries.Add(new ComInterfaceEntry + foreach (var compatibleIface in compatibleIfaces) { - IID = GuidGenerator.GetIID(compatibleIfaceAbiType), - Vtable = (IntPtr)compatibleIfaceAbiType.GetAbiToProjectionVftblPtr() - }); + var compatibleIfaceAbiType = compatibleIface.FindHelperType(); + entries.Add(new ComInterfaceEntry + { + IID = GuidGenerator.GetIID(compatibleIfaceAbiType), + Vtable = (IntPtr)compatibleIfaceAbiType.GetAbiToProjectionVftblPtr() + }); + } } } diff --git a/WinRT.Runtime/Projections.cs b/WinRT.Runtime/Projections.cs index 06b06c9ca..6c345c53d 100644 --- a/WinRT.Runtime/Projections.cs +++ b/WinRT.Runtime/Projections.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Collections.Specialized; using System.ComponentModel; +using System.Linq; using System.Numerics; using System.Reflection; using System.Threading; @@ -206,6 +207,7 @@ private static bool IsTypeWindowsRuntimeTypeNoArray(Type type) || type.GetCustomAttribute() is object; } + // Use TryGetCompatibleWindowsRuntimeTypesForVariantType instead. public static bool TryGetCompatibleWindowsRuntimeTypeForVariantType(Type type, out Type compatibleType) { compatibleType = null; @@ -247,6 +249,118 @@ public static bool TryGetCompatibleWindowsRuntimeTypeForVariantType(Type type, o return true; } + private static HashSet GetCompatibleTypes(Type type) + { + HashSet compatibleTypes = new HashSet(); + + foreach (var iface in type.GetInterfaces()) + { + if (IsTypeWindowsRuntimeTypeNoArray(iface)) + { + compatibleTypes.Add(iface); + } + + if (iface.IsConstructedGenericType + && TryGetCompatibleWindowsRuntimeTypesForVariantType(iface, out var compatibleIfaces)) + { + compatibleTypes.UnionWith(compatibleIfaces); + } + } + + Type baseType = type.BaseType; + while (baseType != null) + { + if (IsTypeWindowsRuntimeTypeNoArray(baseType)) + { + compatibleTypes.Add(baseType); + } + baseType = baseType.BaseType; + } + + return compatibleTypes; + } + + internal static IEnumerable GetAllPossibleTypeCombinations(IEnumerable> compatibleTypesPerGeneric, Type definition) + { + // Implementation adapted from https://stackoverflow.com/a/4424005 + var accum = new List(); + var compatibleTypesPerGenericArray = compatibleTypesPerGeneric.ToArray(); + if (compatibleTypesPerGenericArray.Length > 0) + { + GetAllPossibleTypeCombinationsCore( + accum, + new Stack(), + compatibleTypesPerGenericArray, + compatibleTypesPerGenericArray.Length - 1); + } + return accum; + + void GetAllPossibleTypeCombinationsCore(List accum, Stack stack, IEnumerable[] compatibleTypes, int index) + { + foreach (var type in compatibleTypes[index]) + { + stack.Push(type); + if (index == 0) + { + // IEnumerable on a System.Collections.Generic.Stack + // enumerates in order of removal (last to first). + // As a result, we get the correct ordering here. + accum.Add(definition.MakeGenericType(stack.ToArray())); + } + else + { + GetAllPossibleTypeCombinationsCore(accum, stack, compatibleTypes, index - 1); + } + stack.Pop(); + } + } + } + + internal static bool TryGetCompatibleWindowsRuntimeTypesForVariantType(Type type, out IEnumerable compatibleTypes) + { + compatibleTypes = null; + if (!type.IsConstructedGenericType) + { + throw new ArgumentException(nameof(type)); + } + + var definition = type.GetGenericTypeDefinition(); + + if (!IsTypeWindowsRuntimeTypeNoArray(definition)) + { + return false; + } + + var genericConstraints = definition.GetGenericArguments(); + var genericArguments = type.GetGenericArguments(); + List> compatibleTypesPerGeneric = new List>(); + for (int i = 0; i < genericArguments.Length; i++) + { + List compatibleTypesForGeneric = new List(); + bool argumentCovariantObject = (genericConstraints[i].GenericParameterAttributes & GenericParameterAttributes.VarianceMask) == GenericParameterAttributes.Covariant + && !genericArguments[i].IsValueType; + + if (IsTypeWindowsRuntimeTypeNoArray(genericArguments[i])) + { + compatibleTypesForGeneric.Add(genericArguments[i]); + } + else if (!argumentCovariantObject) + { + return false; + } + + if (argumentCovariantObject) + { + compatibleTypesForGeneric.AddRange(GetCompatibleTypes(genericArguments[i])); + } + + compatibleTypesPerGeneric.Add(compatibleTypesForGeneric); + } + + compatibleTypes = GetAllPossibleTypeCombinations(compatibleTypesPerGeneric, definition); + return true; + } + internal static bool TryGetDefaultInterfaceTypeForRuntimeClassType(Type runtimeClass, out Type defaultInterface) { runtimeClass = runtimeClass.GetRuntimeClassCCWType() ?? runtimeClass;