Skip to content

Commit

Permalink
DataContractSerialization doesn't work with TrimMode - link (#42824)
Browse files Browse the repository at this point in the history
DataContractSerialization has some Reflection "shim" methods that the ILLinker can't see through. This causes critical methods to be trimmed and applications to fail. These methods were put in place in .NET Core 1.0 when the full Reflection API wasn't available.

The fix is to remove these "shim" Reflection APIs and use Reflection directly.

Fix #41525
Fix #42754
  • Loading branch information
eerhardt authored Sep 30, 2020
1 parent 957fdb9 commit 118eee9
Show file tree
Hide file tree
Showing 16 changed files with 124 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ internal static bool IsNonAttributedTypeValidForSerialization(Type type)
else
{
return (type.IsVisible &&
type.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public, Array.Empty<Type>()) != null);
type.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public, null, Array.Empty<Type>(), null) != null);
}
}

Expand Down Expand Up @@ -1368,7 +1368,7 @@ private void SetKeyValuePairAdapterFlags(Type type)
{
_isKeyValuePairAdapter = true;
_keyValuePairGenericArguments = type.GetGenericArguments();
_keyValuePairCtorInfo = type.GetConstructor(Globals.ScanAllMembers, new Type[] { Globals.TypeOfKeyValuePair.MakeGenericType(_keyValuePairGenericArguments) });
_keyValuePairCtorInfo = type.GetConstructor(Globals.ScanAllMembers, null, new Type[] { Globals.TypeOfKeyValuePair.MakeGenericType(_keyValuePairGenericArguments) }, null);
_getKeyValuePairMethodInfo = type.GetMethod("GetKeyValuePair", Globals.ScanAllMembers)!;
}
}
Expand Down Expand Up @@ -1425,7 +1425,7 @@ internal MethodInfo? GetKeyValuePairMethodInfo
if (type.IsValueType)
return null;

ConstructorInfo? ctor = type.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public, Array.Empty<Type>());
ConstructorInfo? ctor = type.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public, null, Array.Empty<Type>(), null);
if (ctor == null)
throw System.Runtime.Serialization.DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidDataContractException(SR.Format(SR.NonAttributedSerializableTypesMustHaveDefaultConstructor, DataContract.GetClrTypeFullName(type))));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private static MethodInfo ObjectEquals
{
if (s_objectEquals == null)
{
s_objectEquals = Globals.TypeOfObject.GetMethod("Equals", BindingFlags.Public | BindingFlags.Static);
s_objectEquals = typeof(object).GetMethod("Equals", BindingFlags.Public | BindingFlags.Static);
Debug.Assert(s_objectEquals != null);
}
return s_objectEquals;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ internal Type GetCollectionElementType()
enumeratorType = GetEnumeratorMethod.ReturnType;
}

MethodInfo? getCurrentMethod = enumeratorType.GetMethod(Globals.GetCurrentMethodName, BindingFlags.Instance | BindingFlags.Public, Array.Empty<Type>());
MethodInfo? getCurrentMethod = enumeratorType.GetMethod(Globals.GetCurrentMethodName, BindingFlags.Instance | BindingFlags.Public, null, Array.Empty<Type>(), null);
if (getCurrentMethod == null)
{
if (enumeratorType.IsInterface)
Expand Down Expand Up @@ -1134,7 +1134,7 @@ private static bool IsCollectionOrTryCreate(Type type, bool tryCreate, out DataC
}
}

