Skip to content

Commit

Permalink
[release/7.0] Ensure we cleanup the marshalling for elements of colle…
Browse files Browse the repository at this point in the history
…ctions (stateful and stateless) (#76693)

* Ensure we cleanup the marshalling for elements of collections (stateful and stateless)

* Add tests

* Fix bad stackalloc size after we moved to strongly-typed buffers

* PR feedback

* Update NonBlittable.cs

* Propagate details for types.

* Update src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs

Co-authored-by: Elinor Fung <elfung@microsoft.com>

Co-authored-by: Jeremy Koritzinsky <jekoritz@microsoft.com>
Co-authored-by: Aaron Robinson <arobins@microsoft.com>
Co-authored-by: Elinor Fung <elfung@microsoft.com>
  • Loading branch information
4 people committed Oct 7, 2022
1 parent ef70886 commit 0d944ae
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,33 @@ protected StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo
StubCodeContext.Stage.Unmarshal));
}

protected StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context)
{
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
StatementSyntax contentsCleanupStatements = GenerateContentsMarshallingStatement(info, context,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(MarshallerHelpers.GetNativeSpanIdentifier(info, context)),
IdentifierName("Length")),
StubCodeContext.Stage.Cleanup);

if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement))
{
return EmptyStatement();
}

return Block(
LocalDeclarationStatement(VariableDeclaration(
GenericName(
Identifier(TypeNames.System_Span),
TypeArgumentList(SingletonSeparatedList(_unmanagedElementType))),
SingletonSeparatedList(
VariableDeclarator(
Identifier(nativeSpanIdentifier))
.WithInitializer(EqualsValueClause(
GetUnmanagedValuesDestination(info, context)))))),
contentsCleanupStatements);
}

protected StatementSyntax GenerateContentsMarshallingStatement(
TypePositionInfo info,
StubCodeContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,37 @@ public static IEnumerable<TypePositionInfo> GetDependentElementsOfMarshallingInf
}
}
}

public static StatementSyntax SkipInitOrDefaultInit(TypePositionInfo info, StubCodeContext context)
{
(TargetFramework fmk, _) = context.GetTargetFramework();
if (info.ManagedType is not PointerTypeInfo
&& info.ManagedType is not ValueTypeInfo { IsByRefLike: true }
&& fmk is TargetFramework.Net)
{
// Use the Unsafe.SkipInit<T> API when available and
// managed type is usable as a generic parameter.
return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
IdentifierName("SkipInit")))
.WithArgumentList(
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(info.InstanceIdentifier))
.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword))))));
}
else
{
// Assign out params to default
return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(info.InstanceIdentifier),
LiteralExpression(
SyntaxKind.DefaultLiteralExpression,
Token(SyntaxKind.DefaultKeyword))));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,27 @@ public StatefulLinearCollectionNonBlittableElementsMarshalling(
}

public TypeSyntax AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context);

if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}

if (!_shape.HasFlag(MarshallerShape.Free))
yield break;

string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
// <marshaller>.Free();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(marshaller),
IdentifierName(ShapeMemberNames.Free)),
ArgumentList()));
}
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ public StatelessFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller,

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context))
{
yield return statement;
}
// <marshallerType>.Free(<nativeIdentifier>);
yield return ExpressionStatement(
InvocationExpression(
Expand Down Expand Up @@ -372,11 +376,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.IntKeyword)),
SingletonSeparatedList(
VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context)))));
VariableDeclarator(numElementsIdentifier))));
// Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
// The value will never be used unless it has been initialized, so this is safe.
yield return MarshallerHelpers.SkipInitOrDefaultInit(
new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
{
InstanceIdentifier = numElementsIdentifier
}, context);
}

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
Expand Down Expand Up @@ -512,7 +524,15 @@ public StatelessLinearCollectionNonBlittableElementsMarshalling(

public TypeSyntax AsNativeType(TypePositionInfo info) => _nativeTypeSyntax;

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context);

if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}
}

public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
{
Expand Down Expand Up @@ -588,11 +608,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i

public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.IntKeyword)),
SingletonSeparatedList(
VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context)))));
VariableDeclarator(numElementsIdentifier))));
// Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
// The value will never be used unless it has been initialized, so this is safe.
yield return MarshallerHelpers.SkipInitOrDefaultInit(
new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
{
InstanceIdentifier = numElementsIdentifier
}, context);
}

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ private MarshallingInfo CreateNativeMarshallingInfo(
}

