Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/5.0] When marshalling a layout class, fall-back to dynamically marshalling the type if it doesn't match the static type in the signature. #50138

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/coreclr/src/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,11 @@ DEFINE_METHOD(MARSHAL, ALLOC_CO_TASK_MEM, AllocCoTa
DEFINE_METHOD(MARSHAL, FREE_CO_TASK_MEM, FreeCoTaskMem, SM_IntPtr_RetVoid)
DEFINE_FIELD(MARSHAL, SYSTEM_MAX_DBCS_CHAR_SIZE, SystemMaxDBCSCharSize)

DEFINE_METHOD(MARSHAL, STRUCTURE_TO_PTR, StructureToPtr, SM_Obj_IntPtr_Bool_RetVoid)
DEFINE_METHOD(MARSHAL, PTR_TO_STRUCTURE, PtrToStructure, SM_IntPtr_Obj_RetVoid)
DEFINE_METHOD(MARSHAL, DESTROY_STRUCTURE, DestroyStructure, SM_IntPtr_Type_RetVoid)
DEFINE_METHOD(MARSHAL, SIZEOF_TYPE, SizeOf, SM_Type_RetInt)

DEFINE_CLASS(NATIVELIBRARY, Interop, NativeLibrary)
DEFINE_METHOD(NATIVELIBRARY, LOADLIBRARYCALLBACKSTUB, LoadLibraryCallbackStub, SM_Str_AssemblyBase_Bool_UInt_RetIntPtr)

Expand Down
129 changes: 126 additions & 3 deletions src/coreclr/src/vm/ilmarshalers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2149,14 +2149,30 @@ void ILLayoutClassPtrMarshalerBase::EmitConvertSpaceCLRToNative(ILCodeStream* ps

EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitBRFALSE(pNullRefLabel);
ILCodeLabel* pTypeMismatchedLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, pTypeMismatchedLabel);
DWORD sizeLocal = pslILEmit->NewLocal(LocalDesc(ELEMENT_TYPE_I4));

pslILEmit->EmitLDC(uNativeSize);
if (emittedTypeCheck)
{
ILCodeLabel* pHaveSizeLabel = pslILEmit->NewCodeLabel();
pslILEmit->EmitBR(pHaveSizeLabel);
pslILEmit->EmitLabel(pTypeMismatchedLabel);
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitCALL(METHOD__OBJECT__GET_TYPE, 1, 1);
pslILEmit->EmitCALL(METHOD__MARSHAL__SIZEOF_TYPE, 1, 1);
pslILEmit->EmitLabel(pHaveSizeLabel);
}
pslILEmit->EmitSTLOC(sizeLocal);
pslILEmit->EmitLDLOC(sizeLocal);
pslILEmit->EmitCALL(METHOD__MARSHAL__ALLOC_CO_TASK_MEM, 1, 1);
pslILEmit->EmitDUP(); // for INITBLK
EmitStoreNativeValue(pslILEmit);

// initialize local block we just allocated
pslILEmit->EmitLDC(0);
pslILEmit->EmitLDC(uNativeSize);
pslILEmit->EmitLDLOC(sizeLocal);
pslILEmit->EmitINITBLK();

pslILEmit->EmitLabel(pNullRefLabel);
Expand All @@ -2180,15 +2196,30 @@ void ILLayoutClassPtrMarshalerBase::EmitConvertSpaceCLRToNativeTemp(ILCodeStream

EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitBRFALSE(pNullRefLabel);
ILCodeLabel* pTypeMismatchedLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, pTypeMismatchedLabel);
DWORD sizeLocal = pslILEmit->NewLocal(LocalDesc(ELEMENT_TYPE_I4));

pslILEmit->EmitLDC(uNativeSize);
if (emittedTypeCheck)
{
ILCodeLabel* pHaveSizeLabel = pslILEmit->NewCodeLabel();
pslILEmit->EmitBR(pHaveSizeLabel);
pslILEmit->EmitLabel(pTypeMismatchedLabel);
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitCALL(METHOD__OBJECT__GET_TYPE, 1, 1);
pslILEmit->EmitCALL(METHOD__MARSHAL__SIZEOF_TYPE, 1, 1);
pslILEmit->EmitLabel(pHaveSizeLabel);
}
pslILEmit->EmitSTLOC(sizeLocal);
pslILEmit->EmitLDLOC(sizeLocal);
pslILEmit->EmitLOCALLOC();
pslILEmit->EmitDUP(); // for INITBLK
EmitStoreNativeValue(pslILEmit);

// initialize local block we just allocated
pslILEmit->EmitLDC(0);
pslILEmit->EmitLDC(uNativeSize);
pslILEmit->EmitLDLOC(sizeLocal);
pslILEmit->EmitINITBLK();

pslILEmit->EmitLabel(pNullRefLabel);
Expand Down Expand Up @@ -2264,7 +2295,24 @@ void ILLayoutClassPtrMarshalerBase::EmitClearNativeTemp(ILCodeStream* pslILEmit)
}
}

