Skip to content

Commit

Permalink
Make SafeHandle parameters accept null
Browse files Browse the repository at this point in the history
This also:
* adds support for SafeHandles that are 32-bits long even in 64-bit processes (e.g. `MSIHANDLE`).
* removes `SafeHandle` from all `extern` methods. They only appear on helper methods now.
* removes the `NullSafeHandle` static class.

Closes #129
  • Loading branch information
AArnott committed Feb 23, 2021
1 parent c1e002e commit 23e317c
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 65 deletions.
213 changes: 166 additions & 47 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,6 @@ public class Generator : IDisposable
private readonly CSharpCompilation? compilation;
private readonly CSharpParseOptions? parseOptions;

private bool nullSafeHandleGenerated;

/// <summary>
/// Initializes a new instance of the <see cref="Generator"/> class.
/// </summary>
Expand Down Expand Up @@ -872,7 +870,6 @@ internal void GenerateConstant(FieldDefinitionHandle fieldDefHandle)

if (BclInteropSafeHandles.TryGetValue(releaseMethod, out TypeSyntax? bclType))
{
this.GenerateNullSafeHandleHelper();
return bclType;
}

Expand All @@ -897,13 +894,6 @@ internal void GenerateConstant(FieldDefinitionHandle fieldDefHandle)
safeHandleType = null;
}

// Do NOT generate a SafeHandle type for typedefs that don't use IntPtr as their field type.
// Otherwise when used, .NET will use a pointer-sized value where another size was appropriate.
if (!this.handleTypeStructsWithIntPtrSizeFields.Contains(releaseMethodParameterType.ToString()))
{
safeHandleType = null;
}

this.releaseMethodsWithSafeHandleTypesGenerating.Add(releaseMethod, safeHandleType);

if (safeHandleType is null)
Expand All @@ -916,7 +906,6 @@ internal void GenerateConstant(FieldDefinitionHandle fieldDefHandle)
return safeHandleType;
}

this.GenerateNullSafeHandleHelper();
this.GenerateExternMethod(releaseMethodHandle);

TypeSyntax releaseMethodReturnType = this.GetReturnTypeCustomAttributes(releaseMethodDef) is { } atts
Expand Down Expand Up @@ -1475,39 +1464,6 @@ private static string FetchTemplate(string name)
return sr.ReadToEnd();
}

private void GenerateNullSafeHandleHelper()
{
if (this.nullSafeHandleGenerated)
{
return;
}

this.nullSafeHandleGenerated = true;

const string className = "NullSafeHandle";
string fullName = $"{this.Namespace}.{className}";
if (this.FindSymbolIfAlreadyAvailable(fullName) is object)
{
return;
}

// static readonly SafeHandle NullHandle = new SafeFileHandle(IntPtr.Zero, ownsHandle: false);
FieldDeclarationSyntax nullHandle = FieldDeclaration(
VariableDeclaration(SafeHandleTypeSyntax)
.AddVariables(VariableDeclarator("NullHandle").WithInitializer(EqualsValueClause(
ObjectCreationExpression(ParseTypeName("Microsoft.Win32.SafeHandles.SafeFileHandle")).AddArgumentListArguments(
Argument(DefaultExpression(IntPtrTypeSyntax)),
Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression)))))))
.AddModifiers(Token(this.Visibility), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));

// static class NullSafeHandle { ... }
ClassDeclarationSyntax helper = ClassDeclaration(className)
.AddModifiers(Token(this.Visibility), Token(SyntaxKind.StaticKeyword))
.AddMembers(nullHandle);

this.safeHandleTypes.Add(helper);
}

private FunctionPointerTypeSyntax FunctionPointer(CallingConvention callingConvention, MethodSignature<TypeSyntax> signature, string delegateName)
=> FunctionPointerType(
FunctionPointerCallingConvention(Token(SyntaxKind.UnmanagedKeyword), FunctionPointerUnmanagedCallingConventionList(SingletonSeparatedList(ToUnmanagedCallingConventionSyntax(callingConvention)))),
Expand Down Expand Up @@ -2182,6 +2138,28 @@ private StructDeclarationSyntax CreateTypeDefStruct(TypeDefinition typeDef)
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) // operators MUST be public
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken)));

