From 191ec61d66f4d129b42f5bdee0d80834fbb4f42d Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Tue, 8 Aug 2023 16:47:45 -0500 Subject: [PATCH] Cleanup caller allocated and callee allocated resources separately (#89982) This PR separates cleaning up caller allocated resources and callee allocated resources into separate stages in the managed to unmanaged direction. Caller allocated parameters (anything except 'out') will clean up the same way. Callee allocated parameters ('out' parameters) will be cleaned up only if the invocation succeeded. --- .../LibraryImportGenerator/Pipeline.md | 21 ++-- .../JSExportCodeGenerator.cs | 16 +-- .../JSImportCodeGenerator.cs | 12 +-- .../ManagedToNativeVTableMethodGenerator.cs | 12 +-- ...anagedHResultExceptionMarshallerFactory.cs | 2 +- .../UnmanagedToManagedStubGenerator.cs | 7 +- .../PInvokeStubCodeGenerator.cs | 13 +-- .../GeneratedStatements.cs | 9 +- .../CustomTypeMarshallingGenerator.cs | 6 +- .../Marshalling/ElementsMarshalling.cs | 19 +++- .../ICustomTypeMarshallingStrategy.cs | 4 +- .../Marshalling/MarshallerHelpers.cs | 29 ++++- .../Marshalling/SafeHandleMarshaller.cs | 2 +- .../StatefulMarshallingStrategy.cs | 77 +++++++++++-- .../StatelessMarshallingStrategy.cs | 102 ++++++++++++++++-- ...nagedToManagedOwnershipTrackingStrategy.cs | 29 +++-- .../StubCodeContext.cs | 9 +- .../RcwAroundCcwTests.cs | 76 +++++++++++-- .../IArrayOfStatelessElements.cs | 12 +++ .../IStatefulFinallyMarshalling.cs | 1 + .../ComInterfaces/IStatelessMarshalling.cs | 2 + .../ManagedComMethodFailureException.cs | 11 ++ 22 files changed, 385 insertions(+), 86 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs diff --git a/docs/design/libraries/LibraryImportGenerator/Pipeline.md b/docs/design/libraries/LibraryImportGenerator/Pipeline.md index 9533241c9ad29..bd38953951751 100644 --- a/docs/design/libraries/LibraryImportGenerator/Pipeline.md +++ b/docs/design/libraries/LibraryImportGenerator/Pipeline.md @@ -91,8 +91,12 @@ The stub code generator itself will handle some initial setup and variable decla - Call `Generate` on the marshalling generator for every parameter 1. `GuaranteedUnmarshal`: conversion of native to managed data even when an exception is thrown - Call `Generate` on the marshalling generator for every parameter. -1. `Cleanup`: free any allocated resources + - If this stage has any statements, put them in an if statement where the condition represents whether the call succeeded +1. `CleanupCallerAllocated`: free any resources allocated by the caller - Call `Generate` on the marshalling generator for every parameter +1. `CleanupCalleeAllocated`: if the native method succeeded, free any resources allocated by the callee (`out` parameters and return values) + - Call `Generate` on the marshalling generator for every parameter + - If this stage has any statements, put them in an if statement where the condition represents whether the call succeeded Generated P/Invoke structure (if no code is generated for `GuaranteedUnmarshal` and `Cleanup`, the `try-finally` is omitted): ```C# @@ -113,7 +117,8 @@ try finally { << GuaranteedUnmarshal >> - << Cleanup >> + << CleanupCalleeAllocated >> + << CleanupCallerAllocated >> } ``` @@ -138,12 +143,12 @@ Support for these features is indicated in code by the `abstract` `SingleFrameSp The various scenarios mentioned above have different levels of support for these specialized features: -| Scenarios | Pinning and Stack allocation across the native context | Storing additional temporary state in locals | -|------|-----|-----| -| P/Invoke | supported | supported | -| Reverse P/Invoke | unsupported | supported | -| User-defined structure content marshalling | unsupported | unsupported | -| non-blittable array marshalling | unsupported | unuspported | +| Scenarios | Pinning and Stack allocation across the native context | Storing additional temporary state in locals | +|--------------------------------------------|--------------------------------------------------------|----------------------------------------------| +| P/Invoke | supported | supported | +| Reverse P/Invoke | unsupported | supported | +| User-defined structure content marshalling | unsupported | unsupported | +| non-blittable array marshalling | unsupported | unuspported | To help enable developers to use the full model described in the [Struct Marshalling design](./StructMarshalling.md), we declare that in contexts where `AdditionalTemporaryStateLivesAcrossStages` is false, developers can still assume that state declared in the `Setup` phase is valid in any phase, but any side effects in code emitted in a phase other than `Setup` will not be guaranteed to be visible in other phases. This enables developers to still use the identifiers declared in the `Setup` phase in their other phases, but they'll need to take care to design their generators to handle these rules. diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs index 9221bef508f3b..0a18241ef7c76 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Linq; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using Microsoft.CodeAnalysis; namespace Microsoft.Interop.JavaScript { @@ -61,13 +61,13 @@ public BlockSyntax GenerateJSExportBody() { StatementSyntax invoke = InvokeSyntax(); GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context); - bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty; + bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty; VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, shouldInitializeVariables); var setupStatements = new List(); SetupSyntax(setupStatements); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true)); } @@ -81,7 +81,7 @@ public BlockSyntax GenerateJSExportBody() tryStatements.Add(invoke); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(InvokeSucceededIdentifier), @@ -94,12 +94,12 @@ public BlockSyntax GenerateJSExportBody() List allStatements = setupStatements; List finallyStatements = new List(); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { - finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal))); + finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated)))); } - finallyStatements.AddRange(statements.Cleanup); + finallyStatements.AddRange(statements.CleanupCallerAllocated); if (finallyStatements.Count > 0) { allStatements.Add( diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs index 2c1d49fd3dd12..1f415b8338008 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSImportCodeGenerator.cs @@ -67,14 +67,14 @@ public BlockSyntax GenerateJSImportBody() { StatementSyntax invoke = InvokeSyntax(); GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context); - bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty; + bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty; VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForManagedToUnmanaged(_marshallers, _context, shouldInitializeVariables); var setupStatements = new List(); BindSyntax(setupStatements); SetupSyntax(setupStatements); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true)); } @@ -88,7 +88,7 @@ public BlockSyntax GenerateJSImportBody() tryStatements.AddRange(statements.PinnedMarshal); tryStatements.Add(invoke); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(InvokeSucceededIdentifier), @@ -100,12 +100,12 @@ public BlockSyntax GenerateJSImportBody() List allStatements = setupStatements; List finallyStatements = new List(); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { - finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal))); + finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated)))); } - finallyStatements.AddRange(statements.Cleanup); + finallyStatements.AddRange(statements.CleanupCallerAllocated); if (finallyStatements.Count > 0) { // Add try-finally block if there are any statements in the finally block diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs index 883ca85161771..dd18f01bb7c88 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs @@ -131,7 +131,7 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray = true; - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(InvokeSucceededIdentifier), @@ -197,12 +197,12 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray allStatements = setupStatements; List finallyStatements = new List(); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { - finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal))); + finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated)))); } - finallyStatements.AddRange(statements.Cleanup); + finallyStatements.AddRange(statements.CleanupCallerAllocated); if (finallyStatements.Count > 0) { // Add try-finally block if there are any statements in the finally block diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionMarshallerFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionMarshallerFactory.cs index e396ea5936fcb..63017cf51265a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionMarshallerFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionMarshallerFactory.cs @@ -83,7 +83,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont { Debug.Assert(info.MarshallingAttributeInfo is ManagedHResultExceptionMarshallingInfo); - if (context.CurrentStage != StubCodeContext.Stage.Unmarshal) + if (context.CurrentStage != StubCodeContext.Stage.NotifyForSuccessfulInvoke) { yield break; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs index 64060d613e429..cc76d853e06bd 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -52,9 +53,11 @@ public BlockSyntax GenerateStubBody(ExpressionSyntax methodToInvoke) _marshallers, _context, methodToInvoke); + Debug.Assert(statements.CleanupCalleeAllocated.IsEmpty); + bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty - || !statements.Cleanup.IsEmpty + || !statements.CleanupCallerAllocated.IsEmpty || !statements.ManagedExceptionCatchClauses.IsEmpty; VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, shouldInitializeVariables); @@ -77,7 +80,7 @@ public BlockSyntax GenerateStubBody(ExpressionSyntax methodToInvoke) SyntaxList catchClauses = List(statements.ManagedExceptionCatchClauses); - finallyStatements.AddRange(statements.Cleanup); + finallyStatements.AddRange(statements.CleanupCallerAllocated); if (finallyStatements.Count > 0) { allStatements.Add( diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs index 34d09a70a66e6..939a623ab7dba 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/PInvokeStubCodeGenerator.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Linq; using System.Collections.Generic; using System.Collections.Immutable; using Microsoft.CodeAnalysis.CSharp; @@ -107,7 +108,7 @@ public PInvokeStubCodeGenerator( public BlockSyntax GeneratePInvokeBody(string dllImportName) { GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context, IdentifierName(dllImportName)); - bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.Cleanup.IsEmpty; + bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty; VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForManagedToUnmanaged(_marshallers, _context, shouldInitializeVariables); var setupStatements = new List(); @@ -121,7 +122,7 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName) initializeToDefault: false)); } - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { setupStatements.Add(MarshallerHelpers.Declare(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier, initializeToDefault: true)); } @@ -148,7 +149,7 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName) } tryStatements.Add(statements.Pin.NestFixedStatements(fixedBlock)); // = true; - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { tryStatements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(InvokeSucceededIdentifier), @@ -160,12 +161,12 @@ public BlockSyntax GeneratePInvokeBody(string dllImportName) List allStatements = setupStatements; List finallyStatements = new List(); - if (!statements.GuaranteedUnmarshal.IsEmpty) + if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { - finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal))); + finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(statements.GuaranteedUnmarshal.Concat(statements.CleanupCalleeAllocated)))); } - finallyStatements.AddRange(statements.Cleanup); + finallyStatements.AddRange(statements.CleanupCallerAllocated); if (finallyStatements.Count > 0) { // Add try-finally block if there are any statements in the finally block diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs index e8e3de05ffff2..d611eb59cd3b4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs @@ -21,7 +21,8 @@ public struct GeneratedStatements public ImmutableArray Unmarshal { get; init; } public ImmutableArray NotifyForSuccessfulInvoke { get; init; } public ImmutableArray GuaranteedUnmarshal { get; init; } - public ImmutableArray Cleanup { get; init; } + public ImmutableArray CleanupCallerAllocated { get; init; } + public ImmutableArray CleanupCalleeAllocated { get; init; } public ImmutableArray ManagedExceptionCatchClauses { get; init; } @@ -38,7 +39,8 @@ public static GeneratedStatements Create(BoundGenerators marshallers, StubCodeCo .AddRange(GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Unmarshal })), NotifyForSuccessfulInvoke = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.NotifyForSuccessfulInvoke }), GuaranteedUnmarshal = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.GuaranteedUnmarshal }), - Cleanup = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.Cleanup }), + CleanupCallerAllocated = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.CleanupCallerAllocated }), + CleanupCalleeAllocated = GenerateStatementsForStubContext(marshallers, context with { CurrentStage = StubCodeContext.Stage.CleanupCalleeAllocated }), ManagedExceptionCatchClauses = GenerateCatchClauseForManagedException(marshallers, context) }; } @@ -182,7 +184,8 @@ private static SyntaxTriviaList GenerateStageTrivia(StubCodeContext.Stage stage) StubCodeContext.Stage.Invoke => "Call the P/Invoke.", StubCodeContext.Stage.UnmarshalCapture => "Capture the native data into marshaller instances in case conversion to managed data throws an exception.", StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.", - StubCodeContext.Stage.Cleanup => "Perform required cleanup.", + StubCodeContext.Stage.CleanupCallerAllocated => "Perform cleanup of caller allocated resources.", + StubCodeContext.Stage.CleanupCalleeAllocated => "Perform cleanup of callee allocated resources.", StubCodeContext.Stage.NotifyForSuccessfulInvoke => "Keep alive any managed objects that need to stay alive across the call.", StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.", _ => throw new ArgumentOutOfRangeException(nameof(stage)) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs index c992f63aae700..30f333d59540e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CustomTypeMarshallingGenerator.cs @@ -97,8 +97,10 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont return _nativeTypeMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); } break; - case StubCodeContext.Stage.Cleanup: - return _nativeTypeMarshaller.GenerateCleanupStatements(info, context); + case StubCodeContext.Stage.CleanupCallerAllocated: + return _nativeTypeMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context); + case StubCodeContext.Stage.CleanupCalleeAllocated: + return _nativeTypeMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context); default: break; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs index 7b71e9e0e9e35..136de0e4cca53 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Linq; using Microsoft.CodeAnalysis; @@ -440,7 +441,7 @@ public override StatementSyntax GenerateElementCleanupStatement(TypePositionInfo indexConstraintName, _elementInfo, _elementMarshaller, - StubCodeContext.Stage.Cleanup); + context.CurrentStage); if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement)) { @@ -531,6 +532,18 @@ public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalState .WithInitializer(EqualsValueClause( CollectionSource.GetManagedValuesDestination(info, context)))))); + StubCodeContext.Stage[] stagesToGenerate; + + // Until we separate CalleeAllocated cleanup and CallerAllocated cleanup in unmanaged to managed, we'll need this hack + if (context.Direction is MarshalDirection.UnmanagedToManaged && info.ByValueContentsMarshalKind is ByValueContentsMarshalKind.Out) + { + stagesToGenerate = new[] { StubCodeContext.Stage.Marshal, StubCodeContext.Stage.PinnedMarshal }; + } + else + { + stagesToGenerate = new[] { StubCodeContext.Stage.Marshal, StubCodeContext.Stage.PinnedMarshal, StubCodeContext.Stage.CleanupCallerAllocated, StubCodeContext.Stage.CleanupCalleeAllocated }; + } + return Block( setNumElements, unmanagedValuesSource, @@ -541,9 +554,7 @@ public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalState IdentifierName(numElementsIdentifier), _elementInfo, new FreeAlwaysOwnedOriginalValueGenerator(_elementMarshaller), - StubCodeContext.Stage.Marshal, - StubCodeContext.Stage.PinnedMarshal, - StubCodeContext.Stage.Cleanup)); + stagesToGenerate)); } private static List GenerateElementStages( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs index f9da5d32b3977..53223656b7ddf 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs @@ -13,7 +13,9 @@ internal interface ICustomTypeMarshallingStrategy { ManagedTypeInfo AsNativeType(TypePositionInfo info); - IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context); + IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context); + + IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context); IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index ebb1f01bf8d3f..325f4741600ac 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -405,6 +405,26 @@ public static MarshalDirection GetMarshalDirection(TypePositionInfo info, StubCo throw new UnreachableException("An element is either a return value or passed by value or by ref."); } + /// + /// Returns which stage cleanup should be performed for the parameter. + /// + public static StubCodeContext.Stage GetCleanupStage(TypePositionInfo info, StubCodeContext context) + { + // Unmanaged to managed doesn't properly handle lifetimes right now and will default to the original behavior. + // Failures will only occur when marshalling fails, and would only cause leaks, not double frees. + // See https://github.com/dotnet/runtime/issues/89483 for more details + if (context.Direction is MarshalDirection.UnmanagedToManaged) + return StubCodeContext.Stage.CleanupCallerAllocated; + + return GetMarshalDirection(info, context) switch + { + MarshalDirection.UnmanagedToManaged => StubCodeContext.Stage.CleanupCalleeAllocated, + MarshalDirection.ManagedToUnmanaged => StubCodeContext.Stage.CleanupCallerAllocated, + MarshalDirection.Bidirectional => StubCodeContext.Stage.CleanupCallerAllocated, + _ => throw new UnreachableException() + }; + } + /// /// Ensure that the count of a collection is available at call time if the parameter is not an out parameter. /// It only looks at an indirection level of 0 (the size of the outer array), so there are some holes in @@ -417,10 +437,10 @@ public static void ValidateCountInfoAvailableAtCall(MarshalDirection stubDirecti if (stubDirection is MarshalDirection.ManagedToUnmanaged) return; - if (info.MarshallingAttributeInfo is NativeLinearCollectionMarshallingInfo collectionMarshallingInfo - && collectionMarshallingInfo.ElementCountInfo is CountElementCountInfo countInfo - && !(info.RefKind is RefKind.Out - || info.ManagedIndex is TypePositionInfo.ReturnIndex)) + if (!(info.RefKind is RefKind.Out + || info.ManagedIndex is TypePositionInfo.ReturnIndex) + && info.MarshallingAttributeInfo is NativeLinearCollectionMarshallingInfo collectionMarshallingInfo + && collectionMarshallingInfo.ElementCountInfo is CountElementCountInfo countInfo) { if (countInfo.ElementInfo.IsByRef && countInfo.ElementInfo.RefKind is RefKind.Out) { @@ -444,6 +464,5 @@ public static void ValidateCountInfoAvailableAtCall(MarshalDirection stubDirecti // If the parameter is multidimensional and a higher indirection level parameter is ByValue [Out], then we should warn. } } - } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/SafeHandleMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/SafeHandleMarshaller.cs index 4e0b3bcf40f92..9f4700ec00f93 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/SafeHandleMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/SafeHandleMarshaller.cs @@ -214,7 +214,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont IdentifierName(newHandleObjectIdentifier))))); } break; - case StubCodeContext.Stage.Cleanup: + case StubCodeContext.Stage.CleanupCallerAllocated: if (!info.IsManagedReturnPosition && (!info.IsByRef || info.RefKind == RefKind.In)) { yield return IfStatement( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs index edb226fd74cac..22c1c33c9f226 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs @@ -30,8 +30,28 @@ public ManagedTypeInfo AsNativeType(TypePositionInfo info) public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated) + yield break; + + if (!_shape.HasFlag(MarshallerShape.Free)) + yield break; + + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetAdditionalIdentifier(info, MarshallerIdentifier)), + IdentifierName(ShapeMemberNames.Free)), + ArgumentList())); + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated) + yield break; + if (!_shape.HasFlag(MarshallerShape.Free)) yield break; @@ -213,9 +233,14 @@ public ManagedTypeInfo AsNativeType(TypePositionInfo info) return _innerMarshaller.AsNativeType(info); } - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { - return _innerMarshaller.GenerateCleanupStatements(info, context); + return _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context); + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + return _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context); } public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) @@ -371,12 +396,26 @@ public StatefulLinearCollectionMarshalling( } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { + // We don't have anything to cleanup specifically related to this value, just the elements. We let the element marshaller decide whether to cleanup in callee or caller cleanup stage if (!_cleanupElements) - { yield break; + + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); + + if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) + { + yield return elementCleanup; } + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + // We don't have anything to cleanup specifically related to this value, just the elements. We let the element marshaller decide whether to cleanup in callee or caller cleanup stage + if (!_cleanupElements) + yield break; StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); @@ -385,6 +424,7 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i yield return elementCleanup; } } + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) @@ -504,13 +544,36 @@ public StatefulFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller) public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (var statement in _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context)) + { + yield return statement; + } + + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated) + yield break; + + string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshaller), + IdentifierName(ShapeMemberNames.Free)), + ArgumentList())); + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { - foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + foreach (var statement in _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context)) { yield return statement; } + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated) + yield break; + string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); // .Free(); yield return ExpressionStatement( diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 6303abe2bfbc8..2ed2653eb96a7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -33,7 +33,9 @@ public ManagedTypeInfo AsNativeType(TypePositionInfo info) public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { @@ -159,7 +161,8 @@ public StatelessCallerAllocatedBufferMarshalling(ICustomTypeMarshallingStrategy } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context); + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context); + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context); public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) @@ -266,9 +269,31 @@ public StatelessFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller, public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { - foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated) + yield break; + + foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context)) + { + yield return statement; + } + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerType, + IdentifierName(ShapeMemberNames.Free)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(context.GetIdentifiers(info).native)))))); + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated) + yield break; + + foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context)) { yield return statement; } @@ -316,8 +341,25 @@ public ManagedTypeInfo AsNativeType(TypePositionInfo info) return _unmanagedType; } - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated) + yield break; + + string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); + // = ; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(numElementsIdentifier), + _numElementsExpression)); + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated) + yield break; + if (MarshallerHelpers.GetMarshalDirection(info, context) == MarshalDirection.ManagedToUnmanaged) { yield return EmptyStatement(); @@ -325,6 +367,7 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i } string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); + // = ; yield return ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, @@ -397,6 +440,7 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) { + // int ; string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); yield return LocalDeclarationStatement( VariableDeclaration( @@ -554,12 +598,13 @@ public StatelessLinearCollectionMarshalling( public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _unmanagedType; - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { if (!_cleanupElementsAndSpace) { yield break; } + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) @@ -567,6 +612,7 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i // If we don't have the numElements variable still available from unmarshal or marshal stage, we need to reassign that again if (!context.AdditionalTemporaryStateLivesAcrossStages) { + // = ; string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); yield return ExpressionStatement( AssignmentExpression( @@ -577,9 +623,45 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i yield return elementCleanup; } - foreach (var statement in _spaceMarshallingStrategy.GenerateCleanupStatements(info, context)) + if (MarshallerHelpers.GetCleanupStage(info, context) is StubCodeContext.Stage.CleanupCallerAllocated) { - yield return statement; + foreach (var statement in _spaceMarshallingStrategy.GenerateCleanupCallerAllocatedResourcesStatements(info, context)) + { + yield return statement; + } + } + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_cleanupElementsAndSpace) + { + yield break; + } + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); + + if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) + { + // If we don't have the numElements variable still available from unmarshal or marshal stage, we need to reassign that again + if (!context.AdditionalTemporaryStateLivesAcrossStages) + { + // = ; + string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(numElementsIdentifier), + _numElementsExpression)); + } + yield return elementCleanup; + } + + if (MarshallerHelpers.GetCleanupStage(info, context) is StubCodeContext.Stage.CleanupCallerAllocated) + { + foreach (var statement in _spaceMarshallingStrategy.GenerateCleanupCalleeAllocatedResourcesStatements(info, context)) + { + yield return statement; + } } } @@ -646,6 +728,10 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content. yield return _elementsMarshalling.GenerateClearManagedValuesDestination(info, context); + foreach (var statement in _spaceMarshallingStrategy.GenerateUnmarshalStatements(info, context)) + { + yield return statement; + } yield break; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs index 5b9d0866f743c..ba43659536d52 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/UnmanagedToManagedOwnershipTrackingStrategy.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -25,7 +24,8 @@ public UnmanagedToManagedOwnershipTrackingStrategy(ICustomTypeMarshallingStrateg public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context); + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, context); + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, context); public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) @@ -78,7 +78,7 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf /// /// Marshalling strategy that uses the tracking variables introduced by to cleanup the original value if the original value is owned - /// in the stage. + /// in the stage. /// internal sealed class CleanupOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy { @@ -91,15 +91,30 @@ public CleanupOwnedOriginalValueMarshalling(ICustomTypeMarshallingStrategy inner public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); - public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable GenerateCleanupCallerAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCallerAllocated) + yield break; // if () // { // // } yield return IfStatement( IdentifierName(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)), - Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context)))); + Block(_innerMarshaller.GenerateCleanupCallerAllocatedResourcesStatements(info, new OwnedValueCodeContext(context)))); + } + + public IEnumerable GenerateCleanupCalleeAllocatedResourcesStatements(TypePositionInfo info, StubCodeContext context) + { + if (MarshallerHelpers.GetCleanupStage(info, context) is not StubCodeContext.Stage.CleanupCalleeAllocated) + yield break; + // if () + // { + // + // } + yield return IfStatement( + IdentifierName(context.GetAdditionalIdentifier(info, OwnershipTrackingHelpers.OwnOriginalValueIdentifier)), + Block(_innerMarshaller.GenerateCleanupCalleeAllocatedResourcesStatements(info, new OwnedValueCodeContext(context)))); } public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); @@ -119,7 +134,7 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i /// /// Marshalling strategy to cache the initial value of a given in a local variable and cleanup that value in the cleanup stage. - /// Useful in scenarios where the value is always owned in all code-paths that reach the stage, so additional ownership tracking is extraneous. + /// Useful in scenarios where the value is always owned in all code-paths that reach the stage, so additional ownership tracking is extraneous. /// internal sealed class FreeAlwaysOwnedOriginalValueGenerator : IMarshallingGenerator { @@ -138,7 +153,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont return GenerateSetupStatements(); } - if (context.CurrentStage == StubCodeContext.Stage.Cleanup) + if (context.CurrentStage == StubCodeContext.Stage.CleanupCallerAllocated) { return GenerateStatementsFromInner(new OwnedValueCodeContext(context)); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs index 35e07f1f7dc47..50af716eb3a51 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StubCodeContext.cs @@ -64,9 +64,14 @@ public enum Stage NotifyForSuccessfulInvoke, /// - /// Perform any cleanup required + /// Perform any cleanup required on caller allocated resources /// - Cleanup, + CleanupCallerAllocated, + + /// + /// Perform any cleanup required on callee allocated resources + /// + CleanupCalleeAllocated, /// /// Convert native data to managed data even in the case of an exception during diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs index 5157e3ca029e8..98390d97417d7 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs @@ -84,6 +84,63 @@ public void IIntArray() Assert.True(data is [2, 4, 6]); } + [Fact] + public void IArrayOfStatelessElements() + { + var obj = CreateWrapper(); + var data = new StatelessType[10]; + + // ByValueContentsOut should only free the returned values + var oldFreeCount = StatelessTypeMarshaller.AllFreeCount; + obj.MethodContentsOut(data, data.Length); + Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount); + + // ByValueContentsOut should only free the elements after the call + oldFreeCount = StatelessTypeMarshaller.AllFreeCount; + obj.MethodContentsIn(data, data.Length); + Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount); + + // ByValueContentsInOut should free elements in both directions + oldFreeCount = StatelessTypeMarshaller.AllFreeCount; + obj.MethodContentsInOut(data, data.Length); + Assert.Equal(oldFreeCount + 20, StatelessTypeMarshaller.AllFreeCount); + } + + [Fact] + public void IArrayOfStatelessElementsThrows() + { + var obj = CreateWrapper(); + var data = new StatelessType[10]; + var oldFreeCount = StatelessTypeMarshaller.AllFreeCount; + try + { + obj.MethodContentsOut(data, 10); + } + catch (Exception) { } + Assert.Equal(oldFreeCount, StatelessTypeMarshaller.AllFreeCount); + + for (int i = 0; i < 10; i++) + { + data[i] = new StatelessType() { I = i }; + } + + oldFreeCount = StatelessTypeMarshaller.AllFreeCount; + try + { + obj.MethodContentsIn(data, 10); + } + catch (Exception) { } + Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount); + + oldFreeCount = StatelessTypeMarshaller.AllFreeCount; + try + { + obj.MethodContentsInOut(data, 10); + } + catch (Exception) { } + Assert.Equal(oldFreeCount + 10, StatelessTypeMarshaller.AllFreeCount); + } + [Fact] public void IJaggedIntArray() { @@ -138,35 +195,36 @@ public void IStatelessFinallyMarshalling() } [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/89747")] public void ICollectionMarshallingFails() { + Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception); + var obj = CreateWrapper(); Assert.Throws(() => - _ = obj.GetConstSize() + obj.Set(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 }, 10) ); - Assert.Throws(() => - _ = obj.Get(out _) + Assert.Throws(hrExceptionType, () => + _ = obj.GetConstSize() ); - Assert.Throws(() => - obj.Set(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 }, 10) + Assert.Throws(hrExceptionType, () => + _ = obj.Get(out _) ); } [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/89747")] public void IJaggedArrayMarshallingFails() { + Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception); var obj = CreateWrapper(); - Assert.Throws(() => + Assert.Throws(hrExceptionType, () => _ = obj.GetConstSize() ); - Assert.Throws(() => + Assert.Throws(hrExceptionType, () => _ = obj.Get(out _, out _) ); var array = new int[][] { new int[] { 1, 2, 3 }, new int[] { 4, 5, }, new int[] { 6, 7, 8, 9 } }; diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IArrayOfStatelessElements.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IArrayOfStatelessElements.cs index abb8842f1e0a2..658606d84ee3c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IArrayOfStatelessElements.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IArrayOfStatelessElements.cs @@ -72,4 +72,16 @@ public void MethodRef(ref StatelessType[] param, int size) } } } + + [GeneratedComClass] + internal partial class ArrayOfStatelessElementsThrows : IArrayOfStatelessElements + { + public void Method(StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + public void MethodContentsIn(StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + public void MethodContentsInOut(StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + public void MethodContentsOut(StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + public void MethodIn(in StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + public void MethodOut(out StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + public void MethodRef(ref StatelessType[] param, int size) => throw new ManagedComMethodFailureException(); + } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulFinallyMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulFinallyMarshalling.cs index 37fce2297a88c..3c9951c03fafe 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulFinallyMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulFinallyMarshalling.cs @@ -19,6 +19,7 @@ internal partial interface IStatefulFinallyMarshalling [PreserveSig] StatefulFinallyType ReturnPreserveSig(); } + [GeneratedComClass] internal partial class StatefulFinallyMarshalling : IStatefulFinallyMarshalling { diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs index 12c98b69ece31..66ccc0c3e1636 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs @@ -48,6 +48,8 @@ internal class StatelessType [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementRef, typeof(Bidirectional))] internal static class StatelessTypeMarshaller { + public static int AllFreeCount => Bidirectional.FreeCount + UnmanagedToManaged.FreeCount + ManagedToUnmanaged.FreeCount; + internal static class Bidirectional { public static int FreeCount { get; private set; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs new file mode 100644 index 0000000000000..de5aea6228a14 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ManagedComMethodFailureException.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace SharedTypes.ComInterfaces +{ + internal class ManagedComMethodFailureException : Exception + { + } +}