int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed;
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary<int, AttributeData>(), 1, ImmutableHashSet<string>.Empty, ref maxIndirectionDepthUsedLocal);
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsedLocal);
if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers))
{
maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,7 @@ public static VariableDeclarations GenerateDeclarationsForManagedToNative(BoundG

if (info.RefKind == RefKind.Out)
{
(TargetFramework fmk, _) = context.GetTargetFramework();
if (info.ManagedType is not PointerTypeInfo
&& info.ManagedType is not ValueTypeInfo { IsByRefLike: true }
&& fmk is TargetFramework.Net)
{
// Use the Unsafe.SkipInit<T> API when available and
// managed type is usable as a generic parameter.
initializations.Add(ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
IdentifierName("SkipInit")))
.WithArgumentList(
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(info.InstanceIdentifier))
.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)))))));
}
else
{
// Assign out params to default
initializations.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(info.InstanceIdentifier),
LiteralExpression(
SyntaxKind.DefaultLiteralExpression,
Token(SyntaxKind.DefaultKeyword)))));
}
initializations.Add(MarshallerHelpers.SkipInitOrDefaultInit(info, context));
}

// Declare variables for parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public partial class Stateless
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
public static partial int SumWithBuffer([MarshalUsing(typeof(ListMarshallerWithBuffer<,>))] List<int> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")]
public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshaller<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List<IntWrapper> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_values")]
public static partial int DoubleValues([MarshalUsing(typeof(ListMarshallerWithPinning<,>))] List<BlittableIntWrapper> values, int length);

Expand Down Expand Up @@ -99,6 +102,9 @@ public partial class Stateful
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
public static partial int Sum([MarshalUsing(typeof(ListMarshallerStateful<,>))] List<int> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")]
public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshallerStateful<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List<IntWrapper> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")]
public static partial int SumInArray([MarshalUsing(typeof(ListMarshallerStateful<,>))] in List<int> values, int numValues);

Expand Down Expand Up @@ -369,6 +375,30 @@ public void NonBlittableElementCollection_GuaranteedUnmarshal()
Assert.True(NativeExportsNE.Collections.Stateful.ListGuaranteedUnmarshal<BoolStruct, BoolStructMarshaller.BoolStructNative>.Marshaller.ToManagedFinallyCalled);
}

[Fact]
public void ElementsFreed()
{
List<IntWrapper> list = new List<IntWrapper>
{
new IntWrapper { i = 1 },
new IntWrapper { i = 10 },
new IntWrapper { i = 24 },
new IntWrapper { i = 30 },
};

int startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree;

NativeExportsNE.Collections.Stateless.SumWithFreeTracking(list, list.Count);

Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree);

startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree;

NativeExportsNE.Collections.Stateful.SumWithFreeTracking(list, list.Count);

Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree);
}

private static List<BoolStruct> GetBoolStructsToAnd(bool result) => new List<BoolStruct>
{
new BoolStruct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,31 @@ public static void Free(int* unmanaged)
}
}

[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerWithFreeCounts))]
public static unsafe class IntWrapperMarshallerWithFreeCounts
{
[ThreadStatic]
public static int NumCallsToFree = 0;

public static int* ConvertToUnmanaged(IntWrapper managed)
{
int* ret = (int*)Marshal.AllocCoTaskMem(sizeof(int));
*ret = managed.i;
return ret;
}

public static IntWrapper ConvertToManaged(int* unmanaged)
{
return new IntWrapper { i = *unmanaged };
}

public static void Free(int* unmanaged)
{
NumCallsToFree++;
Marshal.FreeCoTaskMem((IntPtr)unmanaged);
}
}

[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(Marshaller))]
public static unsafe class IntWrapperMarshallerStateful
{
Expand Down Expand Up @@ -477,14 +502,14 @@ public void FromManaged(List<T> managed, Span<TUnmanagedElement> buffer)

_list = managed;
// Always allocate at least one byte when the list is zero-length.
int spaceToAllocate = Math.Max(managed.Count * sizeof(TUnmanagedElement), 1);
if (spaceToAllocate <= buffer.Length)
int countToAllocate = Math.Max(managed.Count, 1);
if (countToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
_span = buffer[0..countToAllocate];
}
else
{
_allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate);
_allocatedMemory = Marshal.AllocCoTaskMem(countToAllocate * sizeof(TUnmanagedElement));
_span = new Span<TUnmanagedElement>((void*)_allocatedMemory, managed.Count);
}
}
Expand Down

0 comments on commit 0d944ae

Please sign in to comment.