if (isHandle && fieldInfo.FieldType is not IdentifierNameSyntax { Identifier: { ValueText: nameof(IntPtr) } })
{
// Handle types must interop with IntPtr for SafeHandle support, so if IntPtr isn't the field type,
// we need to create new conversion operators.

// public static implicit operator IntPtr(MSIHANDLE value) => new IntPtr(value.Value);
members = members.Add(ConversionOperatorDeclaration(Token(SyntaxKind.ImplicitKeyword), IntPtrTypeSyntax)
.AddParameterListParameters(Parameter(valueParameter.Identifier).WithType(IdentifierName(name)))
.WithExpressionBody(ArrowExpressionClause(
ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, valueParameter, fieldIdentifierName)))))
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) // operators MUST be public
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken)));

// public static explicit operator MSIHANDLE(IntPtr value) => new MSIHANDLE((uint)value.ToInt32());
members = members.Add(ConversionOperatorDeclaration(Token(SyntaxKind.ExplicitKeyword), IdentifierName(name))
.AddParameterListParameters(Parameter(valueParameter.Identifier).WithType(IntPtrTypeSyntax))
.WithExpressionBody(ArrowExpressionClause(ObjectCreationExpression(IdentifierName(name)).AddArgumentListArguments(
Argument(CastExpression(fieldInfo.FieldType, InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, valueParameter, IdentifierName(nameof(IntPtr.ToInt32)))))))))
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) // operators MUST be public
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken)));
}

// public bool Equals(HWND other) => this.Value == other.Value;
IdentifierNameSyntax other = IdentifierName("other");
members = members.Add(MethodDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), nameof(IEquatable<int>.Equals))
Expand Down Expand Up @@ -2490,15 +2468,20 @@ private EnumDeclarationSyntax CreateInteropEnum(TypeDefinition typeDef)

private IEnumerable<MethodDeclarationSyntax> CreateFriendlyOverloads(MethodDefinition methodDefinition, MethodDeclarationSyntax externMethodDeclaration, string declaringTypeName, bool isStatic)
{
#pragma warning disable SA1114 // Parameter list should follow declaration
static ParameterSyntax StripAttributes(ParameterSyntax parameter) => parameter.WithAttributeLists(List<AttributeListSyntax>());
bool IsInterface(string name) => this.typesByName.TryGetValue(name, out TypeDefinitionHandle tdh) && (this.mr.GetTypeDefinition(tdh).Attributes & TypeAttributes.Interface) == TypeAttributes.Interface;
static ExpressionSyntax GetSpanLength(ExpressionSyntax span) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, span, IdentifierName(nameof(Span<int>.Length)));
bool isReleaseMethod = this.releaseMethods.Contains(externMethodDeclaration.Identifier.ValueText);