bool ILLayoutClassPtrMarshalerBase::EmitExactTypeCheck(ILCodeStream* pslILEmit, ILCodeLabel* isNotMatchingTypeLabel)
{
STANDARD_VM_CONTRACT;

if (m_pargs->m_pMT->IsSealed())
{
// If the provided type cannot be derived from, then we don't need to emit the type check.
return false;
}
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitCALL(METHOD__OBJECT__GET_TYPE, 1, 1);
pslILEmit->EmitLDTOKEN(pslILEmit->GetToken(m_pargs->m_pMT));
pslILEmit->EmitCALL(METHOD__TYPE__GET_TYPE_FROM_HANDLE, 1, 1);
pslILEmit->EmitCALLVIRT(pslILEmit->GetToken(CoreLibBinder::GetMethod(METHOD__OBJECT__EQUALS)), 1, 1);
pslILEmit->EmitBRFALSE(isNotMatchingTypeLabel);

return true;
}

void ILLayoutClassPtrMarshaler::EmitConvertContentsCLRToNative(ILCodeStream* pslILEmit)
{
Expand All @@ -2281,6 +2329,9 @@ void ILLayoutClassPtrMarshaler::EmitConvertContentsCLRToNative(ILCodeStream* psl
pslILEmit->EmitLDC(uNativeSize);
pslILEmit->EmitINITBLK();

ILCodeLabel* isNotMatchingTypeLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, isNotMatchingTypeLabel);

MethodDesc* pStructMarshalStub = NDirect::CreateStructMarshalILStub(m_pargs->m_pMT);

EmitLoadManagedValue(pslILEmit);
Expand All @@ -2290,6 +2341,18 @@ void ILLayoutClassPtrMarshaler::EmitConvertContentsCLRToNative(ILCodeStream* psl
EmitLoadCleanupWorkList(pslILEmit);

pslILEmit->EmitCALL(pslILEmit->GetToken(pStructMarshalStub), 4, 0);

if (emittedTypeCheck)
{
pslILEmit->EmitBR(pNullRefLabel);

pslILEmit->EmitLabel(isNotMatchingTypeLabel);
EmitLoadManagedValue(pslILEmit);
EmitLoadNativeValue(pslILEmit);
pslILEmit->EmitLDC(0);
pslILEmit->EmitCALL(METHOD__MARSHAL__STRUCTURE_TO_PTR, 3, 0);
}

pslILEmit->EmitLabel(pNullRefLabel);
}

Expand All @@ -2302,6 +2365,9 @@ void ILLayoutClassPtrMarshaler::EmitConvertContentsNativeToCLR(ILCodeStream* psl
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitBRFALSE(pNullRefLabel);

ILCodeLabel* isNotMatchingTypeLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, isNotMatchingTypeLabel);

MethodDesc* pStructMarshalStub = NDirect::CreateStructMarshalILStub(m_pargs->m_pMT);

EmitLoadManagedValue(pslILEmit);
Expand All @@ -2311,13 +2377,26 @@ void ILLayoutClassPtrMarshaler::EmitConvertContentsNativeToCLR(ILCodeStream* psl
EmitLoadCleanupWorkList(pslILEmit);

pslILEmit->EmitCALL(pslILEmit->GetToken(pStructMarshalStub), 4, 0);
if (emittedTypeCheck)
{
pslILEmit->EmitBR(pNullRefLabel);

pslILEmit->EmitLabel(isNotMatchingTypeLabel);
EmitLoadNativeValue(pslILEmit);
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitCALL(METHOD__MARSHAL__PTR_TO_STRUCTURE, 2, 0);
}
pslILEmit->EmitLabel(pNullRefLabel);
}

void ILLayoutClassPtrMarshaler::EmitClearNativeContents(ILCodeStream * pslILEmit)
{
STANDARD_VM_CONTRACT;

ILCodeLabel* isNotMatchingTypeLabel = pslILEmit->NewCodeLabel();
ILCodeLabel* cleanedUpLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, isNotMatchingTypeLabel);

MethodDesc* pStructMarshalStub = NDirect::CreateStructMarshalILStub(m_pargs->m_pMT);

EmitLoadManagedValue(pslILEmit);
Expand All @@ -2327,6 +2406,19 @@ void ILLayoutClassPtrMarshaler::EmitClearNativeContents(ILCodeStream * pslILEmit
EmitLoadCleanupWorkList(pslILEmit);

pslILEmit->EmitCALL(pslILEmit->GetToken(pStructMarshalStub), 4, 0);

if (emittedTypeCheck)
{
pslILEmit->EmitBR(cleanedUpLabel);

pslILEmit->EmitLabel(isNotMatchingTypeLabel);
EmitLoadNativeValue(pslILEmit);
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitCALL(METHOD__OBJECT__GET_TYPE, 1, 1);
pslILEmit->EmitCALL(METHOD__MARSHAL__DESTROY_STRUCTURE, 2, 0);
}

pslILEmit->EmitLabel(cleanedUpLabel);
}


Expand All @@ -2341,6 +2433,9 @@ void ILBlittablePtrMarshaler::EmitConvertContentsCLRToNative(ILCodeStream* pslIL
EmitLoadNativeValue(pslILEmit);
pslILEmit->EmitBRFALSE(pNullRefLabel);

ILCodeLabel* isNotMatchingTypeLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, isNotMatchingTypeLabel);

EmitLoadNativeValue(pslILEmit); // dest

EmitLoadManagedValue(pslILEmit);
Expand All @@ -2349,6 +2444,17 @@ void ILBlittablePtrMarshaler::EmitConvertContentsCLRToNative(ILCodeStream* pslIL
pslILEmit->EmitLDC(uNativeSize); // size

pslILEmit->EmitCPBLK();

if (emittedTypeCheck)
{
pslILEmit->EmitBR(pNullRefLabel);

pslILEmit->EmitLabel(isNotMatchingTypeLabel);
EmitLoadManagedValue(pslILEmit);
EmitLoadNativeValue(pslILEmit);
pslILEmit->EmitLDC(0);
pslILEmit->EmitCALL(METHOD__MARSHAL__STRUCTURE_TO_PTR, 3, 0);
}
pslILEmit->EmitLabel(pNullRefLabel);
}

Expand All @@ -2360,6 +2466,9 @@ void ILBlittablePtrMarshaler::EmitConvertContentsNativeToCLR(ILCodeStream* pslIL
UINT uNativeSize = m_pargs->m_pMT->GetNativeSize();
int fieldDef = pslILEmit->GetToken(CoreLibBinder::GetField(FIELD__RAW_DATA__DATA));

ILCodeLabel* isNotMatchingTypeLabel = pslILEmit->NewCodeLabel();
bool emittedTypeCheck = EmitExactTypeCheck(pslILEmit, isNotMatchingTypeLabel);

EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitBRFALSE(pNullRefLabel);

Expand All @@ -2371,12 +2480,26 @@ void ILBlittablePtrMarshaler::EmitConvertContentsNativeToCLR(ILCodeStream* pslIL
pslILEmit->EmitLDC(uNativeSize); // size

pslILEmit->EmitCPBLK();

if (emittedTypeCheck)
{
pslILEmit->EmitBR(pNullRefLabel);

pslILEmit->EmitLabel(isNotMatchingTypeLabel);
EmitLoadNativeValue(pslILEmit);
EmitLoadManagedValue(pslILEmit);
pslILEmit->EmitCALL(METHOD__MARSHAL__PTR_TO_STRUCTURE, 2, 0);
}

pslILEmit->EmitLabel(pNullRefLabel);
}

bool ILBlittablePtrMarshaler::CanMarshalViaPinning()
{
return IsCLRToNative(m_dwMarshalFlags) && !IsByref(m_dwMarshalFlags) && !IsFieldMarshal(m_dwMarshalFlags);
return IsCLRToNative(m_dwMarshalFlags) &&
!IsByref(m_dwMarshalFlags) &&
!IsFieldMarshal(m_dwMarshalFlags) &&
m_pargs->m_pMT->IsSealed(); // We can't marshal via pinning if we might need to marshal differently at runtime. See calls to EmitExactTypeCheck where we check the runtime type of the object being marshalled.
}

void ILBlittablePtrMarshaler::EmitMarshalViaPinning(ILCodeStream* pslILEmit)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/src/vm/ilmarshalers.h
Original file line number Diff line number Diff line change
Expand Up @@ -2915,6 +2915,7 @@ class ILLayoutClassPtrMarshalerBase : public ILMarshaler
bool NeedsClearNative() override;
void EmitClearNative(ILCodeStream* pslILEmit) override;
void EmitClearNativeTemp(ILCodeStream* pslILEmit) override;
bool EmitExactTypeCheck(ILCodeStream* pslILEmit, ILCodeLabel* isNotMatchingTypeLabel);
};

class ILLayoutClassPtrMarshaler : public ILLayoutClassPtrMarshalerBase
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/src/vm/metasig.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,10 @@ DEFINE_METASIG_T(SM(Array_Int_Array_Int_Int_RetVoid, C(ARRAY) i C(ARRAY) i i, v)
DEFINE_METASIG_T(SM(Array_Int_Obj_RetVoid, C(ARRAY) i j, v))
DEFINE_METASIG_T(SM(Array_Int_PtrVoid_RetRefObj, C(ARRAY) i P(v), r(j)))

DEFINE_METASIG(SM(Obj_IntPtr_Bool_RetVoid, j I F, v))
DEFINE_METASIG(SM(IntPtr_Obj_RetVoid, I j, v))
DEFINE_METASIG_T(SM(IntPtr_Type_RetVoid, I C(TYPE), v))

// Undefine macros in case we include the file again in the compilation unit

#undef DEFINE_METASIG
Expand Down
15 changes: 12 additions & 3 deletions src/coreclr/src/vm/mlinfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,8 @@ MarshalInfo::MarshalInfo(Module* pModule,
m_pMT = NULL;
m_pMD = pMD;
m_onInstanceMethod = onInstanceMethod;
// [Compat] For backward compatibility reasons, some marshalers imply [In, Out] behavior when marked as [In], [Out], or not marked with either.
BOOL byValAlwaysInOut = FALSE;

#ifdef FEATURE_COMINTEROP
m_fDispItf = FALSE;
Expand Down Expand Up @@ -2007,6 +2009,7 @@ MarshalInfo::MarshalInfo(Module* pModule,
}
m_type = IsFieldScenario() ? MARSHAL_TYPE_BLITTABLE_LAYOUTCLASS : MARSHAL_TYPE_BLITTABLEPTR;
m_args.m_pMT = m_pMT;
byValAlwaysInOut = TRUE;
}
else if (m_pMT->HasLayout())
{
Expand Down Expand Up @@ -2514,10 +2517,16 @@ MarshalInfo::MarshalInfo(Module* pModule,
}
}

// If neither IN nor OUT are true, this signals the URT to use the default
// rules.
if (!m_in && !m_out)
if (!m_byref && byValAlwaysInOut)
{
// Some marshalers expect [In, Out] behavior with [In], [Out], or no directional attributes.
m_in = TRUE;
m_out = TRUE;
}
else if (!m_in && !m_out)
{
// If neither IN nor OUT are true, this signals the URT to use the default
// rules.
if (m_byref ||
(mtype == ELEMENT_TYPE_CLASS
&& !(sig.IsStringType(pModule, pTypeContext))
Expand Down
32 changes: 22 additions & 10 deletions src/tests/Interop/LayoutClass/LayoutClassNative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
#include <xplatform.h>

typedef void *voidPtr;


struct EmptyBase
{
};

struct DerivedSeqClass : public EmptyBase
{
int a;
};

struct SeqClass
{
int a;
Expand All @@ -14,7 +23,7 @@ struct SeqClass
};

struct ExpClass
{
{
int a;
int padding; //padding needs to be added here as we have added 8 byte offset.
union
Expand Down Expand Up @@ -47,32 +56,35 @@ DLL_EXPORT BOOL STDMETHODCALLTYPE SimpleSeqLayoutClassByRef(SeqClass* p)
}

extern "C"
DLL_EXPORT BOOL STDMETHODCALLTYPE SimpleExpLayoutClassByRef(ExpClass* p)
DLL_EXPORT BOOL STDMETHODCALLTYPE DerivedSeqLayoutClassByRef(EmptyBase* p, int expected)
{
if((p->a != 0) || (p->udata.i != 10))
if(((DerivedSeqClass*)p)->a != expected)
{
printf("FAIL: p->a=%d, p->udata.i=%d\n",p->a,p->udata.i);
printf("FAIL: p->a=%d, expected %d\n", ((DerivedSeqClass*)p)->a, expected);
return FALSE;
}
return TRUE;
}

extern "C"
DLL_EXPORT BOOL STDMETHODCALLTYPE SimpleBlittableSeqLayoutClassByRef(BlittableClass* p)
DLL_EXPORT BOOL STDMETHODCALLTYPE SimpleExpLayoutClassByRef(ExpClass* p)
{
if(p->a != 10)
if((p->a != 0) || (p->udata.i != 10))
{
printf("FAIL: p->a=%d\n", p->a);
printf("FAIL: p->a=%d, p->udata.i=%d\n",p->a,p->udata.i);
return FALSE;
}
return TRUE;
}

extern "C"
DLL_EXPORT BOOL STDMETHODCALLTYPE SimpleBlittableSeqLayoutClassByOutAttr(BlittableClass* p)
DLL_EXPORT BOOL STDMETHODCALLTYPE SimpleBlittableSeqLayoutClass_UpdateField(BlittableClass* p)
{
if(!SimpleBlittableSeqLayoutClassByRef(p))
if(p->a != 10)
{
printf("FAIL: p->a=%d\n", p->a);
return FALSE;
}

p->a++;
return TRUE;
Expand Down
Loading