Skip to content

Commit

Permalink
SafeHandle marshalling (dotnet/runtimelab#133)
Browse files Browse the repository at this point in the history
Co-authored-by: Jan Kotas <jkotas@microsoft.com>

Commit migrated from dotnet/runtimelab@2cde5aa
  • Loading branch information
jkoritzinsky authored Oct 5, 2020
1 parent ff96442 commit ea4e979
Show file tree
Hide file tree
Showing 18 changed files with 460 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

<PropertyGroup>
<CompilerPlatformVersion>3.8.0-3.final</CompilerPlatformVersion>
<XunitVersion>2.4.1</XunitVersion>
</PropertyGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()
yield return new[] { CodeSnippets.DelegateMarshalAsParametersAndModifiers };
yield return new[] { CodeSnippets.BlittableStructParametersAndModifiers };
yield return new[] { CodeSnippets.GenericBlittableStructParametersAndModifiers };
yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") };
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit" Version="$(XunitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.2">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ internal class MarshallingGenerators
public static readonly Forwarder Forwarder = new Forwarder();
public static readonly BlittableMarshaller Blittable = new BlittableMarshaller();
public static readonly DelegateMarshaller Delegate = new DelegateMarshaller();
public static readonly SafeHandleMarshaller SafeHandle = new SafeHandleMarshaller();

public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out IMarshallingGenerator generator)
{
Expand Down Expand Up @@ -126,6 +127,10 @@ public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out
generator = Forwarder;
return false;

case { MarshallingAttributeInfo: SafeHandleMarshallingInfo _}:
generator = SafeHandle;
return true;

default:
generator = Forwarder;
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Microsoft.Interop
{
class SafeHandleMarshaller : IMarshallingGenerator
{
public TypeSyntax AsNativeType(TypePositionInfo info)
{
return ParseTypeName("global::System.IntPtr");
}

public ParameterSyntax AsParameter(TypePositionInfo info)
{
var type = info.IsByRef
? PointerType(AsNativeType(info))
: AsNativeType(info);
return Parameter(Identifier(info.InstanceIdentifier))
.WithType(type);
}

public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context)
{
string identifier = context.GetIdentifiers(info).native;
if (info.IsByRef)
{
return Argument(
PrefixUnaryExpression(
SyntaxKind.AddressOfExpression,
IdentifierName(identifier)));
}

return Argument(IdentifierName(identifier));
}

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
// The high level logic (note that the parameter may be in, out or both):
// 1) If this is an input parameter we need to AddRef the SafeHandle.
// 2) If this is an output parameter we need to preallocate a SafeHandle to wrap the new native handle value. We
// must allocate this before the native call to avoid a failure point when we already have a native resource
// allocated. We must allocate a new SafeHandle even if we have one on input since both input and output native
// handles need to be tracked and released by a SafeHandle.
// 3) Initialize a local IntPtr that will be passed to the native call. If we have an input SafeHandle the value
// comes from there otherwise we get it from the new SafeHandle (which is guaranteed to be initialized to an
// invalid handle value).
// 4) If this is a out parameter we also store the original handle value (that we just computed above) in a local
// variable.
// 5) If we successfully AddRef'd the incoming SafeHandle, we need to Release it before we return.
// 6) After the native call, if this is an output parameter and the handle value we passed to native differs from
// the local copy we made then the new handle value is written into the output SafeHandle and that SafeHandle
// is propagated back to the caller.

(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
string addRefdIdentifier = $"{managedIdentifier}__addRefd";
string newHandleObjectIdentifier = info.IsManagedReturnPosition
? managedIdentifier
: $"{managedIdentifier}__newHandle";
string handleValueBackupIdentifier = $"{nativeIdentifier}__original";
switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
yield return LocalDeclarationStatement(
VariableDeclaration(
AsNativeType(info),
SingletonSeparatedList(
VariableDeclarator(nativeIdentifier))));
if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out)
{
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.BoolKeyword)),
SingletonSeparatedList(
VariableDeclarator(addRefdIdentifier)
.WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression))))));

}
if (info.IsByRef && info.RefKind != RefKind.In)
{
// We create the new handle in the Setup phase
// so we eliminate the possible failure points during unmarshalling, where we would
// leak the handle if we failed to create the handle.
yield return LocalDeclarationStatement(
VariableDeclaration(
info.ManagedType.AsTypeSyntax(),
SingletonSeparatedList(
VariableDeclarator(newHandleObjectIdentifier)
.WithInitializer(EqualsValueClause(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
GenericName(Identifier("CreateSafeHandle"),
TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))),
ArgumentList()))))));
yield return LocalDeclarationStatement(
VariableDeclaration(
AsNativeType(info),
SingletonSeparatedList(
VariableDeclarator(handleValueBackupIdentifier)
.WithInitializer(EqualsValueClause(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(newHandleObjectIdentifier),
IdentifierName(nameof(SafeHandle.DangerousGetHandle))),
ArgumentList()))))));
}
else if (info.IsManagedReturnPosition)
{
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
GenericName(Identifier("CreateSafeHandle"),
TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))),
ArgumentList())));
}
break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind != RefKind.Out)
{
// <managedIdentifier>.DangerousAddRef(ref <addRefdIdentifier>);
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(managedIdentifier),
IdentifierName(nameof(SafeHandle.DangerousAddRef))),
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(addRefdIdentifier))
.WithRefKindKeyword(Token(SyntaxKind.RefKeyword))))));