var parameters = externMethodDeclaration.ParameterList.Parameters.Select(StripAttributes).ToList();
var lengthParamUsedBy = new Dictionary<int, int>();
var arguments = externMethodDeclaration.ParameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.Text))).ToList();
var fixedBlocks = new List<VariableDeclarationSyntax>();
var leadingOutsideTryStatements = new List<StatementSyntax>();
var leadingStatements = new List<StatementSyntax>();
var trailingStatements = new List<StatementSyntax>();
var finallyStatements = new List<StatementSyntax>();
bool signatureChanged = false;
foreach (ParameterHandle paramHandle in methodDefinition.GetParameters())
{
Expand All @@ -2523,6 +2506,82 @@ private IEnumerable<MethodDeclarationSyntax> CreateFriendlyOverloads(MethodDefin
bool hasOut = externParam.Modifiers.Any(SyntaxKind.OutKeyword);
arguments[param.SequenceNumber - 1] = arguments[param.SequenceNumber - 1].WithRefKindKeyword(Token(hasOut ? SyntaxKind.OutKeyword : SyntaxKind.RefKeyword));
}
else if (isOut && !isIn && !isReleaseMethod && externParam.Type is PointerTypeSyntax { ElementType: IdentifierNameSyntax outTypeId } && this.TryGetHandleReleaseMethod(outTypeId.Identifier.ValueText, out string? outReleaseMethod) && !this.mr.StringComparer.Equals(methodDefinition.Name, outReleaseMethod))
{
if (this.GenerateSafeHandle(outReleaseMethod) is TypeSyntax safeHandleType)
{
signatureChanged = true;

IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");

// out SafeHandle
parameters[param.SequenceNumber - 1] = externParam
.WithType(safeHandleType)
.WithModifiers(TokenList(Token(SyntaxKind.OutKeyword)));

// HANDLE SomeLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(outTypeId).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// Argument: &SomeLocal
arguments[param.SequenceNumber - 1] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, typeDefHandleName));

// Some = new SafeHandle(SomeLocal, ownsHandle: true);
trailingStatements.Add(ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
origName,
ObjectCreationExpression(safeHandleType).AddArgumentListArguments(
Argument(typeDefHandleName),
Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon("ownsHandle"))))));
}
}
else if (isIn && !isOut && !isReleaseMethod && externParam.Type is IdentifierNameSyntax typeId && this.TryGetHandleReleaseMethod(typeId.Identifier.ValueText, out string? releaseMethod) && !this.mr.StringComparer.Equals(methodDefinition.Name, releaseMethod))
{
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");
signatureChanged = true;

IdentifierNameSyntax refAddedName = IdentifierName(externParam.Identifier.ValueText + "AddRef");

// bool hParamNameAddRef = false;
leadingOutsideTryStatements.Add(LocalDeclarationStatement(
VariableDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword))).AddVariables(
VariableDeclarator(refAddedName.Identifier).WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression))))));

// HANDLE hTemplateFileLocal;
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(externParam.Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// if (hTemplateFile is object)
leadingStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.IsExpression, origName, PredefinedType(Token(SyntaxKind.ObjectKeyword))),
Block().AddStatements(
//// hTemplateFile.DangerousAddRef(ref hTemplateFileAddRef);
ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousAddRef))))
.AddArgumentListArguments(Argument(refAddedName).WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))),
//// hTemplateFileLocal = (HANDLE)hTemplateFile.DangerousGetHandle();
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
typeDefHandleName,
CastExpression(
externParam.Type,
InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousGetHandle))), ArgumentList()))))),
//// else hTemplateFileLocal = default;
ElseClause(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, DefaultExpression(externParam.Type))))));

// if (hTemplateFileAddRef) hTemplateFile.DangerousRelease();
finallyStatements.Add(IfStatement(
refAddedName,
ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousRelease))), ArgumentList()))));

// Accept the SafeHandle instead.
parameters[param.SequenceNumber - 1] = externParam
.WithType(IdentifierName(nameof(SafeHandle)));

// hParamNameLocal;
arguments[param.SequenceNumber - 1] = Argument(typeDefHandleName);
}
else if (externParam.Type is PointerTypeSyntax ptrType
&& !IsVoid(ptrType.ElementType)
&& !(ptrType.ElementType is IdentifierNameSyntax id && IsInterface(id.Identifier.ValueText)))
Expand Down Expand Up @@ -2703,6 +2762,19 @@ private IEnumerable<MethodDeclarationSyntax> CreateFriendlyOverloads(MethodDefin
}
}

TypeSyntax? returnSafeHandleType = externMethodDeclaration.ReturnType is IdentifierNameSyntax returnType
&& this.TryGetHandleReleaseMethod(returnType.Identifier.ValueText, out string? returnReleaseMethod)
? this.GenerateSafeHandle(returnReleaseMethod) : null;
SyntaxToken friendlyMethodName = externMethodDeclaration.Identifier;

