From ea2763950fcaf0946794aef6f0024c2e0ff4b95d Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 25 Sep 2020 17:33:45 -0700 Subject: [PATCH 1/6] Add marshaler for SafeHandle-derived types. Extend the MarshallingInfo discriminated union to not only represent attribute info but to also represent any other marshalling info that requires a Compilation object to determine. This avoids having to plumb passing down a compilation object through 4 API layers and makes the marshaller selection switch statement a little bit simpler. --- .../Ancillary.Interop.csproj | 1 + .../GeneratedDllImportAttribute.cs | 3 +- .../Ancillary.Interop/MarshalEx.cs | 29 ++++ .../DllImportGenerator.Test/Compiles.cs | 1 + .../Marshalling/MarshallingGenerator.cs | 6 + .../Marshalling/SafeHandleMarshaller.cs | 163 ++++++++++++++++++ .../MarshallingAttributeInfo.cs | 16 +- .../DllImportGenerator/StubCodeContext.cs | 9 +- .../DllImportGenerator/StubCodeGenerator.cs | 1 + .../DllImportGenerator/TypeNames.cs | 6 +- .../DllImportGenerator/TypePositionInfo.cs | 28 ++- 11 files changed, 249 insertions(+), 14 deletions(-) create mode 100644 DllImportGenerator/Ancillary.Interop/MarshalEx.cs create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs diff --git a/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj b/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj index 536eec52ae8f..e9063db196be 100644 --- a/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj +++ b/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj @@ -4,6 +4,7 @@ net5.0 8.0 System.Runtime.InteropServices + enable diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs index 0f0d322da4dd..42c15094915a 100644 --- a/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs +++ b/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs @@ -1,5 +1,4 @@ -#nullable enable - + namespace System.Runtime.InteropServices { // [TODO] Remove once the attribute has been added to the BCL diff --git a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs new file mode 100644 index 000000000000..931227dbc8d1 --- /dev/null +++ b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs @@ -0,0 +1,29 @@ + +using System.Reflection; + +namespace System.Runtime.InteropServices +{ + /// + /// Marshalling helper methods that will likely live in S.R.IS.Marshal + /// when we integrate our APIs with dotnet/runtime. + /// + public static class MarshalEx + { + public static TSafeHandle CreateSafeHandle() + where TSafeHandle : SafeHandle + { + if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance, null, Type.EmptyTypes, null) == null) + { + throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor."); + } + + TSafeHandle safeHandle = (TSafeHandle)Activator.CreateInstance(typeof(TSafeHandle), nonPublic: true)!; + return safeHandle; + } + + public static void SetHandle(SafeHandle safeHandle, IntPtr handle) + { + typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle }); + } + } +} diff --git a/DllImportGenerator/DllImportGenerator.Test/Compiles.cs b/DllImportGenerator/DllImportGenerator.Test/Compiles.cs index 33b741e8ad25..3c66f36bbe6d 100644 --- a/DllImportGenerator/DllImportGenerator.Test/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.Test/Compiles.cs @@ -90,6 +90,7 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.DelegateMarshalAsParametersAndModifiers }; yield return new[] { CodeSnippets.BlittableStructParametersAndModifiers }; yield return new[] { CodeSnippets.GenericBlittableStructParametersAndModifiers }; + yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index 2e50dcc2df0d..82e65085ccbd 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -64,6 +64,7 @@ internal class MarshallingGenerators public static readonly Forwarder Forwarder = new Forwarder(); public static readonly BlittableMarshaller Blittable = new BlittableMarshaller(); public static readonly DelegateMarshaller Delegate = new DelegateMarshaller(); + public static readonly SafeHandleMarshaller SafeHandle = new SafeHandleMarshaller(); public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out IMarshallingGenerator generator) { @@ -126,6 +127,11 @@ public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out generator = Forwarder; return false; + case { MarshallingAttributeInfo: SafeHandleMarshallingInfo _}: + + generator = Forwarder; + return false; + default: generator = Forwarder; return false; diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs new file mode 100644 index 000000000000..ad6c88829c0f --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -0,0 +1,163 @@ +using System; +using System.Collections.Generic; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + class SafeHandleMarshaller : IMarshallingGenerator + { + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return ParseTypeName("global::System.IntPtr"); + } + + public ParameterSyntax AsParameter(TypePositionInfo info) + { + var type = info.IsByRef + ? PointerType(AsNativeType(info)) + : AsNativeType(info); + return Parameter(Identifier(info.InstanceIdentifier)) + .WithType(type); + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + string identifier = context.GetIdentifiers(info).native; + if (info.IsByRef) + { + return Argument( + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + IdentifierName(identifier))); + } + + return Argument(IdentifierName(identifier)); + } + + public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) + { + // [TODO] Handle byrefs in a more common place? + // This pattern will become very common (arrays and strings will also use it) + (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); + string addRefdIdentifier = $"{managedIdentifier}__addRefd"; + string newHandleObjectIdentifier = info.IsManagedReturnPosition + ? managedIdentifier + : $"{managedIdentifier}__newHandle"; + string handleValueBackupIdentifier = $"{nativeIdentifier}__original"; + switch (context.CurrentStage) + { + case StubCodeContext.Stage.Setup: + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(nativeIdentifier)))); + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.BoolKeyword)), + SingletonSeparatedList( + VariableDeclarator(addRefdIdentifier) + .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + if (info.IsByRef && info.RefKind != RefKind.In) + { + // We create the new handle in the Setup phase + // so we reduce the possible failure points during unmarshalling, where we would + // leak the handle if we failed to create the handle. + yield return LocalDeclarationStatement( + VariableDeclaration( + info.ManagedType.AsTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(newHandleObjectIdentifier) + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + GenericName(Identifier("CreateSafeHandle"), + TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), + ArgumentList())))))); + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(handleValueBackupIdentifier) + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(newHandleObjectIdentifier), + IdentifierName("DangerousGetHandle")), + ArgumentList())))))); + } + break; + case StubCodeContext.Stage.Marshal: + if (info.RefKind != RefKind.Out) + { + yield return ParseStatement($"{managedIdentifier}.DangerousAddRef(ref {addRefdIdentifier}"); + if (info.IsByRef && info.RefKind != RefKind.In) + { + yield return ParseStatement($"{handleValueBackupIdentifier} = {nativeIdentifier} = {managedIdentifier}.DangerousGetHandle();"); + } + } + break; + case StubCodeContext.Stage.LeakSafeUnmarshal: + StatementSyntax unmarshalStatement = + ParseStatement($"{TypeNames.System_Runtime_InteropServices_MarshalEx}.SetHandle({newHandleObjectIdentifier}, {nativeIdentifier});"); + + if(info.IsManagedReturnPosition) + { + yield return unmarshalStatement; + } + else if (info.RefKind == RefKind.Out) + { + yield return unmarshalStatement; + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + IdentifierName(newHandleObjectIdentifier))); + } + else if (info.RefKind == RefKind.Ref) + { + // Decrement refcount on original SafeHandle if we addrefd + yield return IfStatement( + IdentifierName(addRefdIdentifier), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName("DangerousRelease")), + ArgumentList()))); + + // Do not unmarshal the handle if the value didn't change. + yield return IfStatement( + ParseExpression($"{handleValueBackupIdentifier} != {nativeIdentifier}"), + Block( + unmarshalStatement, + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + IdentifierName(newHandleObjectIdentifier))))); + } + break; + case StubCodeContext.Stage.Cleanup: + if (!info.IsByRef || info.RefKind == RefKind.In) + { + yield return IfStatement( + IdentifierName(addRefdIdentifier), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName("DangerousRelease")), + ArgumentList()))); + } + break; + default: + break; + } + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; + } +} diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index a4736aa155a8..0462c8cdd191 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -12,7 +12,7 @@ namespace Microsoft.Interop // for C# 10 discriminated unions. Once discriminated unions are released, // these should be updated to be implemented as a discriminated union. - internal abstract record MarshallingAttributeInfo {} + internal abstract record MarshallingInfo {} /// /// User-applied System.Runtime.InteropServices.MarshalAsAttribute @@ -23,14 +23,14 @@ internal sealed record MarshalAsInfo( string? CustomMarshallerCookie, UnmanagedType UnmanagedArraySubType, int ArraySizeConst, - short ArraySizeParamIndex) : MarshallingAttributeInfo; + short ArraySizeParamIndex) : MarshallingInfo; /// /// User-applied System.Runtime.InteropServices.BlittableTypeAttribute /// or System.Runtime.InteropServices.GeneratedMarshallingAttribute on a blittable type /// in source in this compilation. /// - internal sealed record BlittableTypeAttributeInfo : MarshallingAttributeInfo; + internal sealed record BlittableTypeAttributeInfo : MarshallingInfo; [Flags] internal enum SupportedMarshallingMethods @@ -47,12 +47,18 @@ internal enum SupportedMarshallingMethods internal sealed record NativeMarshallingAttributeInfo( ITypeSymbol NativeMarshallingType, ITypeSymbol? ValuePropertyType, - SupportedMarshallingMethods MarshallingMethods) : MarshallingAttributeInfo; + SupportedMarshallingMethods MarshallingMethods) : MarshallingInfo; /// /// User-applied System.Runtime.InteropServices.GeneratedMarshallingAttribute /// on a non-blittable type in source in this compilation. /// internal sealed record GeneratedNativeMarshallingAttributeInfo( - string NativeMarshallingFullyQualifiedTypeName) : MarshallingAttributeInfo; + string NativeMarshallingFullyQualifiedTypeName) : MarshallingInfo; + + /// + /// The type of the element is a SafeHandle-derived type with no marshalling attributes. + /// + internal sealed record SafeHandleMarshallingInfo : MarshallingInfo; + } diff --git a/DllImportGenerator/DllImportGenerator/StubCodeContext.cs b/DllImportGenerator/DllImportGenerator/StubCodeContext.cs index 10c254234fcf..f43c7f278ccc 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeContext.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeContext.cs @@ -46,7 +46,14 @@ public enum Stage /// /// Keep alive any managed objects that need to stay alive across the call. /// - KeepAlive + KeepAlive, + + /// + /// Convert native data to managed data + /// where native values will leak if we + /// fail to unmarshal, but do run cleanup. + /// + LeakSafeUnmarshal } public Stage CurrentStage { get; protected set; } diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index e7e870bf888f..8d9e4701f01b 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -141,6 +141,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt Stage.Invoke, Stage.KeepAlive, Stage.Unmarshal, + Stage.LeakSafeUnmarshal, Stage.Cleanup }; diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index 27481eb846e8..284aa2a3694e 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace DllImportGenerator +namespace Microsoft.Interop { static class TypeNames { @@ -19,5 +19,9 @@ static class TypeNames public const string System_Runtime_InteropServices_StructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute"; public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute"; + + public const string System_Runtime_InteropServices_MarshalEx = "System.Runtime.InteropServices.MarshalEx"; + + public const string System_Runtime_InteropServices_SafeHandle = "System.Runtime.InteropServices.SafeHandle"; } } diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 17badcfb3405..980c7555a741 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -39,11 +39,11 @@ private TypePositionInfo() public int NativeIndex { get; set; } public int UnmanagedLCIDConversionArgIndex { get; private set; } - public MarshallingAttributeInfo MarshallingAttributeInfo { get; private set; } + public MarshallingInfo MarshallingAttributeInfo { get; private set; } public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, Compilation compilation) { - var marshallingInfo = GetMarshallingAttributeInfo(paramSymbol.Type, paramSymbol.GetAttributes(), compilation); + var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), compilation); var typeInfo = new TypePositionInfo() { ManagedType = paramSymbol.Type, @@ -58,7 +58,7 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, Compilation compilation) { - var marshallingInfo = GetMarshallingAttributeInfo(type, attributes, compilation); + var marshallingInfo = GetMarshallingInfo(type, attributes, compilation); var typeInfo = new TypePositionInfo() { ManagedType = type, @@ -72,9 +72,9 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, Compilation compilation) + private static MarshallingInfo? GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, Compilation compilation) { - MarshallingAttributeInfo? marshallingInfo = null; + MarshallingInfo? marshallingInfo = null; // Look at attributes on the type. foreach (var attrData in attributes) { @@ -134,6 +134,11 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable Date: Tue, 29 Sep 2020 16:54:01 -0700 Subject: [PATCH 2/6] Enable safe handle marshaller and fix various bugs. Add runtime tests using xunit asserts that currently we just call directly in the Demo. --- .../Ancillary.Interop/MarshalEx.cs | 2 +- DllImportGenerator/Demo/Demo.csproj | 4 + DllImportGenerator/Demo/Program.cs | 7 ++ DllImportGenerator/Demo/SafeHandleTests.cs | 74 +++++++++++++++++++ DllImportGenerator/Directory.Build.props | 1 + .../DllImportGenerator.Test.csproj | 2 +- .../Marshalling/MarshallingGenerator.cs | 7 +- .../Marshalling/SafeHandleMarshaller.cs | 61 ++++++++++++--- .../DllImportGenerator/StubCodeGenerator.cs | 2 +- .../TestAssets/NativeExports/Handles.cs | 54 ++++++++++++++ 10 files changed, 198 insertions(+), 16 deletions(-) create mode 100644 DllImportGenerator/Demo/SafeHandleTests.cs create mode 100644 DllImportGenerator/TestAssets/NativeExports/Handles.cs diff --git a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs index 931227dbc8d1..95be96d86858 100644 --- a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs +++ b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs @@ -12,7 +12,7 @@ public static class MarshalEx public static TSafeHandle CreateSafeHandle() where TSafeHandle : SafeHandle { - if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance, null, Type.EmptyTypes, null) == null) + if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance | BindingFlags.Instance, null, Type.EmptyTypes, null) == null) { throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor."); } diff --git a/DllImportGenerator/Demo/Demo.csproj b/DllImportGenerator/Demo/Demo.csproj index 4cee92efcd88..fdba3fbbb354 100644 --- a/DllImportGenerator/Demo/Demo.csproj +++ b/DllImportGenerator/Demo/Demo.csproj @@ -16,4 +16,8 @@ + + + + diff --git a/DllImportGenerator/Demo/Program.cs b/DllImportGenerator/Demo/Program.cs index ff7d931ccd3c..06715a76c7c0 100644 --- a/DllImportGenerator/Demo/Program.cs +++ b/DllImportGenerator/Demo/Program.cs @@ -31,6 +31,13 @@ static void Main(string[] args) c = b; NativeExportsNE.Sum(a, ref c); Console.WriteLine($"{a} + {b} = {c}"); + + SafeHandleTests tests = new SafeHandleTests(); + + tests.ReturnValue_CreatesSafeHandle(); + tests.ByValue_CorrectlyUnwrapsHandle(); + tests.ByRefSameValue_UsesSameHandleInstance(); + tests.ByRefDifferentValue_UsesNewHandleInstance(); } } } diff --git a/DllImportGenerator/Demo/SafeHandleTests.cs b/DllImportGenerator/Demo/SafeHandleTests.cs new file mode 100644 index 000000000000..bb21cfb565ca --- /dev/null +++ b/DllImportGenerator/Demo/SafeHandleTests.cs @@ -0,0 +1,74 @@ + +using System.Runtime.InteropServices; +using Microsoft.Win32.SafeHandles; +using Xunit; + +namespace Demo +{ + partial class NativeExportsNE + { + public class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private NativeExportsSafeHandle() : base(true) + { + } + + protected override bool ReleaseHandle() + { + Assert.True(NativeExportsNE.ReleaseHandle(handle)); + return true; + } + } + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "alloc_handle")] + public static partial NativeExportsSafeHandle AllocateHandle(); + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "release_handle")] + [return:MarshalAs(UnmanagedType.I1)] + private static partial bool ReleaseHandle(nint handle); + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "is_handle_alive")] + [return:MarshalAs(UnmanagedType.I1)] + public static partial bool IsHandleAlive(NativeExportsSafeHandle handle); + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "modify_handle")] + public static partial void ModifyHandle(ref NativeExportsSafeHandle handle, [MarshalAs(UnmanagedType.I1)] bool newHandle); + } + + public class SafeHandleTests + { + [Fact] + public void ReturnValue_CreatesSafeHandle() + { + using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.AllocateHandle(); + Assert.False(handle.IsClosed); + Assert.False(handle.IsInvalid); + } + + [Fact] + public void ByValue_CorrectlyUnwrapsHandle() + { + using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.AllocateHandle(); + Assert.True(NativeExportsNE.IsHandleAlive(handle)); + } + + [Fact] + public void ByRefSameValue_UsesSameHandleInstance() + { + using NativeExportsNE.NativeExportsSafeHandle handleToDispose = NativeExportsNE.AllocateHandle(); + NativeExportsNE.NativeExportsSafeHandle handle = handleToDispose; + NativeExportsNE.ModifyHandle(ref handle, false); + Assert.Same(handleToDispose, handle); + } + + [Fact] + public void ByRefDifferentValue_UsesNewHandleInstance() + { + using NativeExportsNE.NativeExportsSafeHandle handleToDispose = NativeExportsNE.AllocateHandle(); + NativeExportsNE.NativeExportsSafeHandle handle = handleToDispose; + NativeExportsNE.ModifyHandle(ref handle, true); + Assert.NotSame(handleToDispose, handle); + handle.Dispose(); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/Directory.Build.props b/DllImportGenerator/Directory.Build.props index 81194ffe87da..61ba7c1afc32 100644 --- a/DllImportGenerator/Directory.Build.props +++ b/DllImportGenerator/Directory.Build.props @@ -2,6 +2,7 @@ 3.8.0-3.final + 2.4.1 diff --git a/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj b/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj index 75fb2cb381c8..cd384172bce4 100644 --- a/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj +++ b/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj @@ -15,7 +15,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index 82e65085ccbd..60ef933582a8 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -127,10 +127,9 @@ public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out generator = Forwarder; return false; - case { MarshallingAttributeInfo: SafeHandleMarshallingInfo _}: - - generator = Forwarder; - return false; + case { MarshallingAttributeInfo: SafeHandleMarshallingInfo _}: + generator = SafeHandle; + return true; default: generator = Forwarder; diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index ad6c88829c0f..d6af8e508537 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -55,12 +55,16 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont AsNativeType(info), SingletonSeparatedList( VariableDeclarator(nativeIdentifier)))); - yield return LocalDeclarationStatement( - VariableDeclaration( - PredefinedType(Token(SyntaxKind.BoolKeyword)), - SingletonSeparatedList( - VariableDeclarator(addRefdIdentifier) - .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + { + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.BoolKeyword)), + SingletonSeparatedList( + VariableDeclarator(addRefdIdentifier) + .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + + } if (info.IsByRef && info.RefKind != RefKind.In) { // We create the new handle in the Setup phase @@ -90,14 +94,51 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont IdentifierName("DangerousGetHandle")), ArgumentList())))))); } + else if (info.IsManagedReturnPosition) + { + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + GenericName(Identifier("CreateSafeHandle"), + TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), + ArgumentList()))); + } break; case StubCodeContext.Stage.Marshal: if (info.RefKind != RefKind.Out) { - yield return ParseStatement($"{managedIdentifier}.DangerousAddRef(ref {addRefdIdentifier}"); + // .DangerousAddRef(ref ); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName("DangerousAddRef")), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(addRefdIdentifier)) + .WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))); + + + ExpressionSyntax assignHandleToNativeExpression = + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName("DangerousGetHandle")), + ArgumentList())); if (info.IsByRef && info.RefKind != RefKind.In) { - yield return ParseStatement($"{handleValueBackupIdentifier} = {nativeIdentifier} = {managedIdentifier}.DangerousGetHandle();"); + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(handleValueBackupIdentifier), + assignHandleToNativeExpression)); + } + else + { + yield return ExpressionStatement(assignHandleToNativeExpression); } } break; @@ -131,7 +172,9 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont // Do not unmarshal the handle if the value didn't change. yield return IfStatement( - ParseExpression($"{handleValueBackupIdentifier} != {nativeIdentifier}"), + BinaryExpression(SyntaxKind.NotEqualsExpression, + IdentifierName(handleValueBackupIdentifier), + IdentifierName(nativeIdentifier)), Block( unmarshalStatement, ExpressionStatement( diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index 8d9e4701f01b..d6e87300a6b8 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -152,7 +152,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt int initialCount = statements.Count; context.CurrentStage = stage; - if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal)) + if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal || stage == Stage.LeakSafeUnmarshal)) { // Handle setup and unmarshalling for return var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, context); diff --git a/DllImportGenerator/TestAssets/NativeExports/Handles.cs b/DllImportGenerator/TestAssets/NativeExports/Handles.cs new file mode 100644 index 000000000000..fae57f242fe5 --- /dev/null +++ b/DllImportGenerator/TestAssets/NativeExports/Handles.cs @@ -0,0 +1,54 @@ +using System.Collections.Generic; +using System.Runtime.InteropServices; + + +namespace NativeExports +{ + public static unsafe class Handles + { + private const nint InvalidHandle = -1; + + private static nint LastHandle = 0; + + private static HashSet ActiveHandles = new HashSet(); + + [UnmanagedCallersOnly(EntryPoint = "alloc_handle")] + public static nint AllocateHandle() + { + return AllocateHandleCore(); + } + + private static nint AllocateHandleCore() + { + if (LastHandle == int.MaxValue) + { + return InvalidHandle; + } + + nint newHandle = ++LastHandle; + ActiveHandles.Add(newHandle); + return newHandle; + } + + [UnmanagedCallersOnly(EntryPoint = "release_handle")] + public static byte ReleaseHandle(nint handle) + { + return ActiveHandles.Remove(handle) ? 1 : 0; + } + + [UnmanagedCallersOnly(EntryPoint = "is_handle_alive")] + public static byte IsHandleAlive(nint handle) + { + return ActiveHandles.Contains(handle) ? 1 : 0; + } + + [UnmanagedCallersOnly(EntryPoint = "modify_handle")] + public static void ModifyHandle(nint* handle, byte newHandle) + { + if (newHandle != 0) + { + *handle = AllocateHandleCore(); + } + } + } +} \ No newline at end of file From 29b12f344d31d0fa258e7823126386ca35e86c8c Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 29 Sep 2020 17:00:03 -0700 Subject: [PATCH 3/6] Copy over the "high level logic" comment block from the runtime implementation to explain the overall logic of SafeHandle marshaller. Signed-off-by: Jeremy Koritzinsky --- .../Marshalling/SafeHandleMarshaller.cs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index d6af8e508537..4b6e9f6be1ac 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -39,8 +39,22 @@ public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { - // [TODO] Handle byrefs in a more common place? - // This pattern will become very common (arrays and strings will also use it) + // The high level logic (note that the parameter may be in, out or both): + // 1) If this is an input parameter we need to AddRef the SafeHandle. + // 2) If this is an output parameter we need to preallocate a SafeHandle to wrap the new native handle value. We + // must allocate this before the native call to avoid a failure point when we already have a native resource + // allocated. We must allocate a new SafeHandle even if we have one on input since both input and output native + // handles need to be tracked and released by a SafeHandle. + // 3) Initialize a local IntPtr that will be passed to the native call. If we have an input SafeHandle the value + // comes from there otherwise we get it from the new SafeHandle (which is guaranteed to be initialized to an + // invalid handle value). + // 4) If this is a out parameter we also store the original handle value (that we just computed above) in a local + // variable. + // 5) If we successfully AddRef'd the incoming SafeHandle, we need to Release it before we return. + // 6) After the native call, if this is an output parameter and the handle value we passed to native differs from + // the local copy we made then the new handle value is written into the output SafeHandle and that SafeHandle + // is propagated back to the caller. + (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); string addRefdIdentifier = $"{managedIdentifier}__addRefd"; string newHandleObjectIdentifier = info.IsManagedReturnPosition From 17e94ef491388938dbf1ece43abe8ab429d3e13f Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 30 Sep 2020 10:57:47 -0700 Subject: [PATCH 4/6] Update DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs Co-authored-by: Jan Kotas --- .../DllImportGenerator/Marshalling/SafeHandleMarshaller.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index 4b6e9f6be1ac..75a605441671 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -82,7 +82,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont if (info.IsByRef && info.RefKind != RefKind.In) { // We create the new handle in the Setup phase - // so we reduce the possible failure points during unmarshalling, where we would + // so we eliminate the possible failure points during unmarshalling, where we would // leak the handle if we failed to create the handle. yield return LocalDeclarationStatement( VariableDeclaration( From ea7b3560387c78ebfe2b1a1a1e4b95ea5d3e2595 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 1 Oct 2020 14:29:09 -0700 Subject: [PATCH 5/6] PR feedback. --- .../Marshalling/SafeHandleMarshaller.cs | 26 +++++++++++++------ .../DllImportGenerator/StubCodeContext.cs | 7 +++-- .../DllImportGenerator/StubCodeGenerator.cs | 4 +-- DllImportGenerator/designs/Pipeline.md | 4 +++ 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index 4b6e9f6be1ac..ed8aef66186f 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -105,7 +106,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(newHandleObjectIdentifier), - IdentifierName("DangerousGetHandle")), + IdentifierName(nameof(SafeHandle.DangerousGetHandle))), ArgumentList())))))); } else if (info.IsManagedReturnPosition) @@ -129,7 +130,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(managedIdentifier), - IdentifierName("DangerousAddRef")), + IdentifierName(nameof(SafeHandle.DangerousAddRef))), ArgumentList(SingletonSeparatedList( Argument(IdentifierName(addRefdIdentifier)) .WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))); @@ -141,7 +142,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(managedIdentifier), - IdentifierName("DangerousGetHandle")), + IdentifierName(nameof(SafeHandle.DangerousGetHandle))), ArgumentList())); if (info.IsByRef && info.RefKind != RefKind.In) { @@ -156,9 +157,18 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } } break; - case StubCodeContext.Stage.LeakSafeUnmarshal: - StatementSyntax unmarshalStatement = - ParseStatement($"{TypeNames.System_Runtime_InteropServices_MarshalEx}.SetHandle({newHandleObjectIdentifier}, {nativeIdentifier});"); + case StubCodeContext.Stage.GuaranteedUnmarshal: + StatementSyntax unmarshalStatement = ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MarshalEx), + IdentifierName("SetHandle")), + ArgumentList(SeparatedList( + new [] + { + Argument(IdentifierName(newHandleObjectIdentifier)), + Argument(IdentifierName(nativeIdentifier)) + })))); if(info.IsManagedReturnPosition) { @@ -181,7 +191,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(managedIdentifier), - IdentifierName("DangerousRelease")), + IdentifierName(nameof(SafeHandle.DangerousRelease))), ArgumentList()))); // Do not unmarshal the handle if the value didn't change. @@ -206,7 +216,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(managedIdentifier), - IdentifierName("DangerousRelease")), + IdentifierName(nameof(SafeHandle.DangerousRelease))), ArgumentList()))); } break; diff --git a/DllImportGenerator/DllImportGenerator/StubCodeContext.cs b/DllImportGenerator/DllImportGenerator/StubCodeContext.cs index f43c7f278ccc..0b5f9fe0a862 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeContext.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeContext.cs @@ -49,11 +49,10 @@ public enum Stage KeepAlive, /// - /// Convert native data to managed data - /// where native values will leak if we - /// fail to unmarshal, but do run cleanup. + /// Convert native data to managed data even in the case of an exception during + /// the non-cleanup phases. /// - LeakSafeUnmarshal + GuaranteedUnmarshal } public Stage CurrentStage { get; protected set; } diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index d6e87300a6b8..043819bc068f 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -141,7 +141,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt Stage.Invoke, Stage.KeepAlive, Stage.Unmarshal, - Stage.LeakSafeUnmarshal, + Stage.GuaranteedUnmarshal, Stage.Cleanup }; @@ -152,7 +152,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt int initialCount = statements.Count; context.CurrentStage = stage; - if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal || stage == Stage.LeakSafeUnmarshal)) + if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal || stage == Stage.GuaranteedUnmarshal)) { // Handle setup and unmarshalling for return var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, context); diff --git a/DllImportGenerator/designs/Pipeline.md b/DllImportGenerator/designs/Pipeline.md index d42d3a08e4d5..d34be0159c25 100644 --- a/DllImportGenerator/designs/Pipeline.md +++ b/DllImportGenerator/designs/Pipeline.md @@ -38,9 +38,13 @@ Generation of the stub code happens in stages. The marshalling generator for eac 1. `Invoke`: call to the generated P/Invoke - Call `AsArgument` on the marshalling generator for every parameter - Create invocation statement that calls the generated P/Invoke +1. `KeepAlive`: keep alive any objects who's native representation won't keep them alive across the call. + - Call `Generate` on the marshalling generator for every parameter. 1. `Unmarshal`: conversion of native to managed data - If the method has a non-void return, call `Generate` on the marshalling generator for the return - Call `Generate` on the marshalling generator for every parameter +1. `GuaranteedUnmarshal`: conversion of native to managed data even when an exception is thrown + - Call `Generate` on the marshalling generator for every parameter. 1. `Cleanup`: free any allocated resources - Call `Generate` on the marshalling generator for every parameter From b0c5fefa70b54a4befbceaacc2193bf74201c53f Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 5 Oct 2020 14:46:39 -0700 Subject: [PATCH 6/6] Fix nullability --- DllImportGenerator/DllImportGenerator/TypePositionInfo.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 980c7555a741..e816f8457d79 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -240,7 +240,7 @@ NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(AttributeData attrDat methods); } - static MarshallingInfo CreateTypeBasedMarshallingInfo(ITypeSymbol type, Compilation compilation) + static MarshallingInfo? CreateTypeBasedMarshallingInfo(ITypeSymbol type, Compilation compilation) { var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); if (conversion.Exists && @@ -250,7 +250,7 @@ static MarshallingInfo CreateTypeBasedMarshallingInfo(ITypeSymbol type, Compilat { return new SafeHandleMarshallingInfo(); } - return null!; + return null; } } #nullable restore