From 54eaf066543555a97e331aaf77ee072ace0e13f4 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 13 Apr 2023 23:17:15 -0700 Subject: [PATCH 01/31] wip --- .../ComInterfaceGenerator.cs | 167 ++++++++++++++++-- .../ComInterfaceGeneratorHelpers.cs | 8 + .../ComInterfaceGenerator.Tests.csproj | 5 +- 3 files changed, 162 insertions(+), 18 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 4ca18b10a5e5f..bf883999d746f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -5,9 +5,11 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Threading; +using System.Xml.Schema; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -44,6 +46,56 @@ public static class StepNames public const string GenerateInterfaceInformation = nameof(GenerateInterfaceInformation); public const string GenerateIUnknownDerivedAttribute = nameof(GenerateIUnknownDerivedAttribute); } + public record struct DerivedComInterfacesContext( + int StartingOffset, + ManagedTypeInfo? BaseInterface, + ImmutableArray ShadowingMethodsToGenerate, + bool IsMarkerInterface); + + internal struct BaseInterfacesInfo + { + internal BaseInterfacesInfo(ImmutableArray methodsToShadow, ManagedTypeInfo? baseInterface, bool isMarkerInterface) + { + MethodsToShadow = methodsToShadow; + BaseInterface = baseInterface; + IsMarkerInterface = isMarkerInterface; + } + + public int StartingOffset => MethodsToShadow.Length; + + public ImmutableArray MethodsToShadow { get; } + + public ManagedTypeInfo? BaseInterface { get; } + + public bool IsMarkerInterface { get; } + } + public record ShadowingMethodContext( + int Offset, + ManagedTypeInfo BaseInterfaceType, + string MethodName, + MethodDeclarationSyntax Syntax, + ManagedTypeInfo ReturnType, + ImmutableArray<(ManagedTypeInfo Type, string Name, ParameterSyntax Syntax)> Parameters) + { + public MethodDeclarationSyntax Generate() + { + // DeclarationCopiedFromBaseDeclaration() + // { + // return (()this).(); + // } + return Syntax.WithBody( + Block( + ReturnStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + CastExpression(BaseInterfaceType.Syntax, IdentifierName(Token(SyntaxKind.ThisKeyword))), + IdentifierName(MethodName)), + ArgumentList( + SeparatedList(Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); + } + + } public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -70,19 +122,84 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfaceBaseInfo = interfacesToGenerate.Collect().SelectMany((data, ct) => { - ImmutableArray<(int StartingOffset, ManagedTypeInfo? BaseInterface, bool IsMarkerInterface)>.Builder baseInterfaceInfo = ImmutableArray.CreateBuilder<(int, ManagedTypeInfo?, bool)>(data.Length); + // Interface to BaseInterfaceInformation cache + var methodsToShadowCache = new Dictionary>(SymbolEqualityComparer.Default); + foreach (var iface in data) + { + AddBaseInfo(iface.Symbol, methodsToShadowCache, ct); + } + return data.Select(iface => methodsToShadowCache[iface.Symbol]); + + static void AddBaseInfo(INamedTypeSymbol iface, Dictionary> cache, CancellationToken ct) + { + if (TryGetBaseComInterface(iface, out var baseComIface)) + { + if (!cache.TryGetValue(baseComIface, out var methodsToShadow)) + { + AddBaseInfo(baseComIface, cache, ct); + methodsToShadow = cache[baseComIface]; + } + cache.Add(iface, GetRelevantInformationAboutBaseTypes(iface, methodsToShadow.Length, ct).ToImmutableArray()); + } + else + { + cache.Add(iface, GetRelevantInformationAboutBaseTypes(iface, 3, ct).ToImmutableArray()); + } + + } + + static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface) + { + baseComIface = null; + foreach (var implemented in comIface.Interfaces) + { + if (implemented.GetAttributes().Any(static attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) + { + // We'll filter out cases where there's multiple matching interfaces when determining + // if this is a valid candidate for generation. + Debug.Assert(baseComIface is null); + baseComIface = implemented; + } + } + return baseComIface is not null; + } + + ImmutableArray.Builder baseInterfaceInfo = ImmutableArray.CreateBuilder(data.Length); // Track the calculated last offsets of the interfaces. // If a type has invalid methods, we'll count them and issue an error when generating code for that // interface. - Dictionary derivedNextOffset = new(SymbolEqualityComparer.Default); + Dictionary derivedNextOffsets = new(SymbolEqualityComparer.Default); foreach (var iface in data) { - var (starting, baseType, derivedStarting) = CalculateOffsetsForInterface(iface.Symbol, derivedNextOffset); - baseInterfaceInfo.Add((starting, baseType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(baseType) : null, starting == derivedStarting)); + var (starting, baseType, derivedStarting) = CalculateOffsetsForInterface(iface.Symbol, derivedNextOffsets); + var baseTypes = ImmutableArray.CreateBuilder(); + // TODO: We can cache the list of base types for each interface so we don't have to recalculate them. + while (baseType is not null) + { + (var baseStarting, var nextBaseType, derivedStarting) = CalculateOffsetsForInterface(baseType, derivedNextOffsets); + baseTypes.AddRange(GetRelevantInformationAboutBaseTypes(baseType, baseStarting, ct)); + baseType = nextBaseType; + } + baseInterfaceInfo.Add(new(starting, baseTypes.FirstOrDefault()?.BaseInterfaceType, baseTypes.ToImmutableArray(), starting == derivedStarting)); } - return baseInterfaceInfo.MoveToImmutable(); + //return baseInterfaceInfo.MoveToImmutable(); - static (int Starting, INamedTypeSymbol? BaseType, int DerivedStarting) CalculateOffsetsForInterface(INamedTypeSymbol iface, Dictionary derivedNextOffsetCache) + static IEnumerable GetRelevantInformationAboutBaseTypes(INamedTypeSymbol iface, int starting, CancellationToken ct) + { + foreach (var method in iface.GetMembers().Where(m => m is IMethodSymbol)) + { + var m = (IMethodSymbol)method; + yield return new ShadowingMethodContext( + starting++, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iface), + iface.Name, + (MethodDeclarationSyntax)m.DeclaringSyntaxReferences[0].GetSyntax(ct), + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(m.ReturnType), + m.Parameters.Select(p => (ManagedTypeInfo.CreateTypeInfoForTypeSymbol(p.Type), p.Name, (ParameterSyntax)p.DeclaringSyntaxReferences[0].GetSyntax(ct))).ToImmutableArray()); + + } + } + static (int Starting, INamedTypeSymbol BaseInterface, int DerivedStarting) CalculateOffsetsForInterface(INamedTypeSymbol iface, Dictionary baseInterfaceInfoCache) { INamedTypeSymbol? baseInterface = null; foreach (var implemented in iface.Interfaces) @@ -97,32 +214,35 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } // Cache the starting offsets for each base interface so we don't have to recalculate them. - int startingOffset = 3; + (int StartingOffset, INamedTypeSymbol? BaseInterface) pair = (3, baseInterface); if (baseInterface is not null) { - if (!derivedNextOffsetCache.TryGetValue(baseInterface, out int offset)) + if (!baseInterfaceInfoCache.TryGetValue(baseInterface, out pair)) { - offset = CalculateOffsetsForInterface(baseInterface, derivedNextOffsetCache).DerivedStarting; + var baseInfo = CalculateOffsetsForInterface(baseInterface, baseInterfaceInfoCache); + pair = (baseInfo.DerivedStarting, baseInterface); } - - startingOffset = offset; } // This calculation isn't strictly accurate. This will count methods that aren't in the same declaring syntax as the attribute on the interface, // but we'll emit an error later if that's a problem. We also can't detect this error if the base type is in metadata. - int ifaceDerivedNextOffset = startingOffset + iface.GetMembers().Where(static m => m is IMethodSymbol { IsStatic: false }).Count(); - derivedNextOffsetCache[iface] = ifaceDerivedNextOffset; + int ifaceDerivedNextOffset = pair.StartingOffset + iface.GetMembers().Where(static m => m is IMethodSymbol { IsStatic: false }).Count(); + baseInterfaceInfoCache[iface] = pair; - return (startingOffset, baseInterface, ifaceDerivedNextOffset); + return (pair.StartingOffset, baseInterface, ifaceDerivedNextOffset); } }); + var shadowingMethods = interfaceBaseInfo.Zip(interfacesToGenerate).Select((data, ct) => + { + return (data.Right, data.Left.Select(s => s.Generate())); + }); // Zip the interface base information back with the symbols and syntax for the interface // to calculate the interface context. // The generator infrastructure preserves ordering of the tables once Select statements are in use, // so we can rely on the order matching here. var interfaceContexts = interfacesToGenerate - .Zip(interfaceBaseInfo.Select((data, ct) => data.StartingOffset)) + .Zip(interfaceBaseInfo.Select((data, ct) => data.Length)) .Select((data, ct) => { var (iface, startingOffset) = data; @@ -141,6 +261,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context) startingOffset, guid ?? Guid.Empty); }); + //var shadowingMethods = interfacesToGenerate.Zip(interfaceBaseInfo).Select((data, ct) => + //{ + // var interfaceInfo = data.Left; + // var interfaceBaseInfo = data.Right; + // interfaceBaseInfo.BaseInterface. + + //}) context.RegisterDiagnostics(invalidTypeDiagnostics.Select((data, ct) => data.Diagnostic)); @@ -152,7 +279,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var (interfaceData, interfaceContext) = data; Location interfaceLocation = interfaceData.Syntax.GetLocation(); - var methods = ImmutableArray.CreateBuilder<(MethodDeclarationSyntax Syntax, IMethodSymbol Symbol, int Index, Diagnostic? Diagnostic)>(); + var methods = ImmutableArray.CreateBuilder(); int methodVtableOffset = interfaceContext.MethodStartIndex; foreach (var member in interfaceData.Symbol.GetMembers()) { @@ -644,7 +771,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Im private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName) .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)); - private static InterfaceDeclarationSyntax GenerateImplementationVTable(ImmutableArray interfaceMethodStubs, (int StartingOffset, ManagedTypeInfo? BaseInterface, bool) baseInterfaceTypeInfo) + private static InterfaceDeclarationSyntax GenerateImplementationVTable(ImmutableArray interfaceMethodStubs, DerivedComInterfacesContext baseInterfaceTypeInfo) { const string vtableLocalName = "vtable"; var interfaceType = interfaceMethodStubs[0].OriginalDefiningType; @@ -883,4 +1010,10 @@ static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan (value.Syntax, value.Symbol, value.Index, value.Diagnostic); + public static implicit operator ComInterfaceMethodContext((MethodDeclarationSyntax Syntax, IMethodSymbol Symbol, int Index, Diagnostic? Diagnostic) value) => new ComInterfaceMethodContext(value.Syntax, value.Symbol, value.Index, value.Diagnostic); + } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index e192fcd02d8fe..8595023a7a30b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Microsoft.Interop { @@ -64,5 +65,12 @@ internal static class ComInterfaceGeneratorHelpers return MarshallingGeneratorFactoryKey.Create((env.TargetFramework, env.TargetFrameworkVersion), generatorFactory); } + + //public static IsComMethod(IMethodSymbol method) + //{ + // method.ContainingSymbol is INamedTypeSymbol interfaceSymbol + // && method.DeclaringSyntaxReferences.First().GetSyntax().Parent is InterfaceDeclarationSyntax interfaceDeclaration&& interfaceDeclaration.AttributeLists.Any(attLiToString() == "GeneratedComInterfaceAttribute")) + // && interfaceSymbol.GetAttributes().Any(att=> att.AttributeClass.IsOfType(TypeNames.GeneratedComInterfaceAttribute)) + //} } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj index b8947960debca..46e001c90ae61 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent) false @@ -9,10 +9,13 @@ false None + true + $(MSBuildThisFileDirectory)Generated + From 153f52a18edf0e8e4211ea96841ce4b1cc308cbc Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 17 Apr 2023 16:58:40 -0700 Subject: [PATCH 02/31] wip --- .../ComInterfaceGenerator.cs | 320 +++++++++++------- .../ComInterfaceGeneratorHelpers.cs | 33 +- 2 files changed, 221 insertions(+), 132 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index bf883999d746f..8f18cd4cc9ecb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.ComponentModel.Design; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; @@ -20,13 +21,6 @@ namespace Microsoft.Interop [Generator] public sealed class ComInterfaceGenerator : IIncrementalGenerator { - private sealed record ComInterfaceContext( - ManagedTypeInfo InterfaceType, - ContainingSyntaxContext TypeDefinitionContext, - ContainingSyntax InterfaceTypeSyntax, - int MethodStartIndex, - Guid InterfaceId); - private sealed record class GeneratedStubCodeContext( ManagedTypeInfo OriginalDefiningType, ContainingSyntaxContext ContainingSyntaxContext, @@ -46,11 +40,6 @@ public static class StepNames public const string GenerateInterfaceInformation = nameof(GenerateInterfaceInformation); public const string GenerateIUnknownDerivedAttribute = nameof(GenerateIUnknownDerivedAttribute); } - public record struct DerivedComInterfacesContext( - int StartingOffset, - ManagedTypeInfo? BaseInterface, - ImmutableArray ShadowingMethodsToGenerate, - bool IsMarkerInterface); internal struct BaseInterfacesInfo { @@ -94,7 +83,22 @@ public MethodDeclarationSyntax Generate() ArgumentList( SeparatedList(Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); } + } + private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface) + { + baseComIface = null; + foreach (var implemented in comIface.Interfaces) + { + if (implemented.GetAttributes().Any(static attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) + { + // We'll filter out cases where there's multiple matching interfaces when determining + // if this is a valid candidate for generation. + Debug.Assert(baseComIface is null); + baseComIface = implemented; + } + } + return baseComIface is not null; } public void Initialize(IncrementalGeneratorInitializationContext context) @@ -110,139 +114,84 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Where( static modelData => modelData is not null); - var interfacesWithDiagnostics = attributedInterfaces.Select(static (data, ct) => + + var interfacesAndDiagnostics = attributedInterfaces.Select(static (data, ct) => { - Diagnostic? diagnostic = GetDiagnosticIfInvalidTypeForGeneration(data.Syntax, data.Symbol); - return new { data.Syntax, data.Symbol, Diagnostic = diagnostic }; + Diagnostic? Diagnostic = GetDiagnosticIfInvalidTypeForGeneration(data.Syntax, data.Symbol); + INamedTypeSymbol? BaseInterfaceSymbol = TryGetBaseComInterface(data.Symbol, out var baseComInterface) ? baseComInterface : null; + ComInterfaceContext Context = ComInterfaceContext.From(data.Symbol, data.Syntax); + return new { data.Syntax, data.Symbol, Context, Diagnostic, BaseInterfaceSymbol }; }); // Split the types we want to generate and the ones we don't into two separate groups. - var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null); - var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null); + var interfacesToGenerate = interfacesAndDiagnostics.Where(static data => data.Diagnostic is null); + var invalidTypeDiagnostics = interfacesAndDiagnostics.Where(static data => data.Diagnostic is not null); - var interfaceBaseInfo = interfacesToGenerate.Collect().SelectMany((data, ct) => + // Get the methods for each interface, without the indices + var interfaceMethods = interfacesToGenerate.Select((data, ct) => { - // Interface to BaseInterfaceInformation cache - var methodsToShadowCache = new Dictionary>(SymbolEqualityComparer.Default); - foreach (var iface in data) + INamedTypeSymbol iface = data.Symbol; + List comMethods = new(); + foreach (var member in iface.GetMembers()) { - AddBaseInfo(iface.Symbol, methodsToShadowCache, ct); - } - return data.Select(iface => methodsToShadowCache[iface.Symbol]); - - static void AddBaseInfo(INamedTypeSymbol iface, Dictionary> cache, CancellationToken ct) - { - if (TryGetBaseComInterface(iface, out var baseComIface)) - { - if (!cache.TryGetValue(baseComIface, out var methodsToShadow)) - { - AddBaseInfo(baseComIface, cache, ct); - methodsToShadow = cache[baseComIface]; - } - cache.Add(iface, GetRelevantInformationAboutBaseTypes(iface, methodsToShadow.Length, ct).ToImmutableArray()); - } - else + if (MethodInfo.IsComInterface(iface, member, out MethodInfo? methodInfo)) { - cache.Add(iface, GetRelevantInformationAboutBaseTypes(iface, 3, ct).ToImmutableArray()); + comMethods.Add(methodInfo); } - } + return comMethods; + }); - static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface) + // Using the methods, we can get the offsets and the shadowing methods in one stage. + var interfaceBaseCache = interfacesToGenerate.Zip(interfaceMethods).Collect().Select((data, ct) => + { + Dictionary> allMethods = new(SymbolEqualityComparer.Default); + Dictionary baseInterfaces = new(SymbolEqualityComparer.Default); + + var declaredMethods = data.ToImmutableDictionary( + static pair => (INamedTypeSymbol)pair.Left.Symbol, + static pair => pair.Right, + SymbolEqualityComparer.Default); + Dictionary> allMethodsCache = new(SymbolEqualityComparer.Default); + foreach (var ifaceMethodsPair in data) { - baseComIface = null; - foreach (var implemented in comIface.Interfaces) - { - if (implemented.GetAttributes().Any(static attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) - { - // We'll filter out cases where there's multiple matching interfaces when determining - // if this is a valid candidate for generation. - Debug.Assert(baseComIface is null); - baseComIface = implemented; - } - } - return baseComIface is not null; - } + allMethods.Add(GetMethods(ifaceMethodsPair.Left.Symbol, declaredMethods, allMethodsCache); - ImmutableArray.Builder baseInterfaceInfo = ImmutableArray.CreateBuilder(data.Length); - // Track the calculated last offsets of the interfaces. - // If a type has invalid methods, we'll count them and issue an error when generating code for that - // interface. - Dictionary derivedNextOffsets = new(SymbolEqualityComparer.Default); - foreach (var iface in data) - { - var (starting, baseType, derivedStarting) = CalculateOffsetsForInterface(iface.Symbol, derivedNextOffsets); - var baseTypes = ImmutableArray.CreateBuilder(); - // TODO: We can cache the list of base types for each interface so we don't have to recalculate them. - while (baseType is not null) - { - (var baseStarting, var nextBaseType, derivedStarting) = CalculateOffsetsForInterface(baseType, derivedNextOffsets); - baseTypes.AddRange(GetRelevantInformationAboutBaseTypes(baseType, baseStarting, ct)); - baseType = nextBaseType; - } - baseInterfaceInfo.Add(new(starting, baseTypes.FirstOrDefault()?.BaseInterfaceType, baseTypes.ToImmutableArray(), starting == derivedStarting)); } - //return baseInterfaceInfo.MoveToImmutable(); - static IEnumerable GetRelevantInformationAboutBaseTypes(INamedTypeSymbol iface, int starting, CancellationToken ct) - { - foreach (var method in iface.GetMembers().Where(m => m is IMethodSymbol)) - { - var m = (IMethodSymbol)method; - yield return new ShadowingMethodContext( - starting++, - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iface), - iface.Name, - (MethodDeclarationSyntax)m.DeclaringSyntaxReferences[0].GetSyntax(ct), - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(m.ReturnType), - m.Parameters.Select(p => (ManagedTypeInfo.CreateTypeInfoForTypeSymbol(p.Type), p.Name, (ParameterSyntax)p.DeclaringSyntaxReferences[0].GetSyntax(ct))).ToImmutableArray()); + return allMethods.ToImmutableArray(); - } - } - static (int Starting, INamedTypeSymbol BaseInterface, int DerivedStarting) CalculateOffsetsForInterface(INamedTypeSymbol iface, Dictionary baseInterfaceInfoCache) + static IEnumerable GetMethods(INamedTypeSymbol symbol, ImmutableDictionary>? declaredMethods, Dictionary> allMethodsCache) { - INamedTypeSymbol? baseInterface = null; - foreach (var implemented in iface.Interfaces) + int startingIndex = 3; + if (TryGetBaseComInterface(symbol, out var baseComIface)) { - if (implemented.GetAttributes().Any(static attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) + if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) { - // We'll filter out cases where there's multiple matching interfaces when determining - // if this is a valid candidate for generation. - Debug.Assert(baseInterface is null); - baseInterface = implemented; + baseMethods = GetMethods(baseComIface, declaredMethods, allMethodsCache); + allMethodsCache[baseComIface] = baseMethods; } - } - // Cache the starting offsets for each base interface so we don't have to recalculate them. - (int StartingOffset, INamedTypeSymbol? BaseInterface) pair = (3, baseInterface); - if (baseInterface is not null) - { - if (!baseInterfaceInfoCache.TryGetValue(baseInterface, out pair)) + foreach (var method in BaseMethods) { - var baseInfo = CalculateOffsetsForInterface(baseInterface, baseInterfaceInfoCache); - pair = (baseInfo.DerivedStarting, baseInterface); + startingIndex++; + yield return method; } } - - // This calculation isn't strictly accurate. This will count methods that aren't in the same declaring syntax as the attribute on the interface, - // but we'll emit an error later if that's a problem. We also can't detect this error if the base type is in metadata. - int ifaceDerivedNextOffset = pair.StartingOffset + iface.GetMembers().Where(static m => m is IMethodSymbol { IsStatic: false }).Count(); - baseInterfaceInfoCache[iface] = pair; - - return (pair.StartingOffset, baseInterface, ifaceDerivedNextOffset); + foreach (var method in declaredMethods[symbol]) + { + yield return new ComInterfaceMethodContext(method.Item2, method.Item1, startingIndex++, null); + } } }); - var shadowingMethods = interfaceBaseInfo.Zip(interfacesToGenerate).Select((data, ct) => - { - return (data.Right, data.Left.Select(s => s.Generate())); - }); + // Zip the interface base information back with the symbols and syntax for the interface // to calculate the interface context. // The generator infrastructure preserves ordering of the tables once Select statements are in use, // so we can rely on the order matching here. var interfaceContexts = interfacesToGenerate - .Zip(interfaceBaseInfo.Select((data, ct) => data.Length)) + .Zip(interfaceBaseInfo.Select((data, ct) => data.StartingOffset)) .Select((data, ct) => { var (iface, startingOffset) = data; @@ -261,13 +210,6 @@ static IEnumerable GetRelevantInformationAboutBaseTypes( startingOffset, guid ?? Guid.Empty); }); - //var shadowingMethods = interfacesToGenerate.Zip(interfaceBaseInfo).Select((data, ct) => - //{ - // var interfaceInfo = data.Left; - // var interfaceBaseInfo = data.Right; - // interfaceBaseInfo.BaseInterface. - - //}) context.RegisterDiagnostics(invalidTypeDiagnostics.Select((data, ct) => data.Diagnostic)); @@ -771,7 +713,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Im private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName) .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)); - private static InterfaceDeclarationSyntax GenerateImplementationVTable(ImmutableArray interfaceMethodStubs, DerivedComInterfacesContext baseInterfaceTypeInfo) + private static InterfaceDeclarationSyntax GenerateImplementationVTable(ImmutableArray interfaceMethodStubs, BaseInterfacesInfo baseInterfaceTypeInfo) { const string vtableLocalName = "vtable"; var interfaceType = interfaceMethodStubs[0].OriginalDefiningType; @@ -1009,11 +951,137 @@ static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan(literals))))); } } - } - internal record struct ComInterfaceMethodContext(MethodDeclarationSyntax Syntax, IMethodSymbol Symbol, int Index, Diagnostic? Diagnostic) - { - public static implicit operator (MethodDeclarationSyntax Syntax, IMethodSymbol Symbol, int Index, Diagnostic? Diagnostic)(ComInterfaceMethodContext value) => (value.Syntax, value.Symbol, value.Index, value.Diagnostic); - public static implicit operator ComInterfaceMethodContext((MethodDeclarationSyntax Syntax, IMethodSymbol Symbol, int Index, Diagnostic? Diagnostic) value) => new ComInterfaceMethodContext(value.Syntax, value.Symbol, value.Index, value.Diagnostic); + private sealed record ComInterfaceContext( + ManagedTypeInfo InterfaceType, + ContainingSyntaxContext TypeDefinitionContext, + ContainingSyntax InterfaceTypeSyntax, + [property: Obsolete] int MethodStartIndex, + Guid InterfaceId) + { + public static ComInterfaceContext From(INamedTypeSymbol symbol, TypeDeclarationSyntax syntax) + { + Guid? guid = null; + var guidAttr = symbol.GetAttributes().Where(attr => attr.AttributeClass.ToDisplayString() == TypeNames.System_Runtime_InteropServices_GuidAttribute).SingleOrDefault(); + if (guidAttr is not null) + { + string? guidstr = guidAttr.ConstructorArguments.SingleOrDefault().Value as string; + if (guidstr is not null) + guid = new Guid(guidstr); + } + return new ComInterfaceContext( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), + new ContainingSyntaxContext(syntax), + new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), + -1, + guid ?? Guid.Empty); + } + + } + + /// + /// Represents a method that has been determined to be a COM interface method. + /// + private record MethodInfo([property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax) + { + public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) + { + comMethodInfo = null; + Location interfaceLocation = iface.Locations[0]; + if (member.Kind == SymbolKind.Method && !member.IsStatic) + { + // We only support methods that are defined in the same partial interface definition as the + // [GeneratedComInterface] attribute. + // This restriction not only makes finding the syntax for a given method cheaper, + // but it also enables us to ensure that we can determine vtable method order easily. + Location? methodLocationInAttributedInterfaceDeclaration = null; + foreach (var methodLocation in member.Locations) + { + if (methodLocation.SourceTree == interfaceLocation.SourceTree + && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) + { + methodLocationInAttributedInterfaceDeclaration = methodLocation; + break; + } + } + + MethodDeclarationSyntax? comMethodDeclaringSyntax = null; + + // TODO: this should cause a diagnostic + if (methodLocationInAttributedInterfaceDeclaration is null) + { + return false; + } + + // Find the matching declaration syntax + foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) + { + var declaringSyntax = declaringSyntaxReference.GetSyntax(ct); + Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); + if (declaringSyntax.GetLocation() == methodLocationInAttributedInterfaceDeclaration) + { + comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; + break; + } + } + if (comMethodDeclaringSyntax is null) + throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); + comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax); + return true; + } + return false; + } + } + + /// + /// Represents a method, its declaring interface, and its index in the interface's vtable. + /// + private record ComInterfaceMethodContext(ComInterfaceContext DeclaringInterface, MethodInfo MethodInfo, int Index, Diagnostic? Diagnostic) + { + public record Builder(ImmutableDictionary>? declaredMethods) + { + public Dictionary> allMethodsCache = new(); + public IEnumerable GetMethods(ComInterfaceContext symbol) + { + int startingIndex = 3; + if (TryGetBaseComInterface(symbol, out var baseComIface)) + { + if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) + { + baseMethods = GetMethods(baseComIface, declaredMethods, allMethodsCache); + allMethodsCache[baseComIface] = baseMethods; + } + + foreach (var method in baseMethods) + { + startingIndex++; + yield return method; + } + } + foreach (var method in declaredMethods[symbol]) + { + yield return new ComInterfaceMethodContext(symbol, method, startingIndex++, null); + } + + } + + } + } + + /// + /// Represents an interface and all of the methods that need to be generated for it. + /// + private record ComInterfaceAndMethods(ComInterfaceContext Interface, ImmutableArray Methods) + { + /// + /// COM methods that are declared on the attributed interface declaration. + /// + public IEnumerable DeclaredMethods => Methods.Where(m => m.DeclaringInterface == Interface); + + /// + /// COM methods that are declared on an interface the interface inherits from. + /// + public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index 8595023a7a30b..9f5ae862672b2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -66,11 +66,32 @@ internal static class ComInterfaceGeneratorHelpers return MarshallingGeneratorFactoryKey.Create((env.TargetFramework, env.TargetFrameworkVersion), generatorFactory); } - //public static IsComMethod(IMethodSymbol method) - //{ - // method.ContainingSymbol is INamedTypeSymbol interfaceSymbol - // && method.DeclaringSyntaxReferences.First().GetSyntax().Parent is InterfaceDeclarationSyntax interfaceDeclaration&& interfaceDeclaration.AttributeLists.Any(attLiToString() == "GeneratedComInterfaceAttribute")) - // && interfaceSymbol.GetAttributes().Any(att=> att.AttributeClass.IsOfType(TypeNames.GeneratedComInterfaceAttribute)) - //} + enum ComInterfaceMethodValidity + { + Valid, + NotDefinedInAttributedSyntax, + + } + public static bool IsComMethod(ISymbol member, InterfaceDeclarationSyntax iface) + { + Location? locationInAttributeSyntax = null; + if (member.Kind == SymbolKind.Method && !member.IsStatic) + { + Location interfaceLocation = iface.GetLocation(); + // We only support methods that are defined in the same partial interface definition as the + // [GeneratedComInterface] attribute. + // This restriction not only makes finding the syntax for a given method cheaper, + // but it also enables us to ensure that we can determine vtable method order easily. + foreach (var location in member.Locations) + { + if (location.SourceTree == interfaceLocation.SourceTree + && interfaceLocation.SourceSpan.Contains(location.SourceSpan)) + { + locationInAttributeSyntax = location; + } + } + } + return locationInAttributeSyntax is not null; + } } } From 39869a78d1f0c6839d78bc533d841a2add4ee8a0 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Tue, 18 Apr 2023 10:18:48 -0700 Subject: [PATCH 03/31] wip --- .../ComInterfaceGenerator.cs | 88 ++++++++++++------- .../IncrementalValuesProviderExtensions.cs | 17 ++++ 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 8f18cd4cc9ecb..2a2322eae3901 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -127,6 +127,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfacesToGenerate = interfacesAndDiagnostics.Where(static data => data.Diagnostic is null); var invalidTypeDiagnostics = interfacesAndDiagnostics.Where(static data => data.Diagnostic is not null); + // Get the methods for each interface, without the indices var interfaceMethods = interfacesToGenerate.Select((data, ct) => { @@ -139,9 +140,38 @@ public void Initialize(IncrementalGeneratorInitializationContext context) comMethods.Add(methodInfo); } } - return comMethods; + return comMethods.ToImmutableArray(); + }); + + var interfaceAndBasePairs = interfacesToGenerate.Collect().Select((data, ct) => + { + Dictionary ifaceToBaseMap = new(); + Dictionary contexts = new(SymbolEqualityComparer.Default); + foreach (var iface in data) + { + contexts.Add(iface.Symbol, iface.Context); + } + foreach (var iface in data) + { + ifaceToBaseMap.Add(iface.Context, iface.BaseInterfaceSymbol is not null ? contexts[iface.BaseInterfaceSymbol] : null); + } + return ifaceToBaseMap; + }); + + var interfacesAndMethods = interfacesToGenerate.Select((iface, ct) => iface.Context).Zip(interfaceMethods); + var interfaceToMethodsMap = interfacesAndMethods.Collect().Select((data, ct) => + { + return data.ToImmutableDictionary<(ComInterfaceContext, ImmutableArray), ComInterfaceContext, ImmutableArray>( + static pair => pair.Item1, + static pair => pair.Item2); }); + var interfaceAndMethodsContexts = interfaceToMethodsMap.Combine(interfaceAndBasePairs).SelectMany((data, ct) => + { + return ComInterfaceMethodContext.GetMethods(data.Right, data.Left); + }); + + // Using the methods, we can get the offsets and the shadowing methods in one stage. var interfaceBaseCache = interfacesToGenerate.Zip(interfaceMethods).Collect().Select((data, ct) => { @@ -161,28 +191,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return allMethods.ToImmutableArray(); - static IEnumerable GetMethods(INamedTypeSymbol symbol, ImmutableDictionary>? declaredMethods, Dictionary> allMethodsCache) - { - int startingIndex = 3; - if (TryGetBaseComInterface(symbol, out var baseComIface)) - { - if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) - { - baseMethods = GetMethods(baseComIface, declaredMethods, allMethodsCache); - allMethodsCache[baseComIface] = baseMethods; - } - - foreach (var method in BaseMethods) - { - startingIndex++; - yield return method; - } - } - foreach (var method in declaredMethods[symbol]) - { - yield return new ComInterfaceMethodContext(method.Item2, method.Item1, startingIndex++, null); - } - } }); @@ -1016,7 +1024,7 @@ public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNu // Find the matching declaration syntax foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) { - var declaringSyntax = declaringSyntaxReference.GetSyntax(ct); + var declaringSyntax = declaringSyntaxReference.GetSyntax(); Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); if (declaringSyntax.GetLocation() == methodLocationInAttributedInterfaceDeclaration) { @@ -1038,18 +1046,34 @@ public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNu /// private record ComInterfaceMethodContext(ComInterfaceContext DeclaringInterface, MethodInfo MethodInfo, int Index, Diagnostic? Diagnostic) { - public record Builder(ImmutableDictionary>? declaredMethods) + public static ImmutableArray GetMethods(Dictionary ifaceToBaseMap, ImmutableDictionary> ifaceToMethodsMap) { - public Dictionary> allMethodsCache = new(); - public IEnumerable GetMethods(ComInterfaceContext symbol) + + Dictionary> allMethodsCache = new(); + foreach(var kvp in ifaceToMethodsMap) + { + AddMethods(kvp.Key, kvp.Value); + } + + return allMethodsCache.Select(kvp => new ComInterfaceAndMethods(kvp.Key, kvp.Value.ToImmutableArray())).ToImmutableArray(); + + IEnumerable AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) { + if (allMethodsCache.TryGetValue(iface, out var cachedValue)) + { + foreach(var value in cachedValue) + { + yield return value; + } + } + + List metods = new(); int startingIndex = 3; - if (TryGetBaseComInterface(symbol, out var baseComIface)) + if (ifaceToBaseMap.TryGetValue(iface, out var baseComIface)) { if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) { - baseMethods = GetMethods(baseComIface, declaredMethods, allMethodsCache); - allMethodsCache[baseComIface] = baseMethods; + baseMethods = AddMethods(baseComIface, ifaceToMethodsMap[baseComIface]); } foreach (var method in baseMethods) @@ -1058,9 +1082,9 @@ public IEnumerable GetMethods(ComInterfaceContext sym yield return method; } } - foreach (var method in declaredMethods[symbol]) + foreach (var method in declaredMethods) { - yield return new ComInterfaceMethodContext(symbol, method, startingIndex++, null); + yield return new ComInterfaceMethodContext(iface, method, startingIndex++, null); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index fb0dd80ca7f1a..a6b6d5a5922ca 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -31,6 +31,23 @@ internal static class IncrementalValuesProviderExtensions }); } + public static IncrementalValuesProvider<(T Left, U Right)> ZipSingle(this IncrementalValuesProvider left, IncrementalValueProvider right) + { + return left.Collect().Combine(right).SelectMany((data, ct) => + { + ImmutableArray<(T, U)>.Builder builder = ImmutableArray.CreateBuilder<(T, U)>(data.Left.Length); + for (int i = 0; i < data.Left.Length; i++) + { + builder.Add((data.Left[i], data.Right)); + } + return builder.ToImmutable(); + }); + + + + } + + /// /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed. /// From a5132ad8d98570a9f8a83c898158ebdbaf33431d Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 19 Apr 2023 09:49:18 -0700 Subject: [PATCH 04/31] Wip --- .../ComInterfaceGenerator.cs | 315 +++++++----------- .../ComInterfaceGeneratorHelpers.cs | 28 -- .../IncrementalValuesProviderExtensions.cs | 17 - .../SequenceEqualImmutableArray.cs | 32 +- 4 files changed, 143 insertions(+), 249 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 2a2322eae3901..7c855be8c7702 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -43,17 +43,12 @@ public static class StepNames internal struct BaseInterfacesInfo { - internal BaseInterfacesInfo(ImmutableArray methodsToShadow, ManagedTypeInfo? baseInterface, bool isMarkerInterface) + internal BaseInterfacesInfo(ManagedTypeInfo? baseInterface, bool isMarkerInterface) { - MethodsToShadow = methodsToShadow; BaseInterface = baseInterface; IsMarkerInterface = isMarkerInterface; } - public int StartingOffset => MethodsToShadow.Length; - - public ImmutableArray MethodsToShadow { get; } - public ManagedTypeInfo? BaseInterface { get; } public bool IsMarkerInterface { get; } @@ -94,8 +89,8 @@ private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWh { // We'll filter out cases where there's multiple matching interfaces when determining // if this is a valid candidate for generation. - Debug.Assert(baseComIface is null); baseComIface = implemented; + break; } } return baseComIface is not null; @@ -125,8 +120,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Split the types we want to generate and the ones we don't into two separate groups. var interfacesToGenerate = interfacesAndDiagnostics.Where(static data => data.Diagnostic is null); - var invalidTypeDiagnostics = interfacesAndDiagnostics.Where(static data => data.Diagnostic is not null); - + { + var invalidTypeDiagnostics = interfacesAndDiagnostics.Where(static data => data.Diagnostic is not null); + context.RegisterDiagnostics(invalidTypeDiagnostics.Select((data, ct) => data.Diagnostic)); + } // Get the methods for each interface, without the indices var interfaceMethods = interfacesToGenerate.Select((data, ct) => @@ -135,7 +132,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) List comMethods = new(); foreach (var member in iface.GetMembers()) { - if (MethodInfo.IsComInterface(iface, member, out MethodInfo? methodInfo)) + if (MethodInfo.IsComInterface(data.Context, member, out MethodInfo? methodInfo)) { comMethods.Add(methodInfo); } @@ -143,7 +140,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return comMethods.ToImmutableArray(); }); - var interfaceAndBasePairs = interfacesToGenerate.Collect().Select((data, ct) => + var ifaceToBaseMap = interfacesToGenerate.Collect().Select((data, ct) => { Dictionary ifaceToBaseMap = new(); Dictionary contexts = new(SymbolEqualityComparer.Default); @@ -158,133 +155,43 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return ifaceToBaseMap; }); - var interfacesAndMethods = interfacesToGenerate.Select((iface, ct) => iface.Context).Zip(interfaceMethods); - var interfaceToMethodsMap = interfacesAndMethods.Collect().Select((data, ct) => - { - return data.ToImmutableDictionary<(ComInterfaceContext, ImmutableArray), ComInterfaceContext, ImmutableArray>( - static pair => pair.Item1, - static pair => pair.Item2); - }); - - var interfaceAndMethodsContexts = interfaceToMethodsMap.Combine(interfaceAndBasePairs).SelectMany((data, ct) => - { - return ComInterfaceMethodContext.GetMethods(data.Right, data.Left); - }); - - - // Using the methods, we can get the offsets and the shadowing methods in one stage. - var interfaceBaseCache = interfacesToGenerate.Zip(interfaceMethods).Collect().Select((data, ct) => - { - Dictionary> allMethods = new(SymbolEqualityComparer.Default); - Dictionary baseInterfaces = new(SymbolEqualityComparer.Default); - - var declaredMethods = data.ToImmutableDictionary( - static pair => (INamedTypeSymbol)pair.Left.Symbol, - static pair => pair.Right, - SymbolEqualityComparer.Default); - Dictionary> allMethodsCache = new(SymbolEqualityComparer.Default); - foreach (var ifaceMethodsPair in data) - { - allMethods.Add(GetMethods(ifaceMethodsPair.Left.Symbol, declaredMethods, allMethodsCache); - - } - - return allMethods.ToImmutableArray(); - - }); - - - // Zip the interface base information back with the symbols and syntax for the interface - // to calculate the interface context. - // The generator infrastructure preserves ordering of the tables once Select statements are in use, - // so we can rely on the order matching here. - var interfaceContexts = interfacesToGenerate - .Zip(interfaceBaseInfo.Select((data, ct) => data.StartingOffset)) + var interfaceToMethodsMap = interfacesToGenerate + .Select((iface, ct) => iface.Context) + .Zip(interfaceMethods) + .Collect() .Select((data, ct) => - { - var (iface, startingOffset) = data; - Guid? guid = null; - var guidAttr = iface.Symbol.GetAttributes().Where(attr => attr.AttributeClass.ToDisplayString() == TypeNames.System_Runtime_InteropServices_GuidAttribute).SingleOrDefault(); - if (guidAttr is not null) { - string? guidstr = guidAttr.ConstructorArguments.SingleOrDefault().Value as string; - if (guidstr is not null) - guid = new Guid(guidstr); - } - return new ComInterfaceContext( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iface.Symbol), - new ContainingSyntaxContext(iface.Syntax), - new ContainingSyntax(iface.Syntax.Modifiers, iface.Syntax.Kind(), iface.Syntax.Identifier, iface.Syntax.TypeParameterList), - startingOffset, - guid ?? Guid.Empty); - }); - - context.RegisterDiagnostics(invalidTypeDiagnostics.Select((data, ct) => data.Diagnostic)); + return data.ToImmutableDictionary<(ComInterfaceContext, ImmutableArray), ComInterfaceContext, ImmutableArray>( + static pair => pair.Item1, + static pair => pair.Item2); + }); - // Zip the incremental interface context back with the symbols and syntax for the interface - // to calculate the methods to generate. - var interfacesWithMethods = interfacesToGenerate - .Zip(interfaceContexts) - .Select(static (data, ct) => - { - var (interfaceData, interfaceContext) = data; - Location interfaceLocation = interfaceData.Syntax.GetLocation(); - var methods = ImmutableArray.CreateBuilder(); - int methodVtableOffset = interfaceContext.MethodStartIndex; - foreach (var member in interfaceData.Symbol.GetMembers()) + var interfaceAndMethodsContexts = interfaceToMethodsMap + .Combine(ifaceToBaseMap) + .Combine(context.CreateStubEnvironmentProvider()) + .SelectMany((data, ct) => { - if (member.Kind == SymbolKind.Method && !member.IsStatic) - { - // We only support methods that are defined in the same partial interface definition as the - // [GeneratedComInterface] attribute. - // This restriction not only makes finding the syntax for a given method cheaper, - // but it also enables us to ensure that we can determine vtable method order easily. - Location? locationInAttributeSyntax = null; - foreach (var location in member.Locations) - { - if (location.SourceTree == interfaceLocation.SourceTree - && interfaceLocation.SourceSpan.Contains(location.SourceSpan)) - { - locationInAttributeSyntax = location; - } - } + var ((ifaceToMethodsMap, ifaceToBaseMap), env) = data; + return ComInterfaceAndMethods.GetMethods(ifaceToBaseMap, ifaceToMethodsMap, env, ct); + }); - if (locationInAttributeSyntax is null) - { - methods.Add(( - null!, - (IMethodSymbol)member, - 0, - member.CreateDiagnostic( - GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, - member.ToDisplayString(), - interfaceData.Symbol.ToDisplayString()))); - } - else - { - var syntax = (MethodDeclarationSyntax)interfaceData.Syntax.FindNode(locationInAttributeSyntax.SourceSpan); - var method = (IMethodSymbol)member; - Diagnostic? diagnostic = GetDiagnosticIfInvalidMethodForGeneration(syntax, method); - methods.Add((syntax, method, diagnostic is null ? methodVtableOffset++ : 0, diagnostic)); - } - } - } - return (Interface: interfaceContext, Methods: methods.ToImmutable()); - }); - var interfaceWithMethodsContexts = interfacesWithMethods - .Where(data => data.Methods.Length > 0) + var interfacesWithMethodsAndItsMethods = interfaceAndMethodsContexts + .Where(data => data.DeclaredMethods.Any()); + + var interfacesWithMethods = interfacesWithMethodsAndItsMethods .Select(static (data, ct) => data.Interface); // Marker interfaces are COM interfaces that don't have any methods. // The lack of methods breaks the mechanism we use later to stitch back together interface-level data // and method-level data, but that's okay because marker interfaces are much simpler. // We'll handle them seperately because they are so simple. - var markerInterfaces = interfacesWithMethods - .Where(data => data.Methods.Length == 0) + var markerInterfaces = interfaceAndMethodsContexts + .Where(data => !data.DeclaredMethods.Any()) .Select(static (data, ct) => data.Interface); - var markerInterfaceIUnknownDerived = markerInterfaces.Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) + var markerInterfaceIUnknownDerived = markerInterfaces + .Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); @@ -296,32 +203,23 @@ public void Initialize(IncrementalGeneratorInitializationContext context) GenerateMarkerInterfaceSource(interfaceContext) + iUnknownDerivedAttributeApplication); }); - var methodsWithDiagnostics = interfacesWithMethods.SelectMany(static (data, ct) => data.Methods); + var allMethods = interfaceAndMethodsContexts.SelectMany(static (data, ct) => data.DeclaredMethods); // Split the methods we want to generate and the ones we don't into two separate groups. - var methodsToGenerate = methodsWithDiagnostics.Where(static data => data.Diagnostic is null); - var invalidMethodDiagnostics = methodsWithDiagnostics.Where(static data => data.Diagnostic is not null); - - context.RegisterSourceOutput(invalidMethodDiagnostics, static (context, invalidMethod) => + var methodsToGenerate = allMethods.Where(static data => { - context.ReportDiagnostic(invalidMethod.Diagnostic); + return data.Diagnostic is null; }); + { + var invalidMethods = allMethods.Where(static data => data.Diagnostic is not null); - // Calculate all of information to generate both managed-to-unmanaged and unmanaged-to-managed stubs - // for each method. - IncrementalValuesProvider generateStubInformation = methodsToGenerate - .Combine(context.CreateStubEnvironmentProvider()) - .Select(static (data, ct) => new + context.RegisterSourceOutput(invalidMethods, static (context, invalidMethod) => { - data.Left.Syntax, - data.Left.Symbol, - data.Left.Index, - Environment = data.Right - }) - .Select( - static (data, ct) => CalculateStubInformation(data.Syntax, data.Symbol, data.Index, data.Environment, ct) - ) - .WithTrackingName(StepNames.CalculateStubInformation); + context.ReportDiagnostic(invalidMethod.Diagnostic); + }); + } + + IncrementalValuesProvider generateStubInformation = methodsToGenerate.Select((data, ct) => data.GenerationContext); // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation. var generateManagedToNativeStub = generateStubInformation @@ -379,7 +277,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .SelectNormalized(); // Generate the native interface metadata for each [GeneratedComInterface]-attributed interface. - var nativeInterfaceInformation = interfaceWithMethodsContexts + var nativeInterfaceInformation = interfacesWithMethods .Select(static (context, ct) => GenerateInterfaceInformation(context)) .WithTrackingName(StepNames.GenerateInterfaceInformation) .WithComparer(SyntaxEquivalentComparer.Instance) @@ -387,23 +285,18 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate a method named CreateManagedVirtualFunctionTable on the native interface implementation // that allocates and fills in the memory for the vtable. - var nativeToManagedVtables = - generateStubInformation - .Collect() - .SelectMany(static (data, ct) => GroupContextsForInterfaceGeneration(data.CastArray())) - .Zip(interfaceBaseInfo.Where(static info => !info.IsMarkerInterface)) - .Select(static (data, ct) => GenerateImplementationVTable(ImmutableArray.CreateRange(data.Left.Array.Cast()), data.Right)) + var nativeToManagedVtables = interfacesWithMethodsAndItsMethods.Select((data, ct) => GenerateImplementationVTable(data)) .WithTrackingName(StepNames.GenerateNativeToManagedVTable) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - var iUnknownDerivedAttributeApplication = interfaceWithMethodsContexts + var iUnknownDerivedAttributeApplication = interfacesWithMethods .Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) .WithTrackingName(StepNames.GenerateIUnknownDerivedAttribute) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - var filesToGenerate = interfaceWithMethodsContexts + var filesToGenerate = interfacesWithMethods .Zip(nativeInterfaceInformation) .Zip(managedToNativeInterfaceImplementations) .Zip(nativeToManagedVtableMethods) @@ -474,9 +367,9 @@ public static void** ManagedVirtualMethodTable private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplication(ComInterfaceContext context) => context.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier( - TypeDeclaration(context.InterfaceTypeSyntax.TypeKind, context.InterfaceTypeSyntax.Identifier) - .WithModifiers(context.InterfaceTypeSyntax.Modifiers) - .WithTypeParameterList(context.InterfaceTypeSyntax.TypeParameters) + TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) + .WithModifiers(context.ContainingSyntax.Modifiers) + .WithTypeParameterList(context.ContainingSyntax.TypeParameters) .AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate)))); private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, CancellationToken ct) @@ -721,10 +614,11 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Im private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName) .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)); - private static InterfaceDeclarationSyntax GenerateImplementationVTable(ImmutableArray interfaceMethodStubs, BaseInterfacesInfo baseInterfaceTypeInfo) + private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethods interfaceMethods) { const string vtableLocalName = "vtable"; - var interfaceType = interfaceMethodStubs[0].OriginalDefiningType; + var interfaceType = interfaceMethods.Interface.InterfaceType; + var interfaceMethodStubs = interfaceMethods.DeclaredMethods.Select(m => m.GenerationContext); ImmutableArray vtableExposedContexts = interfaceMethodStubs .Where(c => c.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) @@ -762,11 +656,12 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(Immutable BinaryExpression( SyntaxKind.MultiplyExpression, SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1 + interfaceMethodStubs.Max(x => x.VtableIndexData.Index)))))))))))); + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(3 + interfaceMethods.Methods.Length))))))))))); BlockSyntax fillBaseInterfaceSlots; - if (baseInterfaceTypeInfo.BaseInterface is null) + + if (interfaceMethods.BaseInterface is null) { // If we don't have a base interface, we need to manually fill in the base iUnknown slots. fillBaseInterfaceSlots = Block() @@ -868,7 +763,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(Immutable MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, TypeOfExpression( - ParseTypeName(baseInterfaceTypeInfo.BaseInterface.FullTypeName)), + ParseTypeName(interfaceMethods.BaseInterface.InterfaceType.FullTypeName)), //baseInterfaceTypeInfo.BaseInterface.FullTypeName)), IdentifierName("TypeHandle")))))), IdentifierName("ManagedVirtualMethodTable"))), Argument(IdentifierName(vtableLocalName)), @@ -876,7 +771,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(Immutable ParenthesizedExpression( BinaryExpression(SyntaxKind.MultiplyExpression, SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(baseInterfaceTypeInfo.StartingOffset)))))) + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3)))))) }))))); } @@ -962,12 +857,12 @@ static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan attr.AttributeClass.ToDisplayString() == TypeNames.System_Runtime_InteropServices_GuidAttribute).SingleOrDefault(); @@ -979,23 +874,36 @@ public static ComInterfaceContext From(INamedTypeSymbol symbol, TypeDeclarationS } return new ComInterfaceContext( ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), + syntax, new ContainingSyntaxContext(syntax), new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), - -1, guid ?? Guid.Empty); } + public override int GetHashCode() + { + // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode + return HashCode.Combine(InterfaceType, TypeDefinitionContext, InterfaceId); + } + + public bool Equals(ComInterfaceContext other) + { + // ContainingSyntax and ContainingSyntaxContext are not used in the hash code + return InterfaceType == other.InterfaceType + && TypeDefinitionContext == other.TypeDefinitionContext + && InterfaceId == other.InterfaceId; + } } /// /// Represents a method that has been determined to be a COM interface method. /// - private record MethodInfo([property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax) + private sealed record MethodInfo([property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, Diagnostic? Diagnostic) { - public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) + public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) { comMethodInfo = null; - Location interfaceLocation = iface.Locations[0]; + Location interfaceLocation = ifaceContext.InterfaceDeclaration.GetLocation(); if (member.Kind == SymbolKind.Method && !member.IsStatic) { // We only support methods that are defined in the same partial interface definition as the @@ -1026,7 +934,7 @@ public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNu { var declaringSyntax = declaringSyntaxReference.GetSyntax(); Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); - if (declaringSyntax.GetLocation() == methodLocationInAttributedInterfaceDeclaration) + if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) { comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; break; @@ -1034,7 +942,8 @@ public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNu } if (comMethodDeclaringSyntax is null) throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); - comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax); + var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); + comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, diag); return true; } return false; @@ -1044,32 +953,47 @@ public static bool IsComInterface(INamedTypeSymbol iface, ISymbol member, [NotNu /// /// Represents a method, its declaring interface, and its index in the interface's vtable. /// - private record ComInterfaceMethodContext(ComInterfaceContext DeclaringInterface, MethodInfo MethodInfo, int Index, Diagnostic? Diagnostic) + private sealed record ComInterfaceMethodContext(ComInterfaceContext DeclaringInterface, MethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) { - public static ImmutableArray GetMethods(Dictionary ifaceToBaseMap, ImmutableDictionary> ifaceToMethodsMap) - { + public Diagnostic? Diagnostic => MethodInfo.Diagnostic; + } + + /// + /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). + /// + private sealed record ComInterfaceAndMethods(ComInterfaceContext Interface, ImmutableArray Methods, ComInterfaceContext? BaseInterface) + { + /// + /// COM methods that are declared on the attributed interface declaration. + /// + public IEnumerable DeclaredMethods => Methods.Where(m => m.DeclaringInterface == Interface); + /// + /// COM methods that are declared on an interface the interface inherits from. + /// + public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); + + public static IEnumerable GetMethods(Dictionary ifaceToBaseMap, ImmutableDictionary> ifaceToMethodsMap, StubEnvironment environment, CancellationToken ct) + { Dictionary> allMethodsCache = new(); - foreach(var kvp in ifaceToMethodsMap) + + foreach (var kvp in ifaceToMethodsMap) { - AddMethods(kvp.Key, kvp.Value); + IEnumerable asdf = AddMethods(kvp.Key, kvp.Value); } - return allMethodsCache.Select(kvp => new ComInterfaceAndMethods(kvp.Key, kvp.Value.ToImmutableArray())).ToImmutableArray(); + return allMethodsCache.Select(kvp => new ComInterfaceAndMethods(kvp.Key, kvp.Value.ToImmutableArray(), ifaceToBaseMap[kvp.Key])).ToImmutableArray(); IEnumerable AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) { if (allMethodsCache.TryGetValue(iface, out var cachedValue)) { - foreach(var value in cachedValue) - { - yield return value; - } + return cachedValue; } - List metods = new(); int startingIndex = 3; - if (ifaceToBaseMap.TryGetValue(iface, out var baseComIface)) + List methods = new(); + if (ifaceToBaseMap.TryGetValue(iface, out var baseComIface) && baseComIface is not null) { if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) { @@ -1079,33 +1003,18 @@ IEnumerable AddMethods(ComInterfaceContext iface, IEn foreach (var method in baseMethods) { startingIndex++; - yield return method; + methods.Add(method); } } foreach (var method in declaredMethods) { - yield return new ComInterfaceMethodContext(iface, method, startingIndex++, null); + var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, ct); + methods.Add(new ComInterfaceMethodContext(iface, method, startingIndex++, ctx)); } - + allMethodsCache[iface] = methods; + return methods; } - } } - - /// - /// Represents an interface and all of the methods that need to be generated for it. - /// - private record ComInterfaceAndMethods(ComInterfaceContext Interface, ImmutableArray Methods) - { - /// - /// COM methods that are declared on the attributed interface declaration. - /// - public IEnumerable DeclaredMethods => Methods.Where(m => m.DeclaringInterface == Interface); - - /// - /// COM methods that are declared on an interface the interface inherits from. - /// - public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); - } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index 9f5ae862672b2..27961dec1b7fb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -65,33 +65,5 @@ internal static class ComInterfaceGeneratorHelpers return MarshallingGeneratorFactoryKey.Create((env.TargetFramework, env.TargetFrameworkVersion), generatorFactory); } - - enum ComInterfaceMethodValidity - { - Valid, - NotDefinedInAttributedSyntax, - - } - public static bool IsComMethod(ISymbol member, InterfaceDeclarationSyntax iface) - { - Location? locationInAttributeSyntax = null; - if (member.Kind == SymbolKind.Method && !member.IsStatic) - { - Location interfaceLocation = iface.GetLocation(); - // We only support methods that are defined in the same partial interface definition as the - // [GeneratedComInterface] attribute. - // This restriction not only makes finding the syntax for a given method cheaper, - // but it also enables us to ensure that we can determine vtable method order easily. - foreach (var location in member.Locations) - { - if (location.SourceTree == interfaceLocation.SourceTree - && interfaceLocation.SourceSpan.Contains(location.SourceSpan)) - { - locationInAttributeSyntax = location; - } - } - } - return locationInAttributeSyntax is not null; - } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index a6b6d5a5922ca..fb0dd80ca7f1a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -31,23 +31,6 @@ internal static class IncrementalValuesProviderExtensions }); } - public static IncrementalValuesProvider<(T Left, U Right)> ZipSingle(this IncrementalValuesProvider left, IncrementalValueProvider right) - { - return left.Collect().Combine(right).SelectMany((data, ct) => - { - ImmutableArray<(T, U)>.Builder builder = ImmutableArray.CreateBuilder<(T, U)>(data.Left.Length); - for (int i = 0; i < data.Left.Length; i++) - { - builder.Add((data.Left[i], data.Right)); - } - return builder.ToImmutable(); - }); - - - - } - - /// /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs index f91d7be4ea2bc..adfa131473375 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -14,18 +15,47 @@ namespace Microsoft.Interop /// for many scenarios. This wrapper type allows us to use s in our other record types without having to write an Equals method /// that we may forget to update if we add new elements to the record. /// - public readonly record struct SequenceEqualImmutableArray(ImmutableArray Array, IEqualityComparer Comparer) + public readonly record struct SequenceEqualImmutableArray(ImmutableArray Array, IEqualityComparer Comparer) : IList { public SequenceEqualImmutableArray(ImmutableArray array) : this(array, EqualityComparer.Default) { } + public T this[int index] { get => ((IList)Array)[index]; set => ((IList)Array)[index] = value; } + + public int Count => ((ICollection)Array).Count; + + public bool IsReadOnly => ((ICollection)Array).IsReadOnly; + + public void Add(T item) => ((ICollection)Array).Add(item); + public void Clear() => ((ICollection)Array).Clear(); + public bool Contains(T item) => Array.Contains(item); + public void CopyTo(T[] array, int arrayIndex) => Array.CopyTo(array, arrayIndex); + public bool Equals(SequenceEqualImmutableArray other) { return Array.SequenceEqual(other.Array, Comparer); } + public IEnumerator GetEnumerator() => ((IEnumerable)Array).GetEnumerator(); public override int GetHashCode() => throw new UnreachableException(); + public int IndexOf(T item) => Array.IndexOf(item); + public void Insert(int index, T item) => ((IList)Array).Insert(index, item); + public bool Remove(T item) => ((ICollection)Array).Remove(item); + public void RemoveAt(int index) => ((IList)Array).RemoveAt(index); + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Array).GetEnumerator(); + } + + public static class IEnumerableSequenceEqualImmutableArrayExtensions + { + public static SequenceEqualImmutableArray ToSequenceEqualImmutableArray(this IEnumerable source, IEqualityComparer comparer) + { + return new(source.ToImmutableArray(), comparer); + } + public static SequenceEqualImmutableArray ToSequenceEqualImmutableArray(this IEnumerable source) + { + return new(source.ToImmutableArray()); + } } } From b5b300449751b14126e0d776cca2fdae24e1196e Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 19 Apr 2023 09:49:37 -0700 Subject: [PATCH 05/31] wip --- .../HashCode.cs | 22 +++ .../ValueEqualityImmutableDictionary.cs | 58 ++++++++ ...terfaceGenerator.Tests.DerivedComObject.cs | 30 ++++ ...nerator.Tests.ManagedObjectExposedToCom.cs | 30 ++++ ...InterfaceGenerator.Tests.IComInterface1.cs | 118 +++++++++++++++ ...aceGenerator.Tests.IDerivedComInterface.cs | 139 ++++++++++++++++++ .../ManagedToNativeStubs.g.cs | 106 +++++++++++++ .../NativeInterfaces.g.cs | 33 +++++ .../NativeToManagedStubs.g.cs | 49 ++++++ .../PopulateVTable.g.cs | 21 +++ .../LibraryImports.g.cs | 110 ++++++++++++++ 11 files changed, 716 insertions(+) create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs new file mode 100644 index 0000000000000..5834bb92f6ab3 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Interop +{ + public static class HashCode + { + public static int Combine(params object[] values) + { + int hash = 31; + foreach (object value in values) + { + hash = hash * 29 + value.GetHashCode(); + } + return hash; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs new file mode 100644 index 0000000000000..f8469ebeea90a --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text; + +namespace Microsoft.Interop +{ + internal record struct ValueEqualityImmutableDictionary(ImmutableDictionary Map) : IDictionary + { + public bool Equals(ValueEqualityImmutableDictionary other) + { + if (Count != other.Count) + { + return false; + } + + foreach(var kvp in this) + { + if (!other.TryGetValue(kvp.Key, out var value) || !kvp.Value.Equals(value)) + { + return false; + } + } + return true; + } + + public override int GetHashCode() + { + return HashCode.Combine(Map.Values); + } + + public U this[T key] { get => ((IDictionary)Map)[key]; set => ((IDictionary)Map)[key] = value; } + + public ICollection Keys => ((IDictionary)Map).Keys; + + public ICollection Values => ((IDictionary)Map).Values; + + public int Count => Map.Count; + + public bool IsReadOnly => ((ICollection>)Map).IsReadOnly; + + public void Add(T key, U value) => ((IDictionary)Map).Add(key, value); + public void Add(KeyValuePair item) => ((ICollection>)Map).Add(item); + public void Clear() => ((ICollection>)Map).Clear(); + public bool Contains(KeyValuePair item) => Map.Contains(item); + public bool ContainsKey(T key) => Map.ContainsKey(key); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)Map).CopyTo(array, arrayIndex); + public IEnumerator> GetEnumerator() => ((IEnumerable>)Map).GetEnumerator(); + public bool Remove(T key) => ((IDictionary)Map).Remove(key); + public bool Remove(KeyValuePair item) => ((ICollection>)Map).Remove(item); + public bool TryGetValue(T key, out U value) => Map.TryGetValue(key, out value); + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Map).GetEnumerator(); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs new file mode 100644 index 0000000000000..cb78fb469e93b --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs @@ -0,0 +1,30 @@ +file sealed unsafe class ComClassInformation : System.Runtime.InteropServices.Marshalling.IComExposedClass +{ + private static volatile System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables; + public static System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) + { + count = 1; + if (s_vtables == null) + { + System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 1); + System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; + details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IComInterface1).TypeHandle); + vtables[0] = new() + { + IID = details.Iid, + Vtable = (nint)details.ManagedVirtualMethodTable + }; + s_vtables = vtables; + } + + return s_vtables; + } +} + +namespace ComInterfaceGenerator.Tests +{ + [System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute] + partial class DerivedComObject + { + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs new file mode 100644 index 0000000000000..a5c4b0aa2f52e --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs @@ -0,0 +1,30 @@ +file sealed unsafe class ComClassInformation : System.Runtime.InteropServices.Marshalling.IComExposedClass +{ + private static volatile System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables; + public static System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) + { + count = 1; + if (s_vtables == null) + { + System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 1); + System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; + details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IComInterface1).TypeHandle); + vtables[0] = new() + { + IID = details.Iid, + Vtable = (nint)details.ManagedVirtualMethodTable + }; + s_vtables = vtables; + } + + return s_vtables; + } +} + +namespace ComInterfaceGenerator.Tests +{ + [System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute] + partial class ManagedObjectExposedToCom + { + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs new file mode 100644 index 0000000000000..a5838645227fa --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs @@ -0,0 +1,118 @@ +file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType +{ + public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 3, 153, 63, 44, 134, 181, 177, 70, 136, 27, 173, 252, 233, 175, 71, 177 })); + + private static void** _vtable; + public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); +} + +[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] +file unsafe partial interface InterfaceImplementation : global::ComInterfaceGenerator.Tests.IComInterface1 +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::ComInterfaceGenerator.Tests.IComInterface1.GetData() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); + int __retVal; + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + return __retVal; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + void global::ComInterfaceGenerator.Tests.IComInterface1.SetData(int n) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, n); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + } +} + +file unsafe partial interface InterfaceImplementation +{ + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_GetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int* __invokeRetValUnmanaged__param) + { + global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; + ref int __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; + int __invokeRetVal = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + __invokeRetVal = @this.GetData(); + // Marshal - Convert managed data to native data. + __invokeRetValUnmanaged = __invokeRetVal; + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + + return __retVal; + } + + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_SetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int n) + { + global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + @this.SetData(n); + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + + return __retVal; + } +} + +file unsafe partial interface InterfaceImplementation +{ + internal static void** CreateManagedVirtualFunctionTable() + { + void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::ComInterfaceGenerator.Tests.IComInterface1), sizeof(void*) * 5); + { + nint v0, v1, v2; + System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); + vtable[0] = (void*)v0; + vtable[1] = (void*)v1; + vtable[2] = (void*)v2; + } + + { + vtable[3] = (void*)(delegate* unmanaged )&ABI_GetData; + vtable[4] = (void*)(delegate* unmanaged )&ABI_SetData; + } + + return vtable; + } +} + +namespace ComInterfaceGenerator.Tests +{ + [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] + public partial interface IComInterface1 + { + } +} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs new file mode 100644 index 0000000000000..0f90d2c1a4e5b --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs @@ -0,0 +1,139 @@ +file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType +{ + public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 100, 179, 13, 127, 4, 60, 135, 68, 145, 147, 75, 176, 93, 199, 182, 84 })); + + private static void** _vtable; + public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); +} + +[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] +file unsafe partial interface InterfaceImplementation : global::ComInterfaceGenerator.Tests.IDerivedComInterface +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + void global::ComInterfaceGenerator.Tests.IDerivedComInterface.SetName(string name) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); + int __invokeRetVal; + // Pin - Pin data in preparation for calling the P/Invoke. + fixed (void* __name_native = &global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.GetPinnableReference(name)) + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[5])(__this, (ushort*)__name_native); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + string global::ComInterfaceGenerator.Tests.IDerivedComInterface.GetName() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); + string __retVal; + ushort* __retVal_native = default; + int __invokeRetVal; + try + { + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[6])(__this, &__retVal_native); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + __retVal = global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.ConvertToManaged(__retVal_native); + } + finally + { + // Cleanup - Perform required cleanup. + global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.Free(__retVal_native); + } + + return __retVal; + } +} + +file unsafe partial interface InterfaceImplementation +{ + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_SetName(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, ushort* __name_native) + { + global::ComInterfaceGenerator.Tests.IDerivedComInterface @this = default; + string name = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + name = global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.ConvertToManaged(__name_native); + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + @this.SetName(name); + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + finally + { + // Cleanup - Perform required cleanup. + global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.Free(__name_native); + } + + return __retVal; + } + + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_GetName(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, ushort** __invokeRetValUnmanaged__param) + { + global::ComInterfaceGenerator.Tests.IDerivedComInterface @this = default; + ref ushort* __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; + string __invokeRetVal = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + __invokeRetVal = @this.GetName(); + // Marshal - Convert managed data to native data. + __invokeRetValUnmanaged = (ushort*)global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.ConvertToUnmanaged(__invokeRetVal); + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + finally + { + // Cleanup - Perform required cleanup. + global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.Free(__invokeRetValUnmanaged); + } + + return __retVal; + } +} + +file unsafe partial interface InterfaceImplementation +{ + internal static void** CreateManagedVirtualFunctionTable() + { + void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface), sizeof(void*) * 7); + { + System.Runtime.InteropServices.NativeMemory.Copy(System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(global::ComInterfaceGenerator.Tests.IComInterface1).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * 5)); + } + + { + vtable[5] = (void*)(delegate* unmanaged )&ABI_SetName; + vtable[6] = (void*)(delegate* unmanaged )&ABI_GetName; + } + + return vtable; + } +} + +namespace ComInterfaceGenerator.Tests +{ + [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] + public partial interface IDerivedComInterface + { + } +} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs new file mode 100644 index 0000000000000..7eb9069150637 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs @@ -0,0 +1,106 @@ +// +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + internal unsafe partial interface INativeObject + { + internal unsafe partial interface Native + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject.GetData() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)); + int __retVal; + { + __retVal = ((delegate* unmanaged )__vtable_native[0])(__this); + } + + return __retVal; + } + } + } + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + internal unsafe partial interface INativeObject + { + internal unsafe partial interface Native + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + void global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject.SetData(int x) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)); + { + ((delegate* unmanaged )__vtable_native[1])(__this, x); + } + } + } + } + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class NoImplicitThis + { + internal unsafe partial interface IStaticMethodTable + { + internal unsafe partial interface Native + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable.Add(int x, int y) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable)); + int __retVal; + { + __retVal = ((delegate* unmanaged )__vtable_native[0])(x, y); + } + + return __retVal; + } + } + } + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class NoImplicitThis + { + internal unsafe partial interface IStaticMethodTable + { + internal unsafe partial interface Native + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable.Multiply(int x, int y) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable)); + int __retVal; + { + __retVal = ((delegate* unmanaged )__vtable_native[1])(x, y); + } + + return __retVal; + } + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs new file mode 100644 index 0000000000000..63e98564f3393 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs @@ -0,0 +1,33 @@ +// +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + internal unsafe partial interface INativeObject + { + [System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] + internal partial interface Native : INativeObject + { + } + } + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class NoImplicitThis + { + internal unsafe partial interface IStaticMethodTable + { + [System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] + internal partial interface Native : IStaticMethodTable + { + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs new file mode 100644 index 0000000000000..3fe35a8622fa7 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs @@ -0,0 +1,49 @@ +// +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + internal unsafe partial interface INativeObject + { + internal unsafe partial interface Native + { + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_GetData(void* __this_native) + { + global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject @this; + int __retVal = default; + // Unmarshal - Convert native data to managed data. + @this = (global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper.GetObjectForUnmanagedWrapper>(__this_native); + __retVal = @this.GetData(); + return __retVal; + } + } + } + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + internal unsafe partial interface INativeObject + { + internal unsafe partial interface Native + { + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static void ABI_SetData(void* __this_native, int x) + { + global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject @this; + // Unmarshal - Convert native data to managed data. + @this = (global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper.GetObjectForUnmanagedWrapper>(__this_native); + @this.SetData(x); + } + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs new file mode 100644 index 0000000000000..8429a3ea8c269 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs @@ -0,0 +1,21 @@ +// +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + internal unsafe partial interface INativeObject + { + internal unsafe partial interface Native + { + internal static unsafe void PopulateUnmanagedVirtualMethodTable(void** vtable) + { + vtable[0] = (void*)(delegate* unmanaged )&ABI_GetData; + vtable[1] = (void*)(delegate* unmanaged )&ABI_SetData; + } + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs new file mode 100644 index 0000000000000..e6463e52e37a8 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs @@ -0,0 +1,110 @@ +// +namespace ComInterfaceGenerator.Tests +{ + unsafe partial class NativeExportsNE + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "set_com_object_data", ExactSpelling = true)] + public static extern partial void SetComObjectData(void* obj, int data); + } +} +namespace ComInterfaceGenerator.Tests +{ + unsafe partial class NativeExportsNE + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_com_object_data", ExactSpelling = true)] + public static extern partial int GetComObjectData(void* obj); + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_com_object", ExactSpelling = true)] + public static extern partial void* NewNativeObject(); + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.LibraryImportGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + public static partial global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.NativeObject NewNativeObject() + { + global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.NativeObject __retVal; + void* __retVal_native; + { + __retVal_native = __PInvoke(); + } + + // Unmarshal - Convert native data to managed data. + __retVal = global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.NativeObjectMarshaller.ConvertToManaged(__retVal_native); + return __retVal; + // Local P/Invoke + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "new_native_object", ExactSpelling = true)] + static extern unsafe void* __PInvoke(); + } + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "delete_native_object", ExactSpelling = true)] + public static extern partial void DeleteNativeObject(void* obj); + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "set_native_object_data", ExactSpelling = true)] + public static extern partial void SetNativeObjectData(void* obj, int data); + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class ImplicitThis + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_native_object_data", ExactSpelling = true)] + public static extern partial int GetNativeObjectData(void* obj); + } + } +} +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal unsafe partial class NoImplicitThis + { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.LibraryImportGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + public static partial global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.StaticMethodTable GetStaticFunctionTable() + { + global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.StaticMethodTable __retVal; + void* __retVal_native; + { + __retVal_native = __PInvoke(); + } + + // Unmarshal - Convert native data to managed data. + __retVal = global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.StaticMethodTableMarshaller.ConvertToManaged(__retVal_native); + return __retVal; + // Local P/Invoke + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_static_function_table", ExactSpelling = true)] + static extern unsafe void* __PInvoke(); + } + } + } +} From b7a677c37c6abc9e58cd34dd4b9d98f450b18ab7 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 19 Apr 2023 17:19:57 -0700 Subject: [PATCH 06/31] wip --- .../ComInterfaceGenerator.cs | 180 ++++++++++-------- .../SequenceEqualImmutableArray.cs | 24 +-- .../ValueEqualityImmutableDictionary.cs | 31 ++- 3 files changed, 136 insertions(+), 99 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 7c855be8c7702..bee5ebb26542b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -15,6 +15,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.Interop.CollectionExtensions; namespace Microsoft.Interop { @@ -53,32 +54,6 @@ internal BaseInterfacesInfo(ManagedTypeInfo? baseInterface, bool isMarkerInterfa public bool IsMarkerInterface { get; } } - public record ShadowingMethodContext( - int Offset, - ManagedTypeInfo BaseInterfaceType, - string MethodName, - MethodDeclarationSyntax Syntax, - ManagedTypeInfo ReturnType, - ImmutableArray<(ManagedTypeInfo Type, string Name, ParameterSyntax Syntax)> Parameters) - { - public MethodDeclarationSyntax Generate() - { - // DeclarationCopiedFromBaseDeclaration() - // { - // return (()this).(); - // } - return Syntax.WithBody( - Block( - ReturnStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - CastExpression(BaseInterfaceType.Syntax, IdentifierName(Token(SyntaxKind.ThisKeyword))), - IdentifierName(MethodName)), - ArgumentList( - SeparatedList(Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); - } - } private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface) { @@ -125,7 +100,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterDiagnostics(invalidTypeDiagnostics.Select((data, ct) => data.Diagnostic)); } - // Get the methods for each interface, without the indices + // Get the information we need about methods themselves var interfaceMethods = interfacesToGenerate.Select((data, ct) => { INamedTypeSymbol iface = data.Symbol; @@ -137,9 +112,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) comMethods.Add(methodInfo); } } - return comMethods.ToImmutableArray(); + return comMethods.ToSequenceEqualImmutableArray(); }); + // Create a map of Com interface to its base for use later. var ifaceToBaseMap = interfacesToGenerate.Collect().Select((data, ct) => { Dictionary ifaceToBaseMap = new(); @@ -152,64 +128,65 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { ifaceToBaseMap.Add(iface.Context, iface.BaseInterfaceSymbol is not null ? contexts[iface.BaseInterfaceSymbol] : null); } - return ifaceToBaseMap; + return ifaceToBaseMap.ToValueEqualImmutable(); }); - var interfaceToMethodsMap = interfacesToGenerate + // Generate a map from Com interface to the methods it declares + var interfaceToDeclaredMethodsMap = interfacesToGenerate .Select((iface, ct) => iface.Context) .Zip(interfaceMethods) .Collect() .Select((data, ct) => { - return data.ToImmutableDictionary<(ComInterfaceContext, ImmutableArray), ComInterfaceContext, ImmutableArray>( + return data.ToValueEqualityImmutableDictionary<(ComInterfaceContext, SequenceEqualImmutableArray), ComInterfaceContext, SequenceEqualImmutableArray>( static pair => pair.Item1, static pair => pair.Item2); }); - var interfaceAndMethodsContexts = interfaceToMethodsMap + // Combine info about base methods and declared methods to get a list of interfaces, and all the methods they need to worry about (including both declared and inherited methods) + var interfaceAndMethodsContexts = interfaceToDeclaredMethodsMap .Combine(ifaceToBaseMap) .Combine(context.CreateStubEnvironmentProvider()) .SelectMany((data, ct) => { var ((ifaceToMethodsMap, ifaceToBaseMap), env) = data; - return ComInterfaceAndMethods.GetMethods(ifaceToBaseMap, ifaceToMethodsMap, env, ct); + return ComInterfaceAndMethods.GetAllMethods(ifaceToBaseMap, ifaceToMethodsMap, env, ct); }); + // Separate the methods which declare methods from those that don't declare methods var interfacesWithMethodsAndItsMethods = interfaceAndMethodsContexts .Where(data => data.DeclaredMethods.Any()); var interfacesWithMethods = interfacesWithMethodsAndItsMethods .Select(static (data, ct) => data.Interface); - // Marker interfaces are COM interfaces that don't have any methods. - // The lack of methods breaks the mechanism we use later to stitch back together interface-level data - // and method-level data, but that's okay because marker interfaces are much simpler. - // We'll handle them seperately because they are so simple. - var markerInterfaces = interfaceAndMethodsContexts - .Where(data => !data.DeclaredMethods.Any()) - .Select(static (data, ct) => data.Interface); - - var markerInterfaceIUnknownDerived = markerInterfaces - .Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) - .WithComparer(SyntaxEquivalentComparer.Instance) - .SelectNormalized(); - - context.RegisterSourceOutput(markerInterfaces.Zip(markerInterfaceIUnknownDerived), (context, data) => { - var (interfaceContext, iUnknownDerivedAttributeApplication) = data; - context.AddSource( - interfaceContext.InterfaceType.FullTypeName.Replace("global::", ""), - GenerateMarkerInterfaceSource(interfaceContext) + iUnknownDerivedAttributeApplication); - }); + // Marker interfaces are COM interfaces that don't have any methods. + // The lack of methods breaks the mechanism we use later to stitch back together interface-level data + // and method-level data, but that's okay because marker interfaces are much simpler. + // We'll handle them seperately because they are so simple. + var markerInterfaces = interfaceAndMethodsContexts + .Where(data => !data.DeclaredMethods.Any()) + .Select(static (data, ct) => data.Interface); + + var markerInterfaceIUnknownDerived = markerInterfaces + .Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) + .WithComparer(SyntaxEquivalentComparer.Instance) + .SelectNormalized(); + + context.RegisterSourceOutput(markerInterfaces.Zip(markerInterfaceIUnknownDerived), (context, data) => + { + var (interfaceContext, iUnknownDerivedAttributeApplication) = data; + context.AddSource( + interfaceContext.InterfaceType.FullTypeName.Replace("global::", ""), + GenerateMarkerInterfaceSource(interfaceContext) + iUnknownDerivedAttributeApplication); + }); + } var allMethods = interfaceAndMethodsContexts.SelectMany(static (data, ct) => data.DeclaredMethods); - // Split the methods we want to generate and the ones we don't into two separate groups. - var methodsToGenerate = allMethods.Where(static data => - { - return data.Diagnostic is null; - }); + // Split the methods we want to generate and the ones with warnings into different groups, and warn on the invalid methods { var invalidMethods = allMethods.Where(static data => data.Diagnostic is not null); @@ -218,7 +195,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.ReportDiagnostic(invalidMethod.Diagnostic); }); } - + var methodsToGenerate = allMethods.Where(static data => + { + return data.Diagnostic is null; + }); IncrementalValuesProvider generateStubInformation = methodsToGenerate.Select((data, ct) => data.GenerationContext); // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation. @@ -238,13 +218,19 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterDiagnostics(generateManagedToNativeStub.SelectMany((stubInfo, ct) => stubInfo.Diagnostics.Array)); - var managedToNativeInterfaceImplementations = generateManagedToNativeStub - .Collect() - .SelectMany(static (stubs, ct) => GroupContextsForInterfaceGeneration(stubs)) - .Select(static (interfaceGroup, ct) => GenerateImplementationInterface(interfaceGroup.Array)) - .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) + var managedToNativeInterfaceImplementations = interfacesWithMethodsAndItsMethods.Select((data, ct) => + { + return GenerateImplementationInterface(data); + }).WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); + //var managedToNativeInterfaceImplementations = generateManagedToNativeStub + // .Collect() + // .SelectMany(static (stubs, ct) => GroupContextsForInterfaceGeneration(stubs)) + // .Select(static (interfaceGroup, ct) => GenerateImplementationInterface(interfaceGroup.Array)) + // .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) + // .WithComparer(SyntaxEquivalentComparer.Instance) + // .SelectNormalized(); // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation. IncrementalValuesProvider nativeToManagedStubContexts = @@ -474,13 +460,13 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M containingSyntaxContext, methodSyntaxTemplate, new MethodSignatureDiagnosticLocations(syntax), - new SequenceEqualImmutableArray(callConv, SyntaxEquivalentComparer.Instance), + callConv.ToSequenceEqualImmutableArray(SyntaxEquivalentComparer.Instance), virtualMethodIndexData, new ComExceptionMarshalling(), ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged), ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged), typeKeyOwner, - new SequenceEqualImmutableArray(generatorDiagnostics.Diagnostics.ToImmutableArray()), + generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(), ComInterfaceDispatchMarshallingInfo.Instance); } @@ -594,12 +580,12 @@ private static ImmutableArray interfaceGroup) + private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethods interfaceGroup) { - var definingType = interfaceGroup[0].OriginalDefiningType; + var definingType = interfaceGroup.Interface.InterfaceType; return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) - .WithMembers(List(interfaceGroup.OfType().Select(context => context.Stub.Node))) + .WithMembers(List(interfaceGroup.DeclaredMethods.Select(m => m.ManagedToUnmanagedStub).OfType().Select(ctx => ctx.Stub.Node))) .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute))))); } private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ImmutableArray interfaceGroup) @@ -898,7 +884,12 @@ public bool Equals(ComInterfaceContext other) /// /// Represents a method that has been determined to be a COM interface method. /// - private sealed record MethodInfo([property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, Diagnostic? Diagnostic) + private sealed record MethodInfo( + [property: Obsolete] IMethodSymbol Symbol, + MethodDeclarationSyntax Syntax, + string MethodName, + SequenceEqualImmutableArray<(ManagedTypeInfo Type, string Name, RefKind RefKind)> Parameters, + Diagnostic? Diagnostic) { public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) { @@ -921,8 +912,6 @@ public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol memb } } - MethodDeclarationSyntax? comMethodDeclaringSyntax = null; - // TODO: this should cause a diagnostic if (methodLocationInAttributedInterfaceDeclaration is null) { @@ -930,6 +919,7 @@ public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol memb } // Find the matching declaration syntax + MethodDeclarationSyntax? comMethodDeclaringSyntax = null; foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) { var declaringSyntax = declaringSyntaxReference.GetSyntax(); @@ -942,8 +932,16 @@ public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol memb } if (comMethodDeclaringSyntax is null) throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); + + List<(ManagedTypeInfo ParameterType, string Name, RefKind RefKind)> parameters = new(); + foreach (var parameter in ((IMethodSymbol)member).Parameters) + { + parameters.Add((ManagedTypeInfo.CreateTypeInfoForTypeSymbol(parameter.Type), parameter.Name, parameter.RefKind)); + } + var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); - comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, diag); + + comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); return true; } return false; @@ -955,13 +953,45 @@ public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol memb /// private sealed record ComInterfaceMethodContext(ComInterfaceContext DeclaringInterface, MethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) { + public GeneratedMethodContextBase ManagedToUnmanagedStub + { + get + { + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) + { + return (GeneratedMethodContextBase)new SkippedStubContext(DeclaringInterface.InterfaceType); + } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + } + } + public Diagnostic? Diagnostic => MethodInfo.Diagnostic; + + public MethodDeclarationSyntax GenerateShadow() + { + // DeclarationCopiedFromBaseDeclaration() + // { + // return (()this).(); + // } + return MethodInfo.Syntax.WithBody( + Block( + ReturnStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + CastExpression(DeclaringInterface.InterfaceType.Syntax, IdentifierName(Token(SyntaxKind.ThisKeyword))), + IdentifierName(MethodInfo.MethodName)), + ArgumentList( + // TODO: RefKind keywords + SeparatedList(MethodInfo.Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); + } } /// /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). /// - private sealed record ComInterfaceAndMethods(ComInterfaceContext Interface, ImmutableArray Methods, ComInterfaceContext? BaseInterface) + private sealed record ComInterfaceAndMethods(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods, ComInterfaceContext? BaseInterface) { /// /// COM methods that are declared on the attributed interface declaration. @@ -973,7 +1003,7 @@ private sealed record ComInterfaceAndMethods(ComInterfaceContext Interface, Immu /// public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); - public static IEnumerable GetMethods(Dictionary ifaceToBaseMap, ImmutableDictionary> ifaceToMethodsMap, StubEnvironment environment, CancellationToken ct) + public static IEnumerable GetAllMethods(ValueEqualityImmutableDictionary ifaceToBaseMap, ValueEqualityImmutableDictionary> ifaceToMethodsMap, StubEnvironment environment, CancellationToken ct) { Dictionary> allMethodsCache = new(); @@ -982,7 +1012,7 @@ public static IEnumerable GetMethods(Dictionary asdf = AddMethods(kvp.Key, kvp.Value); } - return allMethodsCache.Select(kvp => new ComInterfaceAndMethods(kvp.Key, kvp.Value.ToImmutableArray(), ifaceToBaseMap[kvp.Key])).ToImmutableArray(); + return allMethodsCache.Select(kvp => new ComInterfaceAndMethods(kvp.Key, kvp.Value.ToSequenceEqualImmutableArray(), ifaceToBaseMap[kvp.Key])); IEnumerable AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs index adfa131473375..c2f6c887990a1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs @@ -15,23 +15,18 @@ namespace Microsoft.Interop /// for many scenarios. This wrapper type allows us to use s in our other record types without having to write an Equals method /// that we may forget to update if we add new elements to the record. /// - public readonly record struct SequenceEqualImmutableArray(ImmutableArray Array, IEqualityComparer Comparer) : IList + public readonly record struct SequenceEqualImmutableArray(ImmutableArray Array, IEqualityComparer Comparer) : IEnumerable { public SequenceEqualImmutableArray(ImmutableArray array) : this(array, EqualityComparer.Default) { } - public T this[int index] { get => ((IList)Array)[index]; set => ((IList)Array)[index] = value; } + public T this[int i] { get => Array[i]; } - public int Count => ((ICollection)Array).Count; + public int Length => Array.Length; - public bool IsReadOnly => ((ICollection)Array).IsReadOnly; - - public void Add(T item) => ((ICollection)Array).Add(item); - public void Clear() => ((ICollection)Array).Clear(); - public bool Contains(T item) => Array.Contains(item); - public void CopyTo(T[] array, int arrayIndex) => Array.CopyTo(array, arrayIndex); + public override int GetHashCode() => throw new NotSupportedException(); public bool Equals(SequenceEqualImmutableArray other) { @@ -39,15 +34,10 @@ public bool Equals(SequenceEqualImmutableArray other) } public IEnumerator GetEnumerator() => ((IEnumerable)Array).GetEnumerator(); - public override int GetHashCode() => throw new UnreachableException(); - public int IndexOf(T item) => Array.IndexOf(item); - public void Insert(int index, T item) => ((IList)Array).Insert(index, item); - public bool Remove(T item) => ((ICollection)Array).Remove(item); - public void RemoveAt(int index) => ((IList)Array).RemoveAt(index); IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Array).GetEnumerator(); } - public static class IEnumerableSequenceEqualImmutableArrayExtensions + public static partial class CollectionExtensions { public static SequenceEqualImmutableArray ToSequenceEqualImmutableArray(this IEnumerable source, IEqualityComparer comparer) { @@ -57,5 +47,9 @@ public static SequenceEqualImmutableArray ToSequenceEqualImmutableArray(th { return new(source.ToImmutableArray()); } + public static SequenceEqualImmutableArray ToSequenceEqual(this ImmutableArray source) + { + return new(source); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs index f8469ebeea90a..5ead994f6db93 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs @@ -9,7 +9,7 @@ namespace Microsoft.Interop { - internal record struct ValueEqualityImmutableDictionary(ImmutableDictionary Map) : IDictionary + public record struct ValueEqualityImmutableDictionary(ImmutableDictionary Map) : IDictionary { public bool Equals(ValueEqualityImmutableDictionary other) { @@ -34,18 +34,10 @@ public override int GetHashCode() } public U this[T key] { get => ((IDictionary)Map)[key]; set => ((IDictionary)Map)[key] = value; } - public ICollection Keys => ((IDictionary)Map).Keys; - public ICollection Values => ((IDictionary)Map).Values; - public int Count => Map.Count; - public bool IsReadOnly => ((ICollection>)Map).IsReadOnly; - - public void Add(T key, U value) => ((IDictionary)Map).Add(key, value); - public void Add(KeyValuePair item) => ((ICollection>)Map).Add(item); - public void Clear() => ((ICollection>)Map).Clear(); public bool Contains(KeyValuePair item) => Map.Contains(item); public bool ContainsKey(T key) => Map.ContainsKey(key); public void CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)Map).CopyTo(array, arrayIndex); @@ -54,5 +46,26 @@ public override int GetHashCode() public bool Remove(KeyValuePair item) => ((ICollection>)Map).Remove(item); public bool TryGetValue(T key, out U value) => Map.TryGetValue(key, out value); IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Map).GetEnumerator(); + public void Add(T key, U value) => ((IDictionary)Map).Add(key, value); + public void Add(KeyValuePair item) => ((ICollection>)Map).Add(item); + public void Clear() => ((ICollection>)Map).Clear(); + } + + public static partial class CollectionExtensions + { + public static ValueEqualityImmutableDictionary ToValueEqualityImmutableDictionary(this IEnumerable srcs, Func keyMap, Func valueMap) + { + return new(srcs.ToImmutableDictionary(keyMap, valueMap)); + } + + public static ValueEqualityImmutableDictionary ToValueEqual(this ImmutableDictionary dict) + { + return new(dict); + } + + public static ValueEqualityImmutableDictionary ToValueEqualImmutable(this Dictionary dict) + { + return new(dict.ToImmutableDictionary()); + } } } From 8b034019f198cd33e1476b56c8bda4fe314de47f Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 20 Apr 2023 11:10:13 -0700 Subject: [PATCH 07/31] delete unused code --- .../ComInterfaceGenerator.cs | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index bee5ebb26542b..30a2e0012628c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -42,19 +42,6 @@ public static class StepNames public const string GenerateIUnknownDerivedAttribute = nameof(GenerateIUnknownDerivedAttribute); } - internal struct BaseInterfacesInfo - { - internal BaseInterfacesInfo(ManagedTypeInfo? baseInterface, bool isMarkerInterface) - { - BaseInterface = baseInterface; - IsMarkerInterface = isMarkerInterface; - } - - public ManagedTypeInfo? BaseInterface { get; } - - public bool IsMarkerInterface { get; } - } - private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface) { baseComIface = null; @@ -224,13 +211,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }).WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - //var managedToNativeInterfaceImplementations = generateManagedToNativeStub - // .Collect() - // .SelectMany(static (stubs, ct) => GroupContextsForInterfaceGeneration(stubs)) - // .Select(static (interfaceGroup, ct) => GenerateImplementationInterface(interfaceGroup.Array)) - // .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) - // .WithComparer(SyntaxEquivalentComparer.Instance) - // .SelectNormalized(); // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation. IncrementalValuesProvider nativeToManagedStubContexts = From d5b188c830b3db5585c0d9bfe77c24670b320ecb Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 20 Apr 2023 13:21:43 -0700 Subject: [PATCH 08/31] More cleaup and renaming --- .../ComInterfaceGenerator.cs | 185 ++++++++---------- 1 file changed, 83 insertions(+), 102 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 30a2e0012628c..18cfd272b74aa 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -76,7 +76,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { Diagnostic? Diagnostic = GetDiagnosticIfInvalidTypeForGeneration(data.Syntax, data.Symbol); INamedTypeSymbol? BaseInterfaceSymbol = TryGetBaseComInterface(data.Symbol, out var baseComInterface) ? baseComInterface : null; - ComInterfaceContext Context = ComInterfaceContext.From(data.Symbol, data.Syntax); + ComInterfaceInfo Context = ComInterfaceInfo.From(data.Symbol, data.Syntax); return new { data.Syntax, data.Symbol, Context, Diagnostic, BaseInterfaceSymbol }; }); @@ -84,11 +84,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfacesToGenerate = interfacesAndDiagnostics.Where(static data => data.Diagnostic is null); { var invalidTypeDiagnostics = interfacesAndDiagnostics.Where(static data => data.Diagnostic is not null); - context.RegisterDiagnostics(invalidTypeDiagnostics.Select((data, ct) => data.Diagnostic)); + context.RegisterDiagnostics(invalidTypeDiagnostics.Select(static (data, ct) => data.Diagnostic)); } // Get the information we need about methods themselves - var interfaceMethods = interfacesToGenerate.Select((data, ct) => + var interfaceMethods = interfacesToGenerate.Select(static (data, ct) => { INamedTypeSymbol iface = data.Symbol; List comMethods = new(); @@ -103,10 +103,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); // Create a map of Com interface to its base for use later. - var ifaceToBaseMap = interfacesToGenerate.Collect().Select((data, ct) => + var ifaceToBaseMap = interfacesToGenerate.Collect().Select(static (data, ct) => { - Dictionary ifaceToBaseMap = new(); - Dictionary contexts = new(SymbolEqualityComparer.Default); + Dictionary ifaceToBaseMap = new(); + Dictionary contexts = new(SymbolEqualityComparer.Default); foreach (var iface in data) { contexts.Add(iface.Symbol, iface.Context); @@ -120,12 +120,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate a map from Com interface to the methods it declares var interfaceToDeclaredMethodsMap = interfacesToGenerate - .Select((iface, ct) => iface.Context) + .Select(static (iface, ct) => iface.Context) .Zip(interfaceMethods) .Collect() - .Select((data, ct) => + .Select(static (data, ct) => { - return data.ToValueEqualityImmutableDictionary<(ComInterfaceContext, SequenceEqualImmutableArray), ComInterfaceContext, SequenceEqualImmutableArray>( + return data.ToValueEqualityImmutableDictionary<(ComInterfaceInfo, SequenceEqualImmutableArray), ComInterfaceInfo, SequenceEqualImmutableArray>( static pair => pair.Item1, static pair => pair.Item2); }); @@ -134,16 +134,16 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfaceAndMethodsContexts = interfaceToDeclaredMethodsMap .Combine(ifaceToBaseMap) .Combine(context.CreateStubEnvironmentProvider()) - .SelectMany((data, ct) => + .SelectMany(static (data, ct) => { var ((ifaceToMethodsMap, ifaceToBaseMap), env) = data; - return ComInterfaceAndMethods.GetAllMethods(ifaceToBaseMap, ifaceToMethodsMap, env, ct); + return ComInterfaceAndMethodsContext.GetAllMethods(ifaceToBaseMap, ifaceToMethodsMap, env, ct); }); // Separate the methods which declare methods from those that don't declare methods var interfacesWithMethodsAndItsMethods = interfaceAndMethodsContexts - .Where(data => data.DeclaredMethods.Any()); + .Where(static data => data.DeclaredMethods.Any()); var interfacesWithMethods = interfacesWithMethodsAndItsMethods .Select(static (data, ct) => data.Interface); @@ -154,11 +154,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // and method-level data, but that's okay because marker interfaces are much simpler. // We'll handle them seperately because they are so simple. var markerInterfaces = interfaceAndMethodsContexts - .Where(data => !data.DeclaredMethods.Any()) + .Where(static data => !data.DeclaredMethods.Any()) .Select(static (data, ct) => data.Interface); var markerInterfaceIUnknownDerived = markerInterfaces - .Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) + .Select(GenerateIUnknownDerivedAttributeApplication) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); @@ -171,10 +171,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); } - var allMethods = interfaceAndMethodsContexts.SelectMany(static (data, ct) => data.DeclaredMethods); // Split the methods we want to generate and the ones with warnings into different groups, and warn on the invalid methods { + var allMethods = interfaceAndMethodsContexts.SelectMany(static (data, ct) => data.DeclaredMethods); var invalidMethods = allMethods.Where(static data => data.Diagnostic is not null); context.RegisterSourceOutput(invalidMethods, static (context, invalidMethod) => @@ -182,82 +182,42 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.ReportDiagnostic(invalidMethod.Diagnostic); }); } - var methodsToGenerate = allMethods.Where(static data => - { - return data.Diagnostic is null; - }); - IncrementalValuesProvider generateStubInformation = methodsToGenerate.Select((data, ct) => data.GenerationContext); // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation. - var generateManagedToNativeStub = generateStubInformation - .Select( - static (data, ct) => - { - if (data.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) - { - return (GeneratedMethodContextBase)new SkippedStubContext(data.OriginalDefiningType); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(data); - return new GeneratedStubCodeContext(data.TypeKeyOwner, data.ContainingSyntaxContext, new(methodStub), new(diagnostics)); - } - ) - .WithTrackingName(StepNames.GenerateManagedToNativeStub); - - context.RegisterDiagnostics(generateManagedToNativeStub.SelectMany((stubInfo, ct) => stubInfo.Diagnostics.Array)); - - var managedToNativeInterfaceImplementations = interfacesWithMethodsAndItsMethods.Select((data, ct) => - { - return GenerateImplementationInterface(data); - }).WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) + context.RegisterDiagnostics(interfacesWithMethodsAndItsMethods + .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics))); + var managedToNativeInterfaceImplementations = interfacesWithMethodsAndItsMethods + .Select(GenerateImplementationInterface) + .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - // Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation. - IncrementalValuesProvider nativeToManagedStubContexts = - generateStubInformation - .Where(static data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional); - // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation. - var generateNativeToManagedStub = generateStubInformation - .Select( - static (data, ct) => - { - if (data.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) - { - return (GeneratedMethodContextBase)new SkippedStubContext(data.OriginalDefiningType); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(data); - return new GeneratedStubCodeContext(data.OriginalDefiningType, data.ContainingSyntaxContext, new(methodStub), new(diagnostics)); - } - ) - .WithTrackingName(StepNames.GenerateNativeToManagedStub); - - context.RegisterDiagnostics(generateNativeToManagedStub.SelectMany((stubInfo, ct) => stubInfo.Diagnostics.Array)); - - var nativeToManagedVtableMethods = generateNativeToManagedStub - .Collect() - .SelectMany(static (stubs, ct) => GroupContextsForInterfaceGeneration(stubs)) - .Select(static (interfaceGroup, ct) => GenerateImplementationVTableMethods(interfaceGroup.Array)) + context.RegisterDiagnostics(interfacesWithMethodsAndItsMethods + .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.NativeToManagedStub.Diagnostics))); + var nativeToManagedVtableMethods = interfacesWithMethodsAndItsMethods + .Select(GenerateImplementationVTableMethods) .WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); // Generate the native interface metadata for each [GeneratedComInterface]-attributed interface. var nativeInterfaceInformation = interfacesWithMethods - .Select(static (context, ct) => GenerateInterfaceInformation(context)) + .Select(GenerateInterfaceInformation) .WithTrackingName(StepNames.GenerateInterfaceInformation) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); // Generate a method named CreateManagedVirtualFunctionTable on the native interface implementation // that allocates and fills in the memory for the vtable. - var nativeToManagedVtables = interfacesWithMethodsAndItsMethods.Select((data, ct) => GenerateImplementationVTable(data)) + var nativeToManagedVtables = interfacesWithMethodsAndItsMethods + .Select(GenerateImplementationVTable) .WithTrackingName(StepNames.GenerateNativeToManagedVTable) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); var iUnknownDerivedAttributeApplication = interfacesWithMethods - .Select(static (context, ct) => GenerateIUnknownDerivedAttributeApplication(context)) + .Select(GenerateIUnknownDerivedAttributeApplication) .WithTrackingName(StepNames.GenerateIUnknownDerivedAttribute) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); @@ -297,7 +257,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); } - private static string GenerateMarkerInterfaceSource(ComInterfaceContext iface) => $$""" + private static string GenerateMarkerInterfaceSource(ComInterfaceInfo iface) => $$""" file unsafe class InterfaceInformation : global::System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType { public static global::System.Guid Iid => new(new global::System.ReadOnlySpan(new byte[] { {{string.Join(",", iface.InterfaceId.ToByteArray())}} })); @@ -331,13 +291,14 @@ public static void** ManagedVirtualMethodTable IdentifierName("InterfaceInformation"), IdentifierName("InterfaceImplementation"))); - private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplication(ComInterfaceContext context) + private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplication(ComInterfaceInfo context, CancellationToken _) => context.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier( TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) .WithModifiers(context.ContainingSyntax.Modifiers) .WithTypeParameterList(context.ContainingSyntax.TypeParameters) .AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate)))); + // Todo: extract info needed from the IMethodSymbol into MethodInfo and only pass a MethodInfo here private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, CancellationToken ct) { ct.ThrowIfCancellationRequested(); @@ -560,7 +521,7 @@ private static ImmutableArray(interfaceGroup.DeclaredMethods.Select(m => m.ManagedToUnmanagedStub).OfType().Select(ctx => ctx.Stub.Node))) .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute))))); } - private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ImmutableArray interfaceGroup) + private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _) { return ImplementationInterfaceTemplate - .WithMembers(List(interfaceGroup.OfType().Select(context => context.Stub.Node))); + .WithMembers(List(comInterfaceAndMethods.DeclaredMethods.Select(m => m.NativeToManagedStub).OfType().Select(context => context.Stub.Node))); } private static readonly TypeSyntax VoidStarStarSyntax = PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))); @@ -580,7 +541,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Im private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName) .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)); - private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethods interfaceMethods) + private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _) { const string vtableLocalName = "vtable"; var interfaceType = interfaceMethods.Interface.InterfaceType; @@ -759,7 +720,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.UnsafeKeyword)) .AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IIUnknownInterfaceType))); - private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceContext context) + private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceInfo context, CancellationToken _) { const string vtableFieldName = "_vtable"; return InterfaceInformationTypeTemplate @@ -821,14 +782,21 @@ static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan + /// Information about a Com interface, but not it's methods + /// + private sealed record ComInterfaceInfo( ManagedTypeInfo InterfaceType, InterfaceDeclarationSyntax InterfaceDeclaration, ContainingSyntaxContext TypeDefinitionContext, ContainingSyntax ContainingSyntax, Guid InterfaceId) { - public static ComInterfaceContext From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax) + public static ComInterfaceInfo From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax) { Guid? guid = null; var guidAttr = symbol.GetAttributes().Where(attr => attr.AttributeClass.ToDisplayString() == TypeNames.System_Runtime_InteropServices_GuidAttribute).SingleOrDefault(); @@ -838,7 +806,7 @@ public static ComInterfaceContext From(INamedTypeSymbol symbol, InterfaceDeclara if (guidstr is not null) guid = new Guid(guidstr); } - return new ComInterfaceContext( + return new ComInterfaceInfo( ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), syntax, new ContainingSyntaxContext(syntax), @@ -852,7 +820,7 @@ public override int GetHashCode() return HashCode.Combine(InterfaceType, TypeDefinitionContext, InterfaceId); } - public bool Equals(ComInterfaceContext other) + public bool Equals(ComInterfaceInfo other) { // ContainingSyntax and ContainingSyntaxContext are not used in the hash code return InterfaceType == other.InterfaceType @@ -862,7 +830,7 @@ public bool Equals(ComInterfaceContext other) } /// - /// Represents a method that has been determined to be a COM interface method. + /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. /// private sealed record MethodInfo( [property: Obsolete] IMethodSymbol Symbol, @@ -871,7 +839,7 @@ private sealed record MethodInfo( SequenceEqualImmutableArray<(ManagedTypeInfo Type, string Name, RefKind RefKind)> Parameters, Diagnostic? Diagnostic) { - public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) + public static bool IsComInterface(ComInterfaceInfo ifaceContext, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) { comMethodInfo = null; Location interfaceLocation = ifaceContext.InterfaceDeclaration.GetLocation(); @@ -931,7 +899,7 @@ public static bool IsComInterface(ComInterfaceContext ifaceContext, ISymbol memb /// /// Represents a method, its declaring interface, and its index in the interface's vtable. /// - private sealed record ComInterfaceMethodContext(ComInterfaceContext DeclaringInterface, MethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) + private sealed record ComMethodContext(ComInterfaceInfo DeclaringInterface, MethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) { public GeneratedMethodContextBase ManagedToUnmanagedStub { @@ -948,6 +916,19 @@ public GeneratedMethodContextBase ManagedToUnmanagedStub public Diagnostic? Diagnostic => MethodInfo.Diagnostic; + public GeneratedMethodContextBase NativeToManagedStub + { + get + { + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) + { + return (GeneratedMethodContextBase)new SkippedStubContext(GenerationContext.OriginalDefiningType); + } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + } + } + public MethodDeclarationSyntax GenerateShadow() { // DeclarationCopiedFromBaseDeclaration() @@ -971,30 +952,30 @@ public MethodDeclarationSyntax GenerateShadow() /// /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). /// - private sealed record ComInterfaceAndMethods(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods, ComInterfaceContext? BaseInterface) + private sealed record ComInterfaceAndMethodsContext(ComInterfaceInfo Interface, SequenceEqualImmutableArray Methods, ComInterfaceInfo? BaseInterface) { /// /// COM methods that are declared on the attributed interface declaration. /// - public IEnumerable DeclaredMethods => Methods.Where(m => m.DeclaringInterface == Interface); + public IEnumerable DeclaredMethods => Methods.Where(m => m.DeclaringInterface == Interface); /// /// COM methods that are declared on an interface the interface inherits from. /// - public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); + public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); - public static IEnumerable GetAllMethods(ValueEqualityImmutableDictionary ifaceToBaseMap, ValueEqualityImmutableDictionary> ifaceToMethodsMap, StubEnvironment environment, CancellationToken ct) + public static IEnumerable GetAllMethods(ValueEqualityImmutableDictionary ifaceToBaseMap, ValueEqualityImmutableDictionary> ifaceToDeclaredMethodsMap, StubEnvironment environment, CancellationToken ct) { - Dictionary> allMethodsCache = new(); + Dictionary> allMethodsCache = new(); - foreach (var kvp in ifaceToMethodsMap) + foreach (var kvp in ifaceToDeclaredMethodsMap) { - IEnumerable asdf = AddMethods(kvp.Key, kvp.Value); + AddMethods(kvp.Key, kvp.Value); } - return allMethodsCache.Select(kvp => new ComInterfaceAndMethods(kvp.Key, kvp.Value.ToSequenceEqualImmutableArray(), ifaceToBaseMap[kvp.Key])); + return allMethodsCache.Select(kvp => new ComInterfaceAndMethodsContext(kvp.Key, kvp.Value.ToSequenceEqualImmutableArray(), ifaceToBaseMap[kvp.Key])); - IEnumerable AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) + ImmutableArray AddMethods(ComInterfaceInfo iface, IEnumerable declaredMethods) { if (allMethodsCache.TryGetValue(iface, out var cachedValue)) { @@ -1002,27 +983,27 @@ IEnumerable AddMethods(ComInterfaceContext iface, IEn } int startingIndex = 3; - List methods = new(); + List methods = new(); + // If we have a base interface, we should add the inherited methods to our list in vtable order if (ifaceToBaseMap.TryGetValue(iface, out var baseComIface) && baseComIface is not null) { if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) { - baseMethods = AddMethods(baseComIface, ifaceToMethodsMap[baseComIface]); - } - - foreach (var method in baseMethods) - { - startingIndex++; - methods.Add(method); + baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); } + methods.AddRange(baseMethods); + startingIndex += baseMethods.Length; } + // Then we append the declared methods in vtable order foreach (var method in declaredMethods) { var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, ct); - methods.Add(new ComInterfaceMethodContext(iface, method, startingIndex++, ctx)); + methods.Add(new ComMethodContext(iface, method, startingIndex++, ctx)); } - allMethodsCache[iface] = methods; - return methods; + // Cache so we don't recalculate if many intherfaces inherit from the same o + var immutableMethods = methods.ToImmutableArray(); + allMethodsCache[iface] = immutableMethods; + return immutableMethods; } } } From 93a9033b3f13405993de77015acda79e0201a714 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Fri, 28 Apr 2023 14:32:55 -0500 Subject: [PATCH 09/31] Add ComInterfaceContext, create diagnostics in Record constructor --- .../ComInterfaceAndMethodsContext.cs | 77 +++ .../ComInterfaceContext.cs | 57 +++ .../ComInterfaceGenerator.cs | 455 ++---------------- .../ComInterfaceGenerator/ComInterfaceInfo.cs | 156 ++++++ .../ComInterfaceGenerator/ComMethodContext.cs | 67 +++ .../ComInterfaceGenerator/ComMethodInfo.cs | 105 ++++ .../IncrementalValuesProviderExtensions.cs | 26 + ...eneratorInitializationContextExtensions.cs | 2 +- ...terfaceGenerator.Tests.DerivedComObject.cs | 2 +- ...nerator.Tests.ManagedObjectExposedToCom.cs | 2 +- ...SharedTypes.ComInterfaces.IGetAndSetInt.cs | 118 +++++ .../SharedTypes.ComInterfaces.IGetIntArray.cs | 108 +++++ .../LibraryImports.g.cs | 16 + 13 files changed, 774 insertions(+), 417 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs new file mode 100644 index 0000000000000..0bb07134c8dd2 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + public sealed partial class ComInterfaceGenerator + { + /// + /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). + /// + private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods) + { + /// + /// COM methods that are declared on the attributed interface declaration. + /// + public IEnumerable DeclaredMethods => Methods.Where((m => m.DeclaringInterface == Interface)); + + /// + /// COM methods that are declared on an interface the interface inherits from. + /// + public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); + + internal static ComInterfaceAndMethodsContext From((ComInterfaceContext, SequenceEqualImmutableArray) data, CancellationToken _) + => new ComInterfaceAndMethodsContext(data.Item1, data.Item2); + + public static IEnumerable CalculateAllMethods(ValueEqualityImmutableDictionary> ifaceToDeclaredMethodsMap, StubEnvironment environment, CancellationToken ct) + { + Dictionary> allMethodsCache = new(); + + foreach (var kvp in ifaceToDeclaredMethodsMap) + { + AddMethods(kvp.Key, kvp.Value); + } + + return allMethodsCache.Select(kvp => new ComInterfaceAndMethodsContext(kvp.Key, kvp.Value.ToSequenceEqual())); + + ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) + { + if (allMethodsCache.TryGetValue(iface, out var cachedValue)) + { + return cachedValue; + } + + int startingIndex = 3; + List methods = new(); + // If we have a base interface, we should add the inherited methods to our list in vtable order + if (iface.Base is not null) + { + var baseComIface = iface.Base; + if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) + { + baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); + } + methods.AddRange(baseMethods); + startingIndex += baseMethods.Length; + } + // Then we append the declared methods in vtable order + foreach (var method in declaredMethods) + { + var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, ct); + methods.Add(new ComMethodContext(iface, method, startingIndex++, ctx)); + } + // Cache so we don't recalculate if many interfaces inherit from the same one + var immutableMethods = methods.ToImmutableArray(); + allMethodsCache[iface] = immutableMethods; + return immutableMethods; + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs new file mode 100644 index 0000000000000..aa0c48e5f04ff --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Threading; + +namespace Microsoft.Interop +{ + public sealed partial class ComInterfaceGenerator + { + private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base) + { + /// + /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. Does not guarantee the ordering of the output. + /// + public static Dictionary.ValueCollection GetContexts(ImmutableArray data, CancellationToken _) + { + Dictionary symbolToInterfaceInfoMap = new(); + foreach (var iface in data) + { + symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface); + } + Dictionary symbolToContextMap = new(); + + foreach (var iface in data) + { + AddContext(iface); + } + return symbolToContextMap.Values; + + ComInterfaceContext AddContext(ComInterfaceInfo iface) + { + if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue)) + { + return cachedValue; + } + + if (iface.BaseInterfaceKey is null) + { + var baselessCtx = new ComInterfaceContext(iface, null); + symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx; + return baselessCtx; + } + + if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext)) + { + baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]); + } + var ctx = new ComInterfaceContext(iface, baseContext); + symbolToContextMap[iface.ThisInterfaceKey] = ctx; + return ctx; + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 18cfd272b74aa..809e56b5e4896 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -4,13 +4,9 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; -using System.ComponentModel.Design; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Threading; -using System.Xml.Schema; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -20,7 +16,7 @@ namespace Microsoft.Interop { [Generator] - public sealed class ComInterfaceGenerator : IIncrementalGenerator + public sealed partial class ComInterfaceGenerator : IIncrementalGenerator { private sealed record class GeneratedStubCodeContext( ManagedTypeInfo OriginalDefiningType, @@ -42,21 +38,6 @@ public static class StepNames public const string GenerateIUnknownDerivedAttribute = nameof(GenerateIUnknownDerivedAttribute); } - private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface) - { - baseComIface = null; - foreach (var implemented in comIface.Interfaces) - { - if (implemented.GetAttributes().Any(static attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) - { - // We'll filter out cases where there's multiple matching interfaces when determining - // if this is a valid candidate for generation. - baseComIface = implemented; - break; - } - } - return baseComIface is not null; - } public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -71,80 +52,64 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Where( static modelData => modelData is not null); - - var interfacesAndDiagnostics = attributedInterfaces.Select(static (data, ct) => + var interfaceSymbolAndDiagnostic = attributedInterfaces.Select(static (data, ct) => { - Diagnostic? Diagnostic = GetDiagnosticIfInvalidTypeForGeneration(data.Syntax, data.Symbol); - INamedTypeSymbol? BaseInterfaceSymbol = TryGetBaseComInterface(data.Symbol, out var baseComInterface) ? baseComInterface : null; - ComInterfaceInfo Context = ComInterfaceInfo.From(data.Symbol, data.Syntax); - return new { data.Syntax, data.Symbol, Context, Diagnostic, BaseInterfaceSymbol }; + var (info, diagnostic) = ComInterfaceInfo.From(data.Symbol, data.Syntax); + return (InterfaceInfo: info, Diagnostic: diagnostic, Symbol: data.Symbol); }); + context.RegisterDiagnostics(interfaceSymbolAndDiagnostic.Select((data, ct) => data.Diagnostic)); - // Split the types we want to generate and the ones we don't into two separate groups. - var interfacesToGenerate = interfacesAndDiagnostics.Where(static data => data.Diagnostic is null); - { - var invalidTypeDiagnostics = interfacesAndDiagnostics.Where(static data => data.Diagnostic is not null); - context.RegisterDiagnostics(invalidTypeDiagnostics.Select(static (data, ct) => data.Diagnostic)); - } + var interfaceSymbolsWithoutDiagnostics = interfaceSymbolAndDiagnostic + .Where(data => data.Diagnostic is null); + + var interfacesToGenerate = interfaceSymbolsWithoutDiagnostics + .Select((data, ct) => data.InterfaceInfo!); + + var interfaceContexts = interfacesToGenerate.Collect().SelectMany(ComInterfaceContext.GetContexts); // Get the information we need about methods themselves - var interfaceMethods = interfacesToGenerate.Select(static (data, ct) => + var interfaceMethods = interfaceSymbolsWithoutDiagnostics.Select(static (pair, ct) => { - INamedTypeSymbol iface = data.Symbol; - List comMethods = new(); - foreach (var member in iface.GetMembers()) + var symbol = pair.Symbol; + var info = pair.InterfaceInfo; + List comMethods = new(); + foreach (var member in symbol.GetMembers()) { - if (MethodInfo.IsComInterface(data.Context, member, out MethodInfo? methodInfo)) + if (ComMethodInfo.IsComMethod(info, member, out ComMethodInfo? methodInfo)) { comMethods.Add(methodInfo); } } return comMethods.ToSequenceEqualImmutableArray(); }); - - // Create a map of Com interface to its base for use later. - var ifaceToBaseMap = interfacesToGenerate.Collect().Select(static (data, ct) => - { - Dictionary ifaceToBaseMap = new(); - Dictionary contexts = new(SymbolEqualityComparer.Default); - foreach (var iface in data) - { - contexts.Add(iface.Symbol, iface.Context); - } - foreach (var iface in data) - { - ifaceToBaseMap.Add(iface.Context, iface.BaseInterfaceSymbol is not null ? contexts[iface.BaseInterfaceSymbol] : null); - } - return ifaceToBaseMap.ToValueEqualImmutable(); - }); + context.RegisterDiagnostics(interfaceMethods.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); // Generate a map from Com interface to the methods it declares - var interfaceToDeclaredMethodsMap = interfacesToGenerate - .Select(static (iface, ct) => iface.Context) + var interfaceToDeclaredMethodsMap = interfaceContexts .Zip(interfaceMethods) .Collect() .Select(static (data, ct) => { - return data.ToValueEqualityImmutableDictionary<(ComInterfaceInfo, SequenceEqualImmutableArray), ComInterfaceInfo, SequenceEqualImmutableArray>( + return data.ToValueEqualityImmutableDictionary<(ComInterfaceContext, SequenceEqualImmutableArray), ComInterfaceContext, SequenceEqualImmutableArray>( static pair => pair.Item1, static pair => pair.Item2); }); // Combine info about base methods and declared methods to get a list of interfaces, and all the methods they need to worry about (including both declared and inherited methods) var interfaceAndMethodsContexts = interfaceToDeclaredMethodsMap - .Combine(ifaceToBaseMap) + .Combine(interfaceContexts.Collect()) .Combine(context.CreateStubEnvironmentProvider()) .SelectMany(static (data, ct) => { var ((ifaceToMethodsMap, ifaceToBaseMap), env) = data; - return ComInterfaceAndMethodsContext.GetAllMethods(ifaceToBaseMap, ifaceToMethodsMap, env, ct); + return ComInterfaceAndMethodsContext.CalculateAllMethods(ifaceToMethodsMap, env, ct); }); - - // Separate the methods which declare methods from those that don't declare methods + // Separate the methods which have methods from those that don't var interfacesWithMethodsAndItsMethods = interfaceAndMethodsContexts - .Where(static data => data.DeclaredMethods.Any()); + .Where(static data => data.Methods.Length != 0); + // Separate out the interface for generation that doesn't depend on the methods var interfacesWithMethods = interfacesWithMethodsAndItsMethods .Select(static (data, ct) => data.Interface); @@ -158,6 +123,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Select(static (data, ct) => data.Interface); var markerInterfaceIUnknownDerived = markerInterfaces + .Select(static (data, ct) => data.Info) .Select(GenerateIUnknownDerivedAttributeApplication) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); @@ -166,20 +132,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var (interfaceContext, iUnknownDerivedAttributeApplication) = data; context.AddSource( - interfaceContext.InterfaceType.FullTypeName.Replace("global::", ""), - GenerateMarkerInterfaceSource(interfaceContext) + iUnknownDerivedAttributeApplication); - }); - } - - - // Split the methods we want to generate and the ones with warnings into different groups, and warn on the invalid methods - { - var allMethods = interfaceAndMethodsContexts.SelectMany(static (data, ct) => data.DeclaredMethods); - var invalidMethods = allMethods.Where(static data => data.Diagnostic is not null); - - context.RegisterSourceOutput(invalidMethods, static (context, invalidMethod) => - { - context.ReportDiagnostic(invalidMethod.Diagnostic); + interfaceContext.Info.Type.FullTypeName.Replace("global::", ""), + GenerateMarkerInterfaceSource(interfaceContext.Info) + iUnknownDerivedAttributeApplication); }); } @@ -203,6 +157,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate the native interface metadata for each [GeneratedComInterface]-attributed interface. var nativeInterfaceInformation = interfacesWithMethods + .Select(static (data, ct) => data.Info) .Select(GenerateInterfaceInformation) .WithTrackingName(StepNames.GenerateInterfaceInformation) .WithComparer(SyntaxEquivalentComparer.Instance) @@ -217,6 +172,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .SelectNormalized(); var iUnknownDerivedAttributeApplication = interfacesWithMethods + .Select(static (data, ct) => data.Info) .Select(GenerateIUnknownDerivedAttributeApplication) .WithTrackingName(StepNames.GenerateIUnknownDerivedAttribute) .WithComparer(SyntaxEquivalentComparer.Instance) @@ -248,7 +204,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) source.WriteLine(); source.WriteLine(); iUnknownDerivedAttribute.WriteTo(source); - return new { TypeName = interfaceContext.InterfaceType.FullTypeName, Source = source.ToString() }; + return new { TypeName = interfaceContext.Info.Type.FullTypeName, Source = source.ToString() }; }); context.RegisterSourceOutput(filesToGenerate, (context, data) => @@ -270,7 +226,7 @@ public static void** ManagedVirtualMethodTable { if (m_vtable == null) { - nint* vtable = (nint*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({{iface.InterfaceType.FullTypeName}}), sizeof(nint) * 3); + nint* vtable = (nint*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({{iface.Type.FullTypeName}}), sizeof(nint) * 3); global::System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out vtable[0], out vtable[1], out vtable[2]); m_vtable = (void**)vtable; } @@ -280,7 +236,7 @@ public static void** ManagedVirtualMethodTable } [global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementation] - file interface InterfaceImplementation : {{iface.InterfaceType.FullTypeName}} + file interface InterfaceImplementation : {{iface.Type.FullTypeName}} {} """; @@ -411,119 +367,14 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M ComInterfaceDispatchMarshallingInfo.Instance); } - private static Diagnostic? GetDiagnosticIfInvalidTypeForGeneration(InterfaceDeclarationSyntax syntax, INamedTypeSymbol type) - { - // Verify the method has no generic types or defined implementation - // and is not marked static or sealed - if (syntax.TypeParameterList is not null) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, syntax.Identifier.GetLocation(), type.Name); - } - - // Verify that the types the method is declared in are marked partial. - for (SyntaxNode? parentNode = syntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) - { - if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, syntax.Identifier.GetLocation(), type.Name, typeDecl.Identifier); - } - } - var guidAttr = type.GetAttributes().Where(attr => attr.AttributeClass.ToDisplayString() == TypeNames.System_Runtime_InteropServices_GuidAttribute).SingleOrDefault(); - var interfaceTypeAttr = type.GetAttributes().Where(attr => attr.AttributeClass.ToDisplayString() == TypeNames.InterfaceTypeAttribute).SingleOrDefault(); - // Assume interfaceType is IUnknown for now - if (interfaceTypeAttr is not null - && (guidAttr is null - || guidAttr.ConstructorArguments.SingleOrDefault().Value as string is null)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, syntax.Identifier.GetLocation(), type.ToDisplayString()); - // Missing Guid - } - - // Error if more than one GeneratedComInterface base interface type. - INamedTypeSymbol? baseInterface = null; - foreach (var implemented in type.Interfaces) - { - if (implemented.GetAttributes().Any(static attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) - { - if (baseInterface is not null) - { - return Diagnostic.Create(GeneratorDiagnostics.MultipleComInterfaceBaseTypesAttribute, syntax.Identifier.GetLocation(), type.ToDisplayString()); - } - baseInterface = implemented; - } - } - - return null; - } - - private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax syntax, IMethodSymbol method) - { - // Verify the method has no generic types or defined implementation - // and is not marked static or sealed - if (syntax.TypeParameterList is not null - || syntax.Body is not null - || syntax.Modifiers.Any(SyntaxKind.SealedKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, syntax.Identifier.GetLocation(), method.Name); - } - - // Verify the method does not have a ref return - if (method.ReturnsByRef || method.ReturnsByRefReadonly) - { - return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, syntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); - } - - return null; - } - - private static ImmutableArray> GroupContextsForInterfaceGeneration(ImmutableArray contexts) - { - // We can end up with an empty set of contexts here as the compiler will call a SelectMany - // after a Collect with no input entries - if (contexts.IsEmpty) - { - return ImmutableArray>.Empty; - } - - ImmutableArray>.Builder allGroupsBuilder = ImmutableArray.CreateBuilder>(); - - // Due to how the source generator driver processes the input item tables and our limitation that methods on COM interfaces can only be defined in a single partial definition of the type, - // we can guarantee that the method contexts are ordered as follows: - // - I1.M1 - // - I1.M2 - // - I1.M3 - // - I2.M1 - // - I2.M2 - // - I2.M3 - // - I3.M1 - // - etc... - // This enable us to group our contexts by their containing syntax rather simply. - ManagedTypeInfo? lastSeenDefiningType = null; - ImmutableArray.Builder groupBuilder = ImmutableArray.CreateBuilder(); - foreach (var context in contexts) - { - if (lastSeenDefiningType is null || lastSeenDefiningType == context.OriginalDefiningType) - { - groupBuilder.Add(context); - } - else - { - allGroupsBuilder.Add(new(groupBuilder.ToImmutable())); - groupBuilder.Clear(); - groupBuilder.Add(context); - } - lastSeenDefiningType = context.OriginalDefiningType; - } - allGroupsBuilder.Add(new(groupBuilder.ToImmutable())); - return allGroupsBuilder.ToImmutable(); - } private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation") .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.UnsafeKeyword), Token(SyntaxKind.PartialKeyword))); - private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken ct) + + private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _) { - var definingType = interfaceGroup.Interface.InterfaceType; + var definingType = interfaceGroup.Interface.Info.Type; return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) .WithMembers(List(interfaceGroup.DeclaredMethods.Select(m => m.ManagedToUnmanagedStub).OfType().Select(ctx => ctx.Stub.Node))) @@ -544,7 +395,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _) { const string vtableLocalName = "vtable"; - var interfaceType = interfaceMethods.Interface.InterfaceType; + var interfaceType = interfaceMethods.Interface.Info.Type; var interfaceMethodStubs = interfaceMethods.DeclaredMethods.Select(m => m.GenerationContext); ImmutableArray vtableExposedContexts = interfaceMethodStubs @@ -588,7 +439,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf BlockSyntax fillBaseInterfaceSlots; - if (interfaceMethods.BaseInterface is null) + if (interfaceMethods.Interface.Base is null) { // If we don't have a base interface, we need to manually fill in the base iUnknown slots. fillBaseInterfaceSlots = Block() @@ -690,7 +541,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, TypeOfExpression( - ParseTypeName(interfaceMethods.BaseInterface.InterfaceType.FullTypeName)), //baseInterfaceTypeInfo.BaseInterface.FullTypeName)), + ParseTypeName(interfaceMethods.Interface.Base.Info.Type.FullTypeName)), //baseInterfaceTypeInfo.BaseInterface.FullTypeName)), IdentifierName("TypeHandle")))))), IdentifierName("ManagedVirtualMethodTable"))), Argument(IdentifierName(vtableLocalName)), @@ -782,230 +633,6 @@ static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan - /// Information about a Com interface, but not it's methods - /// - private sealed record ComInterfaceInfo( - ManagedTypeInfo InterfaceType, - InterfaceDeclarationSyntax InterfaceDeclaration, - ContainingSyntaxContext TypeDefinitionContext, - ContainingSyntax ContainingSyntax, - Guid InterfaceId) - { - public static ComInterfaceInfo From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax) - { - Guid? guid = null; - var guidAttr = symbol.GetAttributes().Where(attr => attr.AttributeClass.ToDisplayString() == TypeNames.System_Runtime_InteropServices_GuidAttribute).SingleOrDefault(); - if (guidAttr is not null) - { - string? guidstr = guidAttr.ConstructorArguments.SingleOrDefault().Value as string; - if (guidstr is not null) - guid = new Guid(guidstr); - } - return new ComInterfaceInfo( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), - syntax, - new ContainingSyntaxContext(syntax), - new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), - guid ?? Guid.Empty); - } - - public override int GetHashCode() - { - // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode - return HashCode.Combine(InterfaceType, TypeDefinitionContext, InterfaceId); - } - - public bool Equals(ComInterfaceInfo other) - { - // ContainingSyntax and ContainingSyntaxContext are not used in the hash code - return InterfaceType == other.InterfaceType - && TypeDefinitionContext == other.TypeDefinitionContext - && InterfaceId == other.InterfaceId; - } - } - - /// - /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. - /// - private sealed record MethodInfo( - [property: Obsolete] IMethodSymbol Symbol, - MethodDeclarationSyntax Syntax, - string MethodName, - SequenceEqualImmutableArray<(ManagedTypeInfo Type, string Name, RefKind RefKind)> Parameters, - Diagnostic? Diagnostic) - { - public static bool IsComInterface(ComInterfaceInfo ifaceContext, ISymbol member, [NotNullWhen(true)] out MethodInfo? comMethodInfo) - { - comMethodInfo = null; - Location interfaceLocation = ifaceContext.InterfaceDeclaration.GetLocation(); - if (member.Kind == SymbolKind.Method && !member.IsStatic) - { - // We only support methods that are defined in the same partial interface definition as the - // [GeneratedComInterface] attribute. - // This restriction not only makes finding the syntax for a given method cheaper, - // but it also enables us to ensure that we can determine vtable method order easily. - Location? methodLocationInAttributedInterfaceDeclaration = null; - foreach (var methodLocation in member.Locations) - { - if (methodLocation.SourceTree == interfaceLocation.SourceTree - && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) - { - methodLocationInAttributedInterfaceDeclaration = methodLocation; - break; - } - } - - // TODO: this should cause a diagnostic - if (methodLocationInAttributedInterfaceDeclaration is null) - { - return false; - } - - // Find the matching declaration syntax - MethodDeclarationSyntax? comMethodDeclaringSyntax = null; - foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) - { - var declaringSyntax = declaringSyntaxReference.GetSyntax(); - Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); - if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) - { - comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; - break; - } - } - if (comMethodDeclaringSyntax is null) - throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); - - List<(ManagedTypeInfo ParameterType, string Name, RefKind RefKind)> parameters = new(); - foreach (var parameter in ((IMethodSymbol)member).Parameters) - { - parameters.Add((ManagedTypeInfo.CreateTypeInfoForTypeSymbol(parameter.Type), parameter.Name, parameter.RefKind)); - } - - var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); - - comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); - return true; - } - return false; - } - } - - /// - /// Represents a method, its declaring interface, and its index in the interface's vtable. - /// - private sealed record ComMethodContext(ComInterfaceInfo DeclaringInterface, MethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) - { - public GeneratedMethodContextBase ManagedToUnmanagedStub - { - get - { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) - { - return (GeneratedMethodContextBase)new SkippedStubContext(DeclaringInterface.InterfaceType); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); - return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); - } - } - - public Diagnostic? Diagnostic => MethodInfo.Diagnostic; - - public GeneratedMethodContextBase NativeToManagedStub - { - get - { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) - { - return (GeneratedMethodContextBase)new SkippedStubContext(GenerationContext.OriginalDefiningType); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); - return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); - } - } - - public MethodDeclarationSyntax GenerateShadow() - { - // DeclarationCopiedFromBaseDeclaration() - // { - // return (()this).(); - // } - return MethodInfo.Syntax.WithBody( - Block( - ReturnStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - CastExpression(DeclaringInterface.InterfaceType.Syntax, IdentifierName(Token(SyntaxKind.ThisKeyword))), - IdentifierName(MethodInfo.MethodName)), - ArgumentList( - // TODO: RefKind keywords - SeparatedList(MethodInfo.Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); - } - } - - /// - /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). - /// - private sealed record ComInterfaceAndMethodsContext(ComInterfaceInfo Interface, SequenceEqualImmutableArray Methods, ComInterfaceInfo? BaseInterface) - { - /// - /// COM methods that are declared on the attributed interface declaration. - /// - public IEnumerable DeclaredMethods => Methods.Where(m => m.DeclaringInterface == Interface); - - /// - /// COM methods that are declared on an interface the interface inherits from. - /// - public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); - - public static IEnumerable GetAllMethods(ValueEqualityImmutableDictionary ifaceToBaseMap, ValueEqualityImmutableDictionary> ifaceToDeclaredMethodsMap, StubEnvironment environment, CancellationToken ct) - { - Dictionary> allMethodsCache = new(); - - foreach (var kvp in ifaceToDeclaredMethodsMap) - { - AddMethods(kvp.Key, kvp.Value); - } - - return allMethodsCache.Select(kvp => new ComInterfaceAndMethodsContext(kvp.Key, kvp.Value.ToSequenceEqualImmutableArray(), ifaceToBaseMap[kvp.Key])); - - ImmutableArray AddMethods(ComInterfaceInfo iface, IEnumerable declaredMethods) - { - if (allMethodsCache.TryGetValue(iface, out var cachedValue)) - { - return cachedValue; - } - - int startingIndex = 3; - List methods = new(); - // If we have a base interface, we should add the inherited methods to our list in vtable order - if (ifaceToBaseMap.TryGetValue(iface, out var baseComIface) && baseComIface is not null) - { - if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) - { - baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); - } - methods.AddRange(baseMethods); - startingIndex += baseMethods.Length; - } - // Then we append the declared methods in vtable order - foreach (var method in declaredMethods) - { - var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, ct); - methods.Add(new ComMethodContext(iface, method, startingIndex++, ctx)); - } - // Cache so we don't recalculate if many intherfaces inherit from the same o - var immutableMethods = methods.ToImmutableArray(); - allMethodsCache[iface] = immutableMethods; - return immutableMethods; - } - } - } + private sealed record InterfaceSymbolInfo(ComInterfaceInfo Info, Diagnostic? Diagnostic, TBaseInterfaceKey ThisInterfaceKey, TBaseInterfaceKey? BaseInterfaceKey); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs new file mode 100644 index 0000000000000..a15c3b59974b9 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -0,0 +1,156 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Interop +{ + public sealed partial class ComInterfaceGenerator + { + /// + /// Information about a Com interface, but not it's methods + /// + private sealed record ComInterfaceInfo( + ManagedTypeInfo Type, + string ThisInterfaceKey, + string? BaseInterfaceKey, + InterfaceDeclarationSyntax Declaration, + ContainingSyntaxContext TypeDefinitionContext, + ContainingSyntax ContainingSyntax, + Guid InterfaceId) + { + public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax) + { + // Verify the method has no generic types or defined implementation + // and is not marked static or sealed + if (syntax.TypeParameterList is not null) + { + return (null, Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedMethodSignature, + syntax.Identifier.GetLocation(), + symbol.Name)); + } + + // Verify that the types the method is declared in are marked partial. + for (SyntaxNode? parentNode = syntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) + { + if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + return (null, Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, + syntax.Identifier.GetLocation(), + symbol.Name, + typeDecl.Identifier)); + } + } + + if (!TryGetGuid(symbol, syntax, out var guid, out var guidDiagnostic)) + return (null, guidDiagnostic); + + if (!TryGetBaseComInterface(symbol, syntax, out var baseSymbol, out var baseDiagnostic)) + return (null, baseDiagnostic); + + return (new ComInterfaceInfo( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), + symbol.ToDisplayString(), + baseSymbol?.ToDisplayString(), + syntax, + new ContainingSyntaxContext(syntax), + new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), + guid ?? Guid.Empty), null); + } + + /// + /// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets . + /// + private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic) + { + baseComIface = null; + foreach (var implemented in comIface.Interfaces) + { + foreach (var attr in implemented.GetAttributes()) + { + if (attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute) + { + // We'll filter out cases where there's multiple matching interfaces when determining + // if this is a valid candidate for generation. + if (baseComIface is not null) + { + diagnostic = Diagnostic.Create( + GeneratorDiagnostics.MultipleComInterfaceBaseTypesAttribute, + syntax.Identifier.GetLocation(), + comIface.ToDisplayString()); + return false; + } + baseComIface = implemented; + } + } + } + diagnostic = null; + return true; + } + + /// + /// Returns true and sets if the guid is present. Returns false and sets diagnostic if the guid is not present or is invalid. + /// + private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out Guid? guid, [NotNullWhen(false)] out Diagnostic? diagnostic) + { + guid = null; + AttributeData? guidAttr = null; + AttributeData? interfaceTypeAttr = null; + foreach (var attr in interfaceSymbol.GetAttributes()) + { + var attrDisplayString = attr.AttributeClass?.ToDisplayString(); + if (attrDisplayString is TypeNames.System_Runtime_InteropServices_GuidAttribute) + guidAttr = attr; + else if (attrDisplayString is TypeNames.InterfaceTypeAttribute) + interfaceTypeAttr = attr; + } + + if (guidAttr is not null) + { + string? guidstr = guidAttr.ConstructorArguments.SingleOrDefault().Value as string; + if (guidstr is not null) + { + try + { + guid = new Guid(guidstr); + } + // Diagnostic will be raised if guid is null + catch (FormatException) { } + catch (OverflowException) { } + } + } + + // Assume interfaceType is IUnknown for now + if (interfaceTypeAttr is not null + && guid is null) + { + diagnostic = Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, syntax.Identifier.GetLocation(), interfaceSymbol.ToDisplayString()); + return false; + } + diagnostic = null; + return true; + } + + public override int GetHashCode() + { + // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode + return HashCode.Combine(Type, TypeDefinitionContext, InterfaceId); + } + + public bool Equals(ComInterfaceInfo other) + { + // ContainingSyntax and ContainingSyntaxContext are not used in the hash code + return Type == other.Type + && TypeDefinitionContext == other.TypeDefinitionContext + && InterfaceId == other.InterfaceId; + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs new file mode 100644 index 0000000000000..dba274da462b4 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + public sealed partial class ComInterfaceGenerator + { + /// + /// Represents a method, its declaring interface, and its index in the interface's vtable. + /// + private sealed record ComMethodContext(ComInterfaceContext DeclaringInterface, ComMethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) + { + public GeneratedMethodContextBase ManagedToUnmanagedStub + { + get + { + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) + { + return (GeneratedMethodContextBase)new SkippedStubContext(DeclaringInterface.Info.Type); + } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + } + } + + public Diagnostic? Diagnostic => MethodInfo.Diagnostic; + + public GeneratedMethodContextBase NativeToManagedStub + { + get + { + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) + { + return (GeneratedMethodContextBase)new SkippedStubContext(GenerationContext.OriginalDefiningType); + } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + } + } + + public MethodDeclarationSyntax GenerateShadow() + { + // DeclarationCopiedFromBaseDeclaration() + // { + // return (()this).(); + // } + return MethodInfo.Syntax.WithBody( + Block( + ReturnStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + CastExpression(DeclaringInterface.Info.Type.Syntax, IdentifierName(Token(SyntaxKind.ThisKeyword))), + IdentifierName(MethodInfo.MethodName)), + ArgumentList( + // TODO: RefKind keywords + SeparatedList(MethodInfo.Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs new file mode 100644 index 0000000000000..f19892e14a738 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Interop +{ + public sealed partial class ComInterfaceGenerator + { + /// + /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. + /// + private sealed record ComMethodInfo( + [property: Obsolete] IMethodSymbol Symbol, + MethodDeclarationSyntax Syntax, + string MethodName, + SequenceEqualImmutableArray<(ManagedTypeInfo Type, string Name, RefKind RefKind)> Parameters, + Diagnostic? Diagnostic) + { + private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) + { + // Verify the method has no generic types or defined implementation + // and is not marked static or sealed + if (comMethodDeclaringSyntax.TypeParameterList is not null + || comMethodDeclaringSyntax.Body is not null + || comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) + { + return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, comMethodDeclaringSyntax.Identifier.GetLocation(), method.Name); + } + + // Verify the method does not have a ref return + if (method.ReturnsByRef || method.ReturnsByRefReadonly) + { + return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, comMethodDeclaringSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); + } + + return null; + } + + public static bool IsComMethod(ComInterfaceInfo ifaceContext, ISymbol member, [NotNullWhen(true)] out ComMethodInfo? comMethodInfo) + { + Diagnostic diag; + comMethodInfo = null; + Location interfaceLocation = ifaceContext.Declaration.GetLocation(); + if (member.Kind == SymbolKind.Method && !member.IsStatic) + { + // We only support methods that are defined in the same partial interface definition as the + // [GeneratedComInterface] attribute. + // This restriction not only makes finding the syntax for a given method cheaper, + // but it also enables us to ensure that we can determine vtable method order easily. + Location? methodLocationInAttributedInterfaceDeclaration = null; + foreach (var methodLocation in member.Locations) + { + if (methodLocation.SourceTree == interfaceLocation.SourceTree + && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) + { + methodLocationInAttributedInterfaceDeclaration = methodLocation; + break; + } + } + + // TODO: this should cause a diagnostic + if (methodLocationInAttributedInterfaceDeclaration is null) + { + return false; + } + + // Find the matching declaration syntax + MethodDeclarationSyntax? comMethodDeclaringSyntax = null; + foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) + { + var declaringSyntax = declaringSyntaxReference.GetSyntax(); + Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); + if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) + { + comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; + break; + } + } + if (comMethodDeclaringSyntax is null) + throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); + + List<(ManagedTypeInfo ParameterType, string Name, RefKind RefKind)> parameters = new(); + foreach (var parameter in ((IMethodSymbol)member).Parameters) + { + parameters.Add((ManagedTypeInfo.CreateTypeInfoForTypeSymbol(parameter.Type), parameter.Name, parameter.RefKind)); + } + + diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); + + comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); + return true; + } + return false; + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index fb0dd80ca7f1a..e578c24f00f57 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; using System.Text; using Microsoft.CodeAnalysis; @@ -31,6 +32,31 @@ internal static class IncrementalValuesProviderExtensions }); } + public static IncrementalValuesProvider<(TGrouper, SequenceEqualImmutableArray)> GroupTuples(this IncrementalValuesProvider<(TGrouper Key, TGroupee Value)> values) + { + return values.Collect().SelectMany(static (values, ct) => + { + var valueMap = new Dictionary>(); + foreach (var value in values) + { + if (!valueMap.TryGetValue(value.Key, out var list)) + { + list = new(); + } + list.Add(value.Value); + valueMap[value.Key] = list; + } + + var builder = ImmutableArray.CreateBuilder<(TGrouper, SequenceEqualImmutableArray)>(valueMap.Count); + foreach (var kvp in valueMap) + { + builder.Add((kvp.Key, kvp.Value.ToSequenceEqualImmutableArray())); + } + + return builder.MoveToImmutable(); + }); + } + /// /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs index 069daeb130fff..d0adc407312d1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalGeneratorInitializationContextExtensions.cs @@ -39,7 +39,7 @@ public static IncrementalValueProvider CreateStubEnvironmentPro public static void RegisterDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider diagnostics) { - context.RegisterSourceOutput(diagnostics, (context, diagnostic) => + context.RegisterSourceOutput(diagnostics.Where(diag => diag is not null), (context, diagnostic) => { context.ReportDiagnostic(diagnostic); }); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs index cb78fb469e93b..a8a8351c5360d 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs @@ -8,7 +8,7 @@ { System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 1); System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; - details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IComInterface1).TypeHandle); + details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(SharedTypes.ComInterfaces.IGetAndSetInt).TypeHandle); vtables[0] = new() { IID = details.Iid, diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs index a5c4b0aa2f52e..84fe3a175d798 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs @@ -8,7 +8,7 @@ { System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 1); System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; - details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IComInterface1).TypeHandle); + details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(SharedTypes.ComInterfaces.IGetAndSetInt).TypeHandle); vtables[0] = new() { IID = details.Iid, diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs new file mode 100644 index 0000000000000..b7ec44242a3b2 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs @@ -0,0 +1,118 @@ +file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType +{ + public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 3, 153, 63, 44, 134, 181, 177, 70, 136, 27, 173, 252, 233, 175, 71, 177 })); + + private static void** _vtable; + public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); +} + +[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] +file unsafe partial interface InterfaceImplementation : global::SharedTypes.ComInterfaces.IGetAndSetInt +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::SharedTypes.ComInterfaces.IGetAndSetInt.GetInt() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IGetAndSetInt)); + int __retVal; + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + return __retVal; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + void global::SharedTypes.ComInterfaces.IGetAndSetInt.SetInt(int x) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IGetAndSetInt)); + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, x); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + } +} + +file unsafe partial interface InterfaceImplementation +{ + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_GetInt(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int* __invokeRetValUnmanaged__param) + { + global::SharedTypes.ComInterfaces.IGetAndSetInt @this = default; + ref int __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; + int __invokeRetVal = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + __invokeRetVal = @this.GetInt(); + // Marshal - Convert managed data to native data. + __invokeRetValUnmanaged = __invokeRetVal; + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + + return __retVal; + } + + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_SetInt(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int x) + { + global::SharedTypes.ComInterfaces.IGetAndSetInt @this = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + @this.SetInt(x); + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + + return __retVal; + } +} + +file unsafe partial interface InterfaceImplementation +{ + internal static void** CreateManagedVirtualFunctionTable() + { + void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::SharedTypes.ComInterfaces.IGetAndSetInt), sizeof(void*) * 5); + { + nint v0, v1, v2; + System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); + vtable[0] = (void*)v0; + vtable[1] = (void*)v1; + vtable[2] = (void*)v2; + } + + { + vtable[3] = (void*)(delegate* unmanaged )&ABI_GetInt; + vtable[4] = (void*)(delegate* unmanaged )&ABI_SetInt; + } + + return vtable; + } +} + +namespace SharedTypes.ComInterfaces +{ + [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] + partial interface IGetAndSetInt + { + } +} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs new file mode 100644 index 0000000000000..f17fa80401881 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs @@ -0,0 +1,108 @@ +file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType +{ + public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 10, 42, 128, 125, 10, 99, 142, 76, 162, 31, 119, 28, 201, 3, 31, 185 })); + + private static void** _vtable; + public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); +} + +[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] +file unsafe partial interface InterfaceImplementation : global::SharedTypes.ComInterfaces.IGetIntArray +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int[] global::SharedTypes.ComInterfaces.IGetIntArray.GetInts() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IGetIntArray)); + int[] __retVal; + int* __retVal_native = default; + int __invokeRetVal; + // Setup - Perform required setup. + int __retVal_native__numElements; + System.Runtime.CompilerServices.Unsafe.SkipInit(out __retVal_native__numElements); + try + { + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal_native); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + __retVal_native__numElements = 10; + __retVal = global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.AllocateContainerForManagedElements(__retVal_native, __retVal_native__numElements); + global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetUnmanagedValuesSource(__retVal_native, __retVal_native__numElements).CopyTo(global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetManagedValuesDestination(__retVal)); + } + finally + { + // Cleanup - Perform required cleanup. + global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.Free(__retVal_native); + } + + return __retVal; + } +} + +file unsafe partial interface InterfaceImplementation +{ + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_GetInts(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int** __invokeRetValUnmanaged__param) + { + global::SharedTypes.ComInterfaces.IGetIntArray @this = default; + ref int* __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; + int[] __invokeRetVal = default; + int __retVal = default; + // Setup - Perform required setup. + int __invokeRetValUnmanaged__numElements; + System.Runtime.CompilerServices.Unsafe.SkipInit(out __invokeRetValUnmanaged__numElements); + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + __invokeRetVal = @this.GetInts(); + // Marshal - Convert managed data to native data. + __invokeRetValUnmanaged = global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.AllocateContainerForUnmanagedElements(__invokeRetVal, out __invokeRetValUnmanaged__numElements); + global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetManagedValuesSource(__invokeRetVal).CopyTo(global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetUnmanagedValuesDestination(__invokeRetValUnmanaged, __invokeRetValUnmanaged__numElements)); + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + finally + { + // Cleanup - Perform required cleanup. + global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.Free(__invokeRetValUnmanaged); + } + + return __retVal; + } +} + +file unsafe partial interface InterfaceImplementation +{ + internal static void** CreateManagedVirtualFunctionTable() + { + void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::SharedTypes.ComInterfaces.IGetIntArray), sizeof(void*) * 4); + { + nint v0, v1, v2; + System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); + vtable[0] = (void*)v0; + vtable[1] = (void*)v1; + vtable[2] = (void*)v2; + } + + { + vtable[3] = (void*)(delegate* unmanaged )&ABI_GetInts; + } + + return vtable; + } +} + +namespace SharedTypes.ComInterfaces +{ + [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] + partial interface IGetIntArray + { + } +} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs index e6463e52e37a8..8e1783fce02a9 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs @@ -24,6 +24,22 @@ internal unsafe partial class NativeExportsNE } } namespace ComInterfaceGenerator.Tests +{ + public unsafe partial class IGetAndSetIntTests + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "new_get_and_set_int", ExactSpelling = true)] + public static extern partial void* NewNativeObject(); + } +} +namespace ComInterfaceGenerator.Tests +{ + public unsafe partial class IGetIntArrayTests + { + [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "new_get_and_set_int_array", ExactSpelling = true)] + public static extern partial void* NewNativeObject(); + } +} +namespace ComInterfaceGenerator.Tests { internal unsafe partial class NativeExportsNE { From b3ad04c346e4a77fce3e0352eabfdca382b90d3c Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 1 May 2023 13:13:28 -0500 Subject: [PATCH 10/31] wip --- .../ComInterfaceGenerator.cs | 23 ++++- .../ComInterfaceGenerator/ComMethodContext.cs | 9 +- .../ComInterfaceGenerator/ComMethodInfo.cs | 20 ++++- ...InterfaceGenerator.Tests.IComInterface1.cs | 7 ++ ...aceGenerator.Tests.IDerivedComInterface.cs | 89 +++++++++++++++++++ ...SharedTypes.ComInterfaces.IGetAndSetInt.cs | 7 ++ .../SharedTypes.ComInterfaces.IGetIntArray.cs | 7 ++ 7 files changed, 153 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 809e56b5e4896..480363baa6f99 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -163,6 +163,19 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); + var shadowingMethods = interfacesWithMethodsAndItsMethods + .Select((data, ct) => + { + var context = data.Interface.Info; + var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); + var asdf = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) + .WithModifiers(context.ContainingSyntax.Modifiers) + .WithTypeParameterList(context.ContainingSyntax.TypeParameters) + .WithMembers(List(methods)); + return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(asdf); + }) + .SelectNormalized(); + // Generate a method named CreateManagedVirtualFunctionTable on the native interface implementation // that allocates and fills in the memory for the vtable. var nativeToManagedVtables = interfacesWithMethodsAndItsMethods @@ -184,9 +197,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Zip(nativeToManagedVtableMethods) .Zip(nativeToManagedVtables) .Zip(iUnknownDerivedAttributeApplication) + .Zip(shadowingMethods) .Select(static (data, ct) => { - var (((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute) = data; + var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data; using StringWriter source = new(); interfaceInfo.WriteTo(source); @@ -204,6 +218,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) source.WriteLine(); source.WriteLine(); iUnknownDerivedAttribute.WriteTo(source); + source.WriteLine(); + source.WriteLine(); + shadowingMethod.WriteTo(source); return new { TypeName = interfaceContext.Info.Type.FullTypeName, Source = source.ToString() }; }); @@ -377,13 +394,13 @@ private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInt var definingType = interfaceGroup.Interface.Info.Type; return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) - .WithMembers(List(interfaceGroup.DeclaredMethods.Select(m => m.ManagedToUnmanagedStub).OfType().Select(ctx => ctx.Stub.Node))) + .WithMembers(List(interfaceGroup.Methods.Select(m => m.ManagedToUnmanagedStub).OfType().Select(ctx => ctx.Stub.Node))) .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute))))); } private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _) { return ImplementationInterfaceTemplate - .WithMembers(List(comInterfaceAndMethods.DeclaredMethods.Select(m => m.NativeToManagedStub).OfType().Select(context => context.Stub.Node))); + .WithMembers(List(comInterfaceAndMethods.Methods.Select(m => m.NativeToManagedStub).OfType().Select(context => context.Stub.Node))); } private static readonly TypeSyntax VoidStarStarSyntax = PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index dba274da462b4..ffbf78ab68083 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -50,17 +50,20 @@ public MethodDeclarationSyntax GenerateShadow() // { // return (()this).(); // } - return MethodInfo.Syntax.WithBody( + return MethodInfo.Syntax + .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword))) + .WithBody( Block( ReturnStatement( InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - CastExpression(DeclaringInterface.Info.Type.Syntax, IdentifierName(Token(SyntaxKind.ThisKeyword))), + CastExpression(DeclaringInterface.Info.Type.Syntax, IdentifierName("this")), // Token(SyntaxKind.ThisKeyword))), IdentifierName(MethodInfo.MethodName)), ArgumentList( // TODO: RefKind keywords - SeparatedList(MethodInfo.Parameters.Select(p => Argument(IdentifierName(p.Name))))))))); + SeparatedList(MethodInfo.Parameters.Select(p => + Argument(IdentifierName(p.Name))))))))); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index f19892e14a738..1da459cdbba48 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -14,6 +14,7 @@ namespace Microsoft.Interop { public sealed partial class ComInterfaceGenerator { + /// /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. /// @@ -21,7 +22,7 @@ private sealed record ComMethodInfo( [property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, string MethodName, - SequenceEqualImmutableArray<(ManagedTypeInfo Type, string Name, RefKind RefKind)> Parameters, + SequenceEqualImmutableArray Parameters, Diagnostic? Diagnostic) { private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) @@ -87,10 +88,10 @@ public static bool IsComMethod(ComInterfaceInfo ifaceContext, ISymbol member, [N if (comMethodDeclaringSyntax is null) throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); - List<(ManagedTypeInfo ParameterType, string Name, RefKind RefKind)> parameters = new(); + List parameters = new(); foreach (var parameter in ((IMethodSymbol)member).Parameters) { - parameters.Add((ManagedTypeInfo.CreateTypeInfoForTypeSymbol(parameter.Type), parameter.Name, parameter.RefKind)); + parameters.Add(ParameterInfo.From(parameter)); } diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); @@ -102,4 +103,17 @@ public static bool IsComMethod(ComInterfaceInfo ifaceContext, ISymbol member, [N } } } + + internal record struct ParameterInfo(ManagedTypeInfo Type, string Name, RefKind RefKind, SequenceEqualImmutableArray Attributes) + { + public static ParameterInfo From(IParameterSymbol parameter) + { + var attributes = new List(); + foreach (var attribute in parameter.GetAttributes()) + { + attributes.Add(AttributeInfo.From(attribute)); + } + return new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(parameter.Type), parameter.Name, parameter.RefKind, attributes.ToSequenceEqualImmutableArray()); + } + } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs index a5838645227fa..869a49747e313 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs @@ -115,4 +115,11 @@ namespace ComInterfaceGenerator.Tests public partial interface IComInterface1 { } +} + +namespace ComInterfaceGenerator.Tests +{ + public partial interface IComInterface1 + { + } } \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs index 0f90d2c1a4e5b..5617867b92132 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs @@ -9,6 +9,36 @@ [System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] file unsafe partial interface InterfaceImplementation : global::ComInterfaceGenerator.Tests.IDerivedComInterface { + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::ComInterfaceGenerator.Tests.IComInterface1.GetData() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); + int __retVal; + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + return __retVal; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + void global::ComInterfaceGenerator.Tests.IComInterface1.SetData(int n) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, n); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] [System.Runtime.CompilerServices.SkipLocalsInitAttribute] void global::ComInterfaceGenerator.Tests.IDerivedComInterface.SetName(string name) @@ -55,6 +85,50 @@ file unsafe partial interface InterfaceImplementation { + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_GetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int* __invokeRetValUnmanaged__param) + { + global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; + ref int __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; + int __invokeRetVal = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + __invokeRetVal = @this.GetData(); + // Marshal - Convert managed data to native data. + __invokeRetValUnmanaged = __invokeRetVal; + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + + return __retVal; + } + + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] + internal static int ABI_SetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int n) + { + global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; + int __retVal = default; + try + { + // Unmarshal - Convert native data to managed data. + __retVal = 0; // S_OK + @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); + @this.SetData(n); + } + catch (System.Exception __exception) + { + __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); + } + + return __retVal; + } + [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] internal static int ABI_SetName(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, ushort* __name_native) { @@ -136,4 +210,19 @@ namespace ComInterfaceGenerator.Tests public partial interface IDerivedComInterface { } +} + +namespace ComInterfaceGenerator.Tests +{ + public partial interface IDerivedComInterface + { + new int GetData() + { + return (global::ComInterfaceGenerator.Tests.IComInterface1)this.GetData(); + }; + new void SetData(int n) + { + return (global::ComInterfaceGenerator.Tests.IComInterface1)this.SetData(n); + }; + } } \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs index b7ec44242a3b2..987bdab087a89 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs @@ -115,4 +115,11 @@ namespace SharedTypes.ComInterfaces partial interface IGetAndSetInt { } +} + +namespace SharedTypes.ComInterfaces +{ + partial interface IGetAndSetInt + { + } } \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs index f17fa80401881..5a2bf629d7625 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs @@ -105,4 +105,11 @@ namespace SharedTypes.ComInterfaces partial interface IGetIntArray { } +} + +namespace SharedTypes.ComInterfaces +{ + partial interface IGetIntArray + { + } } \ No newline at end of file From d0394109ff54810901c99d3d5e3804e1f9235147 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 4 May 2023 17:04:38 -0500 Subject: [PATCH 11/31] Generate shadowing definitions and implementations --- .../ComInterfaceGenerator/AttributeInfo.cs | 19 ++++++ .../ComInterfaceAndMethodsContext.cs | 38 +++++++----- .../ComInterfaceContext.cs | 5 +- .../ComInterfaceGenerator.cs | 28 ++++++--- .../ComInterfaceGenerator/ComMethodContext.cs | 11 ++-- .../CustomTypeMarshallingGenerator.cs | 6 +- ....GeneratedComInterfaceTests.DerivedImpl.cs | 39 +++++++++++++ ...aceGenerator.Tests.IDerivedComInterface.cs | 45 +++++++++----- .../SharedTypes.ComInterfaces.IGetIntArray.cs | 5 -- .../GeneratedComInterfaceTests.cs | 58 ++++++++++++++++++- 10 files changed, 205 insertions(+), 49 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs new file mode 100644 index 0000000000000..997bfb467d501 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace Microsoft.Interop +{ + internal sealed record AttributeInfo(ManagedTypeInfo Type, SequenceEqualImmutableArray Arguments) + { + internal static AttributeInfo From(AttributeData attribute) + { + var type = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(attribute.AttributeClass); + var args = attribute.ConstructorArguments.Select(ca => ca.ToCSharpString()); + return new(type, args.ToSequenceEqualImmutableArray()); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs index 0bb07134c8dd2..622f98136796b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs @@ -14,7 +14,7 @@ public sealed partial class ComInterfaceGenerator /// /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). /// - private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods) + private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods, SequenceEqualImmutableArray ShadowingMethods) { /// /// COM methods that are declared on the attributed interface declaration. @@ -24,23 +24,23 @@ private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfac /// /// COM methods that are declared on an interface the interface inherits from. /// - public IEnumerable ShadowingMethods => Methods.Where(m => m.DeclaringInterface != Interface); + public IEnumerable InheritedMethods => Methods.Where(m => m.DeclaringInterface != Interface); - internal static ComInterfaceAndMethodsContext From((ComInterfaceContext, SequenceEqualImmutableArray) data, CancellationToken _) - => new ComInterfaceAndMethodsContext(data.Item1, data.Item2); + //internal static ComInterfaceAndMethodsContext From((ComInterfaceContext, SequenceEqualImmutableArray) data, CancellationToken _) + // => new ComInterfaceAndMethodsContext(data.Item1, data.Item2); public static IEnumerable CalculateAllMethods(ValueEqualityImmutableDictionary> ifaceToDeclaredMethodsMap, StubEnvironment environment, CancellationToken ct) { - Dictionary> allMethodsCache = new(); + Dictionary Methods, ImmutableArray ShadowingMethods)> allMethodsCache = new(); foreach (var kvp in ifaceToDeclaredMethodsMap) { AddMethods(kvp.Key, kvp.Value); } - return allMethodsCache.Select(kvp => new ComInterfaceAndMethodsContext(kvp.Key, kvp.Value.ToSequenceEqual())); + return allMethodsCache.Select(kvp => new ComInterfaceAndMethodsContext(kvp.Key, kvp.Value.Methods.ToSequenceEqual(), kvp.Value.ShadowingMethods.ToSequenceEqualImmutableArray())); - ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) + (ImmutableArray Methods, ImmutableArray ShadowingMethods) AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) { if (allMethodsCache.TryGetValue(iface, out var cachedValue)) { @@ -53,23 +53,33 @@ ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerab if (iface.Base is not null) { var baseComIface = iface.Base; - if (!allMethodsCache.TryGetValue(baseComIface, out var baseMethods)) + ImmutableArray baseMethods; + if (!allMethodsCache.TryGetValue(baseComIface, out var pair)) { - baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); + baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]).Methods; + } + else + { + baseMethods = pair.Methods; } methods.AddRange(baseMethods); - startingIndex += baseMethods.Length; } + var shadowingMethods = methods.Select(method => + { + var info = method.MethodInfo; + var ctx = CalculateStubInformation(info.Syntax, info.Symbol, startingIndex, environment, iface.Info.Type, ct); + return new ComMethodContext(iface, info, startingIndex++, ctx); + }).ToImmutableArray(); // Then we append the declared methods in vtable order foreach (var method in declaredMethods) { - var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, ct); + var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, iface.Info.Type, ct); methods.Add(new ComMethodContext(iface, method, startingIndex++, ctx)); } // Cache so we don't recalculate if many interfaces inherit from the same one - var immutableMethods = methods.ToImmutableArray(); - allMethodsCache[iface] = immutableMethods; - return immutableMethods; + var finalPair = (methods.ToImmutableArray(), shadowingMethods); + allMethodsCache[iface] = finalPair; + return finalPair; } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs index aa0c48e5f04ff..35b7f39de3dd6 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs @@ -14,7 +14,7 @@ private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceCon /// /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. Does not guarantee the ordering of the output. /// - public static Dictionary.ValueCollection GetContexts(ImmutableArray data, CancellationToken _) + public static IEnumerable GetContexts(ImmutableArray data, CancellationToken _) { Dictionary symbolToInterfaceInfoMap = new(); foreach (var iface in data) @@ -25,9 +25,8 @@ public static Dictionary.ValueCollection GetContext foreach (var iface in data) { - AddContext(iface); + yield return AddContext(iface); } - return symbolToContextMap.Values; ComInterfaceContext AddContext(ComInterfaceInfo iface) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 480363baa6f99..ec48d76e9bd07 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -167,7 +167,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Select((data, ct) => { var context = data.Interface.Info; - var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); + var methods = data.InheritedMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); var asdf = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) .WithModifiers(context.ContainingSyntax.Modifiers) .WithTypeParameterList(context.ContainingSyntax.TypeParameters) @@ -272,7 +272,7 @@ private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplicati .AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate)))); // Todo: extract info needed from the IMethodSymbol into MethodInfo and only pass a MethodInfo here - private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, CancellationToken ct) + private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo typeKeyOwner, CancellationToken ct) { ct.ThrowIfCancellationRequested(); INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); @@ -365,8 +365,6 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M ImmutableArray callConv = VtableIndexStubGenerator.GenerateCallConvSyntaxFromAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute); - var typeKeyOwner = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType); - var virtualMethodIndexData = new VirtualMethodIndexData(index, ImplicitThisParameter: true, MarshalDirection.Bidirectional, true, ExceptionMarshalling.Com); return new IncrementalMethodStubGenerationContext( @@ -392,15 +390,31 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _) { var definingType = interfaceGroup.Interface.Info.Type; + var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (m, m.ManagedToUnmanagedStub)) + .Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext) + .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node + .WithExplicitInterfaceSpecifier( + ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName)))); return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) - .WithMembers(List(interfaceGroup.Methods.Select(m => m.ManagedToUnmanagedStub).OfType().Select(ctx => ctx.Stub.Node))) + .WithMembers( + List( + interfaceGroup.Methods + .Select(m => m.ManagedToUnmanagedStub) + .OfType() + .Select(ctx => ctx.Stub.Node) + .Concat(shadowImplementations))) .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute))))); } private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _) { return ImplementationInterfaceTemplate - .WithMembers(List(comInterfaceAndMethods.Methods.Select(m => m.NativeToManagedStub).OfType().Select(context => context.Stub.Node))); + .WithMembers( + List( + comInterfaceAndMethods.Methods + .Select(m => m.NativeToManagedStub) + .OfType() + .Select(context => context.Stub.Node) )); } private static readonly TypeSyntax VoidStarStarSyntax = PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))); @@ -566,7 +580,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf ParenthesizedExpression( BinaryExpression(SyntaxKind.MultiplyExpression, SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3)))))) + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3)))))) }))))); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index ffbf78ab68083..1776e7a0a5d93 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -50,20 +50,21 @@ public MethodDeclarationSyntax GenerateShadow() // { // return (()this).(); // } + // TODO: Copy full name of parameter types and attributes / attribute arguments for parameters return MethodInfo.Syntax .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword))) - .WithBody( - Block( - ReturnStatement( + .WithExpressionBody( + ArrowExpressionClause( InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - CastExpression(DeclaringInterface.Info.Type.Syntax, IdentifierName("this")), // Token(SyntaxKind.ThisKeyword))), + ParenthesizedExpression( + CastExpression(DeclaringInterface.Info.Type.Syntax, IdentifierName("this"))), IdentifierName(MethodInfo.MethodName)), ArgumentList( // TODO: RefKind keywords SeparatedList(MethodInfo.Parameters.Select(p => - Argument(IdentifierName(p.Name))))))))); + Argument(IdentifierName(p.Name)))))))); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs index d5b2da8213490..70d7cd6c86913 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs @@ -97,7 +97,11 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } break; case StubCodeContext.Stage.Cleanup: - return _nativeTypeMarshaller.GenerateCleanupStatements(info, context); + if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged) + { + return _nativeTypeMarshaller.GenerateCleanupStatements(info, context); + } + break; default: break; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs new file mode 100644 index 0000000000000..02edb1b548f10 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs @@ -0,0 +1,39 @@ +file sealed unsafe class ComClassInformation : System.Runtime.InteropServices.Marshalling.IComExposedClass +{ + private static volatile System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables; + public static System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) + { + count = 2; + if (s_vtables == null) + { + System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 2); + System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; + details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IDerivedComInterface).TypeHandle); + vtables[0] = new() + { + IID = details.Iid, + Vtable = (nint)details.ManagedVirtualMethodTable + }; + details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IComInterface1).TypeHandle); + vtables[1] = new() + { + IID = details.Iid, + Vtable = (nint)details.ManagedVirtualMethodTable + }; + s_vtables = vtables; + } + + return s_vtables; + } +} + +namespace ComInterfaceGenerator.Tests +{ + public unsafe partial class GeneratedComInterfaceTests + { + [System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute] + partial class DerivedImpl + { + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs index 5617867b92132..86d4417dfc78a 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs @@ -81,6 +81,36 @@ return __retVal; } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + int global::ComInterfaceGenerator.Tests.IDerivedComInterface.GetData() + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); + int __retVal; + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + return __retVal; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] + [System.Runtime.CompilerServices.SkipLocalsInitAttribute] + void global::ComInterfaceGenerator.Tests.IDerivedComInterface.SetData(int n) + { + var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); + int __invokeRetVal; + { + __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, n); + } + + // Unmarshal - Convert native data to managed data. + System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); + } } file unsafe partial interface InterfaceImplementation @@ -176,11 +206,6 @@ internal static int ABI_GetName(System.Runtime.InteropServices.ComWrappers.ComIn { __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); } - finally - { - // Cleanup - Perform required cleanup. - global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.Free(__invokeRetValUnmanaged); - } return __retVal; } @@ -216,13 +241,7 @@ namespace ComInterfaceGenerator.Tests { public partial interface IDerivedComInterface { - new int GetData() - { - return (global::ComInterfaceGenerator.Tests.IComInterface1)this.GetData(); - }; - new void SetData(int n) - { - return (global::ComInterfaceGenerator.Tests.IComInterface1)this.SetData(n); - }; + new int GetData() => ((global::ComInterfaceGenerator.Tests.IComInterface1)this).GetData(); + new void SetData(int n) => ((global::ComInterfaceGenerator.Tests.IComInterface1)this).SetData(n); } } \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs index 5a2bf629d7625..0dff938dd3f99 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs @@ -68,11 +68,6 @@ internal static int ABI_GetInts(System.Runtime.InteropServices.ComWrappers.ComIn { __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); } - finally - { - // Cleanup - Perform required cleanup. - global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.Free(__invokeRetValUnmanaged); - } return __retVal; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index 494f0baf88a66..7a12364735374 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -9,6 +9,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; +using System.Runtime.Remoting; using Xunit; using Xunit.Sdk; @@ -21,7 +22,7 @@ internal unsafe partial class NativeExportsNE } -public class GeneratedComInterfaceTests +public partial class GeneratedComInterfaceTests { [Fact] public unsafe void CallNativeComObjectThroughGeneratedStub() @@ -54,4 +55,59 @@ public unsafe void DerivedInterfaceTypeProvidesBaseInterfaceUnmanagedToManagedMe Assert.True(expected.SequenceEqual(actual)); } + + [Fact] + public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() + { + var cw = new SingleQIComWrapper(); + var asdf = new DerivedImpl(); + var nativeObj = cw.GetOrCreateComInterfaceForObject(asdf, CreateComInterfaceFlags.None); + var obj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None); + IDerivedComInterface iface = (IDerivedComInterface)obj; + + Assert.Equal(3, iface.GetData()); + iface.SetData(5); + Assert.Equal(5, iface.GetData()); + + Assert.Equal("myName", iface.GetName()); + // https://github.com/dotnet/runtime/issues/85795 + //iface.SetName("updated"); + //Assert.Equal("updated", iface.GetName()); + + var qiCallCountObj = obj.GetType().GetRuntimeProperties().Where(p => p.Name == "IUnknownStrategy").Single().GetValue(obj); + var countQi = (SingleQIComWrapper.CountQI)qiCallCountObj; + Assert.Equal(1, countQi.QiCallCount); + } + + [GeneratedComClass] + partial class DerivedImpl : IDerivedComInterface + { + int data = 3; + string myName = "myName"; + public int GetData() => data; + [return: MarshalUsing(typeof(Utf16StringMarshaller))] + public string GetName() => myName; + public void SetData(int n) => data = n; + public void SetName([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => myName = name; + } + + class SingleQIComWrapper : StrategyBasedComWrappers + { + public class CountQI : IIUnknownStrategy + { + public CountQI(IIUnknownStrategy iUnknown) => _iUnknownStrategy = iUnknown; + private IIUnknownStrategy _iUnknownStrategy; + public int QiCallCount = 0; + public unsafe void* CreateInstancePointer(void* unknown) => _iUnknownStrategy.CreateInstancePointer(unknown); + public unsafe int QueryInterface(void* instancePtr, in Guid iid, out void* ppObj) + { + QiCallCount++; + return _iUnknownStrategy.QueryInterface(instancePtr, in iid, out ppObj); + } + public unsafe int Release(void* instancePtr) => _iUnknownStrategy.Release(instancePtr); + } + + protected override IIUnknownStrategy GetOrCreateIUnknownStrategy() + => new CountQI(base.GetOrCreateIUnknownStrategy()); + } } From 3915490481bfedc59fb55785573b020544de2f9e Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 4 May 2023 17:10:30 -0500 Subject: [PATCH 12/31] rename --- .../gen/ComInterfaceGenerator/ComInterfaceGenerator.cs | 4 ++-- .../ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index ec48d76e9bd07..7cb28a5f5eb84 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -168,11 +168,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var context = data.Interface.Info; var methods = data.InheritedMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); - var asdf = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) + var typeDecl = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) .WithModifiers(context.ContainingSyntax.Modifiers) .WithTypeParameterList(context.ContainingSyntax.TypeParameters) .WithMembers(List(methods)); - return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(asdf); + return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(typeDecl); }) .SelectNormalized(); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index 7a12364735374..654ef09490ecf 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -60,8 +60,8 @@ public unsafe void DerivedInterfaceTypeProvidesBaseInterfaceUnmanagedToManagedMe public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() { var cw = new SingleQIComWrapper(); - var asdf = new DerivedImpl(); - var nativeObj = cw.GetOrCreateComInterfaceForObject(asdf, CreateComInterfaceFlags.None); + var derivedImpl = new DerivedImpl(); + var nativeObj = cw.GetOrCreateComInterfaceForObject(derivedImpl, CreateComInterfaceFlags.None); var obj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None); IDerivedComInterface iface = (IDerivedComInterface)obj; From 3eba497f5e07b8ce296feb679d87b79a1adc7b9e Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 4 May 2023 17:13:18 -0500 Subject: [PATCH 13/31] Don't generate Vtable methods for inherited methods --- .../ComInterfaceGenerator.cs | 2 +- ...aceGenerator.Tests.IDerivedComInterface.cs | 44 ------------------- 2 files changed, 1 insertion(+), 45 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 7cb28a5f5eb84..0a949b9ed8b72 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -411,7 +411,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co return ImplementationInterfaceTemplate .WithMembers( List( - comInterfaceAndMethods.Methods + comInterfaceAndMethods.DeclaredMethods .Select(m => m.NativeToManagedStub) .OfType() .Select(context => context.Stub.Node) )); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs index 86d4417dfc78a..bca77fe757d09 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs @@ -115,50 +115,6 @@ file unsafe partial interface InterfaceImplementation { - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_GetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int* __invokeRetValUnmanaged__param) - { - global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; - ref int __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; - int __invokeRetVal = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - __invokeRetVal = @this.GetData(); - // Marshal - Convert managed data to native data. - __invokeRetValUnmanaged = __invokeRetVal; - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } - - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_SetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int n) - { - global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - @this.SetData(n); - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] internal static int ABI_SetName(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, ushort* __name_native) { From 04b8309b146f97b3e6012e5f5e5a0c7dc169b56a Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Fri, 5 May 2023 14:58:28 -0500 Subject: [PATCH 14/31] Stub out inherited methods on InterfaceImplementation --- .../ComInterfaceGenerator.cs | 6 ++-- .../ComInterfaceGenerator/ComMethodContext.cs | 14 ++++++++ .../CustomTypeMarshallingGenerator.cs | 6 +--- .../TypeNames.cs | 1 + ...aceGenerator.Tests.IDerivedComInterface.cs | 33 ++----------------- .../GeneratedComInterfaceTests.cs | 5 +-- 6 files changed, 26 insertions(+), 39 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 0a949b9ed8b72..eb57713b67a43 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -395,15 +395,17 @@ private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInt .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node .WithExplicitInterfaceSpecifier( ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName)))); + var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.GenerateUnreachableExceptionStub()); return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) .WithMembers( List( - interfaceGroup.Methods + interfaceGroup.DeclaredMethods .Select(m => m.ManagedToUnmanagedStub) .OfType() .Select(ctx => ctx.Stub.Node) - .Concat(shadowImplementations))) + .Concat(shadowImplementations) + .Concat(inheritedStubs))) .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute))))); } private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index 1776e7a0a5d93..ddf357f768ea1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -44,6 +44,20 @@ public GeneratedMethodContextBase NativeToManagedStub } } + public MethodDeclarationSyntax GenerateUnreachableExceptionStub() + { + // DeclarationCopiedFromBaseDeclaration() => throw new UnreachableException("This method should not be reached"); + return MethodInfo.Syntax + .WithAttributeLists(List()) + .WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier( + ParseName(DeclaringInterface.Info.Type.FullTypeName))) + .WithExpressionBody(ArrowExpressionClause( + ThrowExpression( + ObjectCreationExpression( + ParseTypeName(TypeNames.UnreachableException)) + .WithArgumentList(ArgumentList())))); + } + public MethodDeclarationSyntax GenerateShadow() { // DeclarationCopiedFromBaseDeclaration() diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs index 70d7cd6c86913..d5b2da8213490 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs @@ -97,11 +97,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } break; case StubCodeContext.Stage.Cleanup: - if (elementMarshalDirection is MarshalDirection.UnmanagedToManaged) - { - return _nativeTypeMarshaller.GenerateCleanupStatements(info, context); - } - break; + return _nativeTypeMarshaller.GenerateCleanupStatements(info, context); default: break; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index 4db27ba1e0bde..274f6cae75e86 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -136,5 +136,6 @@ public static string MarshalEx(InteropGenerationOptions options) public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute"; public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute"; public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass"; + public const string UnreachableException = "System.Diagnostics.UnreachableException"; } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs index bca77fe757d09..fcbda10c8adc5 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs @@ -9,36 +9,6 @@ [System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] file unsafe partial interface InterfaceImplementation : global::ComInterfaceGenerator.Tests.IDerivedComInterface { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::ComInterfaceGenerator.Tests.IComInterface1.GetData() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); - int __retVal; - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - return __retVal; - } - - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - void global::ComInterfaceGenerator.Tests.IComInterface1.SetData(int n) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, n); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - } - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] [System.Runtime.CompilerServices.SkipLocalsInitAttribute] void global::ComInterfaceGenerator.Tests.IDerivedComInterface.SetName(string name) @@ -111,6 +81,9 @@ // Unmarshal - Convert native data to managed data. System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); } + + int global::ComInterfaceGenerator.Tests.IComInterface1.GetData() => throw new System.Diagnostics.UnreachableException(); + void global::ComInterfaceGenerator.Tests.IComInterface1.SetData(int n) => throw new System.Diagnostics.UnreachableException(); } file unsafe partial interface InterfaceImplementation diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index 654ef09490ecf..21bb6b28e54dc 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -66,8 +66,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() IDerivedComInterface iface = (IDerivedComInterface)obj; Assert.Equal(3, iface.GetData()); - iface.SetData(5); - Assert.Equal(5, iface.GetData()); + // https://github.com/dotnet/runtime/issues/85795 + //iface.SetData(5); + //Assert.Equal(5, iface.GetData()); Assert.Equal("myName", iface.GetName()); // https://github.com/dotnet/runtime/issues/85795 From 6d78a2dff541a0806c39ca4cac3a460987854e23 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Fri, 5 May 2023 16:01:19 -0500 Subject: [PATCH 15/31] Don't emit compiler generated files --- .../ComInterfaceGenerator.Tests.csproj | 5 - ...terfaceGenerator.Tests.DerivedComObject.cs | 30 --- ....GeneratedComInterfaceTests.DerivedImpl.cs | 39 ---- ...nerator.Tests.ManagedObjectExposedToCom.cs | 30 --- ...InterfaceGenerator.Tests.IComInterface1.cs | 125 ------------- ...aceGenerator.Tests.IDerivedComInterface.cs | 176 ------------------ ...SharedTypes.ComInterfaces.IGetAndSetInt.cs | 125 ------------- .../SharedTypes.ComInterfaces.IGetIntArray.cs | 110 ----------- .../ManagedToNativeStubs.g.cs | 106 ----------- .../NativeInterfaces.g.cs | 33 ---- .../NativeToManagedStubs.g.cs | 49 ----- .../PopulateVTable.g.cs | 21 --- .../LibraryImports.g.cs | 126 ------------- 13 files changed, 975 deletions(-) delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj index a81409aafba1c..88114bc4d0650 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj @@ -7,16 +7,11 @@ true false - - None - true - $(MSBuildThisFileDirectory)Generated true - diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs deleted file mode 100644 index a8a8351c5360d..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.DerivedComObject.cs +++ /dev/null @@ -1,30 +0,0 @@ -file sealed unsafe class ComClassInformation : System.Runtime.InteropServices.Marshalling.IComExposedClass -{ - private static volatile System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables; - public static System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) - { - count = 1; - if (s_vtables == null) - { - System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 1); - System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; - details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(SharedTypes.ComInterfaces.IGetAndSetInt).TypeHandle); - vtables[0] = new() - { - IID = details.Iid, - Vtable = (nint)details.ManagedVirtualMethodTable - }; - s_vtables = vtables; - } - - return s_vtables; - } -} - -namespace ComInterfaceGenerator.Tests -{ - [System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute] - partial class DerivedComObject - { - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs deleted file mode 100644 index 02edb1b548f10..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.GeneratedComInterfaceTests.DerivedImpl.cs +++ /dev/null @@ -1,39 +0,0 @@ -file sealed unsafe class ComClassInformation : System.Runtime.InteropServices.Marshalling.IComExposedClass -{ - private static volatile System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables; - public static System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) - { - count = 2; - if (s_vtables == null) - { - System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 2); - System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; - details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IDerivedComInterface).TypeHandle); - vtables[0] = new() - { - IID = details.Iid, - Vtable = (nint)details.ManagedVirtualMethodTable - }; - details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(ComInterfaceGenerator.Tests.IComInterface1).TypeHandle); - vtables[1] = new() - { - IID = details.Iid, - Vtable = (nint)details.ManagedVirtualMethodTable - }; - s_vtables = vtables; - } - - return s_vtables; - } -} - -namespace ComInterfaceGenerator.Tests -{ - public unsafe partial class GeneratedComInterfaceTests - { - [System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute] - partial class DerivedImpl - { - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs deleted file mode 100644 index 84fe3a175d798..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComClassGenerator/ComInterfaceGenerator.Tests.ManagedObjectExposedToCom.cs +++ /dev/null @@ -1,30 +0,0 @@ -file sealed unsafe class ComClassInformation : System.Runtime.InteropServices.Marshalling.IComExposedClass -{ - private static volatile System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* s_vtables; - public static System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) - { - count = 1; - if (s_vtables == null) - { - System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* vtables = (System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry*)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComClassInformation), sizeof(System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry) * 1); - System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails details; - details = System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(SharedTypes.ComInterfaces.IGetAndSetInt).TypeHandle); - vtables[0] = new() - { - IID = details.Iid, - Vtable = (nint)details.ManagedVirtualMethodTable - }; - s_vtables = vtables; - } - - return s_vtables; - } -} - -namespace ComInterfaceGenerator.Tests -{ - [System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute] - partial class ManagedObjectExposedToCom - { - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs deleted file mode 100644 index 869a49747e313..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IComInterface1.cs +++ /dev/null @@ -1,125 +0,0 @@ -file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType -{ - public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 3, 153, 63, 44, 134, 181, 177, 70, 136, 27, 173, 252, 233, 175, 71, 177 })); - - private static void** _vtable; - public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); -} - -[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] -file unsafe partial interface InterfaceImplementation : global::ComInterfaceGenerator.Tests.IComInterface1 -{ - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::ComInterfaceGenerator.Tests.IComInterface1.GetData() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); - int __retVal; - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - return __retVal; - } - - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - void global::ComInterfaceGenerator.Tests.IComInterface1.SetData(int n) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IComInterface1)); - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, n); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - } -} - -file unsafe partial interface InterfaceImplementation -{ - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_GetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int* __invokeRetValUnmanaged__param) - { - global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; - ref int __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; - int __invokeRetVal = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - __invokeRetVal = @this.GetData(); - // Marshal - Convert managed data to native data. - __invokeRetValUnmanaged = __invokeRetVal; - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } - - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_SetData(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int n) - { - global::ComInterfaceGenerator.Tests.IComInterface1 @this = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - @this.SetData(n); - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } -} - -file unsafe partial interface InterfaceImplementation -{ - internal static void** CreateManagedVirtualFunctionTable() - { - void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::ComInterfaceGenerator.Tests.IComInterface1), sizeof(void*) * 5); - { - nint v0, v1, v2; - System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); - vtable[0] = (void*)v0; - vtable[1] = (void*)v1; - vtable[2] = (void*)v2; - } - - { - vtable[3] = (void*)(delegate* unmanaged )&ABI_GetData; - vtable[4] = (void*)(delegate* unmanaged )&ABI_SetData; - } - - return vtable; - } -} - -namespace ComInterfaceGenerator.Tests -{ - [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] - public partial interface IComInterface1 - { - } -} - -namespace ComInterfaceGenerator.Tests -{ - public partial interface IComInterface1 - { - } -} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs deleted file mode 100644 index fcbda10c8adc5..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/ComInterfaceGenerator.Tests.IDerivedComInterface.cs +++ /dev/null @@ -1,176 +0,0 @@ -file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType -{ - public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 100, 179, 13, 127, 4, 60, 135, 68, 145, 147, 75, 176, 93, 199, 182, 84 })); - - private static void** _vtable; - public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); -} - -[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] -file unsafe partial interface InterfaceImplementation : global::ComInterfaceGenerator.Tests.IDerivedComInterface -{ - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - void global::ComInterfaceGenerator.Tests.IDerivedComInterface.SetName(string name) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); - int __invokeRetVal; - // Pin - Pin data in preparation for calling the P/Invoke. - fixed (void* __name_native = &global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.GetPinnableReference(name)) - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[5])(__this, (ushort*)__name_native); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - } - - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - string global::ComInterfaceGenerator.Tests.IDerivedComInterface.GetName() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); - string __retVal; - ushort* __retVal_native = default; - int __invokeRetVal; - try - { - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[6])(__this, &__retVal_native); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - __retVal = global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.ConvertToManaged(__retVal_native); - } - finally - { - // Cleanup - Perform required cleanup. - global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.Free(__retVal_native); - } - - return __retVal; - } - - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::ComInterfaceGenerator.Tests.IDerivedComInterface.GetData() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); - int __retVal; - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - return __retVal; - } - - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - void global::ComInterfaceGenerator.Tests.IDerivedComInterface.SetData(int n) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface)); - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, n); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - } - - int global::ComInterfaceGenerator.Tests.IComInterface1.GetData() => throw new System.Diagnostics.UnreachableException(); - void global::ComInterfaceGenerator.Tests.IComInterface1.SetData(int n) => throw new System.Diagnostics.UnreachableException(); -} - -file unsafe partial interface InterfaceImplementation -{ - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_SetName(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, ushort* __name_native) - { - global::ComInterfaceGenerator.Tests.IDerivedComInterface @this = default; - string name = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - name = global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.ConvertToManaged(__name_native); - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - @this.SetName(name); - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - finally - { - // Cleanup - Perform required cleanup. - global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.Free(__name_native); - } - - return __retVal; - } - - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_GetName(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, ushort** __invokeRetValUnmanaged__param) - { - global::ComInterfaceGenerator.Tests.IDerivedComInterface @this = default; - ref ushort* __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; - string __invokeRetVal = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - __invokeRetVal = @this.GetName(); - // Marshal - Convert managed data to native data. - __invokeRetValUnmanaged = (ushort*)global::System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller.ConvertToUnmanaged(__invokeRetVal); - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } -} - -file unsafe partial interface InterfaceImplementation -{ - internal static void** CreateManagedVirtualFunctionTable() - { - void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::ComInterfaceGenerator.Tests.IDerivedComInterface), sizeof(void*) * 7); - { - System.Runtime.InteropServices.NativeMemory.Copy(System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(global::ComInterfaceGenerator.Tests.IComInterface1).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * 5)); - } - - { - vtable[5] = (void*)(delegate* unmanaged )&ABI_SetName; - vtable[6] = (void*)(delegate* unmanaged )&ABI_GetName; - } - - return vtable; - } -} - -namespace ComInterfaceGenerator.Tests -{ - [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] - public partial interface IDerivedComInterface - { - } -} - -namespace ComInterfaceGenerator.Tests -{ - public partial interface IDerivedComInterface - { - new int GetData() => ((global::ComInterfaceGenerator.Tests.IComInterface1)this).GetData(); - new void SetData(int n) => ((global::ComInterfaceGenerator.Tests.IComInterface1)this).SetData(n); - } -} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs deleted file mode 100644 index 987bdab087a89..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetAndSetInt.cs +++ /dev/null @@ -1,125 +0,0 @@ -file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType -{ - public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 3, 153, 63, 44, 134, 181, 177, 70, 136, 27, 173, 252, 233, 175, 71, 177 })); - - private static void** _vtable; - public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); -} - -[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] -file unsafe partial interface InterfaceImplementation : global::SharedTypes.ComInterfaces.IGetAndSetInt -{ - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::SharedTypes.ComInterfaces.IGetAndSetInt.GetInt() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IGetAndSetInt)); - int __retVal; - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - return __retVal; - } - - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - void global::SharedTypes.ComInterfaces.IGetAndSetInt.SetInt(int x) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IGetAndSetInt)); - int __invokeRetVal; - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[4])(__this, x); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - } -} - -file unsafe partial interface InterfaceImplementation -{ - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_GetInt(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int* __invokeRetValUnmanaged__param) - { - global::SharedTypes.ComInterfaces.IGetAndSetInt @this = default; - ref int __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; - int __invokeRetVal = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - __invokeRetVal = @this.GetInt(); - // Marshal - Convert managed data to native data. - __invokeRetValUnmanaged = __invokeRetVal; - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } - - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_SetInt(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int x) - { - global::SharedTypes.ComInterfaces.IGetAndSetInt @this = default; - int __retVal = default; - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - @this.SetInt(x); - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } -} - -file unsafe partial interface InterfaceImplementation -{ - internal static void** CreateManagedVirtualFunctionTable() - { - void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::SharedTypes.ComInterfaces.IGetAndSetInt), sizeof(void*) * 5); - { - nint v0, v1, v2; - System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); - vtable[0] = (void*)v0; - vtable[1] = (void*)v1; - vtable[2] = (void*)v2; - } - - { - vtable[3] = (void*)(delegate* unmanaged )&ABI_GetInt; - vtable[4] = (void*)(delegate* unmanaged )&ABI_SetInt; - } - - return vtable; - } -} - -namespace SharedTypes.ComInterfaces -{ - [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] - partial interface IGetAndSetInt - { - } -} - -namespace SharedTypes.ComInterfaces -{ - partial interface IGetAndSetInt - { - } -} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs deleted file mode 100644 index 0dff938dd3f99..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.ComInterfaceGenerator/SharedTypes.ComInterfaces.IGetIntArray.cs +++ /dev/null @@ -1,110 +0,0 @@ -file unsafe class InterfaceInformation : System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType -{ - public static System.Guid Iid { get; } = new(new System.ReadOnlySpan(new byte[] { 10, 42, 128, 125, 10, 99, 142, 76, 162, 31, 119, 28, 201, 3, 31, 185 })); - - private static void** _vtable; - public static void** ManagedVirtualMethodTable => _vtable != null ? _vtable : (_vtable = InterfaceImplementation.CreateManagedVirtualFunctionTable()); -} - -[System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] -file unsafe partial interface InterfaceImplementation : global::SharedTypes.ComInterfaces.IGetIntArray -{ - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int[] global::SharedTypes.ComInterfaces.IGetIntArray.GetInts() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IGetIntArray)); - int[] __retVal; - int* __retVal_native = default; - int __invokeRetVal; - // Setup - Perform required setup. - int __retVal_native__numElements; - System.Runtime.CompilerServices.Unsafe.SkipInit(out __retVal_native__numElements); - try - { - { - __invokeRetVal = ((delegate* unmanaged )__vtable_native[3])(__this, &__retVal_native); - } - - // Unmarshal - Convert native data to managed data. - System.Runtime.InteropServices.Marshal.ThrowExceptionForHR(__invokeRetVal); - __retVal_native__numElements = 10; - __retVal = global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.AllocateContainerForManagedElements(__retVal_native, __retVal_native__numElements); - global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetUnmanagedValuesSource(__retVal_native, __retVal_native__numElements).CopyTo(global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetManagedValuesDestination(__retVal)); - } - finally - { - // Cleanup - Perform required cleanup. - global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.Free(__retVal_native); - } - - return __retVal; - } -} - -file unsafe partial interface InterfaceImplementation -{ - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_GetInts(System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch* __this_native, int** __invokeRetValUnmanaged__param) - { - global::SharedTypes.ComInterfaces.IGetIntArray @this = default; - ref int* __invokeRetValUnmanaged = ref *__invokeRetValUnmanaged__param; - int[] __invokeRetVal = default; - int __retVal = default; - // Setup - Perform required setup. - int __invokeRetValUnmanaged__numElements; - System.Runtime.CompilerServices.Unsafe.SkipInit(out __invokeRetValUnmanaged__numElements); - try - { - // Unmarshal - Convert native data to managed data. - __retVal = 0; // S_OK - @this = System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch.GetInstance(__this_native); - __invokeRetVal = @this.GetInts(); - // Marshal - Convert managed data to native data. - __invokeRetValUnmanaged = global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.AllocateContainerForUnmanagedElements(__invokeRetVal, out __invokeRetValUnmanaged__numElements); - global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetManagedValuesSource(__invokeRetVal).CopyTo(global::System.Runtime.InteropServices.Marshalling.ArrayMarshaller.GetUnmanagedValuesDestination(__invokeRetValUnmanaged, __invokeRetValUnmanaged__numElements)); - } - catch (System.Exception __exception) - { - __retVal = System.Runtime.InteropServices.Marshalling.ExceptionAsHResultMarshaller.ConvertToUnmanaged(__exception); - } - - return __retVal; - } -} - -file unsafe partial interface InterfaceImplementation -{ - internal static void** CreateManagedVirtualFunctionTable() - { - void** vtable = (void**)System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(global::SharedTypes.ComInterfaces.IGetIntArray), sizeof(void*) * 4); - { - nint v0, v1, v2; - System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out v0, out v1, out v2); - vtable[0] = (void*)v0; - vtable[1] = (void*)v1; - vtable[2] = (void*)v2; - } - - { - vtable[3] = (void*)(delegate* unmanaged )&ABI_GetInts; - } - - return vtable; - } -} - -namespace SharedTypes.ComInterfaces -{ - [System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute] - partial interface IGetIntArray - { - } -} - -namespace SharedTypes.ComInterfaces -{ - partial interface IGetIntArray - { - } -} \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs deleted file mode 100644 index 7eb9069150637..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/ManagedToNativeStubs.g.cs +++ /dev/null @@ -1,106 +0,0 @@ -// -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - internal unsafe partial interface INativeObject - { - internal unsafe partial interface Native - { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject.GetData() - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)); - int __retVal; - { - __retVal = ((delegate* unmanaged )__vtable_native[0])(__this); - } - - return __retVal; - } - } - } - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - internal unsafe partial interface INativeObject - { - internal unsafe partial interface Native - { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - void global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject.SetData(int x) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)); - { - ((delegate* unmanaged )__vtable_native[1])(__this, x); - } - } - } - } - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class NoImplicitThis - { - internal unsafe partial interface IStaticMethodTable - { - internal unsafe partial interface Native - { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable.Add(int x, int y) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable)); - int __retVal; - { - __retVal = ((delegate* unmanaged )__vtable_native[0])(x, y); - } - - return __retVal; - } - } - } - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class NoImplicitThis - { - internal unsafe partial interface IStaticMethodTable - { - internal unsafe partial interface Native - { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.ComInterfaceGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - int global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable.Multiply(int x, int y) - { - var(__this, __vtable_native) = ((System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey(typeof(global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.IStaticMethodTable)); - int __retVal; - { - __retVal = ((delegate* unmanaged )__vtable_native[1])(x, y); - } - - return __retVal; - } - } - } - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs deleted file mode 100644 index 63e98564f3393..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeInterfaces.g.cs +++ /dev/null @@ -1,33 +0,0 @@ -// -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - internal unsafe partial interface INativeObject - { - [System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] - internal partial interface Native : INativeObject - { - } - } - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class NoImplicitThis - { - internal unsafe partial interface IStaticMethodTable - { - [System.Runtime.InteropServices.DynamicInterfaceCastableImplementationAttribute] - internal partial interface Native : IStaticMethodTable - { - } - } - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs deleted file mode 100644 index 3fe35a8622fa7..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/NativeToManagedStubs.g.cs +++ /dev/null @@ -1,49 +0,0 @@ -// -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - internal unsafe partial interface INativeObject - { - internal unsafe partial interface Native - { - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static int ABI_GetData(void* __this_native) - { - global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject @this; - int __retVal = default; - // Unmarshal - Convert native data to managed data. - @this = (global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper.GetObjectForUnmanagedWrapper>(__this_native); - __retVal = @this.GetData(); - return __retVal; - } - } - } - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - internal unsafe partial interface INativeObject - { - internal unsafe partial interface Native - { - [System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute] - internal static void ABI_SetData(void* __this_native, int x) - { - global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject @this; - // Unmarshal - Convert native data to managed data. - @this = (global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.INativeObject)System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper.GetObjectForUnmanagedWrapper>(__this_native); - @this.SetData(x); - } - } - } - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs deleted file mode 100644 index 8429a3ea8c269..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.ComInterfaceGenerator/Microsoft.Interop.VtableIndexStubGenerator/PopulateVTable.g.cs +++ /dev/null @@ -1,21 +0,0 @@ -// -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - internal unsafe partial interface INativeObject - { - internal unsafe partial interface Native - { - internal static unsafe void PopulateUnmanagedVirtualMethodTable(void** vtable) - { - vtable[0] = (void*)(delegate* unmanaged )&ABI_GetData; - vtable[1] = (void*)(delegate* unmanaged )&ABI_SetData; - } - } - } - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs deleted file mode 100644 index 8e1783fce02a9..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/Generated/Microsoft.Interop.LibraryImportGenerator/Microsoft.Interop.LibraryImportGenerator/LibraryImports.g.cs +++ /dev/null @@ -1,126 +0,0 @@ -// -namespace ComInterfaceGenerator.Tests -{ - unsafe partial class NativeExportsNE - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "set_com_object_data", ExactSpelling = true)] - public static extern partial void SetComObjectData(void* obj, int data); - } -} -namespace ComInterfaceGenerator.Tests -{ - unsafe partial class NativeExportsNE - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_com_object_data", ExactSpelling = true)] - public static extern partial int GetComObjectData(void* obj); - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_com_object", ExactSpelling = true)] - public static extern partial void* NewNativeObject(); - } -} -namespace ComInterfaceGenerator.Tests -{ - public unsafe partial class IGetAndSetIntTests - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "new_get_and_set_int", ExactSpelling = true)] - public static extern partial void* NewNativeObject(); - } -} -namespace ComInterfaceGenerator.Tests -{ - public unsafe partial class IGetIntArrayTests - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "new_get_and_set_int_array", ExactSpelling = true)] - public static extern partial void* NewNativeObject(); - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.LibraryImportGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - public static partial global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.NativeObject NewNativeObject() - { - global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.NativeObject __retVal; - void* __retVal_native; - { - __retVal_native = __PInvoke(); - } - - // Unmarshal - Convert native data to managed data. - __retVal = global::ComInterfaceGenerator.Tests.NativeExportsNE.ImplicitThis.NativeObjectMarshaller.ConvertToManaged(__retVal_native); - return __retVal; - // Local P/Invoke - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "new_native_object", ExactSpelling = true)] - static extern unsafe void* __PInvoke(); - } - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "delete_native_object", ExactSpelling = true)] - public static extern partial void DeleteNativeObject(void* obj); - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "set_native_object_data", ExactSpelling = true)] - public static extern partial void SetNativeObjectData(void* obj, int data); - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class ImplicitThis - { - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_native_object_data", ExactSpelling = true)] - public static extern partial int GetNativeObjectData(void* obj); - } - } -} -namespace ComInterfaceGenerator.Tests -{ - internal unsafe partial class NativeExportsNE - { - internal unsafe partial class NoImplicitThis - { - [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Interop.LibraryImportGenerator", "42.42.42.42")] - [System.Runtime.CompilerServices.SkipLocalsInitAttribute] - public static partial global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.StaticMethodTable GetStaticFunctionTable() - { - global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.StaticMethodTable __retVal; - void* __retVal_native; - { - __retVal_native = __PInvoke(); - } - - // Unmarshal - Convert native data to managed data. - __retVal = global::ComInterfaceGenerator.Tests.NativeExportsNE.NoImplicitThis.StaticMethodTableMarshaller.ConvertToManaged(__retVal_native); - return __retVal; - // Local P/Invoke - [System.Runtime.InteropServices.DllImportAttribute("Microsoft.Interop.Tests.NativeExportsNE", EntryPoint = "get_static_function_table", ExactSpelling = true)] - static extern unsafe void* __PInvoke(); - } - } - } -} From 230899868425c0122265d3bb905f8db98ad5c322 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 8 May 2023 12:44:25 -0500 Subject: [PATCH 16/31] Fix test and add additional doc comments --- .../ComInterfaceGenerator/AttributeInfo.cs | 3 +++ .../ComInterfaceGenerator/ComInterfaceInfo.cs | 2 +- .../ComInterfaceGenerator/ComMethodContext.cs | 1 + .../IncrementalValuesProviderExtensions.cs | 25 ------------------- .../GeneratedComInterfaceTests.cs | 2 +- 5 files changed, 6 insertions(+), 27 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs index 997bfb467d501..8443100fa4682 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs @@ -7,6 +7,9 @@ namespace Microsoft.Interop { + /// + /// Provides the info necessary for copying an attribute from user code to generated code. + /// internal sealed record AttributeInfo(ManagedTypeInfo Type, SequenceEqualImmutableArray Arguments) { internal static AttributeInfo From(AttributeData attribute) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs index a15c3b59974b9..cfa65e9a2b9ae 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -13,7 +13,7 @@ namespace Microsoft.Interop public sealed partial class ComInterfaceGenerator { /// - /// Information about a Com interface, but not it's methods + /// Information about a Com interface, but not it's methods. /// private sealed record ComInterfaceInfo( ManagedTypeInfo Type, diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index ddf357f768ea1..df2bbf69f442e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -13,6 +13,7 @@ public sealed partial class ComInterfaceGenerator { /// /// Represents a method, its declaring interface, and its index in the interface's vtable. + /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator /// private sealed record ComMethodContext(ComInterfaceContext DeclaringInterface, ComMethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index e578c24f00f57..8586f4635c21c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -32,31 +32,6 @@ internal static class IncrementalValuesProviderExtensions }); } - public static IncrementalValuesProvider<(TGrouper, SequenceEqualImmutableArray)> GroupTuples(this IncrementalValuesProvider<(TGrouper Key, TGroupee Value)> values) - { - return values.Collect().SelectMany(static (values, ct) => - { - var valueMap = new Dictionary>(); - foreach (var value in values) - { - if (!valueMap.TryGetValue(value.Key, out var list)) - { - list = new(); - } - list.Add(value.Value); - valueMap[value.Key] = list; - } - - var builder = ImmutableArray.CreateBuilder<(TGrouper, SequenceEqualImmutableArray)>(valueMap.Count); - foreach (var kvp in valueMap) - { - builder.Add((kvp.Key, kvp.Value.ToSequenceEqualImmutableArray())); - } - - return builder.MoveToImmutable(); - }); - } - /// /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed. /// diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index 21bb6b28e54dc..134e488b830a7 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -70,8 +70,8 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() //iface.SetData(5); //Assert.Equal(5, iface.GetData()); - Assert.Equal("myName", iface.GetName()); // https://github.com/dotnet/runtime/issues/85795 + //Assert.Equal("myName", iface.GetName()); //iface.SetName("updated"); //Assert.Equal("updated", iface.GetName()); From 8883b2830db4f18791eb729c221f88963ab2a436 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Tue, 9 May 2023 11:06:06 -0700 Subject: [PATCH 17/31] Format and comment update --- .../gen/ComInterfaceGenerator/ComInterfaceInfo.cs | 2 +- .../gen/ComInterfaceGenerator/ComMethodInfo.cs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs index cfa65e9a2b9ae..d50fd914466fe 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -121,7 +121,7 @@ private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclar { guid = new Guid(guidstr); } - // Diagnostic will be raised if guid is null + // Catch any issues with the Guid string -- Diagnostic will be raised if guid is null catch (FormatException) { } catch (OverflowException) { } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 1da459cdbba48..3fe7e8c518897 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -14,7 +14,6 @@ namespace Microsoft.Interop { public sealed partial class ComInterfaceGenerator { - /// /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. /// From 130faae6f68a1a885cedde7fae44c807d84c99f7 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Tue, 9 May 2023 11:08:34 -0700 Subject: [PATCH 18/31] Use the right diagnostic property --- .../gen/ComInterfaceGenerator/ComInterfaceInfo.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs index d50fd914466fe..3246597179d3e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -82,7 +82,7 @@ private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceD if (baseComIface is not null) { diagnostic = Diagnostic.Create( - GeneratorDiagnostics.MultipleComInterfaceBaseTypesAttribute, + GeneratorDiagnostics.MultipleComInterfaceBaseTypes, syntax.Identifier.GetLocation(), comIface.ToDisplayString()); return false; From 033d817cd951be9e9661577ab5e70c2e89d55d45 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Tue, 9 May 2023 11:20:27 -0700 Subject: [PATCH 19/31] Remove extra usings --- .../gen/ComInterfaceGenerator/ComInterfaceGenerator.cs | 1 - .../ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs | 1 - .../IncrementalValuesProviderExtensions.cs | 1 - .../GeneratedComInterfaceTests.cs | 5 ----- 4 files changed, 8 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index a1a91df016461..9bd738545fdfe 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -11,7 +11,6 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using static Microsoft.Interop.CollectionExtensions; namespace Microsoft.Interop { diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index 27961dec1b7fb..e192fcd02d8fe 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Microsoft.Interop { diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index 8586f4635c21c..fb0dd80ca7f1a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; using System.Text; using Microsoft.CodeAnalysis; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index 134e488b830a7..ba077b2e33a45 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -2,16 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections; -using System.Diagnostics; using System.Linq; using System.Reflection; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; -using System.Runtime.Remoting; using Xunit; -using Xunit.Sdk; namespace ComInterfaceGenerator.Tests; From ae2fe395cbb5dbdf01338a50e7ee58040fcf3420 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 10 May 2023 10:46:14 -0700 Subject: [PATCH 20/31] Remove commented code, update doc comment --- .../gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs | 3 --- .../gen/ComInterfaceGenerator/ComInterfaceContext.cs | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs index 622f98136796b..20ec29b98c19f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs @@ -26,9 +26,6 @@ private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfac /// public IEnumerable InheritedMethods => Methods.Where(m => m.DeclaringInterface != Interface); - //internal static ComInterfaceAndMethodsContext From((ComInterfaceContext, SequenceEqualImmutableArray) data, CancellationToken _) - // => new ComInterfaceAndMethodsContext(data.Item1, data.Item2); - public static IEnumerable CalculateAllMethods(ValueEqualityImmutableDictionary> ifaceToDeclaredMethodsMap, StubEnvironment environment, CancellationToken ct) { Dictionary Methods, ImmutableArray ShadowingMethods)> allMethodsCache = new(); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs index 35b7f39de3dd6..4565e1f176400 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs @@ -12,7 +12,7 @@ public sealed partial class ComInterfaceGenerator private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base) { /// - /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. Does not guarantee the ordering of the output. + /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. /// public static IEnumerable GetContexts(ImmutableArray data, CancellationToken _) { From 299b3f39a140392622485a1b2af88359af922c51 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 10 May 2023 11:56:14 -0700 Subject: [PATCH 21/31] Make methodInfo creation follow the same pattern as InterfaceInfo --- .../ComInterfaceGenerator.cs | 28 ++--- .../ComInterfaceGenerator/ComMethodContext.cs | 2 - .../ComInterfaceGenerator/ComMethodInfo.cs | 109 +++++++++++------- 3 files changed, 77 insertions(+), 62 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 9bd738545fdfe..a10cc775454d2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -58,7 +58,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterDiagnostics(interfaceSymbolAndDiagnostic.Select((data, ct) => data.Diagnostic)); var interfaceSymbolsWithoutDiagnostics = interfaceSymbolAndDiagnostic - .Where(data => data.Diagnostic is null); + .Where(data => data.Diagnostic is null) + .Select((data, ct) => (data.InterfaceInfo, data.Symbol)); var interfacesToGenerate = interfaceSymbolsWithoutDiagnostics .Select((data, ct) => data.InterfaceInfo!); @@ -66,21 +67,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfaceContexts = interfacesToGenerate.Collect().SelectMany(ComInterfaceContext.GetContexts); // Get the information we need about methods themselves - var interfaceMethods = interfaceSymbolsWithoutDiagnostics.Select(static (pair, ct) => - { - var symbol = pair.Symbol; - var info = pair.InterfaceInfo; - List comMethods = new(); - foreach (var member in symbol.GetMembers()) - { - if (ComMethodInfo.IsComMethod(info, member, out ComMethodInfo? methodInfo)) - { - comMethods.Add(methodInfo); - } - } - return comMethods.ToSequenceEqualImmutableArray(); - }); - context.RegisterDiagnostics(interfaceMethods.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); + var interfaceMethodsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface); + context.RegisterDiagnostics(interfaceMethodsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); + var interfaceMethods = interfaceMethodsAndDiagnostics + .Select(static (methods, ct) => + methods + .Where(pair => pair.Diagnostic is null) + .Select(pair => pair.ComMethod) + .ToSequenceEqualImmutableArray()); // Generate a map from Com interface to the methods it declares var interfaceToDeclaredMethodsMap = interfaceContexts @@ -455,7 +449,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co comInterfaceAndMethods.DeclaredMethods .Select(m => m.NativeToManagedStub) .OfType() - .Select(context => context.Stub.Node) )); + .Select(context => context.Stub.Node))); } private static readonly TypeSyntax VoidStarStarSyntax = PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index df2bbf69f442e..3154dcf447338 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -30,8 +30,6 @@ public GeneratedMethodContextBase ManagedToUnmanagedStub } } - public Diagnostic? Diagnostic => MethodInfo.Diagnostic; - public GeneratedMethodContextBase NativeToManagedStub { get diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 3fe7e8c518897..73938f2471fcb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -18,12 +19,26 @@ public sealed partial class ComInterfaceGenerator /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. /// private sealed record ComMethodInfo( + // Symbols cannot be compared across incremental runs or they'll keep alive the compilation. We should remove a reference to the symbol [property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, string MethodName, SequenceEqualImmutableArray Parameters, Diagnostic? Diagnostic) { + public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken _) + { + List<(ComMethodInfo, Diagnostic?)> methods = new(); + foreach (var member in data.ifaceSymbol.GetMembers()) + { + if (IsComMethodCandidate(member)) + { + methods.Add(From(data.ifaceContext, (IMethodSymbol)member)); + } + } + return methods.ToSequenceEqualImmutableArray(); + + } private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) { // Verify the method has no generic types or defined implementation @@ -44,61 +59,69 @@ private sealed record ComMethodInfo( return null; } - public static bool IsComMethod(ComInterfaceInfo ifaceContext, ISymbol member, [NotNullWhen(true)] out ComMethodInfo? comMethodInfo) + private static bool IsComMethodCandidate(ISymbol member) { - Diagnostic diag; - comMethodInfo = null; + return member.Kind == SymbolKind.Method && !member.IsStatic; + } + + public static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, IMethodSymbol member) + { + Debug.Assert(IsComMethodCandidate(member)); + + // We only support methods that are defined in the same partial interface definition as the + // [GeneratedComInterface] attribute. + // This restriction not only makes finding the syntax for a given method cheaper, + // but it also enables us to ensure that we can determine vtable method order easily. Location interfaceLocation = ifaceContext.Declaration.GetLocation(); - if (member.Kind == SymbolKind.Method && !member.IsStatic) + Location? methodLocationInAttributedInterfaceDeclaration = null; + foreach (var methodLocation in member.Locations) { - // We only support methods that are defined in the same partial interface definition as the - // [GeneratedComInterface] attribute. - // This restriction not only makes finding the syntax for a given method cheaper, - // but it also enables us to ensure that we can determine vtable method order easily. - Location? methodLocationInAttributedInterfaceDeclaration = null; - foreach (var methodLocation in member.Locations) - { - if (methodLocation.SourceTree == interfaceLocation.SourceTree - && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) - { - methodLocationInAttributedInterfaceDeclaration = methodLocation; - break; - } - } - - // TODO: this should cause a diagnostic - if (methodLocationInAttributedInterfaceDeclaration is null) + if (methodLocation.SourceTree == interfaceLocation.SourceTree + && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) { - return false; + methodLocationInAttributedInterfaceDeclaration = methodLocation; + break; } + } + // TODO: this should cause a diagnostic + if (methodLocationInAttributedInterfaceDeclaration is null) + { + throw new NotImplementedException($"Could not find location for method {member.ToDisplayString()} within the attributed declaration"); + } - // Find the matching declaration syntax - MethodDeclarationSyntax? comMethodDeclaringSyntax = null; - foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) - { - var declaringSyntax = declaringSyntaxReference.GetSyntax(); - Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); - if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) - { - comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; - break; - } - } - if (comMethodDeclaringSyntax is null) - throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); - List parameters = new(); - foreach (var parameter in ((IMethodSymbol)member).Parameters) + // Find the matching declaration syntax + MethodDeclarationSyntax? comMethodDeclaringSyntax = null; + foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) + { + var declaringSyntax = declaringSyntaxReference.GetSyntax(); + Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); + if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) { - parameters.Add(ParameterInfo.From(parameter)); + comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; + break; } + } + if (comMethodDeclaringSyntax is null) + { + throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); + } - diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); + var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, member); + if (diag is not null) + { + return (null, diag); + } - comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); - return true; + List parameters = new(); + foreach (var parameter in member.Parameters) + { + parameters.Add(ParameterInfo.From(parameter)); } - return false; + + + var comMethodInfo = new ComMethodInfo(member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); + return (comMethodInfo, null); } } } From 1730faa0ce003457200cb981577f7ed654ca9885 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 10 May 2023 11:56:14 -0700 Subject: [PATCH 22/31] Make methodInfo creation follow the same pattern as InterfaceInfo --- .../ComInterfaceGenerator.cs | 28 ++--- .../ComInterfaceGenerator/ComMethodContext.cs | 2 - .../ComInterfaceGenerator/ComMethodInfo.cs | 109 +++++++++++------- 3 files changed, 77 insertions(+), 62 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 9bd738545fdfe..a10cc775454d2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -58,7 +58,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterDiagnostics(interfaceSymbolAndDiagnostic.Select((data, ct) => data.Diagnostic)); var interfaceSymbolsWithoutDiagnostics = interfaceSymbolAndDiagnostic - .Where(data => data.Diagnostic is null); + .Where(data => data.Diagnostic is null) + .Select((data, ct) => (data.InterfaceInfo, data.Symbol)); var interfacesToGenerate = interfaceSymbolsWithoutDiagnostics .Select((data, ct) => data.InterfaceInfo!); @@ -66,21 +67,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfaceContexts = interfacesToGenerate.Collect().SelectMany(ComInterfaceContext.GetContexts); // Get the information we need about methods themselves - var interfaceMethods = interfaceSymbolsWithoutDiagnostics.Select(static (pair, ct) => - { - var symbol = pair.Symbol; - var info = pair.InterfaceInfo; - List comMethods = new(); - foreach (var member in symbol.GetMembers()) - { - if (ComMethodInfo.IsComMethod(info, member, out ComMethodInfo? methodInfo)) - { - comMethods.Add(methodInfo); - } - } - return comMethods.ToSequenceEqualImmutableArray(); - }); - context.RegisterDiagnostics(interfaceMethods.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); + var interfaceMethodsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface); + context.RegisterDiagnostics(interfaceMethodsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); + var interfaceMethods = interfaceMethodsAndDiagnostics + .Select(static (methods, ct) => + methods + .Where(pair => pair.Diagnostic is null) + .Select(pair => pair.ComMethod) + .ToSequenceEqualImmutableArray()); // Generate a map from Com interface to the methods it declares var interfaceToDeclaredMethodsMap = interfaceContexts @@ -455,7 +449,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co comInterfaceAndMethods.DeclaredMethods .Select(m => m.NativeToManagedStub) .OfType() - .Select(context => context.Stub.Node) )); + .Select(context => context.Stub.Node))); } private static readonly TypeSyntax VoidStarStarSyntax = PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index df2bbf69f442e..3154dcf447338 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -30,8 +30,6 @@ public GeneratedMethodContextBase ManagedToUnmanagedStub } } - public Diagnostic? Diagnostic => MethodInfo.Diagnostic; - public GeneratedMethodContextBase NativeToManagedStub { get diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 3fe7e8c518897..0e749962dd199 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -18,12 +19,26 @@ public sealed partial class ComInterfaceGenerator /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. /// private sealed record ComMethodInfo( + // Symbols cannot be compared across incremental runs or they'll keep alive the compilation. We should remove a reference to the symbol [property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, string MethodName, SequenceEqualImmutableArray Parameters, Diagnostic? Diagnostic) { + public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken _) + { + List<(ComMethodInfo, Diagnostic?)> methods = new(); + foreach (var member in data.ifaceSymbol.GetMembers()) + { + if (IsComMethodCandidate(member)) + { + methods.Add(From(data.ifaceContext, (IMethodSymbol)member)); + } + } + return methods.ToSequenceEqualImmutableArray(); + } + private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) { // Verify the method has no generic types or defined implementation @@ -44,61 +59,69 @@ private sealed record ComMethodInfo( return null; } - public static bool IsComMethod(ComInterfaceInfo ifaceContext, ISymbol member, [NotNullWhen(true)] out ComMethodInfo? comMethodInfo) + private static bool IsComMethodCandidate(ISymbol member) { - Diagnostic diag; - comMethodInfo = null; + return member.Kind == SymbolKind.Method && !member.IsStatic; + } + + private static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, IMethodSymbol member) + { + Debug.Assert(IsComMethodCandidate(member)); + + // We only support methods that are defined in the same partial interface definition as the + // [GeneratedComInterface] attribute. + // This restriction not only makes finding the syntax for a given method cheaper, + // but it also enables us to ensure that we can determine vtable method order easily. Location interfaceLocation = ifaceContext.Declaration.GetLocation(); - if (member.Kind == SymbolKind.Method && !member.IsStatic) + Location? methodLocationInAttributedInterfaceDeclaration = null; + foreach (var methodLocation in member.Locations) { - // We only support methods that are defined in the same partial interface definition as the - // [GeneratedComInterface] attribute. - // This restriction not only makes finding the syntax for a given method cheaper, - // but it also enables us to ensure that we can determine vtable method order easily. - Location? methodLocationInAttributedInterfaceDeclaration = null; - foreach (var methodLocation in member.Locations) - { - if (methodLocation.SourceTree == interfaceLocation.SourceTree - && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) - { - methodLocationInAttributedInterfaceDeclaration = methodLocation; - break; - } - } - - // TODO: this should cause a diagnostic - if (methodLocationInAttributedInterfaceDeclaration is null) + if (methodLocation.SourceTree == interfaceLocation.SourceTree + && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) { - return false; + methodLocationInAttributedInterfaceDeclaration = methodLocation; + break; } + } + // TODO: this should cause a diagnostic + if (methodLocationInAttributedInterfaceDeclaration is null) + { + throw new NotImplementedException($"Could not find location for method {member.ToDisplayString()} within the attributed declaration"); + } - // Find the matching declaration syntax - MethodDeclarationSyntax? comMethodDeclaringSyntax = null; - foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) - { - var declaringSyntax = declaringSyntaxReference.GetSyntax(); - Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); - if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) - { - comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; - break; - } - } - if (comMethodDeclaringSyntax is null) - throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); - List parameters = new(); - foreach (var parameter in ((IMethodSymbol)member).Parameters) + // Find the matching declaration syntax + MethodDeclarationSyntax? comMethodDeclaringSyntax = null; + foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) + { + var declaringSyntax = declaringSyntaxReference.GetSyntax(); + Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); + if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) { - parameters.Add(ParameterInfo.From(parameter)); + comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; + break; } + } + if (comMethodDeclaringSyntax is null) + { + throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); + } - diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, (IMethodSymbol)member); + var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, member); + if (diag is not null) + { + return (null, diag); + } - comMethodInfo = new((IMethodSymbol)member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); - return true; + List parameters = new(); + foreach (var parameter in member.Parameters) + { + parameters.Add(ParameterInfo.From(parameter)); } - return false; + + + var comMethodInfo = new ComMethodInfo(member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); + return (comMethodInfo, null); } } } From f18c21c899c3031dafa6a667ac2527c4e336dd4a Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Wed, 10 May 2023 14:25:50 -0700 Subject: [PATCH 23/31] Remove Diagnostic from ComMethodInfo --- .../gen/ComInterfaceGenerator/ComMethodInfo.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 0e749962dd199..70f74a97c8ba6 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -23,8 +23,7 @@ private sealed record ComMethodInfo( [property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, string MethodName, - SequenceEqualImmutableArray Parameters, - Diagnostic? Diagnostic) + SequenceEqualImmutableArray Parameters) { public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken _) { @@ -120,7 +119,7 @@ private static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, } - var comMethodInfo = new ComMethodInfo(member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray(), diag); + var comMethodInfo = new ComMethodInfo(member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray()); return (comMethodInfo, null); } } From 38e6267eef89206ed7a0178f18ce01e91f0247ea Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Thu, 11 May 2023 21:55:13 -0700 Subject: [PATCH 24/31] Create flat list of methods, group with interfaceContext so we don't need 'marker interfaces' --- .../ComInterfaceAndMethodsContext.cs | 66 +-------- .../ComInterfaceGenerator.cs | 136 +++++++----------- .../ComInterfaceGenerator/ComInterfaceInfo.cs | 20 +-- .../ComInterfaceGenerator/ComMethodContext.cs | 78 +++++++++- .../ComInterfaceGenerator/ComMethodInfo.cs | 38 ++--- .../HashCode.cs | 84 ++++++++++- .../Microsoft.Interop.SourceGeneration.csproj | 1 + .../SequenceEqualImmutableArray.cs | 2 +- .../ValueEqualityImmutableDictionary.cs | 12 +- .../IComInterface1.cs | 17 +++ 10 files changed, 269 insertions(+), 185 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs index 20ec29b98c19f..82f1e2a598706 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Collections.Immutable; using System.Linq; -using System.Threading; using Microsoft.CodeAnalysis; namespace Microsoft.Interop @@ -14,71 +12,21 @@ public sealed partial class ComInterfaceGenerator /// /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). /// - private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods, SequenceEqualImmutableArray ShadowingMethods) + private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods) { + // Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext + // Have a step that runs CalculateMethodStub on each of them. + // Call GroupMethodsByInterfaceForGeneration + /// /// COM methods that are declared on the attributed interface declaration. /// - public IEnumerable DeclaredMethods => Methods.Where((m => m.DeclaringInterface == Interface)); + public IEnumerable DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod); /// /// COM methods that are declared on an interface the interface inherits from. /// - public IEnumerable InheritedMethods => Methods.Where(m => m.DeclaringInterface != Interface); - - public static IEnumerable CalculateAllMethods(ValueEqualityImmutableDictionary> ifaceToDeclaredMethodsMap, StubEnvironment environment, CancellationToken ct) - { - Dictionary Methods, ImmutableArray ShadowingMethods)> allMethodsCache = new(); - - foreach (var kvp in ifaceToDeclaredMethodsMap) - { - AddMethods(kvp.Key, kvp.Value); - } - - return allMethodsCache.Select(kvp => new ComInterfaceAndMethodsContext(kvp.Key, kvp.Value.Methods.ToSequenceEqual(), kvp.Value.ShadowingMethods.ToSequenceEqualImmutableArray())); - - (ImmutableArray Methods, ImmutableArray ShadowingMethods) AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) - { - if (allMethodsCache.TryGetValue(iface, out var cachedValue)) - { - return cachedValue; - } - - int startingIndex = 3; - List methods = new(); - // If we have a base interface, we should add the inherited methods to our list in vtable order - if (iface.Base is not null) - { - var baseComIface = iface.Base; - ImmutableArray baseMethods; - if (!allMethodsCache.TryGetValue(baseComIface, out var pair)) - { - baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]).Methods; - } - else - { - baseMethods = pair.Methods; - } - methods.AddRange(baseMethods); - } - var shadowingMethods = methods.Select(method => - { - var info = method.MethodInfo; - var ctx = CalculateStubInformation(info.Syntax, info.Symbol, startingIndex, environment, iface.Info.Type, ct); - return new ComMethodContext(iface, info, startingIndex++, ctx); - }).ToImmutableArray(); - // Then we append the declared methods in vtable order - foreach (var method in declaredMethods) - { - var ctx = CalculateStubInformation(method.Syntax, method.Symbol, startingIndex, environment, iface.Info.Type, ct); - methods.Add(new ComMethodContext(iface, method, startingIndex++, ctx)); - } - // Cache so we don't recalculate if many interfaces inherit from the same one - var finalPair = (methods.ToImmutableArray(), shadowingMethods); - allMethodsCache[iface] = finalPair; - return finalPair; - } - } + public IEnumerable ShadowingMethods => Methods.Where(m => m.IsInheritedMethod); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index a10cc775454d2..9f09ecda42e01 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Collections.Specialized; using System.IO; using System.Linq; using System.Threading; @@ -66,100 +67,75 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var interfaceContexts = interfacesToGenerate.Collect().SelectMany(ComInterfaceContext.GetContexts); - // Get the information we need about methods themselves - var interfaceMethodsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface); - context.RegisterDiagnostics(interfaceMethodsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); - var interfaceMethods = interfaceMethodsAndDiagnostics + var interfaceMethodsAndSymbolsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface); + context.RegisterDiagnostics(interfaceMethodsAndSymbolsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); + var interfaceMethodSymbols = interfaceMethodsAndSymbolsAndDiagnostics .Select(static (methods, ct) => methods .Where(pair => pair.Diagnostic is null) - .Select(pair => pair.ComMethod) + .Select(pair => (pair.Symbol, pair.ComMethod)) .ToSequenceEqualImmutableArray()); - // Generate a map from Com interface to the methods it declares - var interfaceToDeclaredMethodsMap = interfaceContexts - .Zip(interfaceMethods) + var methodInfoGroups = interfaceMethodSymbols + .Select(static (methods, ct) => + methods.Select(pair => pair.ComMethod).ToSequenceEqualImmutableArray()); + + var methodInfoToSymbolMap = interfaceMethodSymbols + .SelectMany((data, ct) => data) .Collect() - .Select(static (data, ct) => - { - return data.ToValueEqualityImmutableDictionary<(ComInterfaceContext, SequenceEqualImmutableArray), ComInterfaceContext, SequenceEqualImmutableArray>( - static pair => pair.Item1, - static pair => pair.Item2); - }); + .Select((data, ct) => data.ToDictionary(static x => x.ComMethod, static x => x.Symbol)); - // Combine info about base methods and declared methods to get a list of interfaces, and all the methods they need to worry about (including both declared and inherited methods) - var interfaceAndMethodsContexts = interfaceToDeclaredMethodsMap - .Combine(interfaceContexts.Collect()) - .Combine(context.CreateStubEnvironmentProvider()) + // Determine which methods each interface declares and inherits + var interfaceMethodNoContexts = interfaceContexts + .Zip(methodInfoGroups) + .Collect() .SelectMany(static (data, ct) => { - var ((ifaceToMethodsMap, ifaceToBaseMap), env) = data; - return ComInterfaceAndMethodsContext.CalculateAllMethods(ifaceToMethodsMap, env, ct); + return ComMethodContext.CalculateAllMethods(data, ct); }); - // Separate the methods which have methods from those that don't - var interfacesWithMethodsAndItsMethods = interfaceAndMethodsContexts - .Where(static data => data.Methods.Length != 0); - - // Separate out the interface for generation that doesn't depend on the methods - var interfacesWithMethods = interfacesWithMethodsAndItsMethods - .Select(static (data, ct) => data.Interface); - + var interfaceMethodContexts = interfaceMethodNoContexts + .Combine(methodInfoToSymbolMap).Combine(context.CreateStubEnvironmentProvider()).Select((param, ct) => { - // Marker interfaces are COM interfaces that don't have any methods. - // The lack of methods breaks the mechanism we use later to stitch back together interface-level data - // and method-level data, but that's okay because marker interfaces are much simpler. - // We'll handle them seperately because they are so simple. - var markerInterfaces = interfaceAndMethodsContexts - .Where(static data => !data.DeclaredMethods.Any()) - .Select(static (data, ct) => data.Interface); - - var markerInterfaceIUnknownDerived = markerInterfaces - .Select(static (data, ct) => data.Info) - .Select(GenerateIUnknownDerivedAttributeApplication) - .WithComparer(SyntaxEquivalentComparer.Instance) - .SelectNormalized(); - - context.RegisterSourceOutput(markerInterfaces.Zip(markerInterfaceIUnknownDerived), (context, data) => - { - var (interfaceContext, iUnknownDerivedAttributeApplication) = data; - context.AddSource( - interfaceContext.Info.Type.FullTypeName.Replace("global::", ""), - GenerateMarkerInterfaceSource(interfaceContext.Info) + iUnknownDerivedAttributeApplication); - }); - } + var ((data, symbolMap), env) = param; + return new ComMethodContext(data.Method.DeclaringInterface, data.TypeKeyOwner, data.Method.MethodInfo, data.Method.Index, CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct)); + }).WithTrackingName(StepNames.CalculateStubInformation); + + var interfaceAndMethodsContexts = interfaceMethodContexts.Collect() + .Combine(interfaceContexts.Collect()) + .SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct)); // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation. - context.RegisterDiagnostics(interfacesWithMethodsAndItsMethods + context.RegisterDiagnostics(interfaceAndMethodsContexts .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics))); - var managedToNativeInterfaceImplementations = interfacesWithMethodsAndItsMethods + var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts .Select(GenerateImplementationInterface) .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation. - context.RegisterDiagnostics(interfacesWithMethodsAndItsMethods + context.RegisterDiagnostics(interfaceAndMethodsContexts .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.NativeToManagedStub.Diagnostics))); - var nativeToManagedVtableMethods = interfacesWithMethodsAndItsMethods + var nativeToManagedVtableMethods = interfaceAndMethodsContexts .Select(GenerateImplementationVTableMethods) .WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); // Generate the native interface metadata for each [GeneratedComInterface]-attributed interface. - var nativeInterfaceInformation = interfacesWithMethods + var nativeInterfaceInformation = interfaceContexts .Select(static (data, ct) => data.Info) .Select(GenerateInterfaceInformation) .WithTrackingName(StepNames.GenerateInterfaceInformation) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - var shadowingMethods = interfacesWithMethodsAndItsMethods + var shadowingMethods = interfaceAndMethodsContexts .Select((data, ct) => { var context = data.Interface.Info; - var methods = data.InheritedMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); + var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); var typeDecl = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) .WithModifiers(context.ContainingSyntax.Modifiers) .WithTypeParameterList(context.ContainingSyntax.TypeParameters) @@ -170,20 +146,20 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate a method named CreateManagedVirtualFunctionTable on the native interface implementation // that allocates and fills in the memory for the vtable. - var nativeToManagedVtables = interfacesWithMethodsAndItsMethods + var nativeToManagedVtables = interfaceAndMethodsContexts .Select(GenerateImplementationVTable) .WithTrackingName(StepNames.GenerateNativeToManagedVTable) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - var iUnknownDerivedAttributeApplication = interfacesWithMethods + var iUnknownDerivedAttributeApplication = interfaceContexts .Select(static (data, ct) => data.Info) .Select(GenerateIUnknownDerivedAttributeApplication) .WithTrackingName(StepNames.GenerateIUnknownDerivedAttribute) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - var filesToGenerate = interfacesWithMethods + var filesToGenerate = interfaceContexts .Zip(nativeInterfaceInformation) .Zip(managedToNativeInterfaceImplementations) .Zip(nativeToManagedVtableMethods) @@ -374,47 +350,39 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M ComInterfaceDispatchMarshallingInfo.Instance); } - private static ImmutableArray> GroupContextsForInterfaceGeneration(ImmutableArray contexts) + private static ImmutableArray GroupComContextsForInterfaceGeneration(ImmutableArray methods, ImmutableArray interfaces, CancellationToken ct) { + ct.ThrowIfCancellationRequested(); // We can end up with an empty set of contexts here as the compiler will call a SelectMany // after a Collect with no input entries - if (contexts.IsEmpty) + if (interfaces.IsEmpty) { - return ImmutableArray>.Empty; + return ImmutableArray.Empty; } - ImmutableArray>.Builder allGroupsBuilder = ImmutableArray.CreateBuilder>(); - // Due to how the source generator driver processes the input item tables and our limitation that methods on COM interfaces can only be defined in a single partial definition of the type, - // we can guarantee that the method contexts are ordered as follows: + // we can guarantee that, if the interface contexts are in order of I1, I2, I3, I4..., then then method contexts are ordered as follows: // - I1.M1 // - I1.M2 // - I1.M3 // - I2.M1 // - I2.M2 // - I2.M3 - // - I3.M1 + // - I4.M1 (I3 had no methods) // - etc... // This enable us to group our contexts by their containing syntax rather simply. - ManagedTypeInfo? lastSeenDefiningType = null; - ImmutableArray.Builder groupBuilder = ImmutableArray.CreateBuilder(); - foreach (var context in contexts) + var contextList = ImmutableArray.CreateBuilder(); + int methodIndex = 0; + foreach(var iface in interfaces) { - if (lastSeenDefiningType is null || lastSeenDefiningType == context.OriginalDefiningType) - { - groupBuilder.Add(context); - } - else + var methodList = ImmutableArray.CreateBuilder(); + while (methodIndex < methods.Length && methods[methodIndex].TypeKeyOwner == iface) { - allGroupsBuilder.Add(new(groupBuilder.ToImmutable())); - groupBuilder.Clear(); - groupBuilder.Add(context); + methodList.Add(methods[methodIndex++]); } - lastSeenDefiningType = context.OriginalDefiningType; + contextList.Add(new(iface, methodList.ToImmutable().ToSequenceEqual())); } - - allGroupsBuilder.Add(new(groupBuilder.ToImmutable())); - return allGroupsBuilder.ToImmutable(); + return contextList.ToImmutable(); } private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation") @@ -428,7 +396,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInt .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node .WithExplicitInterfaceSpecifier( ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName)))); - var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.GenerateUnreachableExceptionStub()); + var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.GenerateUnreachableExceptionStub()); return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) .WithMembers( @@ -615,7 +583,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf ParenthesizedExpression( BinaryExpression(SyntaxKind.MultiplyExpression, SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3)))))) + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3)))))) }))))); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs index 3246597179d3e..0338dc5c81dd3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -7,13 +7,14 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Roslyn.Utilities; namespace Microsoft.Interop { public sealed partial class ComInterfaceGenerator { /// - /// Information about a Com interface, but not it's methods. + /// Information about a Com interface, but not its methods. /// private sealed record ComInterfaceInfo( ManagedTypeInfo Type, @@ -112,19 +113,12 @@ private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclar interfaceTypeAttr = attr; } - if (guidAttr is not null) + if (guidAttr is not null + && guidAttr.ConstructorArguments.Length == 1 + && guidAttr.ConstructorArguments[0].Value is string guidStr + && Guid.TryParse(guidStr, out var result)) { - string? guidstr = guidAttr.ConstructorArguments.SingleOrDefault().Value as string; - if (guidstr is not null) - { - try - { - guid = new Guid(guidstr); - } - // Catch any issues with the Guid string -- Diagnostic will be raised if guid is null - catch (FormatException) { } - catch (OverflowException) { } - } + guid = result; } // Assume interfaceType is IUnknown for now diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index 3154dcf447338..e3556f5a6c41e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -1,7 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -15,15 +18,25 @@ public sealed partial class ComInterfaceGenerator /// Represents a method, its declaring interface, and its index in the interface's vtable. /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator /// - private sealed record ComMethodContext(ComInterfaceContext DeclaringInterface, ComMethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) + private sealed record ComMethodContext( + ComInterfaceContext DeclaringInterface, + // TypeKeyOwner is also the interface that the code is being generated for. + ComInterfaceContext TypeKeyOwner, + ComMethodInfo MethodInfo, + int Index, + IncrementalMethodStubGenerationContext GenerationContext) { + public bool IsInheritedMethod => DeclaringInterface != TypeKeyOwner; + + public sealed record Builder(ComInterfaceContext DeclaringInterface, ComMethodInfo MethodInfo, int Index); + public GeneratedMethodContextBase ManagedToUnmanagedStub { get { if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) { - return (GeneratedMethodContextBase)new SkippedStubContext(DeclaringInterface.Info.Type); + return new SkippedStubContext(DeclaringInterface.Info.Type); } var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); @@ -36,7 +49,7 @@ public GeneratedMethodContextBase NativeToManagedStub { if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) { - return (GeneratedMethodContextBase)new SkippedStubContext(GenerationContext.OriginalDefiningType); + return new SkippedStubContext(GenerationContext.OriginalDefiningType); } var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); @@ -79,6 +92,65 @@ public MethodDeclarationSyntax GenerateShadow() SeparatedList(MethodInfo.Parameters.Select(p => Argument(IdentifierName(p.Name)))))))); } + + /// + /// Returns a flat list of and it's type key owner that represents all declared methods, and inherited methods. + /// Guarantees the output will be sorted by order of interface input order, then by vtable order. + /// + public static List<(ComInterfaceContext TypeKeyOwner, Builder Method)> CalculateAllMethods(IEnumerable<(ComInterfaceContext, SequenceEqualImmutableArray)> ifaceAndDeclaredMethods, CancellationToken _) + { + // opt : change this to only take in a hierarchy of interfaces. we calc that before and select ober that + var ifaceToDeclaredMethodsMap = ifaceAndDeclaredMethods.ToDictionary(static pair => pair.Item1, static pair => pair.Item2); + // Track insertion order + var allMethodsCache = new Dictionary>(); + + List<(ComInterfaceContext TypeKeyOwner, Builder Method)> accumulator = new(); + foreach (var kvp in ifaceAndDeclaredMethods) + { + var methods = AddMethods(kvp.Item1, kvp.Item2); + foreach (var method in methods) + { + accumulator.Add((kvp.Item1, method)); + } + } + return accumulator; + + ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) + { + if (allMethodsCache.TryGetValue(iface, out var cachedValue)) + { + return cachedValue; + } + + int startingIndex = 3; + List methods = new(); + // If we have a base interface, we should add the inherited methods to our list in vtable order + if (iface.Base is not null) + { + var baseComIface = iface.Base; + ImmutableArray baseMethods; + if (!allMethodsCache.TryGetValue(baseComIface, out var pair)) + { + baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); + } + else + { + baseMethods = pair; + } + methods.AddRange(baseMethods); + startingIndex += baseMethods.Length; + } + // Then we append the declared methods in vtable order + foreach (var method in declaredMethods) + { + methods.Add(new Builder(iface, method, startingIndex++)); + } + // Cache so we don't recalculate if many interfaces inherit from the same one + var imm = methods.ToImmutableArray(); + allMethodsCache[iface] = imm; + return imm; + } + } } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 70f74a97c8ba6..d13ccbd6929cd 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -19,23 +20,24 @@ public sealed partial class ComInterfaceGenerator /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. /// private sealed record ComMethodInfo( - // Symbols cannot be compared across incremental runs or they'll keep alive the compilation. We should remove a reference to the symbol - [property: Obsolete] IMethodSymbol Symbol, MethodDeclarationSyntax Syntax, string MethodName, SequenceEqualImmutableArray Parameters) { - public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken _) + /// + /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa. + /// + public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, IMethodSymbol Symbol, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken ct) { - List<(ComMethodInfo, Diagnostic?)> methods = new(); + var methods = ImmutableArray.CreateBuilder<(ComMethodInfo, IMethodSymbol, Diagnostic?)>(); foreach (var member in data.ifaceSymbol.GetMembers()) { if (IsComMethodCandidate(member)) { - methods.Add(From(data.ifaceContext, (IMethodSymbol)member)); + methods.Add(CalculateMethodInfo(data.ifaceContext, (IMethodSymbol)member, ct)); } } - return methods.ToSequenceEqualImmutableArray(); + return methods.ToImmutable().ToSequenceEqual(); } private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) @@ -63,9 +65,10 @@ private static bool IsComMethodCandidate(ISymbol member) return member.Kind == SymbolKind.Method && !member.IsStatic; } - private static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, IMethodSymbol member) + private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo(ComInterfaceInfo ifaceContext, IMethodSymbol method, CancellationToken ct) { - Debug.Assert(IsComMethodCandidate(member)); + ct.ThrowIfCancellationRequested(); + Debug.Assert(IsComMethodCandidate(method)); // We only support methods that are defined in the same partial interface definition as the // [GeneratedComInterface] attribute. @@ -73,7 +76,7 @@ private static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, // but it also enables us to ensure that we can determine vtable method order easily. Location interfaceLocation = ifaceContext.Declaration.GetLocation(); Location? methodLocationInAttributedInterfaceDeclaration = null; - foreach (var methodLocation in member.Locations) + foreach (var methodLocation in method.Locations) { if (methodLocation.SourceTree == interfaceLocation.SourceTree && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) @@ -85,15 +88,15 @@ private static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, // TODO: this should cause a diagnostic if (methodLocationInAttributedInterfaceDeclaration is null) { - throw new NotImplementedException($"Could not find location for method {member.ToDisplayString()} within the attributed declaration"); + throw new NotImplementedException($"Could not find location for method {method.ToDisplayString()} within the attributed declaration"); } // Find the matching declaration syntax MethodDeclarationSyntax? comMethodDeclaringSyntax = null; - foreach (var declaringSyntaxReference in member.DeclaringSyntaxReferences) + foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences) { - var declaringSyntax = declaringSyntaxReference.GetSyntax(); + var declaringSyntax = declaringSyntaxReference.GetSyntax(ct); Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) { @@ -106,21 +109,20 @@ private static (ComMethodInfo?, Diagnostic?) From(ComInterfaceInfo ifaceContext, throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); } - var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, member); + var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, method); if (diag is not null) { - return (null, diag); + return (null, method, diag); } List parameters = new(); - foreach (var parameter in member.Parameters) + foreach (var parameter in method.Parameters) { parameters.Add(ParameterInfo.From(parameter)); } - - var comMethodInfo = new ComMethodInfo(member, comMethodDeclaringSyntax, member.Name, parameters.ToSequenceEqualImmutableArray()); - return (comMethodInfo, null); + var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, parameters.ToSequenceEqualImmutableArray()); + return (comMethodInfo, method, null); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs index 5834bb92f6ab3..adb441ab899d0 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs @@ -4,17 +4,91 @@ using System; using System.Collections.Generic; using System.Text; +using Roslyn.Utilities; namespace Microsoft.Interop { - public static class HashCode + /// + /// Exposes the hashing utilities from Roslyn + /// + public class HashCode { - public static int Combine(params object[] values) + public static int Combine(T1 t1, T2 t2) { - int hash = 31; - foreach (object value in values) + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + return Hash.Combine(hashCode1, hashCode2); + } + + public static int Combine(T1 t1, T2 t2, T3 t3) + { + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + int hashCode3 = t3 != null ? t3.GetHashCode() : 0; + return Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3); + } + + public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4) + { + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + int hashCode3 = t3 != null ? t3.GetHashCode() : 0; + int hashCode4 = t4 != null ? t4.GetHashCode() : 0; + return Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4); + } + + public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) + { + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + int hashCode3 = t3 != null ? t3.GetHashCode() : 0; + int hashCode4 = t4 != null ? t4.GetHashCode() : 0; + int hashCode5 = t5 != null ? t5.GetHashCode() : 0; + return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5); + } + + public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) + { + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + int hashCode3 = t3 != null ? t3.GetHashCode() : 0; + int hashCode4 = t4 != null ? t4.GetHashCode() : 0; + int hashCode5 = t5 != null ? t5.GetHashCode() : 0; + int hashCode6 = t6 != null ? t6.GetHashCode() : 0; + return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5), hashCode6); + } + + public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) + { + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + int hashCode3 = t3 != null ? t3.GetHashCode() : 0; + int hashCode4 = t4 != null ? t4.GetHashCode() : 0; + int hashCode5 = t5 != null ? t5.GetHashCode() : 0; + int hashCode6 = t6 != null ? t6.GetHashCode() : 0; + int hashCode7 = t7 != null ? t7.GetHashCode() : 0; + return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5), hashCode6), hashCode7); + } + + public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) + { + int hashCode1 = t1 != null ? t1.GetHashCode() : 0; + int hashCode2 = t2 != null ? t2.GetHashCode() : 0; + int hashCode3 = t3 != null ? t3.GetHashCode() : 0; + int hashCode4 = t4 != null ? t4.GetHashCode() : 0; + int hashCode5 = t5 != null ? t5.GetHashCode() : 0; + int hashCode6 = t6 != null ? t6.GetHashCode() : 0; + int hashCode7 = t7 != null ? t7.GetHashCode() : 0; + int hashCode8 = t8 != null ? t8.GetHashCode() : 0; + return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5), hashCode6), hashCode7), hashCode8); + } + + public static int SequentialValuesHash(IEnumerable values) + { + int hash = 0; + foreach (var value in values) { - hash = hash * 29 + value.GetHashCode(); + hash = Hash.Combine(hash, value.GetHashCode()); } return hash; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Microsoft.Interop.SourceGeneration.csproj b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Microsoft.Interop.SourceGeneration.csproj index a85ea9e94678f..825fb6c60b8e3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Microsoft.Interop.SourceGeneration.csproj +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Microsoft.Interop.SourceGeneration.csproj @@ -16,6 +16,7 @@ + diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs index c2f6c887990a1..e1833b3119524 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs @@ -26,7 +26,7 @@ public SequenceEqualImmutableArray(ImmutableArray array) public int Length => Array.Length; - public override int GetHashCode() => throw new NotSupportedException(); + public override int GetHashCode() => HashCode.SequentialValuesHash(Array); public bool Equals(SequenceEqualImmutableArray other) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs index 5ead994f6db93..e957df5a57ac8 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs @@ -5,7 +5,9 @@ using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; -using System.Text; +using System.Linq; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Roslyn.Utilities; namespace Microsoft.Interop { @@ -30,7 +32,13 @@ public bool Equals(ValueEqualityImmutableDictionary other) public override int GetHashCode() { - return HashCode.Combine(Map.Values); + int hash = 17; + foreach(var value in Map) + { + hash = Hash.Combine(hash, value.Key.GetHashCode()); + hash = Hash.Combine(hash, value.Value.GetHashCode()); + } + return hash; } public U this[T key] { get => ((IDictionary)Map)[key]; set => ((IDictionary)Map)[key] = value; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs index bb464c0147c95..efe0237d83b45 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs @@ -15,4 +15,21 @@ public partial interface IComInterface1 int GetData(); void SetData(int n); } + + [GeneratedComInterface] + partial interface I + { + void Method(); + void Method2(); + } + [GeneratedComInterface] + partial interface Empty + { + } + [GeneratedComInterface] + partial interface J + { + void Method(); + void Method2(); + } } From 7c1d275cfa230523ccefb58a037b3f0bd312c70e Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Sun, 14 May 2023 20:39:55 -0700 Subject: [PATCH 25/31] Clean up, rename variables, and add comments --- .../ComInterfaceContext.cs | 6 +- .../ComInterfaceGenerator.cs | 69 ++++++----- .../ComInterfaceGenerator/ComInterfaceInfo.cs | 25 ++-- .../ComInterfaceGenerator/ComMethodContext.cs | 41 +++++-- .../ComInterfaceGenerator/ComMethodInfo.cs | 1 - .../HashCode.cs | 112 +++++++++++------- .../SequenceEqualImmutableArray.cs | 2 + .../ValueEqualityImmutableDictionary.cs | 79 ------------ .../GeneratedComInterfaceTests.cs | 30 +++-- .../IComInterface1.cs | 35 ------ .../ComInterfaceGeneratorOutputShape.cs | 8 ++ .../ComInterfaces}/IDerivedComInterface.cs | 8 +- .../SharedTypes/ComInterfaces/IEmpty.cs | 16 +++ .../ComInterfaces/IGetAndSetInt.cs | 6 +- .../SharedTypes/ComInterfaces/IGetIntArray.cs | 6 +- 15 files changed, 203 insertions(+), 241 deletions(-) delete mode 100644 src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs rename src/libraries/System.Runtime.InteropServices/tests/{ComInterfaceGenerator.Tests => TestAssets/SharedTypes/ComInterfaces}/IDerivedComInterface.cs (69%) create mode 100644 src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IEmpty.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs index 4565e1f176400..5b570c1f0455a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs @@ -14,9 +14,10 @@ private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceCon /// /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. /// - public static IEnumerable GetContexts(ImmutableArray data, CancellationToken _) + public static ImmutableArray GetContexts(ImmutableArray data, CancellationToken _) { Dictionary symbolToInterfaceInfoMap = new(); + var accumulator = ImmutableArray.CreateBuilder(data.Length); foreach (var iface in data) { symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface); @@ -25,8 +26,9 @@ public static IEnumerable GetContexts(ImmutableArray modelData is not null); - var interfaceSymbolAndDiagnostic = attributedInterfaces.Select(static (data, ct) => + var interfaceSymbolAndDiagnostics = attributedInterfaces.Select(static (data, ct) => { var (info, diagnostic) = ComInterfaceInfo.From(data.Symbol, data.Syntax); return (InterfaceInfo: info, Diagnostic: diagnostic, Symbol: data.Symbol); }); - context.RegisterDiagnostics(interfaceSymbolAndDiagnostic.Select((data, ct) => data.Diagnostic)); + context.RegisterDiagnostics(interfaceSymbolAndDiagnostics.Select((data, ct) => data.Diagnostic)); - var interfaceSymbolsWithoutDiagnostics = interfaceSymbolAndDiagnostic + var interfaceSymbolsWithoutDiagnostics = interfaceSymbolAndDiagnostics .Where(data => data.Diagnostic is null) - .Select((data, ct) => (data.InterfaceInfo, data.Symbol)); - - var interfacesToGenerate = interfaceSymbolsWithoutDiagnostics - .Select((data, ct) => data.InterfaceInfo!); + .Select((data, ct) => + (data.InterfaceInfo, data.Symbol)); - var interfaceContexts = interfacesToGenerate.Collect().SelectMany(ComInterfaceContext.GetContexts); + var interfaceContexts = interfaceSymbolsWithoutDiagnostics + .Select((data, ct) => data.InterfaceInfo!) + .Collect() + .SelectMany(ComInterfaceContext.GetContexts); - var interfaceMethodsAndSymbolsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface); - context.RegisterDiagnostics(interfaceMethodsAndSymbolsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); - var interfaceMethodSymbols = interfaceMethodsAndSymbolsAndDiagnostics + var comMethodsAndSymbolsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface); + context.RegisterDiagnostics(comMethodsAndSymbolsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic))); + var methodInfoAndSymbolGroupedByInterface = comMethodsAndSymbolsAndDiagnostics .Select(static (methods, ct) => methods .Where(pair => pair.Diagnostic is null) .Select(pair => (pair.Symbol, pair.ComMethod)) .ToSequenceEqualImmutableArray()); - var methodInfoGroups = interfaceMethodSymbols + var methodInfosGroupedByInterface = methodInfoAndSymbolGroupedByInterface .Select(static (methods, ct) => methods.Select(pair => pair.ComMethod).ToSequenceEqualImmutableArray()); - - var methodInfoToSymbolMap = interfaceMethodSymbols - .SelectMany((data, ct) => data) - .Collect() - .Select((data, ct) => data.ToDictionary(static x => x.ComMethod, static x => x.Symbol)); - - // Determine which methods each interface declares and inherits - var interfaceMethodNoContexts = interfaceContexts - .Zip(methodInfoGroups) + // Create list of methods (inherited and declared) and their owning interface + var comMethodContextBuilders = interfaceContexts + .Zip(methodInfosGroupedByInterface) .Collect() .SelectMany(static (data, ct) => { return ComMethodContext.CalculateAllMethods(data, ct); }); - var interfaceMethodContexts = interfaceMethodNoContexts - .Combine(methodInfoToSymbolMap).Combine(context.CreateStubEnvironmentProvider()).Select((param, ct) => - { - var ((data, symbolMap), env) = param; - return new ComMethodContext(data.Method.DeclaringInterface, data.TypeKeyOwner, data.Method.MethodInfo, data.Method.Index, CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct)); - }).WithTrackingName(StepNames.CalculateStubInformation); - - var interfaceAndMethodsContexts = interfaceMethodContexts.Collect() + // A dictionary isn't incremental, but it will have symbols, so it will never be incremental anyway. + var methodInfoToSymbolMap = methodInfoAndSymbolGroupedByInterface + .SelectMany((data, ct) => data) + .Collect() + .Select((data, ct) => data.ToDictionary(static x => x.ComMethod, static x => x.Symbol)); + var comMethodContexts = comMethodContextBuilders + .Combine(methodInfoToSymbolMap) + .Combine(context.CreateStubEnvironmentProvider()) + .Select((param, ct) => + { + var ((data, symbolMap), env) = param; + return new ComMethodContext( + data.Method.OriginalDeclaringInterface, + data.TypeKeyOwner, + data.Method.MethodInfo, + data.Method.Index, + CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct)); + }).WithTrackingName(StepNames.CalculateStubInformation); + + var interfaceAndMethodsContexts = comMethodContexts.Collect() .Combine(interfaceContexts.Collect()) .SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct)); @@ -376,7 +383,7 @@ private static ImmutableArray GroupComContextsFor foreach(var iface in interfaces) { var methodList = ImmutableArray.CreateBuilder(); - while (methodIndex < methods.Length && methods[methodIndex].TypeKeyOwner == iface) + while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface) { methodList.Add(methods[methodIndex++]); } @@ -666,7 +673,5 @@ static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan(literals))))); } } - - private sealed record InterfaceSymbolInfo(ComInterfaceInfo Info, Diagnostic? Diagnostic, TBaseInterfaceKey ThisInterfaceKey, TBaseInterfaceKey? BaseInterfaceKey); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs index 0338dc5c81dd3..cd747bc6dc1c2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.CodeAnalysis; @@ -18,8 +19,8 @@ public sealed partial class ComInterfaceGenerator /// private sealed record ComInterfaceInfo( ManagedTypeInfo Type, - string ThisInterfaceKey, - string? BaseInterfaceKey, + string ThisInterfaceKey, // For associating interfaces to its base + string? BaseInterfaceKey, // For associating interfaces to its base InterfaceDeclarationSyntax Declaration, ContainingSyntaxContext TypeDefinitionContext, ContainingSyntax ContainingSyntax, @@ -50,10 +51,10 @@ public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSy } } - if (!TryGetGuid(symbol, syntax, out var guid, out var guidDiagnostic)) + if (!TryGetGuid(symbol, syntax, out Guid? guid, out Diagnostic? guidDiagnostic)) return (null, guidDiagnostic); - if (!TryGetBaseComInterface(symbol, syntax, out var baseSymbol, out var baseDiagnostic)) + if (!TryGetBaseComInterface(symbol, syntax, out INamedTypeSymbol? baseSymbol, out Diagnostic? baseDiagnostic)) return (null, baseDiagnostic); return (new ComInterfaceInfo( @@ -69,7 +70,7 @@ public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSy /// /// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets . /// - private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic) + private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic) { baseComIface = null; foreach (var implemented in comIface.Interfaces) @@ -78,8 +79,6 @@ private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceD { if (attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute) { - // We'll filter out cases where there's multiple matching interfaces when determining - // if this is a valid candidate for generation. if (baseComIface is not null) { diagnostic = Diagnostic.Create( @@ -103,14 +102,14 @@ private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclar { guid = null; AttributeData? guidAttr = null; - AttributeData? interfaceTypeAttr = null; + AttributeData? _ = null; // Interface Attribute Type. We'll always assume IUnkown for now. foreach (var attr in interfaceSymbol.GetAttributes()) { var attrDisplayString = attr.AttributeClass?.ToDisplayString(); if (attrDisplayString is TypeNames.System_Runtime_InteropServices_GuidAttribute) guidAttr = attr; else if (attrDisplayString is TypeNames.InterfaceTypeAttribute) - interfaceTypeAttr = attr; + _ = attr; } if (guidAttr is not null @@ -122,10 +121,12 @@ private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclar } // Assume interfaceType is IUnknown for now - if (interfaceTypeAttr is not null - && guid is null) + if (guid is null) { - diagnostic = Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, syntax.Identifier.GetLocation(), interfaceSymbol.ToDisplayString()); + diagnostic = Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, + syntax.Identifier.GetLocation(), + interfaceSymbol.ToDisplayString()); return false; } diagnostic = null; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index e3556f5a6c41e..41915f99d0ed1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -18,17 +18,29 @@ public sealed partial class ComInterfaceGenerator /// Represents a method, its declaring interface, and its index in the interface's vtable. /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator /// + /// + /// The interface that originally declared the method in user code + /// + /// + /// The interface that this methods is being generated for (may be different that OriginalDeclaringInterface if it is an inherited method) + /// + /// The basic information about the method. + /// The index on the interface vtable that points to this method + /// private sealed record ComMethodContext( - ComInterfaceContext DeclaringInterface, - // TypeKeyOwner is also the interface that the code is being generated for. - ComInterfaceContext TypeKeyOwner, + ComInterfaceContext OriginalDeclaringInterface, + ComInterfaceContext OwningInterface, ComMethodInfo MethodInfo, int Index, IncrementalMethodStubGenerationContext GenerationContext) { - public bool IsInheritedMethod => DeclaringInterface != TypeKeyOwner; + /// + /// A partially constructed that does not have a generated for it yet. + /// can be constructed without a reference to an ISymbol, whereas the requires an ISymbol + /// + public sealed record Builder(ComInterfaceContext OriginalDeclaringInterface, ComMethodInfo MethodInfo, int Index); - public sealed record Builder(ComInterfaceContext DeclaringInterface, ComMethodInfo MethodInfo, int Index); + public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface; public GeneratedMethodContextBase ManagedToUnmanagedStub { @@ -36,7 +48,7 @@ public GeneratedMethodContextBase ManagedToUnmanagedStub { if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) { - return new SkippedStubContext(DeclaringInterface.Info.Type); + return new SkippedStubContext(OriginalDeclaringInterface.Info.Type); } var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); @@ -60,9 +72,10 @@ public MethodDeclarationSyntax GenerateUnreachableExceptionStub() { // DeclarationCopiedFromBaseDeclaration() => throw new UnreachableException("This method should not be reached"); return MethodInfo.Syntax + .WithModifiers(TokenList()) .WithAttributeLists(List()) .WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier( - ParseName(DeclaringInterface.Info.Type.FullTypeName))) + ParseName(OriginalDeclaringInterface.Info.Type.FullTypeName))) .WithExpressionBody(ArrowExpressionClause( ThrowExpression( ObjectCreationExpression( @@ -85,7 +98,7 @@ public MethodDeclarationSyntax GenerateShadow() MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, ParenthesizedExpression( - CastExpression(DeclaringInterface.Info.Type.Syntax, IdentifierName("this"))), + CastExpression(OriginalDeclaringInterface.Info.Type.Syntax, IdentifierName("this"))), IdentifierName(MethodInfo.MethodName)), ArgumentList( // TODO: RefKind keywords @@ -99,12 +112,13 @@ public MethodDeclarationSyntax GenerateShadow() /// public static List<(ComInterfaceContext TypeKeyOwner, Builder Method)> CalculateAllMethods(IEnumerable<(ComInterfaceContext, SequenceEqualImmutableArray)> ifaceAndDeclaredMethods, CancellationToken _) { - // opt : change this to only take in a hierarchy of interfaces. we calc that before and select ober that + // Optimization : This step technically only needs a single interface inheritance hierarchy. + // We can calculate all inheritance chains in a previous step and only pass a single inheritance chain to this method. + // This way, when a single method changes, we would only need to recalculate this for the inheritance chain in which that method exists. + var ifaceToDeclaredMethodsMap = ifaceAndDeclaredMethods.ToDictionary(static pair => pair.Item1, static pair => pair.Item2); - // Track insertion order var allMethodsCache = new Dictionary>(); - - List<(ComInterfaceContext TypeKeyOwner, Builder Method)> accumulator = new(); + var accumulator = new List<(ComInterfaceContext TypeKeyOwner, Builder Method)>(); foreach (var kvp in ifaceAndDeclaredMethods) { var methods = AddMethods(kvp.Item1, kvp.Item2); @@ -115,6 +129,9 @@ public MethodDeclarationSyntax GenerateShadow() } return accumulator; + /// + /// Adds methods to a cache and returns inherited and declared methods for the interface in vtable order + /// ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) { if (allMethodsCache.TryGetValue(iface, out var cachedValue)) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index d13ccbd6929cd..bbb8ff2eba531 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; using Microsoft.CodeAnalysis; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs index adb441ab899d0..7b38f16d4b72d 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs @@ -15,72 +15,100 @@ public class HashCode { public static int Combine(T1 t1, T2 t2) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - return Hash.Combine(hashCode1, hashCode2); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + return combinedHash; } public static int Combine(T1 t1, T2 t2, T3 t3) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - int hashCode3 = t3 != null ? t3.GetHashCode() : 0; - return Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int hash3 = t3 != null ? t3.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + combinedHash = Hash.Combine(combinedHash, hash3); + return combinedHash; } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - int hashCode3 = t3 != null ? t3.GetHashCode() : 0; - int hashCode4 = t4 != null ? t4.GetHashCode() : 0; - return Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int hash3 = t3 != null ? t3.GetHashCode() : 0; + int hash4 = t4 != null ? t4.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + combinedHash = Hash.Combine(combinedHash, hash3); + combinedHash = Hash.Combine(combinedHash, hash4); + return combinedHash; } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - int hashCode3 = t3 != null ? t3.GetHashCode() : 0; - int hashCode4 = t4 != null ? t4.GetHashCode() : 0; - int hashCode5 = t5 != null ? t5.GetHashCode() : 0; - return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int hash3 = t3 != null ? t3.GetHashCode() : 0; + int hash4 = t4 != null ? t4.GetHashCode() : 0; + int hash5 = t5 != null ? t5.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + combinedHash = Hash.Combine(combinedHash, hash3); + combinedHash = Hash.Combine(combinedHash, hash4); + combinedHash = Hash.Combine(combinedHash, hash5); + return combinedHash; } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - int hashCode3 = t3 != null ? t3.GetHashCode() : 0; - int hashCode4 = t4 != null ? t4.GetHashCode() : 0; - int hashCode5 = t5 != null ? t5.GetHashCode() : 0; - int hashCode6 = t6 != null ? t6.GetHashCode() : 0; - return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5), hashCode6); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int hash3 = t3 != null ? t3.GetHashCode() : 0; + int hash4 = t4 != null ? t4.GetHashCode() : 0; + int hash5 = t5 != null ? t5.GetHashCode() : 0; + int hash6 = t6 != null ? t6.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + combinedHash = Hash.Combine(combinedHash, hash3); + combinedHash = Hash.Combine(combinedHash, hash4); + combinedHash = Hash.Combine(combinedHash, hash5); + combinedHash = Hash.Combine(combinedHash, hash6); + return combinedHash; } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - int hashCode3 = t3 != null ? t3.GetHashCode() : 0; - int hashCode4 = t4 != null ? t4.GetHashCode() : 0; - int hashCode5 = t5 != null ? t5.GetHashCode() : 0; - int hashCode6 = t6 != null ? t6.GetHashCode() : 0; - int hashCode7 = t7 != null ? t7.GetHashCode() : 0; - return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5), hashCode6), hashCode7); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int hash3 = t3 != null ? t3.GetHashCode() : 0; + int hash4 = t4 != null ? t4.GetHashCode() : 0; + int hash5 = t5 != null ? t5.GetHashCode() : 0; + int hash6 = t6 != null ? t6.GetHashCode() : 0; + int hash7 = t7 != null ? t7.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + combinedHash = Hash.Combine(combinedHash, hash3); + combinedHash = Hash.Combine(combinedHash, hash4); + combinedHash = Hash.Combine(combinedHash, hash5); + combinedHash = Hash.Combine(combinedHash, hash6); + combinedHash = Hash.Combine(combinedHash, hash7); + return combinedHash; } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) { - int hashCode1 = t1 != null ? t1.GetHashCode() : 0; - int hashCode2 = t2 != null ? t2.GetHashCode() : 0; - int hashCode3 = t3 != null ? t3.GetHashCode() : 0; - int hashCode4 = t4 != null ? t4.GetHashCode() : 0; - int hashCode5 = t5 != null ? t5.GetHashCode() : 0; - int hashCode6 = t6 != null ? t6.GetHashCode() : 0; - int hashCode7 = t7 != null ? t7.GetHashCode() : 0; - int hashCode8 = t8 != null ? t8.GetHashCode() : 0; - return Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(Hash.Combine(hashCode1, hashCode2), hashCode3), hashCode4), hashCode5), hashCode6), hashCode7), hashCode8); + int hash1 = t1 != null ? t1.GetHashCode() : 0; + int hash2 = t2 != null ? t2.GetHashCode() : 0; + int hash3 = t3 != null ? t3.GetHashCode() : 0; + int hash4 = t4 != null ? t4.GetHashCode() : 0; + int hash5 = t5 != null ? t5.GetHashCode() : 0; + int hash6 = t6 != null ? t6.GetHashCode() : 0; + int hash7 = t7 != null ? t7.GetHashCode() : 0; + int hash8 = t8 != null ? t8.GetHashCode() : 0; + int combinedHash = Hash.Combine(hash1, hash2); + combinedHash = Hash.Combine(combinedHash, hash3); + combinedHash = Hash.Combine(combinedHash, hash4); + combinedHash = Hash.Combine(combinedHash, hash5); + combinedHash = Hash.Combine(combinedHash, hash6); + combinedHash = Hash.Combine(combinedHash, hash7); + combinedHash = Hash.Combine(combinedHash, hash8); + return combinedHash; } public static int SequentialValuesHash(IEnumerable values) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs index e1833b3119524..0c49cddb4cc15 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SequenceEqualImmutableArray.cs @@ -25,6 +25,8 @@ public SequenceEqualImmutableArray(ImmutableArray array) public T this[int i] { get => Array[i]; } public int Length => Array.Length; + public SequenceEqualImmutableArray Insert(int index, T item) + => new SequenceEqualImmutableArray(Array.Insert(index, item), Comparer); public override int GetHashCode() => HashCode.SequentialValuesHash(Array); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs deleted file mode 100644 index e957df5a57ac8..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ValueEqualityImmutableDictionary.cs +++ /dev/null @@ -1,79 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Roslyn.Utilities; - -namespace Microsoft.Interop -{ - public record struct ValueEqualityImmutableDictionary(ImmutableDictionary Map) : IDictionary - { - public bool Equals(ValueEqualityImmutableDictionary other) - { - if (Count != other.Count) - { - return false; - } - - foreach(var kvp in this) - { - if (!other.TryGetValue(kvp.Key, out var value) || !kvp.Value.Equals(value)) - { - return false; - } - } - return true; - } - - public override int GetHashCode() - { - int hash = 17; - foreach(var value in Map) - { - hash = Hash.Combine(hash, value.Key.GetHashCode()); - hash = Hash.Combine(hash, value.Value.GetHashCode()); - } - return hash; - } - - public U this[T key] { get => ((IDictionary)Map)[key]; set => ((IDictionary)Map)[key] = value; } - public ICollection Keys => ((IDictionary)Map).Keys; - public ICollection Values => ((IDictionary)Map).Values; - public int Count => Map.Count; - public bool IsReadOnly => ((ICollection>)Map).IsReadOnly; - public bool Contains(KeyValuePair item) => Map.Contains(item); - public bool ContainsKey(T key) => Map.ContainsKey(key); - public void CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)Map).CopyTo(array, arrayIndex); - public IEnumerator> GetEnumerator() => ((IEnumerable>)Map).GetEnumerator(); - public bool Remove(T key) => ((IDictionary)Map).Remove(key); - public bool Remove(KeyValuePair item) => ((ICollection>)Map).Remove(item); - public bool TryGetValue(T key, out U value) => Map.TryGetValue(key, out value); - IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Map).GetEnumerator(); - public void Add(T key, U value) => ((IDictionary)Map).Add(key, value); - public void Add(KeyValuePair item) => ((ICollection>)Map).Add(item); - public void Clear() => ((ICollection>)Map).Clear(); - } - - public static partial class CollectionExtensions - { - public static ValueEqualityImmutableDictionary ToValueEqualityImmutableDictionary(this IEnumerable srcs, Func keyMap, Func valueMap) - { - return new(srcs.ToImmutableDictionary(keyMap, valueMap)); - } - - public static ValueEqualityImmutableDictionary ToValueEqual(this ImmutableDictionary dict) - { - return new(dict); - } - - public static ValueEqualityImmutableDictionary ToValueEqualImmutable(this Dictionary dict) - { - return new(dict.ToImmutableDictionary()); - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index ba077b2e33a45..feb08ecbe297c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -6,6 +6,7 @@ using System.Reflection; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; +using SharedTypes.ComInterfaces; using Xunit; namespace ComInterfaceGenerator.Tests; @@ -26,22 +27,22 @@ public unsafe void CallNativeComObjectThroughGeneratedStub() var cw = new StrategyBasedComWrappers(); var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None); - var intObj = (IComInterface1)obj; - Assert.Equal(0, intObj.GetData()); - intObj.SetData(2); - Assert.Equal(2, intObj.GetData()); + var intObj = (IGetAndSetInt)obj; + Assert.Equal(0, intObj.GetInt()); + intObj.SetInt(2); + Assert.Equal(2, intObj.GetInt()); } [Fact] public unsafe void DerivedInterfaceTypeProvidesBaseInterfaceUnmanagedToManagedMembers() { // Make sure that we have the correct derived and base types here. - Assert.Contains(typeof(IComInterface1), typeof(IDerivedComInterface).GetInterfaces()); + Assert.Contains(typeof(IGetAndSetInt), typeof(IDerivedComInterface).GetInterfaces()); - IIUnknownDerivedDetails baseInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(IComInterface1).TypeHandle); + IIUnknownDerivedDetails baseInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(IGetAndSetInt).TypeHandle); IIUnknownDerivedDetails derivedInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(IDerivedComInterface).TypeHandle); - var numBaseMethods = typeof(IComInterface1).GetMethods().Length; + var numBaseMethods = typeof(IGetAndSetInt).GetMethods().Length; var numPointersToCompare = 3 + numBaseMethods; @@ -60,10 +61,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() var obj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None); IDerivedComInterface iface = (IDerivedComInterface)obj; - Assert.Equal(3, iface.GetData()); - // https://github.com/dotnet/runtime/issues/85795 - //iface.SetData(5); - //Assert.Equal(5, iface.GetData()); + Assert.Equal(3, iface.GetInt()); + iface.SetInt(5); + Assert.Equal(5, iface.GetInt()); // https://github.com/dotnet/runtime/issues/85795 //Assert.Equal("myName", iface.GetName()); @@ -80,10 +80,14 @@ partial class DerivedImpl : IDerivedComInterface { int data = 3; string myName = "myName"; - public int GetData() => data; + + public int GetInt() => data; + [return: MarshalUsing(typeof(Utf16StringMarshaller))] public string GetName() => myName; - public void SetData(int n) => data = n; + + public void SetInt(int n) => data = n; + public void SetName([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => myName = name; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs deleted file mode 100644 index efe0237d83b45..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; - -namespace ComInterfaceGenerator.Tests -{ - [GeneratedComInterface] - [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] - [Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")] - public partial interface IComInterface1 - { - int GetData(); - void SetData(int n); - } - - [GeneratedComInterface] - partial interface I - { - void Method(); - void Method2(); - } - [GeneratedComInterface] - partial interface Empty - { - } - [GeneratedComInterface] - partial interface J - { - void Method(); - void Method2(); - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs index d9063c07a7860..d98fc95e6437d 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs @@ -30,6 +30,7 @@ public async Task SingleComInterface() using System.Runtime.InteropServices.Marshalling; [GeneratedComInterface] + [Guid("9D3FD745-3C90-4C10-B140-FAFB01E3541D")] partial interface INativeAPI { void Method(); @@ -48,12 +49,14 @@ public async Task MultipleComInterfaces() using System.Runtime.InteropServices.Marshalling; [GeneratedComInterface] + [Guid("9D3FD745-3C90-4C10-B140-FAFB01E3541D")] partial interface I { void Method(); void Method2(); } [GeneratedComInterface] + [Guid("734AFCEC-8862-43CB-AB29-5A7954929E23")] partial interface J { void Method(); @@ -72,16 +75,19 @@ public async Task EmptyComInterface() using System.Runtime.InteropServices.Marshalling; [GeneratedComInterface] + [Guid("9D3FD745-3C90-4C10-B140-FAFB01E3541D")] partial interface I { void Method(); void Method2(); } [GeneratedComInterface] + [Guid("734AFCEC-8862-43CB-AB29-5A7954929E23")] partial interface Empty { } [GeneratedComInterface] + [Guid("734AFCEC-8862-43CB-AB29-5A7954929E23")] partial interface J { void Method(); @@ -100,12 +106,14 @@ public async Task InheritingComInterfaces() using System.Runtime.InteropServices.Marshalling; [GeneratedComInterface] + [Guid("9D3FD745-3C90-4C10-B140-FAFB01E3541D")] partial interface I { void Method(); void Method2(); } [GeneratedComInterface] + [Guid("734AFCEC-8862-43CB-AB29-5A7954929E23")] partial interface J : I { void MethodA(); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedComInterface.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IDerivedComInterface.cs similarity index 69% rename from src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedComInterface.cs rename to src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IDerivedComInterface.cs index 7df314587e077..ff32a3f632fdf 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedComInterface.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IDerivedComInterface.cs @@ -5,15 +5,17 @@ using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; -namespace ComInterfaceGenerator.Tests +namespace SharedTypes.ComInterfaces { [GeneratedComInterface] - [Guid("7F0DB364-3C04-4487-9193-4BB05DC7B654")] - public partial interface IDerivedComInterface : IComInterface1 + [Guid(_guid)] + internal partial interface IDerivedComInterface : IGetAndSetInt { void SetName([MarshalUsing(typeof(Utf16StringMarshaller))] string name); [return:MarshalUsing(typeof(Utf16StringMarshaller))] string GetName(); + + internal new const string _guid = "7F0DB364-3C04-4487-9193-4BB05DC7B654"; } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IEmpty.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IEmpty.cs new file mode 100644 index 0000000000000..cf99ae8a47530 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IEmpty.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces +{ + [GeneratedComInterface] + [Guid(_guid)] + internal partial interface Empty + { + public const string _guid = "95D19F50-F2D8-4E61-884B-0A9162EA4646"; + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs index 2d76ef4c4f5a5..f1b46c81b7966 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs @@ -2,18 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; -using System.Linq; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; -using System.Text; -using System.Threading.Tasks; namespace SharedTypes.ComInterfaces { [GeneratedComInterface] [Guid(_guid)] - partial interface IGetAndSetInt + internal partial interface IGetAndSetInt { int GetInt(); diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs index 6b99c1d4bcac2..b70de07d1596b 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs @@ -2,18 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; -using System.Linq; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; -using System.Text; -using System.Threading.Tasks; namespace SharedTypes.ComInterfaces { [GeneratedComInterface] [Guid(_guid)] - partial interface IGetIntArray + internal partial interface IGetIntArray { [return: MarshalUsing(ConstantElementCount = 10)] int[] GetInts(); From d78aeaf991c1c4f6a8bc5306c1a7349b66fb60cc Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 15 May 2023 13:02:46 -0700 Subject: [PATCH 26/31] PR Feedback --- .../ComInterfaceGenerator.cs | 21 +++++++---- .../ComInterfaceGenerator/ComMethodContext.cs | 36 ++++++++----------- .../ComInterfaceGenerator/ComMethodInfo.cs | 25 ++----------- .../IncrementalMethodStubGenerationContext.cs | 3 +- .../VtableIndexStubGenerator.cs | 1 + .../SignatureContext.cs | 2 ++ .../GeneratedComInterfaceTests.cs | 3 +- 7 files changed, 37 insertions(+), 54 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 47bbf9d08262c..457fc4af2ac76 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -37,6 +37,7 @@ public static class StepNames public const string GenerateNativeToManagedVTable = nameof(GenerateNativeToManagedVTable); public const string GenerateInterfaceInformation = nameof(GenerateInterfaceInformation); public const string GenerateIUnknownDerivedAttribute = nameof(GenerateIUnknownDerivedAttribute); + public const string GenerateShadowingMethods = nameof(GenerateShadowingMethods); } public void Initialize(IncrementalGeneratorInitializationContext context) @@ -109,13 +110,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct)); }).WithTrackingName(StepNames.CalculateStubInformation); - var interfaceAndMethodsContexts = comMethodContexts.Collect() + var interfaceAndMethodsContexts = comMethodContexts + .Collect() .Combine(interfaceContexts.Collect()) .SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct)); // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation. context.RegisterDiagnostics(interfaceAndMethodsContexts - .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics))); + .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetManagedToUnmanagedStub().Diagnostics))); var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts .Select(GenerateImplementationInterface) .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) @@ -124,7 +126,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation. context.RegisterDiagnostics(interfaceAndMethodsContexts - .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.NativeToManagedStub.Diagnostics))); + .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetNativeToManagedStub().Diagnostics))); var nativeToManagedVtableMethods = interfaceAndMethodsContexts .Select(GenerateImplementationVTableMethods) .WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods) @@ -150,6 +152,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithMembers(List(methods)); return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(typeDecl); }) + .WithTrackingName(StepNames.GenerateShadowingMethods) + .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); // Generate a method named CreateManagedVirtualFunctionTable on the native interface implementation @@ -347,6 +351,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M unmanagedCallConvAttribute, ImmutableArray.Create(FunctionPointerUnmanagedCallingConvention(Identifier("MemberFunction")))); + var declaringType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType); + var virtualMethodIndexData = new VirtualMethodIndexData(index, ImplicitThisParameter: true, MarshalDirection.Bidirectional, true, ExceptionMarshalling.Com); return new IncrementalMethodStubGenerationContext( @@ -360,6 +366,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged), ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged), typeKeyOwner, + declaringType, generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(), ComInterfaceDispatchMarshallingInfo.Instance); } @@ -387,7 +394,7 @@ private static ImmutableArray GroupComContextsFor // This enable us to group our contexts by their containing syntax rather simply. var contextList = ImmutableArray.CreateBuilder(); int methodIndex = 0; - foreach(var iface in interfaces) + foreach (var iface in interfaces) { var methodList = ImmutableArray.CreateBuilder(); while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface) @@ -405,7 +412,7 @@ private static ImmutableArray GroupComContextsFor private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _) { var definingType = interfaceGroup.Interface.Info.Type; - var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (m, m.ManagedToUnmanagedStub)) + var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.GetManagedToUnmanagedStub())) .Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext) .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node .WithExplicitInterfaceSpecifier( @@ -416,7 +423,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInt .WithMembers( List( interfaceGroup.DeclaredMethods - .Select(m => m.ManagedToUnmanagedStub) + .Select(m => m.GetManagedToUnmanagedStub()) .OfType() .Select(ctx => ctx.Stub.Node) .Concat(shadowImplementations) @@ -429,7 +436,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co .WithMembers( List( comInterfaceAndMethods.DeclaredMethods - .Select(m => m.NativeToManagedStub) + .Select(m => m.GetNativeToManagedStub()) .OfType() .Select(context => context.Stub.Node))); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index 41915f99d0ed1..31708cb9b8a1a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -42,30 +42,24 @@ public sealed record Builder(ComInterfaceContext OriginalDeclaringInterface, Com public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface; - public GeneratedMethodContextBase ManagedToUnmanagedStub + public GeneratedMethodContextBase GetManagedToUnmanagedStub() { - get + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) - { - return new SkippedStubContext(OriginalDeclaringInterface.Info.Type); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); - return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + return new SkippedStubContext(OriginalDeclaringInterface.Info.Type); } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); } - public GeneratedMethodContextBase NativeToManagedStub + public GeneratedMethodContextBase GetNativeToManagedStub() { - get + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) - { - return new SkippedStubContext(GenerationContext.OriginalDefiningType); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); - return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + return new SkippedStubContext(GenerationContext.OriginalDefiningType); } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); } public MethodDeclarationSyntax GenerateUnreachableExceptionStub() @@ -89,9 +83,10 @@ public MethodDeclarationSyntax GenerateShadow() // { // return (()this).(); // } - // TODO: Copy full name of parameter types and attributes / attribute arguments for parameters - return MethodInfo.Syntax + var forwarder = new Forwarder(); + return MethodDeclaration(GenerationContext.SignatureContext.StubReturnType, MethodInfo.MethodName) .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword))) + .WithParameterList(ParameterList(SeparatedList(GenerationContext.SignatureContext.StubParameters))) .WithExpressionBody( ArrowExpressionClause( InvocationExpression( @@ -101,9 +96,8 @@ public MethodDeclarationSyntax GenerateShadow() CastExpression(OriginalDeclaringInterface.Info.Type.Syntax, IdentifierName("this"))), IdentifierName(MethodInfo.MethodName)), ArgumentList( - // TODO: RefKind keywords - SeparatedList(MethodInfo.Parameters.Select(p => - Argument(IdentifierName(p.Name)))))))); + SeparatedList(GenerationContext.SignatureContext.ManagedParameters.Select(p => forwarder.AsArgument(p, new ManagedStubCodeContext()))))))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); } /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index bbb8ff2eba531..f440c2c4d5dc1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -20,8 +20,7 @@ public sealed partial class ComInterfaceGenerator /// private sealed record ComMethodInfo( MethodDeclarationSyntax Syntax, - string MethodName, - SequenceEqualImmutableArray Parameters) + string MethodName) { /// /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa. @@ -113,29 +112,9 @@ private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo( { return (null, method, diag); } - - List parameters = new(); - foreach (var parameter in method.Parameters) - { - parameters.Add(ParameterInfo.From(parameter)); - } - - var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, parameters.ToSequenceEqualImmutableArray()); + var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name); return (comMethodInfo, method, null); } } } - - internal record struct ParameterInfo(ManagedTypeInfo Type, string Name, RefKind RefKind, SequenceEqualImmutableArray Attributes) - { - public static ParameterInfo From(IParameterSymbol parameter) - { - var attributes = new List(); - foreach (var attribute in parameter.GetAttributes()) - { - attributes.Add(AttributeInfo.From(attribute)); - } - return new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(parameter.Type), parameter.Name, parameter.RefKind, attributes.ToSequenceEqualImmutableArray()); - } - } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs index 7cd5842ba47c7..ff3d6f32d5f23 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs @@ -21,6 +21,7 @@ internal sealed record IncrementalMethodStubGenerationContext( MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> ManagedToUnmanagedGeneratorFactory, MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> UnmanagedToManagedGeneratorFactory, ManagedTypeInfo TypeKeyOwner, + ManagedTypeInfo DeclaringType, SequenceEqualImmutableArray Diagnostics, - MarshallingInfo ManagedThisMarshallingInfo) : GeneratedMethodContextBase(TypeKeyOwner, Diagnostics); + MarshallingInfo ManagedThisMarshallingInfo) : GeneratedMethodContextBase(DeclaringType, Diagnostics); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs index 13ee527043de0..1d161ff1d5f2b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs @@ -312,6 +312,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M VtableIndexStubGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged), VtableIndexStubGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged), interfaceType, + interfaceType, new SequenceEqualImmutableArray(generatorDiagnostics.Diagnostics.ToImmutableArray()), new ObjectUnwrapperInfo(unwrapperSyntax)); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs index d5d7526c3b45b..44821703e16d9 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs @@ -30,6 +30,8 @@ private SignatureContext() public ImmutableArray ElementTypeInformation { get; init; } + public IEnumerable ManagedParameters => ElementTypeInformation.Where(tpi => tpi.ManagedIndex >= 0); + public TypeSyntax StubReturnType { get; init; } public IEnumerable StubParameters diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs index feb08ecbe297c..9ef4284427eb6 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComInterfaceTests.cs @@ -83,12 +83,11 @@ partial class DerivedImpl : IDerivedComInterface public int GetInt() => data; - [return: MarshalUsing(typeof(Utf16StringMarshaller))] public string GetName() => myName; public void SetInt(int n) => data = n; - public void SetName([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => myName = name; + public void SetName(string name) => myName = name; } class SingleQIComWrapper : StrategyBasedComWrappers From 5659559811d43d2b4fe69854351fd620aae6507b Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 15 May 2023 13:04:40 -0700 Subject: [PATCH 27/31] Forgot to add new file --- .../ManagedStubCodeContext.cs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedStubCodeContext.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedStubCodeContext.cs new file mode 100644 index 0000000000000..d4a83ca851f0c --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedStubCodeContext.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Interop +{ + /// + /// Stub code context for generating code that does not cross a native/managed boundary + /// + internal sealed record ManagedStubCodeContext : StubCodeContext + { + public override bool SingleFrameSpansNativeContext => throw new NotImplementedException(); + + public override bool AdditionalTemporaryStateLivesAcrossStages => throw new NotImplementedException(); + + public override (TargetFramework framework, Version version) GetTargetFramework() => throw new NotImplementedException(); + } +} From 2d201d19655ae3464b9a0ef68f1cfc948e9edbc1 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 15 May 2023 14:00:10 -0700 Subject: [PATCH 28/31] Update src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs Co-authored-by: Jeremy Koritzinsky --- .../gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs index 44821703e16d9..f39f460eaef9f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/SignatureContext.cs @@ -30,7 +30,7 @@ private SignatureContext() public ImmutableArray ElementTypeInformation { get; init; } - public IEnumerable ManagedParameters => ElementTypeInformation.Where(tpi => tpi.ManagedIndex >= 0); + public IEnumerable ManagedParameters => ElementTypeInformation.Where(tpi => !TypePositionInfo.IsSpecialIndex(tpi.ManagedIndex)); public TypeSyntax StubReturnType { get; init; } From 68ea987bf65fb37a88c1d81efb43b6aa3ad5649c Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 15 May 2023 15:28:37 -0700 Subject: [PATCH 29/31] Report diagnostic when unable to analyze code --- .../ComInterfaceGenerator/ComMethodInfo.cs | 4 ++-- .../GeneratorDiagnostics.cs | 21 +++++++++++++++++++ .../Resources/Strings.resx | 12 +++++++++++ .../Resources/xlf/Strings.cs.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.de.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.es.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.fr.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.it.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.ja.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.ko.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.pl.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.pt-BR.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.ru.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.tr.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.zh-Hans.xlf | 20 ++++++++++++++++++ .../Resources/xlf/Strings.zh-Hant.xlf | 20 ++++++++++++++++++ 16 files changed, 295 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index f440c2c4d5dc1..38bf19a462f6a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -86,7 +86,7 @@ private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo( // TODO: this should cause a diagnostic if (methodLocationInAttributedInterfaceDeclaration is null) { - throw new NotImplementedException($"Could not find location for method {method.ToDisplayString()} within the attributed declaration"); + return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString())); } @@ -104,7 +104,7 @@ private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo( } if (comMethodDeclaringSyntax is null) { - throw new NotImplementedException("Found a method that was declared in the attributed interface declaration, but couldn't find the syntax for it."); + return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString())); } var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, method); diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs index 9df719d19d51f..36fc49267fea0 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs @@ -23,6 +23,7 @@ public class Ids public const string MethodNotDeclaredInAttributedInterface = Prefix + "1091"; public const string InvalidGeneratedComInterfaceAttributeUsage = Prefix + "1092"; public const string MultipleComInterfaceBaseTypes = Prefix + "1093"; + public const string AnalysisFailed = Prefix + "1094"; } private const string Category = "ComInterfaceGenerator"; @@ -197,6 +198,26 @@ public class Ids isEnabledByDefault: true, description: GetResourceString(nameof(SR.MultipleComInterfaceBaseTypesDescription))); + public static readonly DiagnosticDescriptor CannotAnalyzeMethodPattern = + new DiagnosticDescriptor( + Ids.AnalysisFailed, + GetResourceString(nameof(SR.AnalysisFailedTitle)), + GetResourceString(nameof(SR.AnalysisFailedMethodMessage)), + Category, + DiagnosticSeverity.Warning, + isEnabledByDefault: true, + description: GetResourceString(nameof(SR.AnalysisFailedDescription))); + + public static readonly DiagnosticDescriptor CannotAnalyzeInterfacePattern = + new DiagnosticDescriptor( + Ids.AnalysisFailed, + GetResourceString(nameof(SR.AnalysisFailedTitle)), + GetResourceString(nameof(SR.AnalysisFailedInterfaceMessage)), + Category, + DiagnosticSeverity.Warning, + isEnabledByDefault: true, + description: GetResourceString(nameof(SR.AnalysisFailedDescription))); + private readonly List _diagnostics = new List(); public IEnumerable Diagnostics => _diagnostics; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/Strings.resx b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/Strings.resx index 0d9b0b6631e08..9c0cc44426487 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/Strings.resx +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/Strings.resx @@ -237,4 +237,16 @@ Specified interface derives from two or more 'GeneratedComInterfaceAttribute'-attributed interfaces. + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + Analysis for generation has failed. + \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.cs.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.cs.xlf index 04697d456efdf..241afaa3a4af3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.cs.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. Zdrojem generovaná volání P/Invokes budou ignorovat všechny nepodporované konfigurace. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.de.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.de.xlf index 9d06e7ad05992..37cda7bfd599b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.de.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.de.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. Quellgenerierte P/Invokes ignorieren alle Konfigurationen, die nicht unterstützt werden. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.es.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.es.xlf index a169ab705ea4d..71b79c45dd520 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.es.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.es.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. Los P/Invoke de un generador de código fuente omitirán cualquier configuración que no esté admitida. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.fr.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.fr.xlf index f8fd5f71c42c4..ed463398efa79 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.fr.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. Les P/Invokes générés par la source ignorent toute configuration qui n’est pas prise en charge. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.it.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.it.xlf index 45db75981a336..bf2b7f5ef0065 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.it.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.it.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. I P/Invoke generati dall'origine ignoreranno qualsiasi configurazione non supportata. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ja.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ja.xlf index 9ae0a4d6af416..c787e252ca99d 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ja.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. ソース生成済みの P/Invoke は、サポートされていない構成を無視します。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ko.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ko.xlf index 65ae4ad12efce..082dafc2a19f6 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ko.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. 소스 생성 P/Invoke는 지원되지 않는 구성을 무시합니다. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pl.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pl.xlf index 6f340f9776c2c..ec95453294f1e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pl.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. Funkcja P/Invokes generowana przez źródło zignoruje każdą nieobsługiwaną konfigurację. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pt-BR.xlf index c12ce4422ba38..03693164c2aab 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.pt-BR.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. P/Invokes gerados pela origem ignorarão qualquer configuração sem suporte. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ru.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ru.xlf index 21eefbe895001..83fef8b4d6caa 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.ru.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. P/Invoke с созданием источника будут игнорировать все неподдерживаемые конфигурации. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.tr.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.tr.xlf index 70597a22a7bf5..cefac4ef490f4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.tr.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. Kaynak tarafından oluşturulan P/Invokes desteklenmeyen yapılandırmaları yok sayar. diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hans.xlf index 73f74d04f4175..b8b3951366925 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hans.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. 源生成的 P/Invoke 将忽略任何不受支持的配置。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hant.xlf index 6ca68e2bc3414..83558aa554498 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Resources/xlf/Strings.zh-Hant.xlf @@ -2,6 +2,26 @@ + + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + The analysis required to generate code for this interface or method has failed due to an unexpected code pattern. If you are using new or unconventional syntax, consider using other syntax. + + + + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + Analysis of interface '{0}' has failed. ComInterfaceGenerator will not generate code for this interface. + + + + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + Analysis of method '{0}' has failed. ComInterfaceGenerator will not generate code for this method. + + + + Analysis for generation has failed. + Analysis for generation has failed. + + Source-generated P/Invokes will ignore any configuration that is not supported. 来源產生的 P/Invokes 將會忽略任何不支援的設定。 From ec978c6160dae3c7c6992a0201421644d740a4c1 Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 15 May 2023 16:11:18 -0700 Subject: [PATCH 30/31] Inline local variables in HashCode --- .../HashCode.cs | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs index 7b38f16d4b72d..1fce3fc765a53 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/HashCode.cs @@ -15,100 +15,64 @@ public class HashCode { public static int Combine(T1 t1, T2 t2) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - return combinedHash; + return Hash.Combine(t1 != null ? t1.GetHashCode() : 0, t2 != null ? t2.GetHashCode() : 0); } public static int Combine(T1 t1, T2 t2, T3 t3) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int hash3 = t3 != null ? t3.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - combinedHash = Hash.Combine(combinedHash, hash3); - return combinedHash; + int combinedHash = t1 != null ? t1.GetHashCode() : 0; + combinedHash = Hash.Combine(combinedHash, t2 != null ? t2.GetHashCode() : 0); + return Hash.Combine(combinedHash, t3 != null ? t3.GetHashCode() : 0); } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int hash3 = t3 != null ? t3.GetHashCode() : 0; - int hash4 = t4 != null ? t4.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - combinedHash = Hash.Combine(combinedHash, hash3); - combinedHash = Hash.Combine(combinedHash, hash4); - return combinedHash; + int combinedHash = t1 != null ? t1.GetHashCode() : 0; + combinedHash = Hash.Combine(combinedHash, t2 != null ? t2.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t3 != null ? t3.GetHashCode() : 0); + return Hash.Combine(combinedHash, t4 != null ? t4.GetHashCode() : 0); } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int hash3 = t3 != null ? t3.GetHashCode() : 0; - int hash4 = t4 != null ? t4.GetHashCode() : 0; - int hash5 = t5 != null ? t5.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - combinedHash = Hash.Combine(combinedHash, hash3); - combinedHash = Hash.Combine(combinedHash, hash4); - combinedHash = Hash.Combine(combinedHash, hash5); - return combinedHash; + int combinedHash = t1 != null ? t1.GetHashCode() : 0; + combinedHash = Hash.Combine(combinedHash, t2 != null ? t2.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t3 != null ? t3.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t4 != null ? t4.GetHashCode() : 0); + return Hash.Combine(combinedHash, t5 != null ? t5.GetHashCode() : 0); } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int hash3 = t3 != null ? t3.GetHashCode() : 0; - int hash4 = t4 != null ? t4.GetHashCode() : 0; - int hash5 = t5 != null ? t5.GetHashCode() : 0; - int hash6 = t6 != null ? t6.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - combinedHash = Hash.Combine(combinedHash, hash3); - combinedHash = Hash.Combine(combinedHash, hash4); - combinedHash = Hash.Combine(combinedHash, hash5); - combinedHash = Hash.Combine(combinedHash, hash6); - return combinedHash; + int combinedHash = t1 != null ? t1.GetHashCode() : 0; + combinedHash = Hash.Combine(combinedHash, t2 != null ? t2.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t3 != null ? t3.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t4 != null ? t4.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t5 != null ? t5.GetHashCode() : 0); + return Hash.Combine(combinedHash, t6 != null ? t6.GetHashCode() : 0); } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int hash3 = t3 != null ? t3.GetHashCode() : 0; - int hash4 = t4 != null ? t4.GetHashCode() : 0; - int hash5 = t5 != null ? t5.GetHashCode() : 0; - int hash6 = t6 != null ? t6.GetHashCode() : 0; - int hash7 = t7 != null ? t7.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - combinedHash = Hash.Combine(combinedHash, hash3); - combinedHash = Hash.Combine(combinedHash, hash4); - combinedHash = Hash.Combine(combinedHash, hash5); - combinedHash = Hash.Combine(combinedHash, hash6); - combinedHash = Hash.Combine(combinedHash, hash7); - return combinedHash; + int combinedHash = t1 != null ? t1.GetHashCode() : 0; + combinedHash = Hash.Combine(combinedHash, t2 != null ? t2.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t3 != null ? t3.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t4 != null ? t4.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t5 != null ? t5.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t6 != null ? t6.GetHashCode() : 0); + return Hash.Combine(combinedHash, t7 != null ? t7.GetHashCode() : 0); } public static int Combine(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) { - int hash1 = t1 != null ? t1.GetHashCode() : 0; - int hash2 = t2 != null ? t2.GetHashCode() : 0; - int hash3 = t3 != null ? t3.GetHashCode() : 0; - int hash4 = t4 != null ? t4.GetHashCode() : 0; - int hash5 = t5 != null ? t5.GetHashCode() : 0; - int hash6 = t6 != null ? t6.GetHashCode() : 0; - int hash7 = t7 != null ? t7.GetHashCode() : 0; - int hash8 = t8 != null ? t8.GetHashCode() : 0; - int combinedHash = Hash.Combine(hash1, hash2); - combinedHash = Hash.Combine(combinedHash, hash3); - combinedHash = Hash.Combine(combinedHash, hash4); - combinedHash = Hash.Combine(combinedHash, hash5); - combinedHash = Hash.Combine(combinedHash, hash6); - combinedHash = Hash.Combine(combinedHash, hash7); - combinedHash = Hash.Combine(combinedHash, hash8); - return combinedHash; + int combinedHash = t1 != null ? t1.GetHashCode() : 0; + combinedHash = Hash.Combine(combinedHash, t2 != null ? t2.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t3 != null ? t3.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t4 != null ? t4.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t5 != null ? t5.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t6 != null ? t6.GetHashCode() : 0); + combinedHash = Hash.Combine(combinedHash, t7 != null ? t7.GetHashCode() : 0); + return Hash.Combine(combinedHash, t8 != null ? t8.GetHashCode() : 0); } public static int SequentialValuesHash(IEnumerable values) From 01fa078305ebb3bdeeca282271b735e94f55471d Mon Sep 17 00:00:00 2001 From: Jackson Schuster Date: Mon, 15 May 2023 17:03:04 -0700 Subject: [PATCH 31/31] Return IUnknown methods in VTable for empty interface --- .../ComInterfaceGenerator.cs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 457fc4af2ac76..303a20aa63849 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -453,22 +453,6 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf var interfaceType = interfaceMethods.Interface.Info.Type; var interfaceMethodStubs = interfaceMethods.DeclaredMethods.Select(m => m.GenerationContext); - ImmutableArray vtableExposedContexts = interfaceMethodStubs - .Where(c => c.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) - .ToImmutableArray(); - - // If none of the methods are exposed as part of the vtable, then don't emit - // a vtable (return null). - if (vtableExposedContexts.Length == 0) - { - return ImplementationInterfaceTemplate - .AddMembers( - CreateManagedVirtualFunctionTableMethodTemplate - .WithBody( - Block( - ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression))))); - } - // void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(, sizeof(void*) * ); var vtableDeclarationStatement = LocalDeclarationStatement(