if (returnSafeHandleType is object && !signatureChanged)
{
// The parameter types are all the same, but we need a friendly overload with a different return type.
// Our only choice is to rename the friendly overload.
friendlyMethodName = Identifier(externMethodDeclaration.Identifier.ValueText + "_SafeHandle");
signatureChanged = true;
}

if (signatureChanged)
{
if (lengthParamUsedBy.Count > 0)
Expand All @@ -2727,30 +2799,77 @@ private IEnumerable<MethodDeclarationSyntax> CreateFriendlyOverloads(MethodDefin
: MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName(externMethodDeclaration.Identifier.Text)))
.AddArgumentListArguments(arguments.ToArray());
bool hasVoidReturn = externMethodDeclaration.ReturnType is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.VoidKeyword } };
var body = Block()
.AddStatements(leadingStatements.ToArray())
.AddStatements(hasVoidReturn ? (StatementSyntax)ExpressionStatement(externInvocation) : ReturnStatement(externInvocation));
var body = Block().AddStatements(leadingStatements.ToArray());
IdentifierNameSyntax resultLocal = IdentifierName("__result");
if (returnSafeHandleType is object)
{
//// HANDLE result = invocation();
body = body.AddStatements(LocalDeclarationStatement(VariableDeclaration(externMethodDeclaration.ReturnType)
.AddVariables(VariableDeclarator(resultLocal.Identifier).WithInitializer(EqualsValueClause(externInvocation)))));

body = body.AddStatements(trailingStatements.ToArray());

//// return new SafeHandle(result, ownsHandle: true);
body = body.AddStatements(ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments(
Argument(resultLocal),
Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon("ownsHandle")))));
}
else if (hasVoidReturn)
{
body = body.AddStatements(ExpressionStatement(externInvocation));
body = body.AddStatements(trailingStatements.ToArray());
}
else
{
// var result = externInvocation();
body = body.AddStatements(LocalDeclarationStatement(VariableDeclaration(externMethodDeclaration.ReturnType)
.AddVariables(VariableDeclarator(resultLocal.Identifier).WithInitializer(EqualsValueClause(externInvocation)))));

body = body.AddStatements(trailingStatements.ToArray());

// return result;
body = body.AddStatements(ReturnStatement(resultLocal));
}

foreach (var fixedExpression in fixedBlocks)
{
body = Block(FixedStatement(fixedExpression, body));
}

if (finallyStatements.Count > 0)
{
body = Block()
.AddStatements(leadingOutsideTryStatements.ToArray())
.AddStatements(TryStatement(body, default, FinallyClause(Block().AddStatements(finallyStatements.ToArray()))));
}
else if (leadingOutsideTryStatements.Count > 0)
{
body = body.WithStatements(body.Statements.InsertRange(0, leadingOutsideTryStatements));
}

var modifiers = TokenList(Token(this.Visibility), Token(SyntaxKind.UnsafeKeyword));
if (isStatic)
{
modifiers = modifiers.Insert(1, Token(SyntaxKind.StaticKeyword));
}

MethodDeclarationSyntax friendlyDeclaration = externMethodDeclaration
.WithIdentifier(friendlyMethodName)
.WithModifiers(modifiers)
.WithAttributeLists(List<AttributeListSyntax>())
.WithParameterList(ParameterList().AddParameters(parameters.ToArray()))
.WithLeadingTrivia(leadingTrivia)
.WithBody(body)
.WithSemicolonToken(default);

if (returnSafeHandleType is object)
{
friendlyDeclaration = friendlyDeclaration.WithReturnType(returnSafeHandleType);
}

yield return friendlyDeclaration;
}
#pragma warning restore SA1114 // Parameter list should follow declaration
}

private bool IsAttribute(CustomAttribute attribute, string ns, string name)
Expand Down
Loading

0 comments on commit 23e317c

Please sign in to comment.