Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vtable entry for covariant IEnumerable<object> for WinRT types too. #601

Merged
merged 4 commits into from
Nov 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions TestComponentCSharp/Class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,34 @@ namespace winrt::TestComponentCSharp::implementation
{
_objectIterableChanged.remove(token);
}
IIterable<IIterable<WF::Point>> Class::IterableOfPointIterablesProperty()
{
return _pointIterableIterable;
}
void Class::IterableOfPointIterablesProperty(IIterable<IIterable<WF::Point>> const& value)
{
for (auto points : value)
{
for (auto point : points)
{
}
}
_pointIterableIterable = value;
}
IIterable<IIterable<WF::IInspectable>> Class::IterableOfObjectIterablesProperty()
{
return _objectIterableIterable;
}
void Class::IterableOfObjectIterablesProperty(IIterable<IIterable<WF::IInspectable>> const& value)
{
for (auto objects : value)
{
for (auto object : objects)
{
}
}
_objectIterableIterable = value;
}
Uri Class::UriProperty()
{
return _uri;
Expand Down
6 changes: 6 additions & 0 deletions TestComponentCSharp/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace winrt::TestComponentCSharp::implementation
Windows::Foundation::IInspectable _object;
winrt::event<Windows::Foundation::EventHandler<Windows::Foundation::IInspectable>> _objectChanged;
Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable> _objectIterable;
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::Point>> _pointIterableIterable;
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> _objectIterableIterable;
winrt::event<Windows::Foundation::EventHandler<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>>> _objectIterableChanged;
Windows::Foundation::Uri _uri;
winrt::event<Windows::Foundation::EventHandler<Windows::Foundation::Uri>> _uriChanged;
Expand Down Expand Up @@ -153,6 +155,10 @@ namespace winrt::TestComponentCSharp::implementation
void CallForObjectIterable(TestComponentCSharp::ProvideObjectIterable const& provideObjectIterable);
winrt::event_token ObjectIterablePropertyChanged(Windows::Foundation::EventHandler<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> const& handler);
void ObjectIterablePropertyChanged(winrt::event_token const& token) noexcept;
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::Point>> IterableOfPointIterablesProperty();
void IterableOfPointIterablesProperty(Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::Point>> const& value);
Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> IterableOfObjectIterablesProperty();
void IterableOfObjectIterablesProperty(Windows::Foundation::Collections::IIterable<Windows::Foundation::Collections::IIterable<Windows::Foundation::IInspectable>> const& value);
Windows::Foundation::Uri UriProperty();
void UriProperty(Windows::Foundation::Uri const& value);
void RaiseUriChanged();
Expand Down
2 changes: 2 additions & 0 deletions TestComponentCSharp/TestComponentCSharp.idl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ namespace TestComponentCSharp
void RaiseObjectIterableChanged();
void CallForObjectIterable(ProvideObjectIterable provideObjectIterable);
event Windows.Foundation.EventHandler<Windows.Foundation.Collections.IIterable<Object> > ObjectIterablePropertyChanged;
Windows.Foundation.Collections.IIterable<Windows.Foundation.Collections.IIterable<Windows.Foundation.Point> > IterableOfPointIterablesProperty;
Windows.Foundation.Collections.IIterable<Windows.Foundation.Collections.IIterable<Object> > IterableOfObjectIterablesProperty;

Windows.Foundation.Uri UriProperty;
void RaiseUriChanged();
Expand Down
41 changes: 40 additions & 1 deletion UnitTest/TestComponentCSharp_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,26 @@ public void TestObjectCasting()

var objects = new List<ManagedType>() { new ManagedType(), new ManagedType() };
var query = from item in objects select item;
TestObject.ObjectIterableProperty = query;
TestObject.ObjectIterableProperty = query;

TestObject.ObjectProperty = "test";
Assert.Equal("test", TestObject.ObjectProperty);

var objectArray = new ManagedType[] { new ManagedType(), new ManagedType() };
TestObject.ObjectIterableProperty = objectArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(objectArray));

