diff --git a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs index a6494bf27883e7..4a932f6dc7637e 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs @@ -1351,26 +1351,20 @@ internal static void DestroyCleanupList(ref CleanupWorkListElement? pCleanupWork internal static Exception GetHRExceptionObject(int hr) { - Exception? ex = null; - GetHRExceptionObject(hr, ObjectHandleOnStack.Create(ref ex)); - ex!.InternalPreserveStackTrace(); - return ex!; + Exception ex = Marshal.GetExceptionForHR(hr)!; + ex.InternalPreserveStackTrace(); + return ex; } - [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "StubHelpers_GetHRExceptionObject")] - private static partial void GetHRExceptionObject(int hr, ObjectHandleOnStack throwable); - #if FEATURE_COMINTEROP - internal static Exception GetCOMHRExceptionObject(int hr, IntPtr pCPCMD, object pThis) + internal static Exception GetCOMHRExceptionObject(int hr, IntPtr pCPCMD, IntPtr pUnk) { - Exception? ex = null; - GetCOMHRExceptionObject(hr, pCPCMD, ObjectHandleOnStack.Create(ref pThis), ObjectHandleOnStack.Create(ref ex)); - ex!.InternalPreserveStackTrace(); - return ex!; + RuntimeMethodHandle handle = RuntimeMethodHandle.FromIntPtr(pCPCMD); + RuntimeType declaringType = RuntimeMethodHandle.GetDeclaringType(handle.GetMethodInfo()); + Exception ex = Marshal.GetExceptionForHR(hr, declaringType.GUID, pUnk)!; + ex.InternalPreserveStackTrace(); + return ex; } - - [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "StubHelpers_GetCOMHRExceptionObject")] - private static partial void GetCOMHRExceptionObject(int hr, IntPtr pCPCMD, ObjectHandleOnStack pThis, ObjectHandleOnStack throwable); #endif // FEATURE_COMINTEROP [ThreadStatic] diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index 5b894a5312d97c..f9d65b28a35ac5 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -1039,7 +1039,7 @@ DEFINE_METHOD(BUFFER, MEMCOPYGC, BulkMoveWithWriteBar DEFINE_CLASS(STUBHELPERS, StubHelpers, StubHelpers) DEFINE_METHOD(STUBHELPERS, GET_DELEGATE_TARGET, GetDelegateTarget, SM_Delegate_RetIntPtr) #ifdef FEATURE_COMINTEROP -DEFINE_METHOD(STUBHELPERS, GET_COM_HR_EXCEPTION_OBJECT, GetCOMHRExceptionObject, SM_Int_IntPtr_Obj_RetException) +DEFINE_METHOD(STUBHELPERS, GET_COM_HR_EXCEPTION_OBJECT, GetCOMHRExceptionObject, SM_Int_IntPtr_IntPtr_RetException) DEFINE_METHOD(STUBHELPERS, GET_COM_IP_FROM_RCW, GetCOMIPFromRCW, SM_Obj_IntPtr_RefIntPtr_RefBool_RetIntPtr) #endif // FEATURE_COMINTEROP DEFINE_METHOD(STUBHELPERS, SET_LAST_ERROR, SetLastError, SM_RetVoid) diff --git a/src/coreclr/vm/dllimport.cpp b/src/coreclr/vm/dllimport.cpp index ace24da99e6849..05a5c538e59e4f 100644 --- a/src/coreclr/vm/dllimport.cpp +++ b/src/coreclr/vm/dllimport.cpp @@ -736,14 +736,9 @@ class ILStubState : public StubState #ifdef FEATURE_COMINTEROP if (SF_IsForwardCOMStub(m_dwStubFlags)) { - // Make sure that the RCW stays alive for the duration of the call. Note that if we do HRESULT - // swapping, we'll pass 'this' to GetCOMHRExceptionObject after returning from the target so - // GC.KeepAlive is not necessary. - if (!SF_IsHRESULTSwapping(m_dwStubFlags)) - { - m_slIL.EmitLoadRCWThis(pcsDispatch, m_dwStubFlags); - pcsDispatch->EmitCALL(METHOD__GC__KEEP_ALIVE, 1, 0); - } + // Make sure that the RCW stays alive for the duration of the call. + m_slIL.EmitLoadRCWThis(pcsDispatch, m_dwStubFlags); + pcsDispatch->EmitCALL(METHOD__GC__KEEP_ALIVE, 1, 0); } #endif // FEATURE_COMINTEROP @@ -761,7 +756,7 @@ class ILStubState : public StubState if (SF_IsCOMStub(m_dwStubFlags)) { m_slIL.EmitLoadStubContext(pcsDispatch, m_dwStubFlags); - m_slIL.EmitLoadRCWThis(pcsDispatch, m_dwStubFlags); + pcsDispatch->EmitLDLOC(m_slIL.GetTargetInterfacePointerLocalNum()); pcsDispatch->EmitCALL(METHOD__STUBHELPERS__GET_COM_HR_EXCEPTION_OBJECT, 3, 1); } diff --git a/src/coreclr/vm/metasig.h b/src/coreclr/vm/metasig.h index 1dcfce6a42ba39..3a7e24b2fc6bcb 100644 --- a/src/coreclr/vm/metasig.h +++ b/src/coreclr/vm/metasig.h @@ -169,7 +169,7 @@ // static methods: -DEFINE_METASIG_T(SM(Int_IntPtr_Obj_RetException, i I j, C(EXCEPTION))) +DEFINE_METASIG_T(SM(Int_IntPtr_IntPtr_RetException, i I I, C(EXCEPTION))) DEFINE_METASIG_T(SM(Type_CharPtr_RuntimeAssembly_Bool_Bool_IntPtr_RetRuntimeType, P(u) C(ASSEMBLY) F F I, C(CLASS))) DEFINE_METASIG_T(SM(Type_RetIntPtr, C(TYPE), I)) DEFINE_METASIG(SM(RefIntPtr_IntPtr_IntPtr_Int_RetObj, r(I) I I i, j)) diff --git a/src/coreclr/vm/qcallentrypoints.cpp b/src/coreclr/vm/qcallentrypoints.cpp index f8238e0811025e..f3343f5555a1f2 100644 --- a/src/coreclr/vm/qcallentrypoints.cpp +++ b/src/coreclr/vm/qcallentrypoints.cpp @@ -518,9 +518,7 @@ static const Entry s_QCall[] = DllImportEntry(StubHelpers_ProfilerBeginTransitionCallback) DllImportEntry(StubHelpers_ProfilerEndTransitionCallback) #endif - DllImportEntry(StubHelpers_GetHRExceptionObject) #if defined(FEATURE_COMINTEROP) - DllImportEntry(StubHelpers_GetCOMHRExceptionObject) DllImportEntry(StubHelpers_GetCOMIPFromRCWSlow) DllImportEntry(ObjectMarshaler_ConvertToNative) DllImportEntry(ObjectMarshaler_ConvertToManaged) diff --git a/src/coreclr/vm/stubhelpers.cpp b/src/coreclr/vm/stubhelpers.cpp index afcab14de162d7..2ac34dbfa30b25 100644 --- a/src/coreclr/vm/stubhelpers.cpp +++ b/src/coreclr/vm/stubhelpers.cpp @@ -543,81 +543,6 @@ extern "C" void QCALLTYPE StubHelpers_ProfilerEndTransitionCallback(MethodDesc* } #endif // PROFILING_SUPPORTED -extern "C" void QCALLTYPE StubHelpers_GetHRExceptionObject(HRESULT hr, QCall::ObjectHandleOnStack result) -{ - QCALL_CONTRACT; - - BEGIN_QCALL; - - GCX_COOP(); - - OBJECTREF oThrowable = NULL; - GCPROTECT_BEGIN(oThrowable); - - // GetExceptionForHR uses equivalant logic as COMPlusThrowHR - GetExceptionForHR(hr, &oThrowable); - result.Set(oThrowable); - - GCPROTECT_END(); - - END_QCALL; -} - -#ifdef FEATURE_COMINTEROP -extern "C" void QCALLTYPE StubHelpers_GetCOMHRExceptionObject( - HRESULT hr, - MethodDesc* pMD, - QCall::ObjectHandleOnStack pThis, - QCall::ObjectHandleOnStack result) -{ - QCALL_CONTRACT; - - BEGIN_QCALL; - - GCX_COOP(); - - struct - { - OBJECTREF oThrowable; - OBJECTREF oref; - } gc; - gc.oThrowable = NULL; - gc.oref = NULL; - GCPROTECT_BEGIN(gc); - - IErrorInfo* pErrorInfo = NULL; - if (pMD != NULL) - { - // Retrieve the interface method table. - MethodTable* pItfMT = CLRToCOMCallInfo::FromMethodDesc(pMD)->m_pInterfaceMT; - - // get 'this' - gc.oref = ObjectToOBJECTREF(pThis.Get()); - - // Get IUnknown pointer for this interface on this object - IUnknown* pUnk = ComObject::GetComIPFromRCW(&gc.oref, pItfMT); - if (pUnk != NULL) - { - // Check to see if the component supports error information for this interface. - IID ItfIID; - pItfMT->GetGuid(&ItfIID, TRUE); - pErrorInfo = GetSupportedErrorInfo(pUnk, ItfIID); - - DWORD cbRef = SafeRelease(pUnk); - LogInteropRelease(pUnk, cbRef, "IUnk to QI for ISupportsErrorInfo"); - } - } - - // GetExceptionForHR will handle lifetime of IErrorInfo. - GetExceptionForHR(hr, pErrorInfo, &gc.oThrowable); - result.Set(gc.oThrowable); - - GCPROTECT_END(); - - END_QCALL; -} -#endif // FEATURE_COMINTEROP - extern "C" void QCALLTYPE StubHelpers_MarshalToManagedVaList(va_list va, VARARGS* pArgIterator) { QCALL_CONTRACT; diff --git a/src/coreclr/vm/stubhelpers.h b/src/coreclr/vm/stubhelpers.h index e77ce5f2e7f46e..acce48a552b831 100644 --- a/src/coreclr/vm/stubhelpers.h +++ b/src/coreclr/vm/stubhelpers.h @@ -49,11 +49,7 @@ extern "C" void* QCALLTYPE StubHelpers_ProfilerBeginTransitionCallback(MethodDes extern "C" void QCALLTYPE StubHelpers_ProfilerEndTransitionCallback(MethodDesc* pTargetMD); #endif -extern "C" void QCALLTYPE StubHelpers_GetHRExceptionObject(HRESULT hr, QCall::ObjectHandleOnStack result); - #ifdef FEATURE_COMINTEROP -extern "C" void QCALLTYPE StubHelpers_GetCOMHRExceptionObject(HRESULT hr, MethodDesc *pMD, QCall::ObjectHandleOnStack pThis, QCall::ObjectHandleOnStack result); - extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget); extern "C" void QCALLTYPE ObjectMarshaler_ConvertToNative(QCall::ObjectHandleOnStack pSrcUNSAFE, VARIANT* pDest); diff --git a/src/libraries/Common/src/Interop/Windows/OleAut32/Interop.GetErrorInfo.cs b/src/libraries/Common/src/Interop/Windows/OleAut32/Interop.GetErrorInfo.cs new file mode 100644 index 00000000000000..eb450acf651a36 --- /dev/null +++ b/src/libraries/Common/src/Interop/Windows/OleAut32/Interop.GetErrorInfo.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static partial class OleAut32 + { + [LibraryImport(Libraries.OleAut32)] + internal static partial int GetErrorInfo(uint dwReserved, out IntPtr ppErrorInfo); + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index 150b1d6a8e11cd..a9aba695d9b215 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -2156,6 +2156,9 @@ Common\Interop\Windows\Ole32\Interop.PropVariantClear.cs + + Common\Interop\Windows\OleAut32\Interop.GetErrorInfo.cs + Common\Interop\Windows\Secur32\Interop.GetUserNameExW.cs diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs index a2b30029b0216f..6a0c4026bca245 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs @@ -653,6 +653,63 @@ public static IntPtr GetHINSTANCE(Module m) return GetExceptionForHRInternal(errorCode, errorInfo); } + public static Exception? GetExceptionForHR(int errorCode, in Guid iid, IntPtr pUnk) + { + if (errorCode >= 0) + { + return null; + } + + return GetExceptionForHRInternal(errorCode, in iid, pUnk); + } + + private static unsafe Exception? GetExceptionForHRInternal(int errorCode, in Guid iid, IntPtr pUnk) + { + const IntPtr NoErrorInfo = -1; // Use -1 to indicate no error info available + + // Normally, we would check if the interface supports IErrorInfo first. However, + // built-in COM calls GetErrorInfo first to clear the error info, so we follow + // that pattern here. + IntPtr errorInfo = NoErrorInfo; + +#if TARGET_WINDOWS + Interop.OleAut32.GetErrorInfo(0, out errorInfo); + if (errorInfo == IntPtr.Zero) + { + errorInfo = NoErrorInfo; + } + + // If there is error info and we have a pointer to the interface, + // we check if it supports ISupportErrorInfo. + if (errorInfo != NoErrorInfo && pUnk != IntPtr.Zero) + { + Guid IID_ISupportErrorInfo = new(0xDF0B3D60, 0x548F, 0x101B, 0x8E, 0x65, 0x08, 0x00, 0x2B, 0x2B, 0xD1, 0x19); + int hr = QueryInterface(pUnk, in IID_ISupportErrorInfo, out IntPtr supportErrorInfo); + if (hr == 0) + { + // Check if the target interface is supported. + // ISupportErrorInfo.InterfaceSupportsErrorInfo slot + fixed (Guid* piid = &iid) + { + hr = ((delegate* unmanaged[MemberFunction])(*(*(void***)supportErrorInfo + 3)))(supportErrorInfo, piid); + } + Release(supportErrorInfo); + } + + // If ISupportErrorInfo isn't supported or the target interface doesn't support IErrorInfo, + // release the error info and mark it as NoErrorInfo to avoid querying for IErrorInfo again. + if (hr != 0) + { + Release(errorInfo); + errorInfo = NoErrorInfo; + } + } +#endif + + // If the error info is valid, its lifetime will be handled by GetExceptionForHRInternal(). + return GetExceptionForHRInternal(errorCode, errorInfo); + } + #if !CORECLR #pragma warning disable IDE0060 private static Exception? GetExceptionForHRInternal(int errorCode, IntPtr errorInfo) @@ -865,6 +922,14 @@ public static void ThrowExceptionForHR(int errorCode, IntPtr errorInfo) } } + public static void ThrowExceptionForHR(int errorCode, in Guid iid, IntPtr pUnk) + { + if (errorCode < 0) + { + throw GetExceptionForHR(errorCode, in iid, pUnk)!; + } + } + public static IntPtr SecureStringToBSTR(SecureString s) { if (s is null) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index b4bc49672185bf..2d94a3bd4a1bdd 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -107,7 +107,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return new ComMethodContext( data.Method, data.OwningInterface, - CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info.Type, ct)); + CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info, ct)); }).WithTrackingName(StepNames.CalculateStubInformation); var interfaceAndMethodsContexts = comMethodContexts @@ -256,7 +256,7 @@ private static bool IsHResultLikeType(ManagedTypeInfo type) || typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase); } - private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo owningInterface, CancellationToken ct) + private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterfaceInfo, CancellationToken ct) { ct.ThrowIfCancellationRequested(); INamedTypeSymbol? lcidConversionAttrType = environment.LcidConversionAttrType; @@ -349,7 +349,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M // Add the HRESULT return value in the native signature. // This element does not have any influence on the managed signature, so don't assign a managed index. ElementTypeInformation = returnSwappedSignatureElements.Add( - new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo()) + new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo(owningInterfaceInfo.InterfaceId)) { NativeIndex = TypePositionInfo.ReturnIndex }) @@ -425,7 +425,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M virtualMethodIndexData, exceptionMarshallingInfo, environment.EnvironmentFlags, - owningInterface, + owningInterfaceInfo.Type, declaringType, generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(), ComInterfaceDispatchMarshallingInfo.Instance); @@ -826,7 +826,7 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceI EqualsValueClause( ImplicitObjectCreationExpression() .AddArgumentListArguments( - Argument(CreateEmbeddedDataBlobCreationStatement(context.InterfaceId.ToByteArray()))))) + Argument(ComInterfaceGeneratorHelpers.CreateEmbeddedDataBlobCreationStatement(context.InterfaceId.ToByteArray()))))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))); if (context.Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper)) @@ -858,20 +858,6 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceI .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.NullLiteralExpression))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))); - - - static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan bytes) - { - var literals = new CollectionElementSyntax[bytes.Length]; - - for (int i = 0; i < bytes.Length; i++) - { - literals[i] = ExpressionElement(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(bytes[i]))); - } - - // [ ] - return CollectionExpression(SeparatedList(literals)); - } } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index 91a3a956a024bf..04516de9cccaf4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -5,6 +5,9 @@ using System.Collections.Generic; using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop { @@ -32,5 +35,18 @@ public static IMarshallingGeneratorResolver GetGeneratorResolver(EnvironmentFlag (false, MarshalDirection.UnmanagedToManaged) => s_unmanagedToManagedEnabledMarshallingGeneratorResolver, _ => throw new UnreachableException(), }; + + public static ExpressionSyntax CreateEmbeddedDataBlobCreationStatement(ReadOnlySpan bytes) + { + var literals = new CollectionElementSyntax[bytes.Length]; + + for (int i = 0; i < bytes.Length; i++) + { + literals[i] = ExpressionElement(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(bytes[i]))); + } + + // [ ] + return CollectionExpression(SeparatedList(literals)); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionGeneratorResolver.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionGeneratorResolver.cs index 8f8254fa6f6251..583b1380bb901a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionGeneratorResolver.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ManagedHResultExceptionGeneratorResolver.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -11,7 +12,7 @@ namespace Microsoft.Interop { - internal sealed record ManagedHResultExceptionMarshallingInfo : MarshallingInfo; + internal sealed record ManagedHResultExceptionMarshallingInfo(Guid InterfaceId) : MarshallingInfo; internal sealed class ManagedHResultExceptionGeneratorResolver : IMarshallingGeneratorResolver { @@ -37,7 +38,7 @@ private sealed class ManagedToUnmanagedMarshaller : IUnboundMarshallingGenerator public ManagedTypeInfo AsNativeType(TypePositionInfo info) => info.ManagedType; public IEnumerable Generate(TypePositionInfo info, StubCodeContext codeContext, StubIdentifierContext context) { - Debug.Assert(info.MarshallingAttributeInfo is ManagedHResultExceptionMarshallingInfo); + ManagedHResultExceptionMarshallingInfo marshallingInfo = (ManagedHResultExceptionMarshallingInfo)info.MarshallingAttributeInfo; if (context.CurrentStage != StubIdentifierContext.Stage.NotifyForSuccessfulInvoke) { @@ -46,11 +47,17 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont (string managedIdentifier, _) = context.GetIdentifiers(info); - // Marshal.ThrowExceptionForHR(); + // Marshal.ThrowExceptionForHR(, new(), ); yield return MethodInvocationStatement( TypeSyntaxes.System_Runtime_InteropServices_Marshal, IdentifierName("ThrowExceptionForHR"), - Argument(IdentifierName(managedIdentifier))); + Argument(IdentifierName(managedIdentifier)), + Argument(ImplicitObjectCreationExpression( + ArgumentList( + SingletonSeparatedList( + Argument(ComInterfaceGeneratorHelpers.CreateEmbeddedDataBlobCreationStatement(marshallingInfo.InterfaceId.ToByteArray())))), + initializer: null)), + Argument(CastExpression(TypeSyntaxes.System_IntPtr, IdentifierName(VirtualMethodPointerStubGenerator.NativeThisParameterIdentifier)))); } public SignatureBehavior GetNativeSignatureBehavior(TypePositionInfo info) => SignatureBehavior.NativeType; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs index a56644cef13f58..54ab6543c8a373 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs @@ -16,9 +16,9 @@ namespace Microsoft.Interop { internal static class VirtualMethodPointerStubGenerator { - private const string NativeThisParameterIdentifier = "__this"; - private const string VirtualMethodTableIdentifier = "__vtable"; - private const string VirtualMethodTarget = "__target"; + internal const string NativeThisParameterIdentifier = "__this"; + internal const string VirtualMethodTableIdentifier = "__vtable"; + internal const string VirtualMethodTarget = "__target"; public static (MethodDeclarationSyntax, ImmutableArray) GenerateManagedToNativeStub( IncrementalMethodStubGenerationContext methodStub, diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index 2a1f06362dc16e..593a42224208d8 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -1068,6 +1068,7 @@ public static void FreeHGlobal(System.IntPtr hglobal) { } public static int GetExceptionCode() { throw null; } public static System.Exception? GetExceptionForHR(int errorCode) { throw null; } public static System.Exception? GetExceptionForHR(int errorCode, System.IntPtr errorInfo) { throw null; } + public static System.Exception? GetExceptionForHR(int errorCode, in System.Guid iid, System.IntPtr pUnk) { throw null; } public static System.IntPtr GetExceptionPointers() { throw null; } [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Marshalling code for the delegate might not be available. Use the GetFunctionPointerForDelegate overload instead.")] [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] @@ -1208,6 +1209,7 @@ public static void StructureToPtr(object structure, System.IntPtr ptr, bool fDel public static void StructureToPtr([System.Diagnostics.CodeAnalysis.DisallowNullAttribute] T structure, System.IntPtr ptr, bool fDeleteOld) { } public static void ThrowExceptionForHR(int errorCode) { } public static void ThrowExceptionForHR(int errorCode, System.IntPtr errorInfo) { } + public static void ThrowExceptionForHR(int errorCode, in System.Guid iid, System.IntPtr pUnk) { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public static System.IntPtr UnsafeAddrOfPinnedArrayElement(System.Array arr, int index) { throw null; } public static System.IntPtr UnsafeAddrOfPinnedArrayElement(T[] arr, int index) { throw null; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/Marshal/ThrowExceptionForHRTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/Marshal/ThrowExceptionForHRTests.cs index 68c87a4b8effb3..faa057cdc9d7ef 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/Marshal/ThrowExceptionForHRTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/Marshal/ThrowExceptionForHRTests.cs @@ -3,12 +3,13 @@ using System.Collections.Generic; using System.Reflection; +using System.Runtime.InteropServices.Marshalling; using Xunit; namespace System.Runtime.InteropServices.Tests { [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotNativeAot))] - public class ThrowExceptionForHRTests + public partial class ThrowExceptionForHRTests { [Theory] [ActiveIssue("https://github.com/mono/mono/issues/15093", TestRuntimes.Mono)] @@ -36,7 +37,7 @@ public void ThrowExceptionForHR_NoErrorInfo_ReturnsValidException(int errorCode) string sourceMaybe = "System.Private.CoreLib"; // If the ThrowExceptionForHR is inlined by the JIT, the source could be the test assembly - Assert.Contains(ex.Source, new string[]{ sourceMaybe, Assembly.GetExecutingAssembly().GetName().Name }); + Assert.Contains(ex.Source, new string[] { sourceMaybe, Assembly.GetExecutingAssembly().GetName().Name }); Assert.Contains(nameof(ThrowExceptionForHR_NoErrorInfo_ReturnsValidException), ex.StackTrace); Assert.Contains(nameof(Marshal.ThrowExceptionForHR), ex.TargetSite.Name); } @@ -77,7 +78,7 @@ public void ThrowExceptionForHR_ErrorInfo_ReturnsValidException(int errorCode, I string sourceMaybe = "System.Private.CoreLib"; // If the ThrowExceptionForHR is inlined by the JIT, the source could be the test assembly - Assert.Contains(ex.Source, new string[]{ sourceMaybe, Assembly.GetExecutingAssembly().GetName().Name }); + Assert.Contains(ex.Source, new string[] { sourceMaybe, Assembly.GetExecutingAssembly().GetName().Name }); Assert.Contains(nameof(ThrowExceptionForHR_ErrorInfo_ReturnsValidException), ex.StackTrace); Assert.Contains(nameof(Marshal.ThrowExceptionForHR), ex.TargetSite.Name); } @@ -94,6 +95,38 @@ public void ThrowExceptionForHR_InvalidHR_Nop(int errorCode) Marshal.ThrowExceptionForHR(errorCode, IntPtr.Zero); } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))] + public void ThrowExceptionForHR_BasedOnISupportErrorInfo() + { + var comWrappers = new StrategyBasedComWrappers(); + Guid iid = new Guid("999b8152-1e6f-4166-8b35-ac89475e96fa"); + var obj = new ConditionallySupportErrorInfo(iid); + IntPtr pUnk = comWrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None); + try + { + var exception = new InvalidOperationException(); + ClearCurrentIErrorInfo(); + + // Set the error info for the current thread to the exception. + _ = Marshal.GetHRForException(exception); + + // The HResult from the IErrorInfo is used because the ISupportErrorInfo interface returned S_OK for the provided iid. + Assert.IsType(Marshal.GetExceptionForHR(new ArgumentException().HResult, iid, pUnk)); + + // Set the error info for the current thread to the exception. + _ = Marshal.GetHRForException(exception); + + var otherIid = new Guid("65af44f4-fd4f-4a35-a6f5-a0c66878fa75"); + + // The HResult from the IErrorInfo is ignored because the ISupportErrorInfo interface returned S_FALSE for the provided otherIid. + Assert.IsType(Marshal.GetExceptionForHR(new ArgumentException().HResult, otherIid, pUnk)); + } + finally + { + Marshal.Release(pUnk); + } + } + private static void ClearCurrentIErrorInfo() { // Ensure that if the thread's current IErrorInfo @@ -101,5 +134,22 @@ private static void ClearCurrentIErrorInfo() // to interpreting the HRESULT. Marshal.GetExceptionForHR(unchecked((int)0x80040001)); } + + [GeneratedComClass] + internal sealed partial class ConditionallySupportErrorInfo(Guid iid) : ISupportErrorInfo + { + public int InterfaceSupportsErrorInfo(in Guid riid) + { + return iid == riid ? 0 : 1; // S_OK or S_FALSE + } + } + + [GeneratedComInterface] + [Guid("DF0B3D60-548F-101B-8E65-08002B2BD119")] + internal partial interface ISupportErrorInfo + { + [PreserveSig] + int InterfaceSupportsErrorInfo(in Guid riid); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ICollectionMarshallingFails.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ICollectionMarshallingFails.cs index 1c6513607d9056..0d331441e7dd9d 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ICollectionMarshallingFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ICollectionMarshallingFails.cs @@ -26,7 +26,7 @@ public void Set( } [GeneratedComClass] - internal partial class ICollectionMarshallingFailsImpl : ICollectionMarshallingFails + internal partial class ICollectionMarshallingFailsImpl : ICollectionMarshallingFails, ISupportErrorInfo { int[] _data = new[] { 1, 2, 3, 4, 5, 6, 7, 8 }; public int[] Get(out int size) @@ -43,5 +43,14 @@ public void Set(int[] value, int size) _data = new int[size]; value.CopyTo(_data, 0); } + + int ISupportErrorInfo.InterfaceSupportsErrorInfo(in Guid riid) + { + if (riid == typeof(ICollectionMarshallingFails).GUID) + { + return 0; // S_OK + } + return 1; // S_FALSE + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IJaggedIntArrayMarshallingFails.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IJaggedIntArrayMarshallingFails.cs index 4da8802dff4079..1a4eda0dca222f 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IJaggedIntArrayMarshallingFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IJaggedIntArrayMarshallingFails.cs @@ -43,7 +43,7 @@ void Set( } [GeneratedComClass] - internal partial class IJaggedIntArrayMarshallingFailsImpl : IJaggedIntArrayMarshallingFails + internal partial class IJaggedIntArrayMarshallingFailsImpl : IJaggedIntArrayMarshallingFails, ISupportErrorInfo { int[][] _data = new int[][] { new int[] { 1, 2, 3 }, new int[] { 4, 5 }, new int[] { 6, 7, 8, 9 } }; int[] _widths = new int[] { 3, 2, 4 }; @@ -76,5 +76,14 @@ public void Set(int[][] array, int[] widths, int length) _data = array; _widths = widths; } + + int ISupportErrorInfo.InterfaceSupportsErrorInfo(in Guid riid) + { + if (riid == typeof(IJaggedIntArrayMarshallingFails).GUID) + { + return 0; // S_OK + } + return 1; // S_FALSE + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IStringArrayMarshallingFails.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IStringArrayMarshallingFails.cs index 77afe03c85fddb..5492f147b137e2 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IStringArrayMarshallingFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/IStringArrayMarshallingFails.cs @@ -28,7 +28,7 @@ internal partial interface IStringArrayMarshallingFails /// Implements IStringArrayMarshallingFails. /// [GeneratedComClass] - internal partial class IStringArrayMarshallingFailsImpl : IStringArrayMarshallingFails + internal partial class IStringArrayMarshallingFailsImpl : IStringArrayMarshallingFails, ISupportErrorInfo { public static string[] StartingStrings { get; } = new string[] { "Hello", "World", "Lorem", "Ipsum", "Dolor", "Sample", "Text", ".Net", "Interop", "string" }; private string[] _strings = StartingStrings; @@ -40,6 +40,15 @@ internal partial class IStringArrayMarshallingFailsImpl : IStringArrayMarshallin public void RefParam([MarshalUsing(ConstantElementCount = 10)] ref string[] value) => value[0] = _strings[0]; [return: MarshalUsing(ConstantElementCount = 10)] public string[] ReturnValue() => _strings; + + int ISupportErrorInfo.InterfaceSupportsErrorInfo(in Guid riid) + { + if (riid == typeof(IStringArrayMarshallingFails).GUID) + { + return 0; // S_OK + } + return 1; // S_FALSE + } } /// diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ISupportErrorInfo.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ISupportErrorInfo.cs new file mode 100644 index 00000000000000..b34a824ab394b3 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/MarshallingFails/ISupportErrorInfo.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces.MarshallingFails +{ + [GeneratedComInterface] + [Guid("DF0B3D60-548F-101B-8E65-08002B2BD119")] + internal partial interface ISupportErrorInfo + { + [PreserveSig] + int InterfaceSupportsErrorInfo(in Guid riid); + } +}