Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle SetLastError=true #360

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<LangVersion>8.0</LangVersion>
<RootNamespace>System.Runtime.InteropServices</RootNamespace>
<Nullable>enable</Nullable>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

</Project>
74 changes: 74 additions & 0 deletions DllImportGenerator/Ancillary.Interop/MarshalEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,79 @@ public static void SetHandle(SafeHandle safeHandle, IntPtr handle)
{
typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle });
}

/// <summary>
/// Set the last platform invoke error on the thread
/// </summary>
public static void SetLastWin32Error(int error)
{
typeof(Marshal).GetMethod("SetLastWin32Error", BindingFlags.NonPublic | BindingFlags.Static)!.Invoke(null, new object[] { error });
}

/// <summary>
/// Get the last system error on the current thread (errno on Unix, GetLastError on Windows)
/// </summary>
public static unsafe int GetLastSystemError()
{
// Would be internal call that handles getting the last error for the thread using the PAL

if (OperatingSystem.IsWindows())
{
return Kernel32.GetLastError();
}
else if (OperatingSystem.IsMacOS())
{
return *libc.__error();
}
else if (OperatingSystem.IsLinux())
{
return *libc.__errno_location();
}

throw new NotImplementedException();
elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
/// Set the last system error on the current thread (errno on Unix, SetLastError on Windows)
/// </summary>
public static unsafe void SetLastSystemError(int error)
{
// Would be internal call that handles setting the last error for the thread using the PAL

if (OperatingSystem.IsWindows())
{
Kernel32.SetLastError(error);
}
else if (OperatingSystem.IsMacOS())
{
*libc.__error() = error;
}
else if (OperatingSystem.IsLinux())
{
*libc.__errno_location() = error;
}
else
{
throw new NotImplementedException();
}
}

private class Kernel32
{
[DllImport(nameof(Kernel32))]
public static extern void SetLastError(int error);

[DllImport(nameof(Kernel32))]
public static extern int GetLastError();
}

private class libc
{
[DllImport(nameof(libc))]
internal static unsafe extern int* __errno_location();

[DllImport(nameof(libc))]
internal static unsafe extern int* __error();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.Runtime.InteropServices;

using Xunit;

namespace DllImportGenerator.IntegrationTests
{
[BlittableType]
public struct SetLastErrorMarshaller
{
public int val;

public SetLastErrorMarshaller(int i)
{
val = i;
}

public int ToManaged()
{
// Explicity set the last error to something else on unmarshalling
MarshalEx.SetLastWin32Error(val * 2);
return val;
}
}

partial class NativeExportsNE
{
public partial class SetLastError
{
[GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "set_error", SetLastError = true)]
public static partial int SetError(int error, byte shouldSetError);

[GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "set_error_return_string", SetLastError = true)]
[return: MarshalUsing(typeof(SetLastErrorMarshaller))]
public static partial int SetError_CustomMarshallingSetsError(int error, byte shouldSetError);

[GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "set_error_return_string", SetLastError = true)]
[return: MarshalAs(UnmanagedType.LPWStr)]
public static partial string SetError_NonBlittableSignature(int error, [MarshalAs(UnmanagedType.U1)] bool shouldSetError, [MarshalAs(UnmanagedType.LPWStr)] string errorString);
}
}

