diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 77b40b5dc3d2a..0690bc0da1986 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.IO; using System.Linq; +using System.Reflection; using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -448,46 +449,49 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M // Create the stub. var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly); - // Search for the element information for the managed return value. - // We need to transform it such that any return type is converted to an out parameter at the end of the parameter list. - ImmutableArray returnSwappedSignatureElements = signatureContext.ElementTypeInformation; - for (int i = 0; i < returnSwappedSignatureElements.Length; ++i) + if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig)) { - if (returnSwappedSignatureElements[i].IsManagedReturnPosition) + // Search for the element information for the managed return value. + // We need to transform it such that any return type is converted to an out parameter at the end of the parameter list. + ImmutableArray returnSwappedSignatureElements = signatureContext.ElementTypeInformation; + for (int i = 0; i < returnSwappedSignatureElements.Length; ++i) { - if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void) + if (returnSwappedSignatureElements[i].IsManagedReturnPosition) { - // Return type is void, just remove the element from the signature list. - // We don't introduce an out parameter. - returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i); - } - else - { - // Convert the current element into an out parameter on the native signature - // while keeping it at the return position in the managed signature. - var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with + if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void) { - RefKind = RefKind.Out, - RefKindSyntax = SyntaxKind.OutKeyword, - ManagedIndex = TypePositionInfo.ReturnIndex, - NativeIndex = symbol.Parameters.Length - }; - returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut); + // Return type is void, just remove the element from the signature list. + // We don't introduce an out parameter. + returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i); + } + else + { + // Convert the current element into an out parameter on the native signature + // while keeping it at the return position in the managed signature. + var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with + { + RefKind = RefKind.Out, + RefKindSyntax = SyntaxKind.OutKeyword, + ManagedIndex = TypePositionInfo.ReturnIndex, + NativeIndex = symbol.Parameters.Length + }; + returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut); + } + break; } - break; } - } - signatureContext = signatureContext with - { - // Add the HRESULT return value in the native signature. - // This element does not have any influence on the managed signature, so don't assign a managed index. - ElementTypeInformation = returnSwappedSignatureElements.Add( - new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo()) - { - NativeIndex = TypePositionInfo.ReturnIndex - }) - }; + signatureContext = signatureContext with + { + // Add the HRESULT return value in the native signature. + // This element does not have any influence on the managed signature, so don't assign a managed index. + ElementTypeInformation = returnSwappedSignatureElements.Add( + new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo()) + { + NativeIndex = TypePositionInfo.ReturnIndex + }) + }; + } var containingSyntaxContext = new ContainingSyntaxContext(syntax); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs similarity index 56% rename from src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs rename to src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs index 627d5f860eaf5..a418eaafa2ea4 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs @@ -2,19 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; using System.Linq; using System.Reflection.Metadata; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Operations; using Microsoft.CodeAnalysis.Testing; +using Microsoft.Interop; using Xunit; -using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; +using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace ComInterfaceGenerator.Unit.Tests { - public class CallingConventionForwarding + public class TargetSignatureTests { [Fact] public async Task NoSpecifiedCallConvForwardsDefault() @@ -32,7 +34,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType } """; - await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) => + await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) => { Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention); Assert.Empty(signature.UnmanagedCallingConventionTypes); @@ -56,7 +58,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType } """; - await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) => + await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) => { Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention); Assert.Equal(newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"), Assert.Single(signature.UnmanagedCallingConventionTypes), SymbolEqualityComparer.Default); @@ -80,7 +82,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType } """; - await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) => + await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) => { Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention); Assert.Empty(signature.UnmanagedCallingConventionTypes); @@ -105,7 +107,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType } """; - await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) => + await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) => { Assert.Equal(SignatureCallingConvention.CDecl, signature.CallingConvention); Assert.Empty(signature.UnmanagedCallingConventionTypes); @@ -130,7 +132,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType } """; - await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) => + await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) => { Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention); Assert.Equal(new[] @@ -162,7 +164,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType } """; - await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) => + await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) => { Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention); Assert.Equal(new[] @@ -176,9 +178,80 @@ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signa }); } - private static async Task VerifySourceGeneratorAsync(string source, string interfaceName, string methodName, Action signatureValidator) + [Fact] + public async Task ComInterfaceMethodFunctionPointerReturnsInt() { - CallingConventionForwardingTest test = new(interfaceName, methodName, signatureValidator) + string source = $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0A617667-4961-4F90-B74F-6DC368E98179")] + partial interface IComInterface + { + void Method(); + } + """; + + await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) => + { + Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType); + }); + } + + [Fact] + public async Task ComInterfaceMethodFunctionPointerReturnTypeChangedToOutParameter() + { + string source = $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0A617667-4961-4F90-B74F-6DC368E98179")] + partial interface IComInterface + { + long Method(); + } + """; + + await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) => + { + Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType); + Assert.Equal(2, signature.Parameters.Length); + Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), signature.Parameters[0].Type, SymbolEqualityComparer.Default); + Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Int64)), signature.Parameters[^1].Type, SymbolEqualityComparer.Default); + }); + } + + [Fact] + public async Task ComInterfaceMethodPreserveSigFunctionPointerReturnTypePreserved() + { + string source = $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0A617667-4961-4F90-B74F-6DC368E98179")] + partial interface IComInterface + { + [PreserveSig] + long Method(); + } + """; + + await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) => + { + Assert.Equal(SpecialType.System_Int64, signature.ReturnType.SpecialType); + Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), Assert.Single(signature.Parameters).Type, SymbolEqualityComparer.Default); + }); + } + + private static async Task VerifyVirtualMethodIndexGeneratorAsync(string source, string interfaceName, string methodName, Action signatureValidator) + { + VirtualMethodIndexTargetSignatureTest test = new(interfaceName, methodName, signatureValidator) { TestCode = source, TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck @@ -186,14 +259,24 @@ private static async Task VerifySourceGeneratorAsync(string source, string inter await test.RunAsync(); } + private static async Task VerifyComInterfaceGeneratorAsync(string source, string interfaceName, string methodName, Action signatureValidator) + { + ComInterfaceTargetSignatureTest test = new(interfaceName, methodName, signatureValidator) + { + TestCode = source, + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck + }; - class CallingConventionForwardingTest : VerifyCS.Test + await test.RunAsync(); + } + + private abstract class TargetSignatureTestBase : VerifyCS.Test { private readonly Action _signatureValidator; private readonly string _interfaceName; private readonly string _methodName; - public CallingConventionForwardingTest(string interfaceName, string methodName, Action signatureValidator) + protected TargetSignatureTestBase(string interfaceName, string methodName, Action signatureValidator) : base(referenceAncillaryInterop: true) { _signatureValidator = signatureValidator; @@ -205,12 +288,14 @@ protected override void VerifyFinalCompilation(Compilation compilation) { _signatureValidator(compilation, FindFunctionPointerInvocationSignature(compilation)); } + + protected abstract INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface); private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compilation) { INamedTypeSymbol? userDefinedInterface = compilation.Assembly.GetTypeByMetadataName(_interfaceName); Assert.NotNull(userDefinedInterface); - INamedTypeSymbol generatedInterfaceImplementation = Assert.Single(userDefinedInterface.GetTypeMembers("Native")); + INamedTypeSymbol generatedInterfaceImplementation = FindImplementationInterface(compilation, userDefinedInterface); IMethodSymbol methodImplementation = Assert.Single(generatedInterfaceImplementation.GetMembers($"global::{_interfaceName}.{_methodName}").OfType()); @@ -223,5 +308,38 @@ private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compila return Assert.Single(body.Descendants().OfType()).GetFunctionPointerSignature(); } } + + private sealed class VirtualMethodIndexTargetSignatureTest : TargetSignatureTestBase + { + public VirtualMethodIndexTargetSignatureTest(string interfaceName, string methodName, Action signatureValidator) + : base(interfaceName, methodName, signatureValidator) + { + } + + protected override IEnumerable GetSourceGenerators() => new[] { typeof(VtableIndexStubGenerator) }; + + protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface) => Assert.Single(userDefinedInterface.GetTypeMembers("Native")); + } + + private sealed class ComInterfaceTargetSignatureTest : TargetSignatureTestBase + { + public ComInterfaceTargetSignatureTest(string interfaceName, string methodName, Action signatureValidator) : base(interfaceName, methodName, signatureValidator) + { + } + protected override IEnumerable GetSourceGenerators() => new[] { typeof(Microsoft.Interop.ComInterfaceGenerator) }; + + protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface) + { + INamedTypeSymbol? iUnknownDerivedAttributeType = compilation.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute`2"); + + Assert.NotNull(iUnknownDerivedAttributeType); + + AttributeData iUnknownDerivedAttribute = Assert.Single( + userDefinedInterface.GetAttributes(), + attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, iUnknownDerivedAttributeType)); + + return (INamedTypeSymbol)iUnknownDerivedAttribute.AttributeClass!.TypeArguments[1]; + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Common/Verifiers/CSharpSourceGeneratorVerifier.cs b/src/libraries/System.Runtime.InteropServices/tests/Common/Verifiers/CSharpSourceGeneratorVerifier.cs index 167f00029ed45..1af721d9dfc5c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Common/Verifiers/CSharpSourceGeneratorVerifier.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Common/Verifiers/CSharpSourceGeneratorVerifier.cs @@ -19,7 +19,7 @@ namespace Microsoft.Interop.UnitTests.Verifiers { public static class CSharpSourceGeneratorVerifier - where TSourceGenerator : IIncrementalGenerator, new() + where TSourceGenerator : new() { public static DiagnosticResult Diagnostic(string diagnosticId) => new DiagnosticResult(diagnosticId, DiagnosticSeverity.Error);