diff --git a/docs/design/features/unsafeaccessors.md b/docs/design/features/unsafeaccessors.md new file mode 100644 index 0000000000000..8ce4d2a22ed26 --- /dev/null +++ b/docs/design/features/unsafeaccessors.md @@ -0,0 +1,137 @@ +# `UnsafeAccessorAttribute` + +## Background and motivation + +Number of existing .NET serializers depend on skipping member visibility checks for data serialization. Examples include System.Text.Json or EF Core. In order to skip the visibility checks, the serializers typically use dynamically emitted code (Reflection.Emit or Linq.Expressions) and classic reflection APIs as slow fallback. Neither of these two options are great for source generated serializers and native AOT compilation. This API proposal introduces a first class zero-overhead mechanism for skipping visibility checks. + +## Semantics + +This attribute will be applied to an `extern static` method. The implementation of the `extern static` method annotated with this attribute will be provided by the runtime based on the information in the attribute and the signature of the method that the attribute is applied to. The runtime will try to find the matching method or field and forward the call to it. If the matching method or field is not found, the body of the `extern static` method will throw `MissingFieldException` or `MissingMethodException`. + +For `Method`, `StaticMethod`, `Field`, and `StaticField`, the type of the first argument of the annotated `extern static` method identifies the owning type. Only the specific type defined will be examined for inaccessible members. The type hierarchy is not walked looking for a match. + +The value of the first argument is treated as `this` pointer for instance fields and methods. + +The first argument must be passed as `ref` for instance fields and methods on structs. + +The value of the first argument is not used by the implementation for static fields and methods. + +The return value for an accessor to a field can be `ref` if setting of the field is desired. + +Constructors can be accessed using Constructor or Method. + +The return type is considered for the signature match. Modreqs and modopts are initially not considered for the signature match. However, if an ambiguity exists ignoring modreqs and modopts, a precise match is attempted. If an ambiguity still exists, `AmbiguousMatchException` is thrown. + +By default, the attributed method's name dictates the name of the method/field. This can cause confusion in some cases since language abstractions, like C# local functions, generate mangled IL names. The solution to this is to use the `nameof` mechanism and define the `Name` property. + +Scenarios involving generics may require creating new generic types to contain the `extern static` method definition. The decision was made to require all `ELEMENT_TYPE_VAR` and `ELEMENT_TYPE_MVAR` instances to match identically type and generic parameter index. This means if the target method for access uses an `ELEMENT_TYPE_VAR`, the `extern static` method must also use an `ELEMENT_TYPE_VAR`. For example: + +```csharp +class C +{ + T M(U u) => default; +} + +class Accessor +{ + // Correct - V is an ELEMENT_TYPE_VAR and W is ELEMENT_TYPE_VAR, + // respectively the same as T and U in the definition of C::M(). + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static void CallM(C c, W w); + + // Incorrect - Since Y must be an ELEMENT_TYPE_VAR, but is ELEMENT_TYPE_MVAR below. + // [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + // extern static void CallM(C c, Z z); +} +``` + +Methods with the `UnsafeAccessorAttribute` that access members with generic parameters are expected to have the same declared constraints with the target member. Failure to do so results in unspecified behavior. For example: + +```csharp +class C +{ + T M(U u) where U: Base => default; +} + +class Accessor +{ + // Correct - Constraints match the target member. + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static void CallM(C c, W w) where W: Base; + + // Incorrect - Constraints do not match target member. + // [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + // extern static void CallM(C c, W w); +} +``` + +## API + +```csharp +namespace System.Runtime.CompilerServices; + +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)] +public class UnsafeAccessorAttribute : Attribute +{ + public UnsafeAccessorAttribute(UnsafeAccessorKind kind); + + public UnsafeAccessorKind Kind { get; } + + // The name defaults to the annotated method name if not specified. + // The name must be null for constructors + public string? Name { get; set; } +} + +public enum UnsafeAccessorKind +{ + Constructor, // call instance constructor (`newobj` in IL) + Method, // call instance method (`callvirt` in IL) + StaticMethod, // call static method (`call` in IL) + Field, // address of instance field (`ldflda` in IL) + StaticField // address of static field (`ldsflda` in IL) +}; +``` + +## API Usage + +```csharp +class UserData +{ + private UserData() { } + public string Name { get; set; } +} + +[UnsafeAccessor(UnsafeAccessorKind.Constructor)] +extern static UserData CallPrivateConstructor(); + +// This API allows accessing backing fields for auto-implemented properties with unspeakable names. +[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")] +extern static ref string GetName(UserData userData); + +UserData ud = CallPrivateConstructor(); +GetName(ud) = "Joe"; +``` + +Using generics + +```csharp +class UserData +{ + private T _field; + private UserData(T t) { _field = t; } + private U ConvertFieldToT() => (U)_field; +} + +// The Accessors class provides the generic Type parameter for the method definitions. +class Accessors +{ + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + extern static UserData CallPrivateConstructor(V v); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "ConvertFieldToT")] + extern static U CallConvertFieldToT(UserData userData); +} + +UserData ud = Accessors.CallPrivateConstructor("Joe"); +Accessors.CallPrivateConstructor(ud); +``` \ No newline at end of file diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs index b6f31e130142e..4a4723d2b7afd 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs @@ -686,27 +686,27 @@ private ILToken NewToken(object value, int tokenType) public ILToken NewToken(TypeDesc value) { - return NewToken(value, 0x01000000); + return NewToken(value, 0x01000000); // mdtTypeRef } public ILToken NewToken(MethodDesc value) { - return NewToken(value, 0x0a000000); + return NewToken(value, 0x0a000000); // mdtMemberRef } public ILToken NewToken(FieldDesc value) { - return NewToken(value, 0x0a000000); + return NewToken(value, 0x0a000000); // mdtMemberRef } public ILToken NewToken(string value) { - return NewToken(value, 0x70000000); + return NewToken(value, 0x70000000); // mdtString } public ILToken NewToken(MethodSignature value) { - return NewToken(value, 0x11000000); + return NewToken(value, 0x11000000); // mdtSignature } public ILLocalVariable NewLocal(TypeDesc localType, bool isPinned = false) diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index 6338f725be223..e8e97f2eb3197 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -29,12 +29,6 @@ public static MethodIL TryGetIL(EcmaMethod method) return GenerateAccessorBadImageFailure(method); } - // Block generic support early - if (method.HasInstantiation || method.OwningType.HasInstantiation) - { - return GenerateAccessorBadImageFailure(method); - } - if (!TryParseUnsafeAccessorAttribute(method, decodedAttribute.Value, out UnsafeAccessorKind kind, out string name)) { return GenerateAccessorBadImageFailure(method); @@ -54,7 +48,7 @@ public static MethodIL TryGetIL(EcmaMethod method) firstArgType = sig[0]; } - bool isAmbiguous = false; + SetTargetResult result; // Using the kind type, perform the following: // 1) Validate the basic type information from the signature. @@ -77,9 +71,10 @@ public static MethodIL TryGetIL(EcmaMethod method) } const string ctorName = ".ctor"; - if (!TrySetTargetMethod(ref context, ctorName, out isAmbiguous)) + result = TrySetTargetMethod(ref context, ctorName); + if (result is not SetTargetResult.Success) { - return GenerateAccessorSpecificFailure(ref context, ctorName, isAmbiguous); + return GenerateAccessorSpecificFailure(ref context, ctorName, result); } break; case UnsafeAccessorKind.Method: @@ -105,9 +100,10 @@ public static MethodIL TryGetIL(EcmaMethod method) } context.IsTargetStatic = kind == UnsafeAccessorKind.StaticMethod; - if (!TrySetTargetMethod(ref context, name, out isAmbiguous)) + result = TrySetTargetMethod(ref context, name); + if (result is not SetTargetResult.Success) { - return GenerateAccessorSpecificFailure(ref context, name, isAmbiguous); + return GenerateAccessorSpecificFailure(ref context, name, result); } break; @@ -136,9 +132,10 @@ public static MethodIL TryGetIL(EcmaMethod method) } context.IsTargetStatic = kind == UnsafeAccessorKind.StaticField; - if (!TrySetTargetField(ref context, name, ((ParameterizedType)retType).GetParameterType())) + result = TrySetTargetField(ref context, name, ((ParameterizedType)retType).GetParameterType()); + if (result is not SetTargetResult.Success) { - return GenerateAccessorSpecificFailure(ref context, name, isAmbiguous); + return GenerateAccessorSpecificFailure(ref context, name, result); } break; @@ -232,6 +229,12 @@ private static bool ValidateTargetType(TypeDesc targetTypeMaybe, out TypeDesc va targetType = null; } + // We do not support signature variables as a target (for example, VAR and MVAR). + if (targetType is SignatureVariable) + { + targetType = null; + } + validated = targetType; return validated != null; } @@ -366,7 +369,45 @@ private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationConte return true; } - private static bool TrySetTargetMethod(ref GenerationContext context, string name, out bool isAmbiguous, bool ignoreCustomModifiers = true) + private static bool VerifyDeclarationSatisfiesTargetConstraints(MethodDesc declaration, TypeDesc targetType, MethodDesc targetMethod) + { + Debug.Assert(declaration != null); + Debug.Assert(targetType != null); + Debug.Assert(targetMethod != null); + + if (targetType.HasInstantiation) + { + Instantiation declClassInst = declaration.OwningType.Instantiation; + var instType = targetType.Context.GetInstantiatedType((MetadataType)targetType.GetTypeDefinition(), declClassInst); + if (!instType.CheckConstraints()) + { + return false; + } + + targetMethod = instType.FindMethodOnExactTypeWithMatchingTypicalMethod(targetMethod); + } + + if (targetMethod.HasInstantiation) + { + Instantiation declMethodInst = declaration.Instantiation; + var instMethod = targetType.Context.GetInstantiatedMethod(targetMethod, declMethodInst); + if (!instMethod.CheckConstraints()) + { + return false; + } + } + return true; + } + + private enum SetTargetResult + { + Success, + Missing, + Ambiguous, + Invalid, + } + + private static SetTargetResult TrySetTargetMethod(ref GenerationContext context, string name, bool ignoreCustomModifiers = true) { TypeDesc targetType = context.TargetType; @@ -399,23 +440,39 @@ private static bool TrySetTargetMethod(ref GenerationContext context, string nam // We have detected ambiguity when ignoring custom modifiers. // Start over, but look for a match requiring custom modifiers // to match precisely. - if (TrySetTargetMethod(ref context, name, out isAmbiguous, ignoreCustomModifiers: false)) - return true; + if (SetTargetResult.Success == TrySetTargetMethod(ref context, name, ignoreCustomModifiers: false)) + return SetTargetResult.Success; } - - isAmbiguous = true; - return false; + return SetTargetResult.Ambiguous; } targetMaybe = md; } - isAmbiguous = false; + if (targetMaybe != null) + { + if (!VerifyDeclarationSatisfiesTargetConstraints(context.Declaration, targetType, targetMaybe)) + { + return SetTargetResult.Invalid; + } + + if (targetMaybe.HasInstantiation) + { + TypeDesc[] methodInstantiation = new TypeDesc[targetMaybe.Instantiation.Length]; + for (int i = 0; i < methodInstantiation.Length; ++i) + { + methodInstantiation[i] = targetMaybe.Context.GetSignatureVariable(i, true); + } + targetMaybe = targetMaybe.Context.GetInstantiatedMethod(targetMaybe, new Instantiation(methodInstantiation)); + } + Debug.Assert(targetMaybe is not null); + } + context.TargetMethod = targetMaybe; - return context.TargetMethod != null; + return context.TargetMethod != null ? SetTargetResult.Success : SetTargetResult.Missing; } - private static bool TrySetTargetField(ref GenerationContext context, string name, TypeDesc fieldType) + private static SetTargetResult TrySetTargetField(ref GenerationContext context, string name, TypeDesc fieldType) { TypeDesc targetType = context.TargetType; @@ -431,10 +488,10 @@ private static bool TrySetTargetField(ref GenerationContext context, string name && fieldType == fd.FieldType) { context.TargetField = fd; - return true; + return SetTargetResult.Success; } } - return false; + return SetTargetResult.Missing; } private static MethodIL GenerateAccessor(ref GenerationContext context) @@ -486,7 +543,7 @@ private static MethodIL GenerateAccessor(ref GenerationContext context) return emit.Link(context.Declaration); } - private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext context, string name, bool ambiguous) + private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext context, string name, SetTargetResult result) { ILEmitter emit = new ILEmitter(); ILCodeStream codeStream = emit.NewCodeStream(); @@ -496,14 +553,19 @@ private static MethodIL GenerateAccessorSpecificFailure(ref GenerationContext co MethodDesc thrower; TypeSystemContext typeSysContext = context.Declaration.Context; - if (ambiguous) + if (result is SetTargetResult.Ambiguous) { codeStream.EmitLdc((int)ExceptionStringID.AmbiguousMatchUnsafeAccessor); thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowAmbiguousMatchException"); } + else if (result is SetTargetResult.Invalid) + { + codeStream.EmitLdc((int)ExceptionStringID.InvalidProgramDefault); + thrower = typeSysContext.GetHelperEntryPoint("ThrowHelpers", "ThrowInvalidProgramException"); + } else { - + Debug.Assert(result is SetTargetResult.Missing); ExceptionStringID id; if (context.Kind == UnsafeAccessorKind.Field || context.Kind == UnsafeAccessorKind.StaticField) { diff --git a/src/coreclr/vm/callconvbuilder.cpp b/src/coreclr/vm/callconvbuilder.cpp index 20f95f1222410..3075087ee7b83 100644 --- a/src/coreclr/vm/callconvbuilder.cpp +++ b/src/coreclr/vm/callconvbuilder.cpp @@ -298,15 +298,12 @@ namespace { STANDARD_VM_CONTRACT; - TypeHandle type; - MethodDesc* pMD; - FieldDesc* pFD; + ResolvedToken resolved{}; + pResolver->ResolveToken(token, &resolved); - pResolver->ResolveToken(token, &type, &pMD, &pFD); + _ASSERTE(!resolved.TypeHandle.IsNull()); - _ASSERTE(!type.IsNull()); - - *nameOut = type.GetMethodTable()->GetFullyQualifiedNameInfo(namespaceOut); + *nameOut = resolved.TypeHandle.GetMethodTable()->GetFullyQualifiedNameInfo(namespaceOut); return S_OK; } diff --git a/src/coreclr/vm/dynamicmethod.cpp b/src/coreclr/vm/dynamicmethod.cpp index bd5bebcce50f2..065d80d57fcc1 100644 --- a/src/coreclr/vm/dynamicmethod.cpp +++ b/src/coreclr/vm/dynamicmethod.cpp @@ -1325,7 +1325,7 @@ void LCGMethodResolver::AddToUsedIndCellList(BYTE * indcell) } -void LCGMethodResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD) +void LCGMethodResolver::ResolveToken(mdToken token, ResolvedToken* resolvedToken) { STANDARD_VM_CONTRACT; @@ -1335,24 +1335,35 @@ void LCGMethodResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc DECLARE_ARGHOLDER_ARRAY(args, 5); + TypeHandle handle; + MethodDesc* pMD = NULL; + FieldDesc* pFD = NULL; args[ARGNUM_0] = OBJECTREF_TO_ARGHOLDER(ObjectFromHandle(m_managedResolver)); args[ARGNUM_1] = DWORD_TO_ARGHOLDER(token); - args[ARGNUM_2] = pTH; - args[ARGNUM_3] = ppMD; - args[ARGNUM_4] = ppFD; + args[ARGNUM_2] = &handle; + args[ARGNUM_3] = &pMD; + args[ARGNUM_4] = &pFD; CALL_MANAGED_METHOD_NORET(args); - _ASSERTE(*ppMD == NULL || *ppFD == NULL); + _ASSERTE(pMD == NULL || pFD == NULL); - if (pTH->IsNull()) + if (handle.IsNull()) { - if (*ppMD != NULL) *pTH = (*ppMD)->GetMethodTable(); - else - if (*ppFD != NULL) *pTH = (*ppFD)->GetEnclosingMethodTable(); + if (pMD != NULL) + { + handle = pMD->GetMethodTable(); + } + else if (pFD != NULL) + { + handle = pFD->GetEnclosingMethodTable(); + } } - _ASSERTE(!pTH->IsNull()); + _ASSERTE(!handle.IsNull()); + resolvedToken->TypeHandle = handle; + resolvedToken->Method = pMD; + resolvedToken->Field = pFD; } //--------------------------------------------------------------------------------------- diff --git a/src/coreclr/vm/dynamicmethod.h b/src/coreclr/vm/dynamicmethod.h index ddbe3c795cfe3..a26a241006113 100644 --- a/src/coreclr/vm/dynamicmethod.h +++ b/src/coreclr/vm/dynamicmethod.h @@ -37,6 +37,15 @@ class ChunkAllocator void Delete(); }; +struct ResolvedToken final +{ + TypeHandle TypeHandle; + SigPointer TypeSignature; + SigPointer MethodSignature; + MethodDesc* Method; + FieldDesc* Field; +}; + //--------------------------------------------------------------------------------------- // class DynamicResolver @@ -90,7 +99,7 @@ class DynamicResolver virtual OBJECTHANDLE ConstructStringLiteral(mdToken metaTok) = 0; virtual BOOL IsValidStringRef(mdToken metaTok) = 0; virtual STRINGREF GetStringLiteral(mdToken metaTok) = 0; - virtual void ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD) = 0; + virtual void ResolveToken(mdToken token, ResolvedToken* resolvedToken) = 0; virtual SigPointer ResolveSignature(mdToken token) = 0; virtual SigPointer ResolveSignatureForVarArg(mdToken token) = 0; virtual void GetEHInfo(unsigned EHnumber, CORINFO_EH_CLAUSE* clause) = 0; @@ -141,7 +150,7 @@ class LCGMethodResolver : public DynamicResolver OBJECTHANDLE ConstructStringLiteral(mdToken metaTok); BOOL IsValidStringRef(mdToken metaTok); - void ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD); + void ResolveToken(mdToken token, ResolvedToken* resolvedToken); SigPointer ResolveSignature(mdToken token); SigPointer ResolveSignatureForVarArg(mdToken token); void GetEHInfo(unsigned EHnumber, CORINFO_EH_CLAUSE* clause); diff --git a/src/coreclr/vm/ilstubresolver.cpp b/src/coreclr/vm/ilstubresolver.cpp index c24be260c692e..1efb9c2975e16 100644 --- a/src/coreclr/vm/ilstubresolver.cpp +++ b/src/coreclr/vm/ilstubresolver.cpp @@ -133,13 +133,10 @@ STRINGREF ILStubResolver::GetStringLiteral(mdToken metaTok) return NULL; } -void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD) +void ILStubResolver::ResolveToken(mdToken token, ResolvedToken* resolvedToken) { STANDARD_VM_CONTRACT; - - *pTH = NULL; - *ppMD = NULL; - *ppFD = NULL; + _ASSERTE(resolvedToken != NULL); switch (TypeFromToken(token)) { @@ -147,8 +144,8 @@ void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** { MethodDesc* pMD = m_pCompileTimeState->m_tokenLookupMap.LookupMethodDef(token); _ASSERTE(pMD); - *ppMD = pMD; - *pTH = TypeHandle(pMD->GetMethodTable()); + resolvedToken->Method = pMD; + resolvedToken->TypeHandle = TypeHandle(pMD->GetMethodTable()); } break; @@ -156,7 +153,7 @@ void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** { TypeHandle typeHnd = m_pCompileTimeState->m_tokenLookupMap.LookupTypeDef(token); _ASSERTE(!typeHnd.IsNull()); - *pTH = typeHnd; + resolvedToken->TypeHandle = typeHnd; } break; @@ -164,10 +161,59 @@ void ILStubResolver::ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** { FieldDesc* pFD = m_pCompileTimeState->m_tokenLookupMap.LookupFieldDef(token); _ASSERTE(pFD); - *ppFD = pFD; - *pTH = TypeHandle(pFD->GetEnclosingMethodTable()); + resolvedToken->Field = pFD; + resolvedToken->TypeHandle = TypeHandle(pFD->GetEnclosingMethodTable()); + } + break; + +#if !defined(DACCESS_COMPILE) + case mdtMemberRef: + { + TokenLookupMap::MemberRefEntry entry = m_pCompileTimeState->m_tokenLookupMap.LookupMemberRef(token); + if (entry.Type == mdtFieldDef) + { + _ASSERTE(entry.Entry.Field != NULL); + + if (entry.ClassSignatureToken != mdTokenNil) + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + resolvedToken->Field = entry.Entry.Field; + resolvedToken->TypeHandle = TypeHandle(entry.Entry.Field->GetApproxEnclosingMethodTable()); + } + else + { + _ASSERTE(entry.Type == mdtMethodDef); + _ASSERTE(entry.Entry.Method != NULL); + + if (entry.ClassSignatureToken != mdTokenNil) + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + resolvedToken->Method = entry.Entry.Method; + MethodTable* pMT = entry.Entry.Method->GetMethodTable(); + _ASSERTE(!pMT->ContainsGenericVariables()); + resolvedToken->TypeHandle = TypeHandle(pMT); + } + } + break; + + case mdtMethodSpec: + { + TokenLookupMap::MethodSpecEntry entry = m_pCompileTimeState->m_tokenLookupMap.LookupMethodSpec(token); + _ASSERTE(entry.Method != NULL); + + if (entry.ClassSignatureToken != mdTokenNil) + resolvedToken->TypeSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.ClassSignatureToken); + + if (entry.MethodSignatureToken != mdTokenNil) + resolvedToken->MethodSignature = m_pCompileTimeState->m_tokenLookupMap.LookupSig(entry.MethodSignatureToken); + + resolvedToken->Method = entry.Method; + MethodTable* pMT = entry.Method->GetMethodTable(); + _ASSERTE(!pMT->ContainsGenericVariables()); + resolvedToken->TypeHandle = TypeHandle(pMT); } break; +#endif // !defined(DACCESS_COMPILE) default: UNREACHABLE_MSG("unexpected metadata token type"); diff --git a/src/coreclr/vm/ilstubresolver.h b/src/coreclr/vm/ilstubresolver.h index 82a1217d79c7e..ea823e7f77380 100644 --- a/src/coreclr/vm/ilstubresolver.h +++ b/src/coreclr/vm/ilstubresolver.h @@ -35,7 +35,7 @@ class ILStubResolver : DynamicResolver OBJECTHANDLE ConstructStringLiteral(mdToken metaTok); BOOL IsValidStringRef(mdToken metaTok); STRINGREF GetStringLiteral(mdToken metaTok); - void ResolveToken(mdToken token, TypeHandle * pTH, MethodDesc ** ppMD, FieldDesc ** ppFD); + void ResolveToken(mdToken token, ResolvedToken* resolvedToken); SigPointer ResolveSignature(mdToken token); SigPointer ResolveSignatureForVarArg(mdToken token); void GetEHInfo(unsigned EHnumber, CORINFO_EH_CLAUSE* clause); diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index 5e6b0cbeeafdd..06f205428f0d2 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -156,15 +156,13 @@ inline CORINFO_MODULE_HANDLE GetScopeHandle(MethodDesc* method) //This is common refactored code from within several of the access check functions. static BOOL ModifyCheckForDynamicMethod(DynamicResolver *pResolver, TypeHandle *pOwnerTypeForSecurity, - AccessCheckOptions::AccessCheckType *pAccessCheckType, - DynamicResolver** ppAccessContext) + AccessCheckOptions::AccessCheckType *pAccessCheckType) { CONTRACTL { STANDARD_VM_CHECK; PRECONDITION(CheckPointer(pResolver)); PRECONDITION(CheckPointer(pOwnerTypeForSecurity)); PRECONDITION(CheckPointer(pAccessCheckType)); - PRECONDITION(CheckPointer(ppAccessContext)); PRECONDITION(*pAccessCheckType == AccessCheckOptions::kNormalAccessibilityChecks); } CONTRACTL_END; @@ -883,7 +881,18 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken if (IsDynamicScope(pResolvedToken->tokenScope)) { - GetDynamicResolver(pResolvedToken->tokenScope)->ResolveToken(pResolvedToken->token, &th, &pMD, &pFD); + ResolvedToken resolved{}; + GetDynamicResolver(pResolvedToken->tokenScope)->ResolveToken(pResolvedToken->token, &resolved); + + th = resolved.TypeHandle; + pMD = resolved.Method; + pFD = resolved.Field; + + // Record supplied signatures. + if (!resolved.TypeSignature.IsNull()) + resolved.TypeSignature.GetSignature(&pResolvedToken->pTypeSpec, &pResolvedToken->cbTypeSpec); + if (!resolved.MethodSignature.IsNull()) + resolved.MethodSignature.GetSignature(&pResolvedToken->pMethodSpec, &pResolvedToken->cbMethodSpec); // // Check that we got the expected handles and fill in missing data if necessary @@ -893,18 +902,10 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken if (pMD != NULL) { - if ((tkType != mdtMethodDef) && (tkType != mdtMemberRef)) + if ((tkType != mdtMethodDef) && (tkType != mdtMemberRef) && (tkType != mdtMethodSpec)) ThrowBadTokenException(pResolvedToken); if ((tokenType & CORINFO_TOKENKIND_Method) == 0) ThrowBadTokenException(pResolvedToken); - if (th.IsNull()) - th = pMD->GetMethodTable(); - - // "PermitUninstDefOrRef" check - if ((tokenType != CORINFO_TOKENKIND_Ldtoken) && pMD->ContainsGenericVariables()) - { - COMPlusThrow(kInvalidProgramException); - } // if this is a BoxedEntryPointStub get the UnboxedEntryPoint one if (pMD->IsUnboxingStub()) @@ -924,8 +925,6 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken ThrowBadTokenException(pResolvedToken); if ((tokenType & CORINFO_TOKENKIND_Field) == 0) ThrowBadTokenException(pResolvedToken); - if (th.IsNull()) - th = pFD->GetApproxEnclosingMethodTable(); if (pFD->IsStatic() && (tokenType != CORINFO_TOKENKIND_Ldtoken)) { @@ -959,7 +958,7 @@ void CEEInfo::resolveToken(/* IN, OUT */ CORINFO_RESOLVED_TOKEN * pResolvedToken else { mdToken metaTOK = pResolvedToken->token; - Module * pModule = (Module *)pResolvedToken->tokenScope; + Module * pModule = GetModule(pResolvedToken->tokenScope); switch (TypeFromToken(metaTOK)) { @@ -1705,7 +1704,9 @@ void CEEInfo::getFieldInfo (CORINFO_RESOLVED_TOKEN * pResolvedToken, SigTypeContext::InitTypeContext(pCallerForSecurity, &typeContext); SigPointer sigptr(pResolvedToken->pTypeSpec, pResolvedToken->cbTypeSpec); - fieldTypeForSecurity = sigptr.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + + Module* targetModule = GetModule(pResolvedToken->tokenScope); + fieldTypeForSecurity = sigptr.GetTypeHandleThrowing(targetModule, &typeContext); // typeHnd can be a variable type if (fieldTypeForSecurity.GetMethodTable() == NULL) @@ -1717,15 +1718,13 @@ void CEEInfo::getFieldInfo (CORINFO_RESOLVED_TOKEN * pResolvedToken, BOOL doAccessCheck = TRUE; AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; - DynamicResolver * pAccessContext = NULL; - //More in code:CEEInfo::getCallInfo, but the short version is that the caller and callee Descs do //not completely describe the type. TypeHandle callerTypeForSecurity = TypeHandle(pCallerForSecurity->GetMethodTable()); if (IsDynamicScope(pResolvedToken->tokenScope)) { doAccessCheck = ModifyCheckForDynamicMethod(GetDynamicResolver(pResolvedToken->tokenScope), &callerTypeForSecurity, - &accessCheckType, &pAccessContext); + &accessCheckType); } //Now for some link time checks. @@ -1737,7 +1736,7 @@ void CEEInfo::getFieldInfo (CORINFO_RESOLVED_TOKEN * pResolvedToken, { //Well, let's check some visibility at least. AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, FALSE, pField); @@ -1851,22 +1850,19 @@ CEEInfo::findCallSiteSig( { _ASSERTE(TypeFromToken(sigMethTok) == mdtMethodDef); - TypeHandle classHandle; - MethodDesc * pMD = NULL; - FieldDesc * pFD = NULL; - // in this case a method is asked for its sig. Resolve the method token and get the sig - pResolver->ResolveToken(sigMethTok, &classHandle, &pMD, &pFD); - if (pMD == NULL) + ResolvedToken resolved{}; + pResolver->ResolveToken(sigMethTok, &resolved); + if (resolved.Method == NULL) COMPlusThrow(kInvalidProgramException); PCCOR_SIGNATURE pSig = NULL; DWORD cbSig; - pMD->GetSig(&pSig, &cbSig); + resolved.Method->GetSig(&pSig, &cbSig); sig = SigPointer(pSig, cbSig); - context = MAKE_METHODCONTEXT(pMD); - scopeHnd = GetScopeHandle(pMD->GetModule()); + context = MAKE_METHODCONTEXT(resolved.Method); + scopeHnd = GetScopeHandle(resolved.Method->GetModule()); } sig.GetSignature(&pSig, &cbSig); @@ -3250,7 +3246,7 @@ void CEEInfo::ComputeRuntimeLookupForSharedGenericToken(DictionaryEntryKind entr sigBuilder.AppendData(pContextMT->GetNumDicts() - 1); } - Module * pModule = (Module *)pResolvedToken->tokenScope; + Module * pModule = GetModule(pResolvedToken->tokenScope); switch (entryKind) { @@ -4931,7 +4927,6 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( BOOL doAccessCheck = TRUE; AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; - DynamicResolver * pAccessContext = NULL; //All access checks must be done on the open instantiation. MethodDesc * pCallerForSecurity = GetMethodForSecurity(callerHandle); @@ -4944,7 +4939,7 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( SigTypeContext::InitTypeContext(pCallerForSecurity, &typeContext); SigPointer sigptr(pResolvedToken->pTypeSpec, pResolvedToken->cbTypeSpec); - pCalleeForSecurity = sigptr.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + pCalleeForSecurity = sigptr.GetTypeHandleThrowing(GetModule(pResolvedToken->tokenScope), &typeContext); } while (pCalleeForSecurity.HasTypeParam()) @@ -4955,8 +4950,7 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( if (IsDynamicScope(pResolvedToken->tokenScope)) { doAccessCheck = ModifyCheckForDynamicMethod(GetDynamicResolver(pResolvedToken->tokenScope), - &callerTypeForSecurity, &accessCheckType, - &pAccessContext); + &callerTypeForSecurity, &accessCheckType); } //Since this is a check against a TypeHandle, there are some things we can stick in a TypeHandle that @@ -4971,7 +4965,7 @@ CorInfoIsAccessAllowedResult CEEInfo::canAccessClass( if (doAccessCheck) { AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, FALSE /*throw on error*/, pCalleeForSecurity.GetMethodTable()); @@ -5543,7 +5537,7 @@ void CEEInfo::getCallInfo( if (pResolvedToken->pTypeSpec != NULL) { SigPointer sigptr(pResolvedToken->pTypeSpec, pResolvedToken->cbTypeSpec); - calleeTypeForSecurity = sigptr.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + calleeTypeForSecurity = sigptr.GetTypeHandleThrowing(GetModule(pResolvedToken->tokenScope), &typeContext); // typeHnd can be a variable type if (calleeTypeForSecurity.GetMethodTable() == NULL) @@ -5570,7 +5564,7 @@ void CEEInfo::getCallInfo( IfFailThrow(sp.GetByte(&etype)); // Load the generic method instantiation - THROW_BAD_FORMAT_MAYBE(etype == (BYTE)IMAGE_CEE_CS_CALLCONV_GENERICINST, 0, (Module *)pResolvedToken->tokenScope); + THROW_BAD_FORMAT_MAYBE(etype == (BYTE)IMAGE_CEE_CS_CALLCONV_GENERICINST, 0, GetModule(pResolvedToken->tokenScope)); IfFailThrow(sp.GetData(&nGenericMethodArgs)); @@ -5584,7 +5578,7 @@ void CEEInfo::getCallInfo( for (uint32_t i = 0; i < nGenericMethodArgs; i++) { - genericMethodArgs[i] = sp.GetTypeHandleThrowing((Module *)pResolvedToken->tokenScope, &typeContext); + genericMethodArgs[i] = sp.GetTypeHandleThrowing(GetModule(pResolvedToken->tokenScope), &typeContext); _ASSERTE (!genericMethodArgs[i].IsNull()); IfFailThrow(sp.SkipExactlyOne()); } @@ -5604,14 +5598,13 @@ void CEEInfo::getCallInfo( BOOL doAccessCheck = TRUE; BOOL canAccessMethod = TRUE; AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; - DynamicResolver * pAccessContext = NULL; callerTypeForSecurity = TypeHandle(pCallerForSecurity->GetMethodTable()); if (pCallerForSecurity->IsDynamicMethod()) { doAccessCheck = ModifyCheckForDynamicMethod(pCallerForSecurity->AsDynamicMethodDesc()->GetResolver(), &callerTypeForSecurity, - &accessCheckType, &pAccessContext); + &accessCheckType); } pResult->accessAllowed = CORINFO_ACCESS_ALLOWED; @@ -5619,7 +5612,7 @@ void CEEInfo::getCallInfo( if (doAccessCheck) { AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, FALSE, pCalleeForSecurity); @@ -12375,10 +12368,11 @@ void CEEJitInfo::setEHinfo ( ((pEHClause->Flags & COR_ILEXCEPTION_CLAUSE_FILTER) == 0) && (clause->ClassToken != NULL)) { - MethodDesc * pMD; FieldDesc * pFD; - m_pMethodBeingCompiled->AsDynamicMethodDesc()->GetResolver()->ResolveToken(clause->ClassToken, (TypeHandle *)&pEHClause->TypeHandle, &pMD, &pFD); + ResolvedToken resolved{}; + m_pMethodBeingCompiled->AsDynamicMethodDesc()->GetResolver()->ResolveToken(clause->ClassToken, &resolved); + pEHClause->TypeHandle = (void*)resolved.TypeHandle.AsPtr(); SetHasCachedTypeHandle(pEHClause); - LOG((LF_EH, LL_INFO1000000, " CachedTypeHandle: 0x%08lx -> 0x%08lx\n", clause->ClassToken, pEHClause->TypeHandle)); + LOG((LF_EH, LL_INFO1000000, " CachedTypeHandle: 0x%08x -> %p\n", clause->ClassToken, pEHClause->TypeHandle)); } EE_TO_JIT_TRANSITION(); @@ -12982,18 +12976,17 @@ PCODE UnsafeJitFunction(PrepareCodeConfig* config, //and its return type. AccessCheckOptions::AccessCheckType accessCheckType = AccessCheckOptions::kNormalAccessibilityChecks; TypeHandle ownerTypeForSecurity = TypeHandle(pMethodForSecurity->GetMethodTable()); - DynamicResolver *pAccessContext = NULL; BOOL doAccessCheck = TRUE; if (pMethodForSecurity->IsDynamicMethod()) { doAccessCheck = ModifyCheckForDynamicMethod(pMethodForSecurity->AsDynamicMethodDesc()->GetResolver(), &ownerTypeForSecurity, - &accessCheckType, &pAccessContext); + &accessCheckType); } if (doAccessCheck) { AccessCheckOptions accessCheckOptions(accessCheckType, - pAccessContext, + NULL, TRUE /*Throw on error*/, pMethodForSecurity); diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp index b59a2d23d7d03..08ad1eede6468 100644 --- a/src/coreclr/vm/methodtable.cpp +++ b/src/coreclr/vm/methodtable.cpp @@ -472,6 +472,17 @@ WORD MethodTable::GetNumMethods() return GetClass()->GetNumMethods(); } +PTR_MethodTable MethodTable::GetTypicalMethodTable() +{ + LIMITED_METHOD_DAC_CONTRACT; + if (IsArray()) + return (PTR_MethodTable)this; + + PTR_MethodTable methodTableMaybe = GetModule()->LookupTypeDef(GetCl()).AsMethodTable(); + _ASSERTE(methodTableMaybe->IsTypicalTypeDefinition()); + return methodTableMaybe; +} + //========================================================================================== BOOL MethodTable::HasSameTypeDefAs(MethodTable *pMT) { diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h index 83057c623fba9..6e4f68f29bce7 100644 --- a/src/coreclr/vm/methodtable.h +++ b/src/coreclr/vm/methodtable.h @@ -1183,6 +1183,8 @@ class MethodTable return !HasInstantiation() || IsGenericTypeDefinition(); } + PTR_MethodTable GetTypicalMethodTable(); + BOOL HasSameTypeDefAs(MethodTable *pMT); //------------------------------------------------------------------- diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 7df5865af7920..ff447af8ab46e 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1116,6 +1116,7 @@ namespace : Kind{ kind } , Declaration{ pMD } , DeclarationSig{ pMD } + , TargetTypeSig{} , TargetType{} , IsTargetStatic{ false } , TargetMethod{} @@ -1125,13 +1126,14 @@ namespace UnsafeAccessorKind Kind; MethodDesc* Declaration; MetaSig DeclarationSig; + SigPointer TargetTypeSig; TypeHandle TargetType; bool IsTargetStatic; MethodDesc* TargetMethod; FieldDesc* TargetField; }; - TypeHandle ValidateTargetType(TypeHandle targetTypeMaybe) + TypeHandle ValidateTargetType(TypeHandle targetTypeMaybe, CorElementType targetFromSig) { TypeHandle targetType = targetTypeMaybe.IsByRef() ? targetTypeMaybe.GetTypeParam() @@ -1142,6 +1144,12 @@ namespace if (targetType.IsTypeDesc()) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + // We do not support generic signature types as valid targets. + if (targetFromSig == ELEMENT_TYPE_VAR || targetFromSig == ELEMENT_TYPE_MVAR) + { + ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); + } + return targetType; } @@ -1167,16 +1175,29 @@ namespace ModuleBase* pModule2 = method->GetModule(); const Substitution* pSubst2 = NULL; + // + // Parsing the signature follows details defined in ECMA-335 - II.23.2.1 + // + // Validate calling convention if ((*pSig1 & IMAGE_CEE_CS_CALLCONV_MASK) != (*pSig2 & IMAGE_CEE_CS_CALLCONV_MASK)) { return false; } - BYTE callConv = *pSig1; + BYTE callConvDecl = *pSig1; + BYTE callConvMethod = *pSig2; pSig1++; pSig2++; + // Handle generic param count + DWORD declGenericCount = 0; + DWORD methodGenericCount = 0; + if (callConvDecl & IMAGE_CEE_CS_CALLCONV_GENERIC) + IfFailThrow(CorSigUncompressData_EndPtr(pSig1, pEndSig1, &declGenericCount)); + if (callConvMethod & IMAGE_CEE_CS_CALLCONV_GENERIC) + IfFailThrow(CorSigUncompressData_EndPtr(pSig2, pEndSig2, &methodGenericCount)); + DWORD declArgCount; DWORD methodArgCount; IfFailThrow(CorSigUncompressData_EndPtr(pSig1, pEndSig1, &declArgCount)); @@ -1250,6 +1271,74 @@ namespace return true; } + void VerifyDeclarationSatisfiesTargetConstraints(MethodDesc* declaration, MethodTable* targetType, MethodDesc* targetMethod) + { + CONTRACTL + { + STANDARD_VM_CHECK; + PRECONDITION(declaration != NULL); + PRECONDITION(targetType != NULL); + PRECONDITION(targetMethod != NULL); + } + CONTRACTL_END; + + // If the target method has no generic parameters there is nothing to verify + if (!targetMethod->HasClassOrMethodInstantiation()) + return; + + // Construct a context for verifying target's constraints are + // satisfied by the declaration. + Instantiation declClassInst; + Instantiation declMethodInst; + Instantiation targetClassInst; + Instantiation targetMethodInst; + if (targetType->HasInstantiation()) + { + declClassInst = declaration->GetMethodTable()->GetInstantiation(); + targetClassInst = targetType->GetTypicalMethodTable()->GetInstantiation(); + } + if (targetMethod->HasMethodInstantiation()) + { + declMethodInst = declaration->LoadTypicalMethodDefinition()->GetMethodInstantiation(); + targetMethodInst = targetMethod->LoadTypicalMethodDefinition()->GetMethodInstantiation(); + } + + SigTypeContext typeContext; + SigTypeContext::InitTypeContext(declClassInst, declMethodInst, &typeContext); + + InstantiationContext instContext{ &typeContext }; + + // + // Validate constraints on Type parameters + // + DWORD typeParamCount = targetClassInst.GetNumArgs(); + if (typeParamCount != declClassInst.GetNumArgs()) + COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); + + for (DWORD i = 0; i < typeParamCount; ++i) + { + TypeHandle arg = declClassInst[i]; + TypeVarTypeDesc* param = targetClassInst[i].AsGenericVariable(); + if (!param->SatisfiesConstraints(&typeContext, arg, &instContext)) + COMPlusThrow(kInvalidProgramException, W("Argument_GenTypeConstraintsNotEqual")); + } + + // + // Validate constraints on Method parameters + // + DWORD methodParamCount = targetMethodInst.GetNumArgs(); + if (methodParamCount != declMethodInst.GetNumArgs()) + COMPlusThrow(kInvalidProgramException, W("Argument_GenMethodConstraintsNotEqual")); + + for (DWORD i = 0; i < methodParamCount; ++i) + { + TypeHandle arg = declMethodInst[i]; + TypeVarTypeDesc* param = targetMethodInst[i].AsGenericVariable(); + if (!param->SatisfiesConstraints(&typeContext, arg, &instContext)) + COMPlusThrow(kInvalidProgramException, W("Argument_GenMethodConstraintsNotEqual")); + } + } + bool TrySetTargetMethod( GenerationContext& cxt, LPCUTF8 methodName, @@ -1264,11 +1353,13 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + MethodTable* pMT = targetType.AsMethodTable(); + MethodDesc* targetMaybe = NULL; // Following a similar iteration pattern found in MemberLoader::FindMethod(). // However, we are only operating on the current type not walking the type hierarchy. - MethodTable::IntroducedMethodIterator iter(targetType.AsMethodTable()); + MethodTable::IntroducedMethodIterator iter(pMT); for (; iter.IsValid(); iter.Next()) { MethodDesc* curr = iter.GetMethodDesc(); @@ -1304,6 +1395,9 @@ namespace targetMaybe = curr; } + if (targetMaybe != NULL) + VerifyDeclarationSatisfiesTargetConstraints(cxt.Declaration, pMT, targetMaybe); + cxt.TargetMethod = targetMaybe; return cxt.TargetMethod != NULL; } @@ -1321,19 +1415,47 @@ namespace TypeHandle targetType = cxt.TargetType; _ASSERTE(!targetType.IsTypeDesc()); + MethodTable* pMT = targetType.AsMethodTable(); + + CorElementType elemType = fieldType.GetSignatureCorElementType(); ApproxFieldDescIterator fdIterator( - targetType.AsMethodTable(), + pMT, (cxt.IsTargetStatic ? ApproxFieldDescIterator::STATIC_FIELDS : ApproxFieldDescIterator::INSTANCE_FIELDS)); PTR_FieldDesc pField; while ((pField = fdIterator.Next()) != NULL) { // Validate the name and target type match. - if (strcmp(fieldName, pField->GetName()) == 0 - && fieldType == pField->LookupFieldTypeHandle()) + if (strcmp(fieldName, pField->GetName()) != 0) + continue; + + // We check if the possible field is class or valuetype + // since generic fields need resolution. + CorElementType fieldTypeMaybe = pField->GetFieldType(); + if (fieldTypeMaybe == ELEMENT_TYPE_CLASS + || fieldTypeMaybe == ELEMENT_TYPE_VALUETYPE) + { + if (fieldType != pField->LookupFieldTypeHandle()) + continue; + } + else + { + if (elemType != fieldTypeMaybe) + continue; + } + + if (cxt.Kind == UnsafeAccessorKind::StaticField && pMT->HasGenericsStaticsInfo()) { - cxt.TargetField = pField; - return true; + // Statics require the exact typed field as opposed to the canonically + // typed field. In order to do that we lookup the current index of the + // approx field and then use that index to get the precise field from + // the approx field. + MethodTable* pFieldMT = pField->GetApproxEnclosingMethodTable(); + DWORD index = pFieldMT->GetIndexForFieldDesc(pField); + pField = pMT->GetFieldDescByIndex(index); } + + cxt.TargetField = pField; + return true; } return false; } @@ -1351,12 +1473,14 @@ namespace ilResolver->SetStubMethodDesc(cxt.Declaration); ilResolver->SetStubTargetMethodDesc(cxt.TargetMethod); - // [TODO] Handle generics - SigTypeContext emptyContext; + SigTypeContext genericContext; + if (cxt.Declaration->GetClassification() == mcInstantiated) + SigTypeContext::InitTypeContext(cxt.Declaration, &genericContext); + ILStubLinker sl( cxt.Declaration->GetModule(), cxt.Declaration->GetSignature(), - &emptyContext, + &genericContext, cxt.TargetMethod, (ILStubLinkerFlags)ILSTUB_LINKER_FLAG_NONE); @@ -1377,24 +1501,126 @@ namespace switch (cxt.Kind) { case UnsafeAccessorKind::Constructor: + { _ASSERTE(cxt.TargetMethod != NULL); - pCode->EmitNEWOBJ(pCode->GetToken(cxt.TargetMethod), targetArgCount); + mdToken target; + if (!cxt.TargetType.HasInstantiation()) + { + target = pCode->GetToken(cxt.TargetMethod); + } + else + { + PCCOR_SIGNATURE sig; + uint32_t sigLen; + cxt.TargetTypeSig.GetSignature(&sig, &sigLen); + mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + target = pCode->GetToken(cxt.TargetMethod, targetTypeSigToken); + } + pCode->EmitNEWOBJ(target, targetArgCount); break; + } case UnsafeAccessorKind::Method: - _ASSERTE(cxt.TargetMethod != NULL); - pCode->EmitCALLVIRT(pCode->GetToken(cxt.TargetMethod), targetArgCount, targetRetCount); - break; case UnsafeAccessorKind::StaticMethod: + { _ASSERTE(cxt.TargetMethod != NULL); - pCode->EmitCALL(pCode->GetToken(cxt.TargetMethod), targetArgCount, targetRetCount); + mdToken target; + if (!cxt.TargetMethod->HasClassOrMethodInstantiation()) + { + target = pCode->GetToken(cxt.TargetMethod); + } + else + { + DWORD targetGenericCount = cxt.TargetMethod->GetNumGenericMethodArgs(); + + mdToken methodSpecSigToken = mdTokenNil; + SigBuilder sigBuilder; + uint32_t sigLen; + PCCOR_SIGNATURE sig; + if (targetGenericCount != 0) + { + // Create signature for the MethodSpec. See ECMA-335 - II.23.2.15 + sigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); + sigBuilder.AppendData(targetGenericCount); + for (DWORD i = 0; i < targetGenericCount; ++i) + { + sigBuilder.AppendElementType(ELEMENT_TYPE_MVAR); + sigBuilder.AppendData(i); + } + sigLen; + sig = (PCCOR_SIGNATURE)sigBuilder.GetSignature((DWORD*)&sigLen); + methodSpecSigToken = pCode->GetSigToken(sig, sigLen); + } + + cxt.TargetTypeSig.GetSignature(&sig, &sigLen); + mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + + if (methodSpecSigToken == mdTokenNil) + { + // Create a MemberRef + target = pCode->GetToken(cxt.TargetMethod, targetTypeSigToken); + _ASSERTE(TypeFromToken(target) == mdtMemberRef); + } + else + { + // Use the method declaration Instantiation to find the instantiated MethodDesc target. + Instantiation methodInst = cxt.Declaration->GetMethodInstantiation(); + MethodDesc* instantiatedTarget = MethodDesc::FindOrCreateAssociatedMethodDesc(cxt.TargetMethod, cxt.TargetType.GetMethodTable(), FALSE, methodInst, TRUE); + + // Create a MethodSpec + target = pCode->GetToken(instantiatedTarget, targetTypeSigToken, methodSpecSigToken); + _ASSERTE(TypeFromToken(target) == mdtMethodSpec); + } + } + + if (cxt.Kind == UnsafeAccessorKind::StaticMethod) + { + pCode->EmitCALL(target, targetArgCount, targetRetCount); + } + else + { + pCode->EmitCALLVIRT(target, targetArgCount, targetRetCount); + } break; + } case UnsafeAccessorKind::Field: + { _ASSERTE(cxt.TargetField != NULL); - pCode->EmitLDFLDA(pCode->GetToken(cxt.TargetField)); + mdToken target; + if (!cxt.TargetType.HasInstantiation()) + { + target = pCode->GetToken(cxt.TargetField); + } + else + { + // See the static field case for why this can be mdTokenNil. + mdToken targetTypeSigToken = mdTokenNil; + target = pCode->GetToken(cxt.TargetField, targetTypeSigToken); + } + pCode->EmitLDFLDA(target); break; + } case UnsafeAccessorKind::StaticField: _ASSERTE(cxt.TargetField != NULL); - pCode->EmitLDSFLDA(pCode->GetToken(cxt.TargetField)); + mdToken target; + if (!cxt.TargetType.HasInstantiation()) + { + target = pCode->GetToken(cxt.TargetField); + } + else + { + // For accessing a generic instance field, every instantiation will + // be at the same offset, and be the same size, with the same GC layout, + // as long as the generic is canonically equivalent. However, for static fields, + // while the offset, size and GC layout remain the same, the address of the + // field is different, and needs to be found by a lookup of some form. The + // current form of lookup means the exact type isn't with a type signature. + PCCOR_SIGNATURE sig; + uint32_t sigLen; + cxt.TargetTypeSig.GetSignature(&sig, &sigLen); + mdToken targetTypeSigToken = pCode->GetSigToken(sig, sigLen); + target = pCode->GetToken(cxt.TargetField, targetTypeSigToken); + } + pCode->EmitLDSFLDA(target); break; default: _ASSERTE(!"Unknown UnsafeAccessorKind"); @@ -1449,10 +1675,6 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET if (!IsStatic()) ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - // Block generic support early - if (HasClassOrMethodInstantiation()) - ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); - UnsafeAccessorKind kind; SString name; @@ -1467,12 +1689,19 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET // * Instance member access - examine type of first parameter // * Static member access - examine type of first parameter TypeHandle retType; + CorElementType retCorType; TypeHandle firstArgType; + CorElementType firstArgCorType = ELEMENT_TYPE_END; + retCorType = context.DeclarationSig.GetReturnType(); retType = context.DeclarationSig.GetRetTypeHandleThrowing(); UINT argCount = context.DeclarationSig.NumFixedArgs(); if (argCount > 0) { context.DeclarationSig.NextArg(); + + // Get the target type signature and resolve to a type handle. + context.TargetTypeSig = context.DeclarationSig.GetArgProps(); + (void)context.TargetTypeSig.PeekElemType(&firstArgCorType); firstArgType = context.DeclarationSig.GetLastTypeHandleThrowing(); } @@ -1491,7 +1720,9 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } - context.TargetType = ValidateTargetType(retType); + // Get the target type signature from the return type. + context.TargetTypeSig = context.DeclarationSig.GetReturnProps(); + context.TargetType = ValidateTargetType(retType, retCorType); if (!TrySetTargetMethod(context, ".ctor")) MemberLoader::ThrowMissingMethodException(context.TargetType.AsMethodTable(), ".ctor"); break; @@ -1511,7 +1742,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } - context.TargetType = ValidateTargetType(firstArgType); + context.TargetType = ValidateTargetType(firstArgType, firstArgCorType); context.IsTargetStatic = kind == UnsafeAccessorKind::StaticMethod; if (!TrySetTargetMethod(context, name.GetUTF8())) MemberLoader::ThrowMissingMethodException(context.TargetType.AsMethodTable(), name.GetUTF8()); @@ -1536,7 +1767,7 @@ bool MethodDesc::TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMET ThrowHR(COR_E_BADIMAGEFORMAT, BFA_INVALID_UNSAFEACCESSOR); } - context.TargetType = ValidateTargetType(firstArgType); + context.TargetType = ValidateTargetType(firstArgType, firstArgCorType); context.IsTargetStatic = kind == UnsafeAccessorKind::StaticField; if (!TrySetTargetField(context, name.GetUTF8(), retType.GetTypeParam())) MemberLoader::ThrowMissingFieldException(context.TargetType.AsMethodTable(), name.GetUTF8()); diff --git a/src/coreclr/vm/siginfo.hpp b/src/coreclr/vm/siginfo.hpp index a0ec6b3d4a26c..fab9a79260d2d 100644 --- a/src/coreclr/vm/siginfo.hpp +++ b/src/coreclr/vm/siginfo.hpp @@ -394,7 +394,7 @@ class Substitution Substitution( ModuleBase * pModuleArg, - const SigPointer & sigInst, + SigPointer sigInst, const Substitution * pNextSubstitution) { LIMITED_METHOD_CONTRACT; diff --git a/src/coreclr/vm/stubgen.cpp b/src/coreclr/vm/stubgen.cpp index fcb0bf05bdc6e..5ecb723b68c5d 100644 --- a/src/coreclr/vm/stubgen.cpp +++ b/src/coreclr/vm/stubgen.cpp @@ -3127,6 +3127,18 @@ int ILStubLinker::GetToken(MethodDesc* pMD) return m_tokenMap.GetToken(pMD); } +int ILStubLinker::GetToken(MethodDesc* pMD, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_tokenMap.GetToken(pMD, typeSignature); +} + +int ILStubLinker::GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature) +{ + STANDARD_VM_CONTRACT; + return m_tokenMap.GetToken(pMD, typeSignature, methodSignature); +} + int ILStubLinker::GetToken(MethodTable* pMT) { STANDARD_VM_CONTRACT; @@ -3145,6 +3157,12 @@ int ILStubLinker::GetToken(FieldDesc* pFD) return m_tokenMap.GetToken(pFD); } +int ILStubLinker::GetToken(FieldDesc* pFD, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_tokenMap.GetToken(pFD, typeSignature); +} + int ILStubLinker::GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { STANDARD_VM_CONTRACT; @@ -3221,6 +3239,16 @@ int ILCodeStream::GetToken(MethodDesc* pMD) STANDARD_VM_CONTRACT; return m_pOwner->GetToken(pMD); } +int ILCodeStream::GetToken(MethodDesc* pMD, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(pMD, typeSignature); +} +int ILCodeStream::GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(pMD, typeSignature, methodSignature); +} int ILCodeStream::GetToken(MethodTable* pMT) { STANDARD_VM_CONTRACT; @@ -3236,6 +3264,11 @@ int ILCodeStream::GetToken(FieldDesc* pFD) STANDARD_VM_CONTRACT; return m_pOwner->GetToken(pFD); } +int ILCodeStream::GetToken(FieldDesc* pFD, mdToken typeSignature) +{ + STANDARD_VM_CONTRACT; + return m_pOwner->GetToken(pFD, typeSignature); +} int ILCodeStream::GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { STANDARD_VM_CONTRACT; diff --git a/src/coreclr/vm/stubgen.h b/src/coreclr/vm/stubgen.h index 56b9b6f6cdcc1..968e5f9b48291 100644 --- a/src/coreclr/vm/stubgen.h +++ b/src/coreclr/vm/stubgen.h @@ -295,10 +295,13 @@ class TokenLookupMap for (COUNT_T i = 0; i < pSrc->m_signatures.GetCount(); i++) { const CQuickBytesSpecifySize<16>& src = pSrc->m_signatures[i]; - CQuickBytesSpecifySize<16>& dst = *m_signatures.Append(); - dst.AllocThrows(src.Size()); - memcpy(dst.Ptr(), src.Ptr(), src.Size()); + auto dst = m_signatures.Append(); + dst->AllocThrows(src.Size()); + memcpy(dst->Ptr(), src.Ptr(), src.Size()); } + + m_memberRefs.Set(pSrc->m_memberRefs); + m_methodSpecs.Set(pSrc->m_methodSpecs); } TypeHandle LookupTypeDef(mdToken token) @@ -316,6 +319,55 @@ class TokenLookupMap WRAPPER_NO_CONTRACT; return LookupTokenWorker(token); } + + struct MemberRefEntry final + { + CorTokenType Type; + mdToken ClassSignatureToken; + union + { + FieldDesc* Field; + MethodDesc* Method; + } Entry; + }; + MemberRefEntry LookupMemberRef(mdToken token) + { + CONTRACTL + { + NOTHROW; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(RidFromToken(token) - 1 < m_memberRefs.GetCount()); + PRECONDITION(RidFromToken(token) != 0); + PRECONDITION(TypeFromToken(token) == mdtMemberRef); + } + CONTRACTL_END; + + return m_memberRefs[static_cast(RidFromToken(token) - 1)]; + } + + struct MethodSpecEntry final + { + mdToken ClassSignatureToken; + mdToken MethodSignatureToken; + MethodDesc* Method; + }; + MethodSpecEntry LookupMethodSpec(mdToken token) + { + CONTRACTL + { + NOTHROW; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(RidFromToken(token) - 1 < m_methodSpecs.GetCount()); + PRECONDITION(RidFromToken(token) != 0); + PRECONDITION(TypeFromToken(token) == mdtMethodSpec); + } + CONTRACTL_END; + + return m_methodSpecs[static_cast(RidFromToken(token) - 1)]; + } + SigPointer LookupSig(mdToken token) { CONTRACTL @@ -345,11 +397,67 @@ class TokenLookupMap WRAPPER_NO_CONTRACT; return GetTokenWorker(pMD); } + mdToken GetToken(MethodDesc* pMD, mdToken typeSignature) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(pMD != NULL); + } + CONTRACTL_END; + + MemberRefEntry* entry; + mdToken token = GetMemberRefWorker(&entry); + entry->Type = mdtMethodDef; + entry->ClassSignatureToken = typeSignature; + entry->Entry.Method = pMD; + return token; + } + mdToken GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(pMD != NULL); + PRECONDITION(typeSignature != mdTokenNil); + PRECONDITION(methodSignature != mdTokenNil); + } + CONTRACTL_END; + + MethodSpecEntry* entry; + mdToken token = GetMethodSpecWorker(&entry); + entry->ClassSignatureToken = typeSignature; + entry->MethodSignatureToken = methodSignature; + entry->Method = pMD; + return token; + } mdToken GetToken(FieldDesc* pFieldDesc) { WRAPPER_NO_CONTRACT; return GetTokenWorker(pFieldDesc); } + mdToken GetToken(FieldDesc* pFieldDesc, mdToken typeSignature) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(pFieldDesc != NULL); + } + CONTRACTL_END; + + MemberRefEntry* entry; + mdToken token = GetMemberRefWorker(&entry); + entry->Type = mdtFieldDef; + entry->ClassSignatureToken = typeSignature; + entry->Entry.Field = pFieldDesc; + return token; + } mdToken GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig) { @@ -370,6 +478,38 @@ class TokenLookupMap } protected: + mdToken GetMemberRefWorker(MemberRefEntry** entry) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(entry != NULL); + } + CONTRACTL_END; + + mdToken token = TokenFromRid(m_memberRefs.GetCount(), mdtMemberRef) + 1; + *entry = &*m_memberRefs.Append(); // Dereference the iterator and then take the address + return token; + } + + mdToken GetMethodSpecWorker(MethodSpecEntry** entry) + { + CONTRACTL + { + THROWS; + MODE_ANY; + GC_NOTRIGGER; + PRECONDITION(entry != NULL); + } + CONTRACTL_END; + + mdToken token = TokenFromRid(m_methodSpecs.GetCount(), mdtMethodSpec) + 1; + *entry = &*m_methodSpecs.Append(); // Dereference the iterator and then take the address + return token; + } + template HandleType LookupTokenWorker(mdToken token) { @@ -411,9 +551,11 @@ class TokenLookupMap return token; } - unsigned int m_nextAvailableRid; + uint32_t m_nextAvailableRid; CQuickBytesSpecifySize m_qbEntries; SArray, FALSE> m_signatures; + SArray m_memberRefs; + SArray m_methodSpecs; }; class ILCodeLabel; @@ -580,9 +722,12 @@ class ILStubLinker // ILCodeLabel* NewCodeLabel(); int GetToken(MethodDesc* pMD); + int GetToken(MethodDesc* pMD, mdToken typeSignature); + int GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature); int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); int GetToken(FieldDesc* pFD); + int GetToken(FieldDesc* pFD, mdToken typeSignature); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); DWORD NewLocal(CorElementType typ = ELEMENT_TYPE_I); DWORD NewLocal(LocalDesc loc); @@ -809,9 +954,12 @@ class ILCodeStream // int GetToken(MethodDesc* pMD); + int GetToken(MethodDesc* pMD, mdToken typeSignature); + int GetToken(MethodDesc* pMD, mdToken typeSignature, mdToken methodSignature); int GetToken(MethodTable* pMT); int GetToken(TypeHandle th); int GetToken(FieldDesc* pFD); + int GetToken(FieldDesc* pFD, mdToken typeSignature); int GetSigToken(PCCOR_SIGNATURE pSig, DWORD cbSig); DWORD NewLocal(CorElementType typ = ELEMENT_TYPE_I); diff --git a/src/coreclr/vm/typehandle.h b/src/coreclr/vm/typehandle.h index 8483a935af613..f0f5a4604ab22 100644 --- a/src/coreclr/vm/typehandle.h +++ b/src/coreclr/vm/typehandle.h @@ -647,9 +647,7 @@ inline CHECK CheckPointer(TypeHandle th, IsNullOK ok = NULL_NOT_OK) /*************************************************************************/ // Instantiation is representation of generic instantiation. -// It is simple read-only array of TypeHandles. In NGen, the type handles -// may be encoded using indirections. That's one reason why it is convenient -// to have wrapper class that performs the decoding. +// It is simple read-only array of TypeHandles. class Instantiation { public: @@ -695,6 +693,14 @@ class Instantiation } #endif + Instantiation& operator=(const Instantiation& inst) + { + _ASSERTE(this != &inst); + m_pArgs = inst.m_pArgs; + m_nArgs = inst.m_nArgs; + return *this; + } + // Return i-th instantiation argument TypeHandle operator[](DWORD iArg) const { diff --git a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx index b6d8fd00aa404..f5777fc29413a 100644 --- a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx +++ b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx @@ -1101,6 +1101,12 @@ GenericArguments[{0}], '{1}', on '{2}' violates the constraint of type '{3}'. + + Generic type constraints do not match. + + + Generic method constraints do not match. + The number of generic arguments provided doesn't equal the arity of the generic type definition. @@ -3346,7 +3352,7 @@ Object type {0} does not match target type {1}. - + Non-static field requires a target. diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs new file mode 100644 index 0000000000000..e1029797bf12c --- /dev/null +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -0,0 +1,460 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using Xunit; + +struct Struct { } + +public static unsafe class UnsafeAccessorsTestsGenerics +{ + class MyList + { + public const string StaticGenericFieldName = nameof(_GF); + public const string StaticFieldName = nameof(_F); + public const string GenericFieldName = nameof(_list); + + static MyList() + { + _F = typeof(T).ToString(); + } + + public static void SetStaticGenericField(T val) => _GF = val; + private static T _GF; + private static string _F; + + private List _list; + + public MyList() => _list = new(); + + private MyList(int i) => _list = new(i); + + private MyList(List list) => _list = list; + + private void Clear() => _list.Clear(); + + private void Add(T t) => _list.Add(t); + + private void AddWithIgnore(T t, U _) => _list.Add(t); + + private bool CanCastToElementType(U t) => t is T; + + private static bool CanUseElementType(U t) => t is T; + + private static Type ElementType() => typeof(T); + + private void Add(int a) => + Unsafe.As>(_list).Add(a); + + private void Add(string a) => + Unsafe.As>(_list).Add(a); + + private void Add(Struct a) => + Unsafe.As>(_list).Add(a); + + public int Count => _list.Count; + + public int Capacity => _list.Capacity; + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_AccessStaticFieldClass() + { + Console.WriteLine($"Running {nameof(Verify_Generic_AccessStaticFieldClass)}"); + + Assert.Equal(typeof(int).ToString(), GetPrivateStaticFieldInt((MyList)null)); + + Assert.Equal(typeof(string).ToString(), GetPrivateStaticFieldString((MyList)null)); + + Assert.Equal(typeof(Struct).ToString(), GetPrivateStaticFieldStruct((MyList)null)); + + { + int expected = 10; + MyList.SetStaticGenericField(expected); + Assert.Equal(expected, GetPrivateStaticField((MyList)null)); + } + { + string expected = "abc"; + MyList.SetStaticGenericField(expected); + Assert.Equal(expected, GetPrivateStaticField((MyList)null)); + } + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticFieldName)] + extern static ref string GetPrivateStaticFieldInt(MyList d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticFieldName)] + extern static ref string GetPrivateStaticFieldString(MyList d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticFieldName)] + extern static ref string GetPrivateStaticFieldStruct(MyList d); + + [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=MyList.StaticGenericFieldName)] + extern static ref V GetPrivateStaticField(MyList d); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_AccessFieldClass() + { + Console.WriteLine($"Running {nameof(Verify_Generic_AccessFieldClass)}"); + { + MyList a = new(); + Assert.NotNull(GetPrivateField(a)); + } + { + MyList a = new(); + Assert.NotNull(GetPrivateField(a)); + } + { + MyList a = new(); + Assert.NotNull(GetPrivateField(a)); + } + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name=MyList.GenericFieldName)] + extern static ref List GetPrivateField(MyList a); + } + + class Base + { + protected virtual string CreateMessageGeneric(T t) => $"{nameof(Base)}:{t}"; + } + + class GenericBase : Base + { + protected virtual string CreateMessage(T t) => $"{nameof(GenericBase)}:{t}"; + protected override string CreateMessageGeneric(U u) => $"{nameof(GenericBase)}:{u}"; + } + + sealed class Derived1 : GenericBase + { + protected override string CreateMessage(string u) => $"{nameof(Derived1)}:{u}"; + protected override string CreateMessageGeneric(U t) => $"{nameof(Derived1)}:{t}"; + } + + sealed class Derived2 : GenericBase + { + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_InheritanceMethodResolution() + { + string expect = "abc"; + Console.WriteLine($"Running {nameof(Verify_Generic_InheritanceMethodResolution)}"); + { + Base a = new(); + Assert.Equal($"{nameof(Base)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Base)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Base)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + GenericBase a = new(); + Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + GenericBase a = new(); + Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + GenericBase a = new(); + Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + Derived1 a = new(); + Assert.Equal($"{nameof(Derived1)}:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Derived1)}:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Derived1)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + } + { + // Verify resolution of generic override logic. + Derived1 a1 = new(); + Derived2 a2 = new(); + Assert.Equal($"{nameof(Derived1)}:{expect}", Accessors.CreateMessage(a1, expect)); + Assert.Equal($"{nameof(GenericBase)}:{expect}", Accessors.CreateMessage(a2, expect)); + } + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessageGeneric")] + extern static string CreateMessage(Base b, W w); + } + + sealed class Accessors + { + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + public extern static MyList Create(int a); + + [UnsafeAccessor(UnsafeAccessorKind.Constructor)] + public extern static MyList CreateWithList(List a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = ".ctor")] + public extern static void CallCtorAsMethod(MyList l, List a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void AddInt(MyList l, int a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void AddString(MyList l, string a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void AddStruct(MyList l, Struct a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Clear")] + public extern static void Clear(MyList l); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "Add")] + public extern static void Add(MyList l, V element); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "AddWithIgnore")] + public extern static void AddWithIgnore(MyList l, V element, W ignore); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CanCastToElementType")] + public extern static bool CanCastToElementType(MyList l, W element); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessage")] + public extern static string CreateMessage(GenericBase b, V v); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "ElementType")] + public extern static Type ElementType(MyList l); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "CanUseElementType")] + public extern static bool CanUseElementType(MyList l, W element); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_CallCtor() + { + Console.WriteLine($"Running {nameof(Verify_Generic_CallCtor)}"); + + // Call constructor with non-generic parameter + { + MyList a = Accessors.Create(1); + Assert.Equal(1, a.Capacity); + } + { + MyList a = Accessors.Create(2); + Assert.Equal(2, a.Capacity); + } + { + MyList a = Accessors.Create(3); + Assert.Equal(3, a.Capacity); + } + + // Call constructor using generic parameter + { + MyList a = Accessors.CreateWithList([ 1 ]); + Assert.Equal(1, a.Count); + } + { + MyList a = Accessors.CreateWithList([ "1", "2" ]); + Assert.Equal(2, a.Count); + } + { + MyList a = Accessors.CreateWithList([new Struct(), new Struct(), new Struct()]); + Assert.Equal(3, a.Count); + } + + // Call constructors as methods + { + MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); + Accessors.CallCtorAsMethod(a, [1]); + Assert.Equal(1, a.Count); + } + { + MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); + Accessors.CallCtorAsMethod(a, ["1", "2"]); + Assert.Equal(2, a.Count); + } + { + MyList a = (MyList)RuntimeHelpers.GetUninitializedObject(typeof(MyList)); + Accessors.CallCtorAsMethod(a, [new Struct(), new Struct(), new Struct()]); + Assert.Equal(3, a.Count); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeNonGenericInstanceMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeNonGenericInstanceMethod)}"); + { + MyList a = new(); + Accessors.AddInt(a, 1); + Assert.Equal(1, a.Count); + Accessors.Clear(a); + Assert.Equal(0, a.Count); + } + { + MyList a = new(); + Accessors.AddString(a, "1"); + Accessors.AddString(a, "2"); + Assert.Equal(2, a.Count); + Accessors.Clear(a); + Assert.Equal(0, a.Count); + } + { + MyList a = new(); + Accessors.AddStruct(a, new Struct()); + Accessors.AddStruct(a, new Struct()); + Accessors.AddStruct(a, new Struct()); + Assert.Equal(3, a.Count); + Accessors.Clear(a); + Assert.Equal(0, a.Count); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeGenericInstanceMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeGenericInstanceMethod)}"); + { + MyList a = new(); + Assert.True(Accessors.CanCastToElementType(a, 1)); + Assert.False(Accessors.CanCastToElementType(a, string.Empty)); + Assert.False(Accessors.CanCastToElementType(a, new Struct())); + Assert.Equal(0, a.Count); + Accessors.Add(a, 1); + Accessors.AddWithIgnore(a, 1, 1); + Accessors.AddWithIgnore(a, 1, string.Empty); + Accessors.AddWithIgnore(a, 1, new Struct()); + Assert.Equal(4, a.Count); + } + { + MyList a = new(); + Assert.False(Accessors.CanCastToElementType(a, 1)); + Assert.True(Accessors.CanCastToElementType(a, string.Empty)); + Assert.False(Accessors.CanCastToElementType(a, new Struct())); + Assert.Equal(0, a.Count); + Accessors.Add(a, string.Empty); + Accessors.AddWithIgnore(a, string.Empty, 1); + Accessors.AddWithIgnore(a, string.Empty, string.Empty); + Accessors.AddWithIgnore(a, string.Empty, new Struct()); + Assert.Equal(4, a.Count); + } + { + MyList a = new(); + Assert.False(Accessors.CanCastToElementType(a, 1)); + Assert.False(Accessors.CanCastToElementType(a, string.Empty)); + Assert.True(Accessors.CanCastToElementType(a, new Struct())); + Assert.Equal(0, a.Count); + Accessors.Add(a, new Struct()); + Accessors.AddWithIgnore(a, new Struct(), 1); + Accessors.AddWithIgnore(a, new Struct(), string.Empty); + Accessors.AddWithIgnore(a, new Struct(), new Struct()); + Assert.Equal(4, a.Count); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeNonGenericStaticMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeNonGenericStaticMethod)}"); + { + Assert.Equal(typeof(int), Accessors.ElementType(null)); + Assert.Equal(typeof(string), Accessors.ElementType(null)); + Assert.Equal(typeof(Struct), Accessors.ElementType(null)); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_GenericTypeGenericStaticMethod() + { + Console.WriteLine($"Running {nameof(Verify_Generic_GenericTypeGenericStaticMethod)}"); + { + Assert.True(Accessors.CanUseElementType(null, 1)); + Assert.False(Accessors.CanUseElementType(null, string.Empty)); + Assert.False(Accessors.CanUseElementType(null, new Struct())); + } + { + Assert.False(Accessors.CanUseElementType(null, 1)); + Assert.True(Accessors.CanUseElementType(null, string.Empty)); + Assert.False(Accessors.CanUseElementType(null, new Struct())); + } + { + Assert.False(Accessors.CanUseElementType(null, 1)); + Assert.False(Accessors.CanUseElementType(null, string.Empty)); + Assert.True(Accessors.CanUseElementType(null, new Struct())); + } + } + + class ClassWithConstraints + { + private string M() where T : U, IEquatable + => $"{typeof(T)}|{typeof(U)}"; + + private static string SM() where T : U, IEquatable + => $"{typeof(T)}|{typeof(U)}"; + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_ConstraintEnforcement() + { + Console.WriteLine($"Running {nameof(Verify_Generic_ConstraintEnforcement)}"); + + Assert.Equal($"{typeof(string)}|{typeof(object)}", CallMethod(new ClassWithConstraints())); + Assert.Equal($"{typeof(string)}|{typeof(object)}", CallStaticMethod(null)); + Assert.Throws(() => CallMethod_NoConstraints(new ClassWithConstraints())); + Assert.Throws(() => CallMethod_MissingConstraint(new ClassWithConstraints())); + Assert.Throws(() => CallStaticMethod_NoConstraints(null)); + Assert.Throws(() => CallStaticMethod_MissingConstraint(null)); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static string CallMethod(ClassWithConstraints c) where V : W, IEquatable; + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static string CallMethod_NoConstraints(ClassWithConstraints c); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + extern static string CallMethod_MissingConstraint(ClassWithConstraints c) where V : W; + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + extern static string CallStaticMethod(ClassWithConstraints c) where V : W, IEquatable; + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + extern static string CallStaticMethod_NoConstraints(ClassWithConstraints c); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + extern static string CallStaticMethod_MissingConstraint(ClassWithConstraints c) where V : W; + } + + class Invalid + { + [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] + public static extern string CallToString(U a); + } + + class Invalid + { + [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] + public static extern string CallToString(T a); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/89439", TestRuntimes.Mono)] + public static void Verify_Generic_InvalidUseUnsafeAccessor() + { + Console.WriteLine($"Running {nameof(Verify_Generic_InvalidUseUnsafeAccessor)}"); + + Assert.Throws(() => Invalid.CallToString(0)); + Assert.Throws(() => Invalid.CallToString(0)); + Assert.Throws(() => Invalid.CallToString(string.Empty)); + Assert.Throws(() => Invalid.CallToString(string.Empty)); + Assert.Throws(() => Invalid.CallToString(new Struct())); + Assert.Throws(() => Invalid.CallToString(new Struct())); + } +} diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs index 6e0a562f32a9b..30f65993da6cc 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.cs @@ -85,33 +85,6 @@ struct UserDataValue public string GetFieldValue() => _f; } - class UserDataGenericClass - { - public const string StaticGenericFieldName = nameof(_GF); - public const string GenericFieldName = nameof(_gf); - public const string StaticGenericMethodName = nameof(_GM); - public const string GenericMethodName = nameof(_gm); - - public const string StaticFieldName = nameof(_F); - public const string FieldName = nameof(_f); - public const string StaticMethodName = nameof(_M); - public const string MethodName = nameof(_m); - - private static T _GF; - private T _gf; - - private static string _F = PrivateStatic; - private string _f; - - public UserDataGenericClass() { _f = Private; } - - private static string _GM(T s, ref T sr, in T si) => typeof(T).ToString(); - private string _gm(T s, ref T sr, in T si) => typeof(T).ToString(); - - private static string _M(string s, ref string sr, in string si) => s; - private string _m(string s, ref string sr, in string si) => s; - } - [UnsafeAccessor(UnsafeAccessorKind.Constructor)] extern static UserDataClass CallPrivateConstructorClass(); @@ -215,23 +188,6 @@ public static void Verify_AccessFieldClass() extern static ref string GetPrivateField(UserDataClass d); } - [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/92633")] - public static void Verify_AccessStaticFieldGenericClass() - { - Console.WriteLine($"Running {nameof(Verify_AccessStaticFieldGenericClass)}"); - - Assert.Equal(PrivateStatic, GetPrivateStaticFieldInt((UserDataGenericClass)null)); - - Assert.Equal(PrivateStatic, GetPrivateStaticFieldString((UserDataGenericClass)null)); - - [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticFieldName)] - extern static ref string GetPrivateStaticFieldInt(UserDataGenericClass d); - - [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name=UserDataGenericClass.StaticFieldName)] - extern static ref string GetPrivateStaticFieldString(UserDataGenericClass d); - } - [Fact] public static void Verify_AccessStaticFieldValue() { @@ -259,23 +215,6 @@ public static void Verify_AccessFieldValue() extern static ref string GetPrivateField(ref UserDataValue d); } - [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/92633")] - public static void Verify_AccessFieldGenericClass() - { - Console.WriteLine($"Running {nameof(Verify_AccessFieldGenericClass)}"); - - Assert.Equal(Private, GetPrivateFieldInt(new UserDataGenericClass())); - - Assert.Equal(Private, GetPrivateFieldString(new UserDataGenericClass())); - - [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.FieldName)] - extern static ref string GetPrivateFieldInt(UserDataGenericClass d); - - [UnsafeAccessor(UnsafeAccessorKind.Field, Name=UserDataGenericClass.FieldName)] - extern static ref string GetPrivateFieldString(UserDataGenericClass d); - } - [Fact] public static void Verify_AccessStaticMethodClass() { @@ -587,15 +526,6 @@ class Invalid { [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] public extern string NonStatic(string a); - - [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] - public static extern string CallToString(U a); - } - - class Invalid - { - [UnsafeAccessor(UnsafeAccessorKind.Method, Name=nameof(ToString))] - public static extern string CallToString(T a); } [Fact] @@ -620,8 +550,6 @@ public static void Verify_InvalidUseUnsafeAccessor() Assert.Throws(() => LookUpFailsOnPointers(null)); Assert.Throws(() => LookUpFailsOnFunctionPointers(null)); Assert.Throws(() => new Invalid().NonStatic(string.Empty)); - Assert.Throws(() => Invalid.CallToString(string.Empty)); - Assert.Throws(() => Invalid.CallToString(string.Empty)); Assert.Throws(() => { string str = string.Empty; diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj index 876d006ea96eb..f551f9b48c249 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.csproj @@ -6,6 +6,7 @@ +