var strArray = new string[] { "str1", "str2", "str3" };
TestObject.ObjectIterableProperty = strArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(strArray));

var uriArray = new Uri[] { new Uri("http://aka.ms/cswinrt"), new Uri("http://github.com") };
TestObject.ObjectIterableProperty = uriArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(uriArray));

var objectUriArray = new object[] { new Uri("http://github.com") };
TestObject.ObjectIterableProperty = objectUriArray;
Assert.True(TestObject.ObjectIterableProperty.SequenceEqual(objectUriArray));
}

[Fact]
Expand Down Expand Up @@ -2250,6 +2269,26 @@ public void TestIBindableVector()
{
CustomBindableVectorTest vector = new CustomBindableVectorTest();
Assert.NotNull(vector);
}

[Fact]
public void TestCovariance()
{
var listOfListOfPoints = new List<List<Point>>() {
new List<Point>{ new Point(1, 1), new Point(1, 2), new Point(1, 3) },
new List<Point>{ new Point(2, 1), new Point(2, 2), new Point(2, 3) },
new List<Point>{ new Point(3, 1), new Point(3, 2), new Point(3, 3) }
};
TestObject.IterableOfPointIterablesProperty = listOfListOfPoints;
Assert.True(TestObject.IterableOfPointIterablesProperty.SequenceEqual(listOfListOfPoints));

var listOfListOfUris = new List<List<Uri>>() {
new List<Uri>{ new Uri("http://aka.ms/cswinrt"), new Uri("http://github.com") },
new List<Uri>{ new Uri("http://aka.ms/cswinrt") },
new List<Uri>{ new Uri("http://aka.ms/cswinrt"), new Uri("http://microsoft.com") }
};
TestObject.IterableOfObjectIterablesProperty = listOfListOfUris;
Assert.True(TestObject.IterableOfObjectIterablesProperty.SequenceEqual(listOfListOfUris));
}
}
}
15 changes: 9 additions & 6 deletions WinRT.Runtime/ComWrappersSupport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,17 @@ internal static List<ComInterfaceEntry> GetInterfaceTableEntries(object obj)
}

if (iface.IsConstructedGenericType
&& Projections.TryGetCompatibleWindowsRuntimeTypeForVariantType(iface, out var compatibleIface))
&& Projections.TryGetCompatibleWindowsRuntimeTypesForVariantType(iface, out var compatibleIfaces))
{
var compatibleIfaceAbiType = compatibleIface.FindHelperType();
entries.Add(new ComInterfaceEntry
foreach (var compatibleIface in compatibleIfaces)
{
IID = GuidGenerator.GetIID(compatibleIfaceAbiType),
Vtable = (IntPtr)compatibleIfaceAbiType.GetAbiToProjectionVftblPtr()
});
var compatibleIfaceAbiType = compatibleIface.FindHelperType();
entries.Add(new ComInterfaceEntry
{
IID = GuidGenerator.GetIID(compatibleIfaceAbiType),
Vtable = (IntPtr)compatibleIfaceAbiType.GetAbiToProjectionVftblPtr()
});
}
}
}

Expand Down
114 changes: 114 additions & 0 deletions WinRT.Runtime/Projections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Collections.Specialized;
using System.ComponentModel;
using System.Linq;
using System.Numerics;
using System.Reflection;
using System.Threading;
Expand Down Expand Up @@ -206,6 +207,7 @@ private static bool IsTypeWindowsRuntimeTypeNoArray(Type type)
|| type.GetCustomAttribute<WindowsRuntimeTypeAttribute>() is object;
}

// Use TryGetCompatibleWindowsRuntimeTypesForVariantType instead.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the comment a TODO?

Copy link
Member Author

@manodasanW manodasanW Nov 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the issue is this API was marked public even though we only used it internally. Since we don't expect anyone else to use it or for it to be used in cross module scenarios, we can accept the breaking change and remove the old version of the API. But before I did that, I wanted to make sure there weren't concerns with that, so I instead put a comment.

