diff --git a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs index 61e71adcff477a..50f97ac625e3a6 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs @@ -1357,14 +1357,18 @@ internal static Exception GetHRExceptionObject(int hr) } #if FEATURE_COMINTEROP - internal static Exception GetCOMHRExceptionObject(int hr, IntPtr pCPCMD, IntPtr pUnk) + internal static unsafe Exception GetCOMHRExceptionObject(int hr, IntPtr pCPCMD, IntPtr pUnk) { - RuntimeMethodHandle handle = RuntimeMethodHandle.FromIntPtr(pCPCMD); - RuntimeType declaringType = RuntimeMethodHandle.GetDeclaringType(handle.GetMethodInfo()); + Debug.Assert(pCPCMD != IntPtr.Zero); + MethodTable* interfaceType = GetComInterfaceFromMethodDesc(pCPCMD); + RuntimeType declaringType = RuntimeTypeHandle.GetRuntimeType(interfaceType); Exception ex = Marshal.GetExceptionForHR(hr, declaringType.GUID, pUnk)!; ex.InternalPreserveStackTrace(); return ex; } + + [MethodImpl(MethodImplOptions.InternalCall)] + private static extern unsafe MethodTable* GetComInterfaceFromMethodDesc(IntPtr pCPCMD); #endif // FEATURE_COMINTEROP [ThreadStatic] diff --git a/src/coreclr/vm/ecalllist.h b/src/coreclr/vm/ecalllist.h index 502586170e843f..266db5df974044 100644 --- a/src/coreclr/vm/ecalllist.h +++ b/src/coreclr/vm/ecalllist.h @@ -352,6 +352,7 @@ FCFuncStart(gStubHelperFuncs) FCFuncElement("SetLastError", StubHelpers::SetLastError) FCFuncElement("ClearLastError", StubHelpers::ClearLastError) #ifdef FEATURE_COMINTEROP + FCFuncElement("GetComInterfaceFromMethodDesc", StubHelpers::GetComInterfaceFromMethodDesc) FCFuncElement("GetCOMIPFromRCW", StubHelpers::GetCOMIPFromRCW) #endif // FEATURE_COMINTEROP FCFuncElement("CalcVaListSize", StubHelpers::CalcVaListSize) diff --git a/src/coreclr/vm/stubhelpers.cpp b/src/coreclr/vm/stubhelpers.cpp index 4437f432123d05..b51dae748310a3 100644 --- a/src/coreclr/vm/stubhelpers.cpp +++ b/src/coreclr/vm/stubhelpers.cpp @@ -266,6 +266,14 @@ FORCEINLINE static IUnknown* GetCOMIPFromRCW_GetTargetFromRCWCache(SOleTlsData* return NULL; } +FCIMPL1(MethodTable*, StubHelpers::GetComInterfaceFromMethodDesc, MethodDesc* pMD) +{ + FCALL_CONTRACT; + _ASSERTE(pMD != NULL); + return CLRToCOMCallInfo::FromMethodDesc(pMD)->m_pInterfaceMT; +} +FCIMPLEND + //================================================================================================================== // The GetCOMIPFromRCW helper exists in four specialized versions to optimize CLR->COM perf. Please be careful when // changing this code as one of these methods is executed as part of every CLR->COM call so every instruction counts. diff --git a/src/coreclr/vm/stubhelpers.h b/src/coreclr/vm/stubhelpers.h index 00fdd97b6d1b64..35a76a654535d3 100644 --- a/src/coreclr/vm/stubhelpers.h +++ b/src/coreclr/vm/stubhelpers.h @@ -26,6 +26,7 @@ class StubHelpers //------------------------------------------------------- #ifdef FEATURE_COMINTEROP + static FCDECL1(MethodTable*, GetComInterfaceFromMethodDesc, MethodDesc* pMD); static FCDECL3(IUnknown*, GetCOMIPFromRCW, Object* pSrcUNSAFE, MethodDesc* pMD, void **ppTarget); #endif // FEATURE_COMINTEROP diff --git a/src/tests/Interop/COM/NETClients/ComDisabled/Program.cs b/src/tests/Interop/COM/NETClients/ComDisabled/Program.cs index ca74c5cb74a151..a88ed3032a0fd3 100644 --- a/src/tests/Interop/COM/NETClients/ComDisabled/Program.cs +++ b/src/tests/Interop/COM/NETClients/ComDisabled/Program.cs @@ -22,7 +22,7 @@ public static int TestEntryPoint() try { - var server = (Server.Contract.Servers.NumericTesting)new Server.Contract.Servers.NumericTestingClass(); + var server = new Server.Contract.Servers.NumericTesting(); } catch (NotSupportedException) when (OperatingSystem.IsWindows()) { diff --git a/src/tests/Interop/COM/NETClients/Events/Program.cs b/src/tests/Interop/COM/NETClients/Events/Program.cs index 3aacc923a2c73b..c3b70ac05379cd 100644 --- a/src/tests/Interop/COM/NETClients/Events/Program.cs +++ b/src/tests/Interop/COM/NETClients/Events/Program.cs @@ -20,7 +20,7 @@ static void Validate_BasicCOMEvent() { Console.WriteLine($"{nameof(Validate_BasicCOMEvent)}..."); - var eventTesting = (EventTesting)new EventTestingClass(); + var eventTesting = new EventTesting(); // Verify event handler subscription @@ -57,7 +57,7 @@ static void Validate_COMEventViaComAwareEventInfo() { Console.WriteLine($"{nameof(Validate_COMEventViaComAwareEventInfo)}..."); - var eventTesting = (EventTesting)new EventTestingClass(); + var eventTesting = new EventTesting(); // Verify event handler subscription diff --git a/src/tests/Interop/COM/NETClients/IDispatch/Program.cs b/src/tests/Interop/COM/NETClients/IDispatch/Program.cs index 4324913a5a1fb6..f76ddd27d32015 100644 --- a/src/tests/Interop/COM/NETClients/IDispatch/Program.cs +++ b/src/tests/Interop/COM/NETClients/IDispatch/Program.cs @@ -19,7 +19,7 @@ public class Program { static void Validate_Numeric_In_ReturnByRef() { - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); byte b1 = 1; byte b2 = b1; @@ -74,7 +74,7 @@ static private bool EqualByBound(double expected, double actual) static void Validate_Float_In_ReturnAndUpdateByRef() { - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); float a = .1f; float b = .2f; @@ -91,7 +91,7 @@ static void Validate_Float_In_ReturnAndUpdateByRef() static void Validate_Double_In_ReturnAndUpdateByRef() { - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); double a = .1; double b = .2; @@ -114,7 +114,7 @@ static int GetErrorCodeFromHResult(int hresult) static void Validate_Exception() { - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); int errorCode = 1127; string resultString = errorCode.ToString("x"); @@ -174,7 +174,7 @@ static void Validate_Exception() static void Validate_StructNotSupported() { Console.WriteLine($"IDispatch with structs not supported..."); - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); var input = new HFA_4() { x = 1f, y = 2f, z = 3f, w = 4f }; Assert.Throws(() => dispatchTesting.DoubleHVAValues(ref input)); @@ -182,7 +182,7 @@ static void Validate_StructNotSupported() static void Validate_LCID_Marshaled() { - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); CultureInfo oldCulture = CultureInfo.CurrentCulture; CultureInfo newCulture = new CultureInfo("es-ES", false); try @@ -200,7 +200,7 @@ static void Validate_LCID_Marshaled() static void Validate_Enumerator() { - var dispatchTesting = (DispatchTesting)new DispatchTestingClass(); + var dispatchTesting = new DispatchTesting(); var expected = System.Linq.Enumerable.Range(0, 10); { @@ -251,7 +251,7 @@ static System.Collections.Generic.IEnumerable ConvertEnumerable(System.Coll static void Validate_ValueCoerce_ReturnToManaged() { - var dispatchCoerceTesting = (DispatchCoerceTesting)new DispatchCoerceTestingClass(); + var dispatchCoerceTesting = new DispatchCoerceTesting(); Console.WriteLine($"Calling {nameof(DispatchCoerceTesting.ReturnToManaged)} ..."); diff --git a/src/tests/Interop/COM/NETClients/IInspectable/Program.cs b/src/tests/Interop/COM/NETClients/IInspectable/Program.cs index a539802ead4347..c36bcd4f8196cf 100644 --- a/src/tests/Interop/COM/NETClients/IInspectable/Program.cs +++ b/src/tests/Interop/COM/NETClients/IInspectable/Program.cs @@ -18,7 +18,7 @@ public class Program { static void Validate_IInspectable() { - var server = (InspectableTesting)new InspectableTestingClass(); + var server = new InspectableTesting(); Assert.Throws(() => _ = (IInspectableTesting2)server); } diff --git a/src/tests/Interop/COM/NETClients/Licensing/Program.cs b/src/tests/Interop/COM/NETClients/Licensing/Program.cs index 77d3d5111c94aa..e8f21273b4f76d 100644 --- a/src/tests/Interop/COM/NETClients/Licensing/Program.cs +++ b/src/tests/Interop/COM/NETClients/Licensing/Program.cs @@ -23,13 +23,13 @@ static void ActivateLicensedObject() Console.WriteLine($"Calling {nameof(ActivateLicensedObject)}..."); // Validate activation - var licenseTesting = (LicenseTesting)new LicenseTestingClass(); + var licenseTesting = new LicenseTesting(); // Validate license denial licenseTesting.SetNextDenyLicense(true); try { - var tmp = (LicenseTesting)new LicenseTestingClass(); + var tmp = new LicenseTesting(); Assert.Fail("Activation of licensed class should fail"); } catch (COMException e) @@ -86,7 +86,7 @@ static void ActivateUnderDesigntimeContext() LicenseManager.CurrentContext = new MockLicenseContext(typeof(LicenseTestingClass), LicenseUsageMode.Designtime); LicenseManager.CurrentContext.SetSavedLicenseKey(typeof(LicenseTestingClass), licKey); - var licenseTesting = (LicenseTesting)new LicenseTestingClass(); + var licenseTesting = new LicenseTesting(); // During design time the IClassFactory::CreateInstance will be called - no license Assert.Null(licenseTesting.GetLicense()); @@ -111,7 +111,7 @@ static void ActivateUnderRuntimeContext() LicenseManager.CurrentContext = new MockLicenseContext(typeof(LicenseTestingClass), LicenseUsageMode.Runtime); LicenseManager.CurrentContext.SetSavedLicenseKey(typeof(LicenseTestingClass), licKey); - var licenseTesting = (LicenseTesting)new LicenseTestingClass(); + var licenseTesting = new LicenseTesting(); // During runtime the IClassFactory::CreateInstance2 will be called with license from context Assert.Equal(licKey, licenseTesting.GetLicense()); diff --git a/src/tests/Interop/COM/NETClients/MiscTypes/Program.cs b/src/tests/Interop/COM/NETClients/MiscTypes/Program.cs index 4c578303fa3fd1..a8c0f579e50290 100644 --- a/src/tests/Interop/COM/NETClients/MiscTypes/Program.cs +++ b/src/tests/Interop/COM/NETClients/MiscTypes/Program.cs @@ -87,7 +87,7 @@ private static void ValidationTests() { Console.WriteLine($"Running {nameof(ValidationTests)} ..."); - var miscTypeTesting = (Server.Contract.Servers.MiscTypesTesting)new Server.Contract.Servers.MiscTypesTestingClass(); + var miscTypeTesting = new Server.Contract.Servers.MiscTypesTesting(); Console.WriteLine("-- Primitives <=> VARIANT..."); { @@ -226,7 +226,7 @@ private static void ValidateNegativeTests() { Console.WriteLine($"Running {nameof(ValidateNegativeTests)} ..."); - var miscTypeTesting = (Server.Contract.Servers.MiscTypesTesting)new Server.Contract.Servers.MiscTypesTestingClass(); + var miscTypeTesting = new Server.Contract.Servers.MiscTypesTesting(); Console.WriteLine("-- DispatchWrapper with non-IDispatch object <=> VARIANT..."); { diff --git a/src/tests/Interop/COM/NETClients/Primitives/ArrayTests.cs b/src/tests/Interop/COM/NETClients/Primitives/ArrayTests.cs index 2edffed11d24fc..06e9a116bf45a4 100644 --- a/src/tests/Interop/COM/NETClients/Primitives/ArrayTests.cs +++ b/src/tests/Interop/COM/NETClients/Primitives/ArrayTests.cs @@ -17,7 +17,7 @@ class ArrayTests public ArrayTests() { - this.server = (Server.Contract.Servers.ArrayTesting)new Server.Contract.Servers.ArrayTestingClass(); + this.server = new Server.Contract.Servers.ArrayTesting(); double acc = 0.0; int[] rawData = BaseData.ToArray(); @@ -31,6 +31,7 @@ public ArrayTests() public void Run() { + Console.WriteLine(nameof(ArrayTests)); this.Marshal_ByteArray(); this.Marshal_ShortArray(); this.Marshal_UShortArray(); diff --git a/src/tests/Interop/COM/NETClients/Primitives/CallViaReflectionTests.cs b/src/tests/Interop/COM/NETClients/Primitives/CallViaReflectionTests.cs index 561f6c685abb8e..eefd5a3d50ae3a 100644 --- a/src/tests/Interop/COM/NETClients/Primitives/CallViaReflectionTests.cs +++ b/src/tests/Interop/COM/NETClients/Primitives/CallViaReflectionTests.cs @@ -13,7 +13,7 @@ class CallViaReflectionTests public CallViaReflectionTests() { - this.server = (Server.Contract.Servers.NumericTesting)new Server.Contract.Servers.NumericTestingClass(); + this.server = new Server.Contract.Servers.NumericTesting(); } public void Run() diff --git a/src/tests/Interop/COM/NETClients/Primitives/ColorTests.cs b/src/tests/Interop/COM/NETClients/Primitives/ColorTests.cs index 0ecea298de08fd..81afd3291122f9 100644 --- a/src/tests/Interop/COM/NETClients/Primitives/ColorTests.cs +++ b/src/tests/Interop/COM/NETClients/Primitives/ColorTests.cs @@ -13,11 +13,12 @@ class ColorTests private readonly Server.Contract.Servers.ColorTesting server; public ColorTests() { - this.server = (Server.Contract.Servers.ColorTesting)new Server.Contract.Servers.ColorTestingClass(); + this.server = new Server.Contract.Servers.ColorTesting(); } public void Run() { + Console.WriteLine(nameof(ColorTests)); this.VerifyColorMarshalling(); this.VerifyGetRed(); } diff --git a/src/tests/Interop/COM/NETClients/Primitives/ErrorTests.cs b/src/tests/Interop/COM/NETClients/Primitives/ErrorTests.cs index 17c5a41674f8a0..58ac77aa1822b2 100644 --- a/src/tests/Interop/COM/NETClients/Primitives/ErrorTests.cs +++ b/src/tests/Interop/COM/NETClients/Primitives/ErrorTests.cs @@ -12,11 +12,12 @@ class ErrorTests private readonly Server.Contract.Servers.ErrorMarshalTesting server; public ErrorTests() { - this.server = (Server.Contract.Servers.ErrorMarshalTesting)new Server.Contract.Servers.ErrorMarshalTestingClass(); + this.server = new Server.Contract.Servers.ErrorMarshalTesting(); } public void Run() { + Console.WriteLine(nameof(ErrorTests)); this.VerifyExpectedException(); this.VerifyReturnHResult(); this.VerifyHelpLink(); diff --git a/src/tests/Interop/COM/NETClients/Primitives/NumericTests.cs b/src/tests/Interop/COM/NETClients/Primitives/NumericTests.cs index 6517191f53cc58..9a6530e10955a8 100644 --- a/src/tests/Interop/COM/NETClients/Primitives/NumericTests.cs +++ b/src/tests/Interop/COM/NETClients/Primitives/NumericTests.cs @@ -18,11 +18,12 @@ public NumericTests(int seed = 37) Console.WriteLine($"Numeric RNG seed: {this.seed}"); this.rng = new Random(this.seed); - this.server = (Server.Contract.Servers.NumericTesting)new Server.Contract.Servers.NumericTestingClass(); + this.server = new Server.Contract.Servers.NumericTesting(); } public void Run() { + Console.WriteLine(nameof(NumericTests)); int a = this.rng.Next(); int b = this.rng.Next(); diff --git a/src/tests/Interop/COM/NETClients/Primitives/StringTests.cs b/src/tests/Interop/COM/NETClients/Primitives/StringTests.cs index d88b84070afdbc..1de229f27cc284 100644 --- a/src/tests/Interop/COM/NETClients/Primitives/StringTests.cs +++ b/src/tests/Interop/COM/NETClients/Primitives/StringTests.cs @@ -46,11 +46,12 @@ class StringTests public StringTests() { - this.server = (Server.Contract.Servers.StringTesting)new Server.Contract.Servers.StringTestingClass(); + this.server = new Server.Contract.Servers.StringTesting(); } public void Run() { + Console.WriteLine(nameof(StringTests)); this.Marshal_LPString(); this.Marshal_LPWString(); this.Marshal_BStrString();