public class SetLastErrorTests
{
[Theory]
[InlineData(0)]
[InlineData(2)]
[InlineData(-5)]
public void LastWin32Error_HasExpectedValue(int error)
{
string errorString = error.ToString();
string ret = NativeExportsNE.SetLastError.SetError_NonBlittableSignature(error, shouldSetError: true, errorString);
Assert.Equal(error, Marshal.GetLastWin32Error());
Assert.Equal(errorString, ret);

// Clear the last error
MarshalEx.SetLastWin32Error(0);

NativeExportsNE.SetLastError.SetError(error, shouldSetError: 1);
Assert.Equal(error, Marshal.GetLastWin32Error());

MarshalEx.SetLastWin32Error(0);

// Custom marshalling sets the last error on unmarshalling.
// Last error should reflect error from native call, not unmarshalling.
NativeExportsNE.SetLastError.SetError_CustomMarshallingSetsError(error, shouldSetError: 1);
Assert.Equal(error, Marshal.GetLastWin32Error());
}

[Fact]
public void ClearPreviousError()
{
int error = 100;
MarshalEx.SetLastWin32Error(error);

// Don't actually set the error in the native call. SetLastError=true should clear any existing error.
string errorString = error.ToString();
string ret = NativeExportsNE.SetLastError.SetError_NonBlittableSignature(error, shouldSetError: false, errorString);
Assert.Equal(0, Marshal.GetLastWin32Error());
Assert.Equal(errorString, ret);

MarshalEx.SetLastWin32Error(error);

// Don't actually set the error in the native call. SetLastError=true should clear any existing error.
NativeExportsNE.SetLastError.SetError(error, shouldSetError: 0);
Assert.Equal(0, Marshal.GetLastWin32Error());

// Don't actually set the error in the native call. Custom marshalling still sets the last error.
// SetLastError=true should clear any existing error and ignore error set by custom marshalling.
NativeExportsNE.SetLastError.SetError_CustomMarshallingSetsError(error, shouldSetError: 0);
Assert.Equal(0, Marshal.GetLastWin32Error());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()

// Unsupported named arguments
// * BestFitMapping, ThrowOnUnmappableChar
// [TODO]: Expected diagnostic count should be 2 once we support SetLastError
yield return new object[] { CodeSnippets.AllDllImportNamedArguments, 3, 0 };
yield return new object[] { CodeSnippets.AllDllImportNamedArguments, 2, 0 };

// LCIDConversion
yield return new object[] { CodeSnippets.LCIDConversionAttribute, 1, 0 };
Expand Down
28 changes: 4 additions & 24 deletions DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ namespace DllImportGenerator.UnitTests
{
public class Compiles
{
public static IEnumerable<object[]> CodeSnippetsToCompile_NoDiagnostics()
public static IEnumerable<object[]> CodeSnippetsToCompile()
{
yield return new[] { CodeSnippets.TrivialClassDeclarations };
yield return new[] { CodeSnippets.TrivialStructDeclarations };
yield return new[] { CodeSnippets.MultipleAttributes };
yield return new[] { CodeSnippets.NestedNamespace };
yield return new[] { CodeSnippets.NestedTypes };
yield return new[] { CodeSnippets.UserDefinedEntryPoint };
//yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
yield return new[] { CodeSnippets.DefaultParameters };
yield return new[] { CodeSnippets.UseCSharpFeaturesForConstants };

Expand Down Expand Up @@ -161,14 +161,9 @@ public static IEnumerable<object[]> CodeSnippetsToCompile_NoDiagnostics()
yield return new[] { CodeSnippets.CustomStructMarshallingMarshalUsingParametersAndModifiers };
}

public static IEnumerable<object[]> CodeSnippetsToCompile_WithDiagnostics()
{
yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
}

[Theory]
[MemberData(nameof(CodeSnippetsToCompile_NoDiagnostics))]
public async Task ValidateSnippets_NoDiagnostics(string source)
[MemberData(nameof(CodeSnippetsToCompile))]
public async Task ValidateSnippets(string source)
{
Compilation comp = await TestUtils.CreateCompilation(source);
TestUtils.AssertPreSourceGeneratorCompilation(comp);
Expand All @@ -179,20 +174,5 @@ public async Task ValidateSnippets_NoDiagnostics(string source)
var newCompDiags = newComp.GetDiagnostics();
Assert.Empty(newCompDiags);
}

[Theory]
[MemberData(nameof(CodeSnippetsToCompile_WithDiagnostics))]
public async Task ValidateSnippets_WithDiagnostics(string source)
{
Compilation comp = await TestUtils.CreateCompilation(source);
TestUtils.AssertPreSourceGeneratorCompilation(comp);

var newComp = TestUtils.RunGenerators(comp, out var generatorDiags, new Microsoft.Interop.DllImportGenerator());
Assert.NotEmpty(generatorDiags);
Assert.All(generatorDiags, d => Assert.StartsWith(Microsoft.Interop.GeneratorDiagnostics.Ids.Prefix, d.Id));

var newCompDiags = newComp.GetDiagnostics();
Assert.Empty(newCompDiags);
}
}
}
6 changes: 0 additions & 6 deletions DllImportGenerator/DllImportGenerator/DllImportGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,6 @@ public void Execute(GeneratorExecutionContext context)
generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar));
}

// [TODO] Remove once we support SetLastError=true
if (dllImportData.SetLastError)
{
generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.SetLastError), "true");
}

if (lcidConversionAttr != null)
{
// Using LCIDConversion with GeneratedDllImport is not supported
Expand Down
53 changes: 52 additions & 1 deletion DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ internal sealed class StubCodeGenerator : StubCodeContext
public string ReturnNativeIdentifier { get; private set; } = ReturnIdentifier;

private const string InvokeReturnIdentifier = "__invokeRetVal";
private const string LastErrorIdentifier = "__lastError";

// Error code representing success. This maps to S_OK for Windows HRESULT semantics and 0 for POSIX errno semantics.
private const int SuccessErrorCode = 0;

private static readonly Stage[] Stages = new Stage[]
{
Expand Down Expand Up @@ -170,6 +174,14 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
AppendVariableDeclations(setupStatements, retMarshaller.TypeInfo, retMarshaller.Generator);
}

if (this.dllImportData.SetLastError)
{
// Declare variable for last error
setupStatements.Add(MarshallerHelpers.DeclareWithDefault(
PredefinedType(Token(SyntaxKind.IntKeyword)),
LastErrorIdentifier));
}

var tryStatements = new List<StatementSyntax>();
var finallyStatements = new List<StatementSyntax>();
var invoke = InvocationExpression(IdentifierName(dllImportName));
Expand Down Expand Up @@ -235,11 +247,37 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
invoke));
}

if (this.dllImportData.SetLastError)
{
// Marshal.SetLastSystemError(0);
elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
var clearLastError = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("SetLastSystemError")),
ArgumentList(SingletonSeparatedList(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode)))))));

// <lastError> = Marshal.GetLastSystemError();
var getLastError = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(LastErrorIdentifier),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("GetLastSystemError")))));

invokeStatement = Block(clearLastError, invokeStatement, getLastError);
}

// Nest invocation in fixed statements
if (fixedStatements.Any())
{
fixedStatements.Reverse();
invokeStatement = fixedStatements.First().WithStatement(Block(invokeStatement));
invokeStatement = fixedStatements.First().WithStatement(invokeStatement);
foreach (var fixedStatement in fixedStatements.Skip(1))
{
invokeStatement = fixedStatement.WithStatement(Block(invokeStatement));
Expand Down Expand Up @@ -274,6 +312,19 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
allStatements.AddRange(tryStatements);
}

if (this.dllImportData.SetLastError)
{
// Marshal.SetLastWin32Error(<lastError>);
allStatements.Add(ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("SetLastWin32Error")),
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(LastErrorIdentifier)))))));
}

// Return
if (!stubReturnsVoid)
allStatements.Add(ReturnStatement(IdentifierName(ReturnIdentifier)));
Expand Down
Loading