public static bool TryGetCompatibleWindowsRuntimeTypeForVariantType(Type type, out Type compatibleType)
{
compatibleType = null;
Expand Down Expand Up @@ -247,6 +249,118 @@ public static bool TryGetCompatibleWindowsRuntimeTypeForVariantType(Type type, o
return true;
}

private static HashSet<Type> GetCompatibleTypes(Type type)
{
HashSet<Type> compatibleTypes = new HashSet<Type>();

foreach (var iface in type.GetInterfaces())
{
if (IsTypeWindowsRuntimeTypeNoArray(iface))
{
compatibleTypes.Add(iface);
}

if (iface.IsConstructedGenericType
&& TryGetCompatibleWindowsRuntimeTypesForVariantType(iface, out var compatibleIfaces))
{
compatibleTypes.UnionWith(compatibleIfaces);
}
}

Type baseType = type.BaseType;
while (baseType != null)
{
if (IsTypeWindowsRuntimeTypeNoArray(baseType))
{
compatibleTypes.Add(baseType);
}
baseType = baseType.BaseType;
}

return compatibleTypes;
}

internal static IEnumerable<Type> GetAllPossibleTypeCombinations(IEnumerable<IEnumerable<Type>> compatibleTypesPerGeneric, Type definition)
{
// Implementation adapted from https://stackoverflow.com/a/4424005
var accum = new List<Type>();
var compatibleTypesPerGenericArray = compatibleTypesPerGeneric.ToArray();
if (compatibleTypesPerGenericArray.Length > 0)
{
GetAllPossibleTypeCombinationsCore(
accum,
new Stack<Type>(),
compatibleTypesPerGenericArray,
compatibleTypesPerGenericArray.Length - 1);
}
return accum;

void GetAllPossibleTypeCombinationsCore(List<Type> accum, Stack<Type> stack, IEnumerable<Type>[] compatibleTypes, int index)
{
foreach (var type in compatibleTypes[index])
{
stack.Push(type);
if (index == 0)
{
// IEnumerable on a System.Collections.Generic.Stack
// enumerates in order of removal (last to first).
// As a result, we get the correct ordering here.
accum.Add(definition.MakeGenericType(stack.ToArray()));
}
else
{
GetAllPossibleTypeCombinationsCore(accum, stack, compatibleTypes, index - 1);
}
stack.Pop();
}
}
}

internal static bool TryGetCompatibleWindowsRuntimeTypesForVariantType(Type type, out IEnumerable<Type> compatibleTypes)
{
compatibleTypes = null;
if (!type.IsConstructedGenericType)
{
throw new ArgumentException(nameof(type));
}

var definition = type.GetGenericTypeDefinition();

if (!IsTypeWindowsRuntimeTypeNoArray(definition))
{
return false;
}

var genericConstraints = definition.GetGenericArguments();
var genericArguments = type.GetGenericArguments();
List<List<Type>> compatibleTypesPerGeneric = new List<List<Type>>();
for (int i = 0; i < genericArguments.Length; i++)
{
List<Type> compatibleTypesForGeneric = new List<Type>();
bool argumentCovariantObject = (genericConstraints[i].GenericParameterAttributes & GenericParameterAttributes.VarianceMask) == GenericParameterAttributes.Covariant
&& !genericArguments[i].IsValueType;

if (IsTypeWindowsRuntimeTypeNoArray(genericArguments[i]))
{
compatibleTypesForGeneric.Add(genericArguments[i]);
}
else if (!argumentCovariantObject)
{
return false;
}

if (argumentCovariantObject)
{
compatibleTypesForGeneric.AddRange(GetCompatibleTypes(genericArguments[i]));
}

compatibleTypesPerGeneric.Add(compatibleTypesForGeneric);
}

compatibleTypes = GetAllPossibleTypeCombinations(compatibleTypesPerGeneric, definition);
return true;
}

internal static bool TryGetDefaultInterfaceTypeForRuntimeClassType(Type runtimeClass, out Type defaultInterface)
{
runtimeClass = runtimeClass.GetRuntimeClassCCWType() ?? runtimeClass;
Expand Down