diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs index f51e0d522a321..e75b25eab45d7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs @@ -315,6 +315,33 @@ protected StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo StubCodeContext.Stage.Unmarshal)); } + protected StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context) + { + string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context); + StatementSyntax contentsCleanupStatements = GenerateContentsMarshallingStatement(info, context, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(MarshallerHelpers.GetNativeSpanIdentifier(info, context)), + IdentifierName("Length")), + StubCodeContext.Stage.Cleanup); + + if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement)) + { + return EmptyStatement(); + } + + return Block( + LocalDeclarationStatement(VariableDeclaration( + GenericName( + Identifier(TypeNames.System_Span), + TypeArgumentList(SingletonSeparatedList(_unmanagedElementType))), + SingletonSeparatedList( + VariableDeclarator( + Identifier(nativeSpanIdentifier)) + .WithInitializer(EqualsValueClause( + GetUnmanagedValuesDestination(info, context)))))), + contentsCleanupStatements); + } + protected StatementSyntax GenerateContentsMarshallingStatement( TypePositionInfo info, StubCodeContext context, diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index c245f8db61bcb..98213f4f7dd25 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -292,5 +292,37 @@ public static IEnumerable GetDependentElementsOfMarshallingInf } } } + + public static StatementSyntax SkipInitOrDefaultInit(TypePositionInfo info, StubCodeContext context) + { + (TargetFramework fmk, _) = context.GetTargetFramework(); + if (info.ManagedType is not PointerTypeInfo + && info.ManagedType is not ValueTypeInfo { IsByRefLike: true } + && fmk is TargetFramework.Net) + { + // Use the Unsafe.SkipInit API when available and + // managed type is usable as a generic parameter. + return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe), + IdentifierName("SkipInit"))) + .WithArgumentList( + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(info.InstanceIdentifier)) + .WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)))))); + } + else + { + // Assign out params to default + return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(info.InstanceIdentifier), + LiteralExpression( + SyntaxKind.DefaultLiteralExpression, + Token(SyntaxKind.DefaultKeyword)))); + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs index ad487bf6d4015..20707af44ed66 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs @@ -459,7 +459,27 @@ public StatefulLinearCollectionNonBlittableElementsMarshalling( } public TypeSyntax AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context); + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context); + + if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) + { + yield return elementCleanup; + } + + if (!_shape.HasFlag(MarshallerShape.Free)) + yield break; + + string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshaller), + IdentifierName(ShapeMemberNames.Free)), + ArgumentList())); + } public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 213eec0855f30..aa09b9d8d3d01 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -251,6 +251,10 @@ public StatelessFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) { + foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } // .Free(); yield return ExpressionStatement( InvocationExpression( @@ -372,11 +376,19 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) { + string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); yield return LocalDeclarationStatement( VariableDeclaration( PredefinedType(Token(SyntaxKind.IntKeyword)), SingletonSeparatedList( - VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context))))); + VariableDeclarator(numElementsIdentifier)))); + // Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable. + // The value will never be used unless it has been initialized, so this is safe. + yield return MarshallerHelpers.SkipInitOrDefaultInit( + new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance) + { + InstanceIdentifier = numElementsIdentifier + }, context); } public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); @@ -512,7 +524,15 @@ public StatelessLinearCollectionNonBlittableElementsMarshalling( public TypeSyntax AsNativeType(TypePositionInfo info) => _nativeTypeSyntax; - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context); + + if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) + { + yield return elementCleanup; + } + } public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { @@ -588,11 +608,19 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) { + string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); yield return LocalDeclarationStatement( VariableDeclaration( PredefinedType(Token(SyntaxKind.IntKeyword)), SingletonSeparatedList( - VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context))))); + VariableDeclarator(numElementsIdentifier)))); + // Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable. + // The value will never be used unless it has been initialized, so this is safe. + yield return MarshallerHelpers.SkipInitOrDefaultInit( + new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance) + { + InstanceIdentifier = numElementsIdentifier + }, context); } public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index 19680588e7bad..fb4fcb4f7a049 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -593,7 +593,7 @@ private MarshallingInfo CreateNativeMarshallingInfo( } int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed; - Func getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsedLocal); + Func getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionDepthUsedLocal); if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers)) { maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/VariableDeclarations.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/VariableDeclarations.cs index 4c06dedb5eefd..dbc73df2f8d2c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/VariableDeclarations.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/VariableDeclarations.cs @@ -29,34 +29,7 @@ public static VariableDeclarations GenerateDeclarationsForManagedToNative(BoundG if (info.RefKind == RefKind.Out) { - (TargetFramework fmk, _) = context.GetTargetFramework(); - if (info.ManagedType is not PointerTypeInfo - && info.ManagedType is not ValueTypeInfo { IsByRefLike: true } - && fmk is TargetFramework.Net) - { - // Use the Unsafe.SkipInit API when available and - // managed type is usable as a generic parameter. - initializations.Add(ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe), - IdentifierName("SkipInit"))) - .WithArgumentList( - ArgumentList(SingletonSeparatedList( - Argument(IdentifierName(info.InstanceIdentifier)) - .WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword))))))); - } - else - { - // Assign out params to default - initializations.Add(ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(info.InstanceIdentifier), - LiteralExpression( - SyntaxKind.DefaultLiteralExpression, - Token(SyntaxKind.DefaultKeyword))))); - } + initializations.Add(MarshallerHelpers.SkipInitOrDefaultInit(info, context)); } // Declare variables for parameters diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs index 48d592735280c..51214883bd4b3 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs @@ -25,6 +25,9 @@ public partial class Stateless [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] public static partial int SumWithBuffer([MarshalUsing(typeof(ListMarshallerWithBuffer<,>))] List values, int numValues); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")] + public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshaller<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List values, int numValues); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_values")] public static partial int DoubleValues([MarshalUsing(typeof(ListMarshallerWithPinning<,>))] List values, int length); @@ -99,6 +102,9 @@ public partial class Stateful [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] public static partial int Sum([MarshalUsing(typeof(ListMarshallerStateful<,>))] List values, int numValues); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")] + public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshallerStateful<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List values, int numValues); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")] public static partial int SumInArray([MarshalUsing(typeof(ListMarshallerStateful<,>))] in List values, int numValues); @@ -369,6 +375,30 @@ public void NonBlittableElementCollection_GuaranteedUnmarshal() Assert.True(NativeExportsNE.Collections.Stateful.ListGuaranteedUnmarshal.Marshaller.ToManagedFinallyCalled); } + [Fact] + public void ElementsFreed() + { + List list = new List + { + new IntWrapper { i = 1 }, + new IntWrapper { i = 10 }, + new IntWrapper { i = 24 }, + new IntWrapper { i = 30 }, + }; + + int startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree; + + NativeExportsNE.Collections.Stateless.SumWithFreeTracking(list, list.Count); + + Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree); + + startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree; + + NativeExportsNE.Collections.Stateful.SumWithFreeTracking(list, list.Count); + + Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree); + } + private static List GetBoolStructsToAnd(bool result) => new List { new BoolStruct diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs index 6035dc2dc5b1f..88e594d9f6a2d 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs @@ -196,6 +196,31 @@ public static void Free(int* unmanaged) } } + [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerWithFreeCounts))] + public static unsafe class IntWrapperMarshallerWithFreeCounts + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + public static int* ConvertToUnmanaged(IntWrapper managed) + { + int* ret = (int*)Marshal.AllocCoTaskMem(sizeof(int)); + *ret = managed.i; + return ret; + } + + public static IntWrapper ConvertToManaged(int* unmanaged) + { + return new IntWrapper { i = *unmanaged }; + } + + public static void Free(int* unmanaged) + { + NumCallsToFree++; + Marshal.FreeCoTaskMem((IntPtr)unmanaged); + } + } + [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(Marshaller))] public static unsafe class IntWrapperMarshallerStateful { @@ -477,14 +502,14 @@ public void FromManaged(List managed, Span buffer) _list = managed; // Always allocate at least one byte when the list is zero-length. - int spaceToAllocate = Math.Max(managed.Count * sizeof(TUnmanagedElement), 1); - if (spaceToAllocate <= buffer.Length) + int countToAllocate = Math.Max(managed.Count, 1); + if (countToAllocate <= buffer.Length) { - _span = buffer[0..spaceToAllocate]; + _span = buffer[0..countToAllocate]; } else { - _allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + _allocatedMemory = Marshal.AllocCoTaskMem(countToAllocate * sizeof(TUnmanagedElement)); _span = new Span((void*)_allocatedMemory, managed.Count); } }