getEnumeratorMethod = Globals.TypeOfIEnumerable.GetMethod(Globals.GetEnumeratorMethodName)!;
getEnumeratorMethod = typeof(IEnumerable).GetMethod(Globals.GetEnumeratorMethodName)!;
}
if (tryCreate)
dataContract = new CollectionDataContract(type, (CollectionKind)(i + 1), itemType, getEnumeratorMethod, addMethod, null/*defaultCtor*/);
Expand Down Expand Up @@ -1364,7 +1364,7 @@ private static void GetCollectionMethods(Type type, Type interfaceType, Type[] a

if (getEnumeratorMethod == null)
{
getEnumeratorMethod = type.GetMethod(Globals.GetEnumeratorMethodName, BindingFlags.Instance | BindingFlags.Public, Array.Empty<Type>());
getEnumeratorMethod = type.GetMethod(Globals.GetEnumeratorMethodName, BindingFlags.Instance | BindingFlags.Public, null, Array.Empty<Type>(), null);
if (getEnumeratorMethod == null || !Globals.TypeOfIEnumerator.IsAssignableFrom(getEnumeratorMethod.ReturnType))
{
Type? ienumerableInterface = interfaceType.GetInterfaces().Where(t => t.FullName!.StartsWith("System.Collections.Generic.IEnumerable")).FirstOrDefault();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ internal MethodInfo? ParseMethod
{
if (!_parseMethodSet)
{
MethodInfo? method = UnderlyingType.GetMethod(Globals.ParseMethodName, BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(string) });
MethodInfo? method = UnderlyingType.GetMethod(Globals.ParseMethodName, BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null);

if (method != null && method.ReturnType == UnderlyingType)
{
Expand Down Expand Up @@ -1978,7 +1978,7 @@ private static void ImportKnownTypeAttributes(Type? type, Dictionary<Type, Type>
if (methodName.Length == 0)
DataContract.ThrowInvalidDataContractException(SR.Format(SR.KnownTypeAttributeEmptyString, DataContract.GetClrTypeFullName(type)), type);

MethodInfo? method = type.GetMethod(methodName, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public, Array.Empty<Type>());
MethodInfo? method = type.GetMethod(methodName, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public, null, Array.Empty<Type>(), null);
if (method == null)
DataContract.ThrowInvalidDataContractException(SR.Format(SR.KnownTypeAttributeUnknownMethod, methodName, DataContract.GetClrTypeFullName(type)), type);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,46 +785,6 @@ internal static Uri DataContractXsdBaseNamespaceUri
}
}

#region Contract compliance for System.Type

private static bool TypeSequenceEqual(Type[] seq1, Type[] seq2)
{
if (seq1 == null || seq2 == null || seq1.Length != seq2.Length)
return false;
for (int i = 0; i < seq1.Length; i++)
{
if (!seq1[i].Equals(seq2[i]) && !seq1[i].IsAssignableFrom(seq2[i]))
return false;
}
return true;
}

private static MethodBase? FilterMethodBases(MethodBase[]? methodBases, Type[] parameterTypes, string methodName)
{
if (methodBases == null || string.IsNullOrEmpty(methodName))
return null;

var matchedMethods = methodBases.Where(method => method.Name.Equals(methodName));
matchedMethods = matchedMethods.Where(method => TypeSequenceEqual(method.GetParameters().Select(param => param.ParameterType).ToArray(), parameterTypes));
return matchedMethods.FirstOrDefault();
}

internal static ConstructorInfo? GetConstructor(this Type type, BindingFlags bindingFlags, Type[] parameterTypes)
{
ConstructorInfo[] constructorInfos = type.GetConstructors(bindingFlags);
var constructorInfo = FilterMethodBases(constructorInfos.Cast<MethodBase>().ToArray()!, parameterTypes, ".ctor");
return constructorInfo != null ? (ConstructorInfo)constructorInfo : null;
}

internal static MethodInfo? GetMethod(this Type type, string methodName, BindingFlags bindingFlags, Type[] parameterTypes)
{
var methodInfos = type.GetMethods(bindingFlags);
var methodInfo = FilterMethodBases(methodInfos.Cast<MethodBase>().ToArray()!, parameterTypes, methodName);
return methodInfo != null ? (MethodInfo)methodInfo : null;
}

#endregion

private static readonly Type? s_typeOfScriptObject;

internal static ClassDataContract CreateScriptObjectClassDataContract()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ public static MethodInfo GetUninitializedObjectMethod
{
if (s_getUninitializedObjectMethod == null)
{
s_getUninitializedObjectMethod = typeof(XmlFormatReaderGenerator).GetMethod("UnsafeGetUninitializedObject", Globals.ScanAllMembers, new Type[] { typeof(Type) });
s_getUninitializedObjectMethod = typeof(XmlFormatReaderGenerator).GetMethod("UnsafeGetUninitializedObject", Globals.ScanAllMembers, null, new Type[] { typeof(Type) }, null);
Debug.Assert(s_getUninitializedObjectMethod != null);
}
return s_getUninitializedObjectMethod;
Expand All @@ -180,7 +180,7 @@ public static MethodInfo IsStartElementMethod0
{
if (s_isStartElementMethod0 == null)
{
s_isStartElementMethod0 = typeof(XmlReaderDelegator).GetMethod("IsStartElement", Globals.ScanAllMembers, Array.Empty<Type>());
s_isStartElementMethod0 = typeof(XmlReaderDelegator).GetMethod("IsStartElement", Globals.ScanAllMembers, null, Array.Empty<Type>(), null);
Debug.Assert(s_isStartElementMethod0 != null);
}
return s_isStartElementMethod0;
Expand All @@ -192,7 +192,7 @@ public static MethodInfo IsStartElementMethod2
{
if (s_isStartElementMethod2 == null)
{
s_isStartElementMethod2 = typeof(XmlReaderDelegator).GetMethod("IsStartElement", Globals.ScanAllMembers, new Type[] { typeof(XmlDictionaryString), typeof(XmlDictionaryString) });
s_isStartElementMethod2 = typeof(XmlReaderDelegator).GetMethod("IsStartElement", Globals.ScanAllMembers, null, new Type[] { typeof(XmlDictionaryString), typeof(XmlDictionaryString) }, null);
Debug.Assert(s_isStartElementMethod2 != null);
}
return s_isStartElementMethod2;
Expand Down Expand Up @@ -371,7 +371,7 @@ public static MethodInfo WriteAttributeStringMethod
{
if (s_writeAttributeStringMethod == null)
{
s_writeAttributeStringMethod = typeof(XmlWriterDelegator).GetMethod("WriteAttributeString", Globals.ScanAllMembers, new Type[] { typeof(string), typeof(string), typeof(string), typeof(string) });
s_writeAttributeStringMethod = typeof(XmlWriterDelegator).GetMethod("WriteAttributeString", Globals.ScanAllMembers, null, new Type[] { typeof(string), typeof(string), typeof(string), typeof(string) }, null);
Debug.Assert(s_writeAttributeStringMethod != null);
}
return s_writeAttributeStringMethod;
Expand All @@ -383,7 +383,7 @@ public static MethodInfo WriteEndElementMethod
{
if (s_writeEndElementMethod == null)
{
s_writeEndElementMethod = typeof(XmlWriterDelegator).GetMethod("WriteEndElement", Globals.ScanAllMembers, Array.Empty<Type>());
s_writeEndElementMethod = typeof(XmlWriterDelegator).GetMethod("WriteEndElement", Globals.ScanAllMembers, null, Array.Empty<Type>(), null);
Debug.Assert(s_writeEndElementMethod != null);
}
return s_writeEndElementMethod;
Expand Down Expand Up @@ -431,7 +431,7 @@ public static MethodInfo WriteStartElementMethod
{
if (s_writeStartElementMethod == null)
{
s_writeStartElementMethod = typeof(XmlWriterDelegator).GetMethod("WriteStartElement", Globals.ScanAllMembers, new Type[] { typeof(XmlDictionaryString), typeof(XmlDictionaryString) });
s_writeStartElementMethod = typeof(XmlWriterDelegator).GetMethod("WriteStartElement", Globals.ScanAllMembers, null, new Type[] { typeof(XmlDictionaryString), typeof(XmlDictionaryString) }, null);
Debug.Assert(s_writeStartElementMethod != null);
}
return s_writeStartElementMethod;
Expand All @@ -444,7 +444,7 @@ public static MethodInfo WriteStartElementStringMethod
{
if (s_writeStartElementStringMethod == null)
{
s_writeStartElementStringMethod = typeof(XmlWriterDelegator).GetMethod("WriteStartElement", Globals.ScanAllMembers, new Type[] { typeof(string), typeof(string) });
s_writeStartElementStringMethod = typeof(XmlWriterDelegator).GetMethod("WriteStartElement", Globals.ScanAllMembers, null, new Type[] { typeof(string), typeof(string) }, null);
Debug.Assert(s_writeStartElementStringMethod != null);
}
return s_writeStartElementStringMethod;
Expand All @@ -457,7 +457,7 @@ public static MethodInfo ParseEnumMethod
{
if (s_parseEnumMethod == null)
{
s_parseEnumMethod = typeof(Enum).GetMethod("Parse", BindingFlags.Static | BindingFlags.Public, new Type[] { typeof(Type), typeof(string) });
s_parseEnumMethod = typeof(Enum).GetMethod("Parse", BindingFlags.Static | BindingFlags.Public, null, new Type[] { typeof(Type), typeof(string) }, null);
Debug.Assert(s_parseEnumMethod != null);
}
return s_parseEnumMethod;
Expand All @@ -470,7 +470,7 @@ public static MethodInfo GetJsonMemberNameMethod
{
if (s_getJsonMemberNameMethod == null)
{
s_getJsonMemberNameMethod = typeof(XmlObjectSerializerReadContextComplexJson).GetMethod("GetJsonMemberName", Globals.ScanAllMembers, new Type[] { typeof(XmlReaderDelegator) });
s_getJsonMemberNameMethod = typeof(XmlObjectSerializerReadContextComplexJson).GetMethod("GetJsonMemberName", Globals.ScanAllMembers, null, new Type[] { typeof(XmlReaderDelegator) }, null);
Debug.Assert(s_getJsonMemberNameMethod != null);
}
return s_getJsonMemberNameMethod;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,11 @@ private void ReadCollection(CollectionDataContract collectionContract)
{
case CollectionKind.GenericDictionary:
type = Globals.TypeOfDictionaryGeneric.MakeGenericType(itemType.GetGenericArguments());
constructor = type.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, Array.Empty<Type>())!;
constructor = type.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, Array.Empty<Type>(), null)!;
break;
case CollectionKind.Dictionary:
type = Globals.TypeOfHashtable;
constructor = type.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, Array.Empty<Type>())!;
constructor = type.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, Array.Empty<Type>(), null)!;
break;
case CollectionKind.Collection:
case CollectionKind.GenericCollection:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ private void WriteCollection(CollectionDataContract collectionContract)
{
enumeratorType = collectionContract.GetEnumeratorMethod.ReturnType;
}
MethodInfo? moveNextMethod = enumeratorType.GetMethod(Globals.MoveNextMethodName, BindingFlags.Instance | BindingFlags.Public, Array.Empty<Type>());
MethodInfo? getCurrentMethod = enumeratorType.GetMethod(Globals.GetCurrentMethodName, BindingFlags.Instance | BindingFlags.Public, Array.Empty<Type>());
MethodInfo? moveNextMethod = enumeratorType.GetMethod(Globals.MoveNextMethodName, BindingFlags.Instance | BindingFlags.Public, null, Array.Empty<Type>(), null);
MethodInfo? getCurrentMethod = enumeratorType.GetMethod(Globals.GetCurrentMethodName, BindingFlags.Instance | BindingFlags.Public, null, Array.Empty<Type>(), null);
if (moveNextMethod == null || getCurrentMethod == null)
{
if (enumeratorType.IsInterface)
Expand Down Expand Up @@ -388,15 +388,15 @@ private void WriteCollection(CollectionDataContract collectionContract)
_ilg.Call(_objectLocal, collectionContract.GetEnumeratorMethod);
if (isDictionary)
{
ConstructorInfo dictEnumCtor = enumeratorType.GetConstructor(Globals.ScanAllMembers, new Type[] { Globals.TypeOfIDictionaryEnumerator })!;
ConstructorInfo dictEnumCtor = enumeratorType.GetConstructor(Globals.ScanAllMembers, null, new Type[] { Globals.TypeOfIDictionaryEnumerator }, null)!;
_ilg.ConvertValue(collectionContract.GetEnumeratorMethod.ReturnType, Globals.TypeOfIDictionaryEnumerator);
_ilg.New(dictEnumCtor);
}
else if (isGenericDictionary)
{
Debug.Assert(keyValueTypes != null);
Type ctorParam = Globals.TypeOfIEnumeratorGeneric.MakeGenericType(Globals.TypeOfKeyValuePair.MakeGenericType(keyValueTypes));
ConstructorInfo dictEnumCtor = enumeratorType.GetConstructor(Globals.ScanAllMembers, new Type[] { ctorParam })!;
ConstructorInfo dictEnumCtor = enumeratorType.GetConstructor(Globals.ScanAllMembers, null, new Type[] { ctorParam }, null)!;
_ilg.ConvertValue(collectionContract.GetEnumeratorMethod.ReturnType, ctorParam);
_ilg.New(dictEnumCtor);
}
Expand Down Expand Up @@ -560,7 +560,9 @@ private bool TryWritePrimitiveArray(Type type, Type itemType, LocalBuilder value
MethodInfo writeArrayMethodInfo = typeof(JsonWriterDelegator).GetMethod(
writeArrayMethod,
Globals.ScanAllMembers,
new Type[] { type, typeof(XmlDictionaryString), typeof(XmlDictionaryString) })!;
null,
new Type[] { type, typeof(XmlDictionaryString), typeof(XmlDictionaryString) },
null)!;
_ilg.Call(_xmlWriterArg, writeArrayMethodInfo, value, itemName, null);
return true;
}
Expand Down
Loading

0 comments on commit 118eee9

Please sign in to comment.