ExpressionSyntax assignHandleToNativeExpression =
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(managedIdentifier),
IdentifierName(nameof(SafeHandle.DangerousGetHandle))),
ArgumentList()));
if (info.IsByRef && info.RefKind != RefKind.In)
{
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(handleValueBackupIdentifier),
assignHandleToNativeExpression));
}
else
{
yield return ExpressionStatement(assignHandleToNativeExpression);
}
}
break;
case StubCodeContext.Stage.GuaranteedUnmarshal:
StatementSyntax unmarshalStatement = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("SetHandle")),
ArgumentList(SeparatedList(
new []
{
Argument(IdentifierName(newHandleObjectIdentifier)),
Argument(IdentifierName(nativeIdentifier))
}))));

if(info.IsManagedReturnPosition)
{
yield return unmarshalStatement;
}
else if (info.RefKind == RefKind.Out)
{
yield return unmarshalStatement;
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
IdentifierName(newHandleObjectIdentifier)));
}
else if (info.RefKind == RefKind.Ref)
{
// Decrement refcount on original SafeHandle if we addrefd
yield return IfStatement(
IdentifierName(addRefdIdentifier),
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(managedIdentifier),
IdentifierName(nameof(SafeHandle.DangerousRelease))),
ArgumentList())));

// Do not unmarshal the handle if the value didn't change.
yield return IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(handleValueBackupIdentifier),
IdentifierName(nativeIdentifier)),
Block(
unmarshalStatement,
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
IdentifierName(newHandleObjectIdentifier)))));
}
break;
case StubCodeContext.Stage.Cleanup:
if (!info.IsByRef || info.RefKind == RefKind.In)
{
yield return IfStatement(
IdentifierName(addRefdIdentifier),
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(managedIdentifier),
IdentifierName(nameof(SafeHandle.DangerousRelease))),
ArgumentList())));
}
break;
default:
break;
}
}

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Interop
// for C# 10 discriminated unions. Once discriminated unions are released,
// these should be updated to be implemented as a discriminated union.

internal abstract record MarshallingAttributeInfo {}
internal abstract record MarshallingInfo {}

/// <summary>
/// User-applied System.Runtime.InteropServices.MarshalAsAttribute
Expand All @@ -23,14 +23,14 @@ internal sealed record MarshalAsInfo(
string? CustomMarshallerCookie,
UnmanagedType UnmanagedArraySubType,
int ArraySizeConst,
short ArraySizeParamIndex) : MarshallingAttributeInfo;
short ArraySizeParamIndex) : MarshallingInfo;

/// <summary>
/// User-applied System.Runtime.InteropServices.BlittableTypeAttribute
/// or System.Runtime.InteropServices.GeneratedMarshallingAttribute on a blittable type
/// in source in this compilation.
/// </summary>
internal sealed record BlittableTypeAttributeInfo : MarshallingAttributeInfo;
internal sealed record BlittableTypeAttributeInfo : MarshallingInfo;

[Flags]
internal enum SupportedMarshallingMethods
Expand All @@ -47,12 +47,18 @@ internal enum SupportedMarshallingMethods
internal sealed record NativeMarshallingAttributeInfo(
ITypeSymbol NativeMarshallingType,
ITypeSymbol? ValuePropertyType,
SupportedMarshallingMethods MarshallingMethods) : MarshallingAttributeInfo;
SupportedMarshallingMethods MarshallingMethods) : MarshallingInfo;

/// <summary>
/// User-applied System.Runtime.InteropServices.GeneratedMarshallingAttribute
/// on a non-blittable type in source in this compilation.
/// </summary>
internal sealed record GeneratedNativeMarshallingAttributeInfo(
string NativeMarshallingFullyQualifiedTypeName) : MarshallingAttributeInfo;
string NativeMarshallingFullyQualifiedTypeName) : MarshallingInfo;

/// <summary>
/// The type of the element is a SafeHandle-derived type with no marshalling attributes.
/// </summary>
internal sealed record SafeHandleMarshallingInfo : MarshallingInfo;

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ public enum Stage
/// <summary>
/// Keep alive any managed objects that need to stay alive across the call.
/// </summary>
KeepAlive
KeepAlive,

/// <summary>
/// Convert native data to managed data even in the case of an exception during
/// the non-cleanup phases.
/// </summary>
GuaranteedUnmarshal
}

public Stage CurrentStage { get; protected set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
Stage.Invoke,
Stage.KeepAlive,
Stage.Unmarshal,
Stage.GuaranteedUnmarshal,
Stage.Cleanup
};

Expand All @@ -151,7 +152,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
int initialCount = statements.Count;
context.CurrentStage = stage;

if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal))
if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal || stage == Stage.GuaranteedUnmarshal))
{
// Handle setup and unmarshalling for return
var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Text;

namespace DllImportGenerator
namespace Microsoft.Interop
{
static class TypeNames
{
Expand All @@ -19,5 +19,9 @@ static class TypeNames
public const string System_Runtime_InteropServices_StructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute";

public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute";

public const string System_Runtime_InteropServices_MarshalEx = "System.Runtime.InteropServices.MarshalEx";

public const string System_Runtime_InteropServices_SafeHandle = "System.Runtime.InteropServices.SafeHandle";
}
}
Loading

0 comments on commit ea4e979

Please sign in to comment.