diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs index d511fea16e8021..847485fdfb0fe5 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs @@ -60,7 +60,7 @@ private TypePreinit(MetadataType owningType, CompilationModuleGroup compilationG if (!field.IsStatic || field.IsLiteral || field.IsThreadStatic || field.HasRva) continue; - _fieldValues.Add(field, NewUninitializedLocationValue(field.FieldType)); + _fieldValues.Add(field, NewUninitializedLocationValue(field.FieldType, field)); } } @@ -195,7 +195,7 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack new ValueTypeValue(BitConverter.GetBytes(value)); } + private sealed class ComInterfaceEntryArrayValue : BaseValueTypeValue + { + private readonly FieldDesc[] _targetFields; + private readonly byte[][] _guidBytes; + private readonly MetadataType _entryType; + + public override int Size => _entryType.InstanceFieldSize.AsInt * _targetFields.Length; + + public ComInterfaceEntryArrayValue(TypeDesc type, TypeDesc entryType) + { + Debug.Assert(IsCompatible(type, out _)); + Debug.Assert(IsComInterfaceEntryType(entryType)); + Debug.Assert(((MetadataType)type).InstanceFieldSize.AsInt % ((MetadataType)entryType).InstanceFieldSize.AsInt == 0); + + _entryType = (MetadataType)entryType; + + int numFields = ((MetadataType)type).InstanceFieldSize.AsInt / _entryType.InstanceFieldSize.AsInt; + _targetFields = new FieldDesc[numFields]; + _guidBytes = new byte[numFields][]; + for (int i = 0; i < numFields; i++) + _guidBytes[i] = new byte[16]; + } + + private static bool IsComInterfaceEntryType(TypeDesc type) + => type is MetadataType mdType + && mdType.Name == "ComInterfaceEntry" + && mdType.ContainingType is MetadataType { Name: "ComWrappers", Namespace: "System.Runtime.InteropServices" } comWrappersType + && comWrappersType.Module == comWrappersType.Context.SystemModule; + + public static bool IsCompatible(TypeDesc type, out TypeDesc entryType) + { + entryType = null; + + if (!type.IsValueType + || type.HasInstantiation + || type is not MetadataType mdType + || !mdType.IsSequentialLayout + || mdType.GetClassLayout() is not { PackingSize: 0, Size: 0 } + || mdType.IsInlineArray) + { + return false; + } + + foreach (FieldDesc field in type.GetFields()) + { + if (field.IsStatic) + continue; + + entryType = field.FieldType; + + if (!IsComInterfaceEntryType(entryType)) + return false; + } + + return entryType != null; + } + + public override bool TryCompareEquality(Value value, out bool result) + { + result = false; + return false; + } + + public override void WriteFieldData(ref ObjectDataBuilder builder, NodeFactory factory) + { + for (int i = 0; i < _targetFields.Length; i++) + { + Debug.Assert(_entryType.GetField("IID").Offset.AsInt == 0); + builder.EmitBytes(_guidBytes[i]); + + Debug.Assert(_entryType.GetField("Vtable").Offset.AsInt == _guidBytes[i].Length); + if (_targetFields[i] is not FieldDesc targetField) + { + builder.EmitZeroPointer(); + } + else + { + Debug.Assert(targetField.IsStatic && !targetField.HasGCStaticBase && !targetField.IsThreadStatic && !targetField.HasRva); + ISymbolNode nonGcStaticBase = factory.TypeNonGCStaticsSymbol((MetadataType)targetField.OwningType); + builder.EmitPointerReloc(nonGcStaticBase, targetField.Offset.AsInt); + } + } + } + + public override bool GetRawData(NodeFactory factory, out object data) + { + data = null; + return false; + } + + public override bool TryCreateByRef(out Value value) + { + value = new ComInterfaceEntrySlotReference(this, 0); + return true; + } + + private sealed class ComInterfaceEntrySlotReference : ByRefValueBase, IHasInstanceFields + { + private readonly ComInterfaceEntryArrayValue _parent; + private readonly int _index; + + public ComInterfaceEntrySlotReference(ComInterfaceEntryArrayValue parent, int index) + => (_parent, _index) = (parent, index); + + public override bool TryCompareEquality(Value value, out bool result) + { + result = false; + return false; + } + + public override bool GetRawData(NodeFactory factory, out object data) + { + data = null; + return false; + } + + public override void WriteFieldData(ref ObjectDataBuilder builder, NodeFactory factory) + { + throw new NotSupportedException(); + } + + bool IHasInstanceFields.TrySetField(FieldDesc field, Value value) + { + if (field.OwningType != _parent._entryType) + return false; + + if (field.Name == "IID" + && value is ValueTypeValue guidValue + && guidValue.Size == _parent._guidBytes[_index].Length) + { + Array.Copy(guidValue.InstanceBytes, _parent._guidBytes[_index], _parent._guidBytes[_index].Length); + return true; + } + else if (field.Name == "Vtable" + && value is ByRefValueBase byrefValue + && byrefValue.BackingField != null) + { + _parent._targetFields[_index] = byrefValue.BackingField; + return true; + } + + return false; + } + + Value IHasInstanceFields.GetField(FieldDesc field) + { + // Not actually invalid, but we don't need this. + ThrowHelper.ThrowInvalidProgramException(); + return null; // unreached + } + + ByRefValueBase IHasInstanceFields.GetFieldAddress(FieldDesc field) + { + if (field.OwningType == _parent._entryType) + { + // Get address of IID or Vtable field on ComInterfaceEntry this ref points to. + // Not actually invalid, but we don't need this. + ThrowHelper.ThrowInvalidProgramException(); + } + else if (field.FieldType == _parent._entryType + && _index == 0 + && field.Offset.AsInt % _parent._entryType.InstanceFieldSize.AsInt == 0 + && field.Offset.AsInt < _parent._entryType.InstanceFieldSize.AsInt * _parent._targetFields.Length) + { + // Get address of a field within an array of ComInterfaceEntry. + int index = field.Offset.AsInt / _parent._entryType.InstanceFieldSize.AsInt; + return new ComInterfaceEntrySlotReference(_parent, index); + } + + ThrowHelper.ThrowInvalidProgramException(); + return null; // unreached + } + } + } + private sealed class VTableLikeStructValue : BaseValueTypeValue, IAssignableValue { private readonly MetadataType _type; private readonly MethodDesc[] _methods; + private readonly FieldDesc _fieldThatOwnsMemory; - public VTableLikeStructValue(MetadataType type) - : this(type, new MethodDesc[GetFieldCount(type)]) + public VTableLikeStructValue(MetadataType type, FieldDesc fieldThatOwnsMemory) + : this(type, new MethodDesc[GetFieldCount(type)], fieldThatOwnsMemory) { } - private VTableLikeStructValue(MetadataType type, MethodDesc[] methods) - => (_type, _methods) = (type, methods); + private VTableLikeStructValue(MetadataType type, MethodDesc[] methods, FieldDesc fieldThatOwnsMemory) + => (_type, _methods, _fieldThatOwnsMemory) = (type, methods, fieldThatOwnsMemory); private static int GetFieldCount(MetadataType type) { @@ -2435,7 +2615,8 @@ public static bool IsCompatible(TypeDesc type) || type.HasInstantiation || type is not MetadataType mdType || !mdType.IsSequentialLayout - || mdType.GetClassLayout() is not { PackingSize: 0, Size: 0 }) + || mdType.GetClassLayout() is not { PackingSize: 0, Size: 0 } + || mdType.IsInlineArray) { return false; } @@ -2480,13 +2661,13 @@ public override void WriteFieldData(ref ObjectDataBuilder builder, NodeFactory f public override bool TryCreateByRef(out Value value) { - value = new VTableLikeSlotReferenceValue(_methods, index: 0, _type.Context.Target.PointerSize); + value = new VTableLikeSlotReferenceValue(this, index: 0); return true; } public override Value Clone() { - return new VTableLikeStructValue(_type, (MethodDesc[])_methods.Clone()); + return new VTableLikeStructValue(_type, (MethodDesc[])_methods.Clone(), fieldThatOwnsMemory: null); } bool IAssignableValue.TryAssign(Value value) @@ -2503,12 +2684,13 @@ bool IAssignableValue.TryAssign(Value value) private sealed class VTableLikeSlotReferenceValue : ByRefValueBase, IHasInstanceFields { - private readonly MethodDesc[] _methods; + private readonly VTableLikeStructValue _parent; private readonly int _index; - private readonly int _pointerSize; - public VTableLikeSlotReferenceValue(MethodDesc[] methods, int index, int pointerSize) - => (_methods, _index, _pointerSize) = (methods, index, pointerSize); + public override FieldDesc BackingField => _index == 0 ? _parent._fieldThatOwnsMemory : null; + + public VTableLikeSlotReferenceValue(VTableLikeStructValue parent, int index) + => (_parent, _index) = (parent, index); public override bool TryCompareEquality(Value value, out bool result) { @@ -2531,13 +2713,13 @@ public override bool TryStore(Value value) { if (value is MethodPointerValue methodPointer) { - _methods[_index] = methodPointer.PointedToMethod; + _parent._methods[_index] = methodPointer.PointedToMethod; return true; } else if (value is VTableLikeStructValue otherStruct - && _methods.Length - _index >= otherStruct._methods.Length) + && _parent._methods.Length - _index >= otherStruct._methods.Length) { - Array.Copy(otherStruct._methods, 0, _methods, _index, otherStruct._methods.Length); + Array.Copy(otherStruct._methods, 0, _parent._methods, _index, otherStruct._methods.Length); return true; } @@ -2548,15 +2730,15 @@ public override bool TryLoad(TypeDesc type, out Value value) { if (!VTableLikeStructValue.IsCompatible(type) || type is not MetadataType mdType - || mdType.InstanceFieldSize.AsInt > (_methods.Length - _index) * _pointerSize) + || mdType.InstanceFieldSize.AsInt > (_parent._methods.Length - _index) * _parent._type.Context.Target.PointerSize) { value = null; return false; } MethodDesc[] slots = new MethodDesc[GetFieldCount(mdType)]; - Array.Copy(_methods, _index, slots, 0, slots.Length); - value = new VTableLikeStructValue(mdType, slots); + Array.Copy(_parent._methods, _index, slots, 0, slots.Length); + value = new VTableLikeStructValue(mdType, slots, fieldThatOwnsMemory: null); return true; } @@ -2568,10 +2750,10 @@ private int GetFieldIndex(FieldDesc field) if (!VTableLikeStructValue.IsCompatible(field.OwningType)) ThrowHelper.ThrowInvalidProgramException(); - Debug.Assert(field.Offset.AsInt % _pointerSize == 0 && field.FieldType.IsFunctionPointer); + Debug.Assert(field.Offset.AsInt % _parent._type.Context.Target.PointerSize == 0 && field.FieldType.IsFunctionPointer); - int index = (field.Offset.AsInt / _pointerSize) + _index; - if (index >= _methods.Length) + int index = (field.Offset.AsInt / _parent._type.Context.Target.PointerSize) + _index; + if (index >= _parent._methods.Length) ThrowHelper.ThrowInvalidProgramException(); return index; @@ -2582,36 +2764,36 @@ bool IHasInstanceFields.TrySetField(FieldDesc field, Value value) if (value is not MethodPointerValue methodPtr) return false; - _methods[GetFieldIndex(field)] = methodPtr.PointedToMethod; + _parent._methods[GetFieldIndex(field)] = methodPtr.PointedToMethod; return true; } Value IHasInstanceFields.GetField(FieldDesc field) { - MethodDesc method = _methods[GetFieldIndex(field)]; + MethodDesc method = _parent._methods[GetFieldIndex(field)]; if (method is not null) return new MethodPointerValue(method); else - return _pointerSize == 8 ? ValueTypeValue.FromInt64(0) : ValueTypeValue.FromInt32(0); + return _parent._type.Context.Target.PointerSize == 8 ? ValueTypeValue.FromInt64(0) : ValueTypeValue.FromInt32(0); } ByRefValueBase IHasInstanceFields.GetFieldAddress(FieldDesc field) { - return new VTableLikeSlotReferenceValue(_methods, GetFieldIndex(field), _pointerSize); + return new VTableLikeSlotReferenceValue(_parent, GetFieldIndex(field)); } public override bool TryInitialize(int size) { - if (size % _pointerSize != 0) + if (size % _parent._type.Context.Target.PointerSize != 0) return false; - int numSlots = size / _pointerSize; - if (_index + numSlots > _methods.Length) + int numSlots = size / _parent._type.Context.Target.PointerSize; + if (_index + numSlots > _parent._methods.Length) return false; for (int i = _index; i < numSlots; i++) - _methods[i] = null; + _parent._methods[i] = null; return true; } @@ -2889,6 +3071,8 @@ public virtual bool TryLoad(TypeDesc type, out Value value) return false; } public virtual bool TryInitialize(int size) => false; + + public virtual FieldDesc BackingField => null; } private sealed class ByRefValue : ByRefValueBase, IHasInstanceFields diff --git a/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs b/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs index 562c78d8aea586..8bf32c97f6a347 100644 --- a/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs +++ b/src/tests/nativeaot/SmokeTests/Preinitialization/Preinitialization.cs @@ -67,6 +67,7 @@ private static int Main() TestVTableManipulation.Run(); TestVTableNegativeScenarios.Run(); TestByRefFieldAddressEquality.Run(); + TestComInterfaceEntry.Run(); #else Console.WriteLine("Preinitialization is disabled in multimodule builds for now. Skipping test."); #endif @@ -2033,6 +2034,79 @@ public static void Run() } } +unsafe class TestComInterfaceEntry +{ + struct MyVTableEntries + { + public ComWrappers.ComInterfaceEntry TinyImpl; + public ComWrappers.ComInterfaceEntry SmallImpl; + } + + class VtableEntries + { + [FixedAddressValueType] + public static MyVTableEntries Entries; + + static VtableEntries() + { + Entries.TinyImpl.IID = new Guid(0x1234, 0x4567, 0x789A, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89); + Entries.TinyImpl.Vtable = ITinyVtableImpl.VftablePtr; + Entries.SmallImpl.IID = new Guid(0x4321, 0x7654, 0xA987, 0x21, 0x32, 0x43, 0x54, 0x65, 0x76, 0x87, 0x98); + Entries.SmallImpl.Vtable = ISmallVtableImpl.VftablePtr; + } + } + + class ITinyVtableImpl + { + [FixedAddressValueType] + private static readonly ITinyVtable Vtbl; + + public static nint VftablePtr => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in Vtbl)); + + static ITinyVtableImpl() + { + Vtbl.Method = &Method; + } + } + + class ISmallVtableImpl + { + [FixedAddressValueType] + private static readonly ISmallVtable Vtbl; + + public static nint VftablePtr => (nint)Unsafe.AsPointer(ref Unsafe.AsRef(in Vtbl)); + + static ISmallVtableImpl() + { + Vtbl.Method1 = &Method; + Vtbl.Method2 = &OtherMethod; + } + } + + public unsafe struct ITinyVtable + { + public delegate* Method; + } + + public unsafe struct ISmallVtable + { + public delegate* Method1; + public delegate* Method2; + } + + static void Method() { } + static void OtherMethod() { } + + public static void Run() + { + Assert.IsPreinitialized(typeof(VtableEntries)); + Assert.AreEqual(ITinyVtableImpl.VftablePtr, VtableEntries.Entries.TinyImpl.Vtable); + Assert.AreEqual(new Guid(0x1234, 0x4567, 0x789A, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89), VtableEntries.Entries.TinyImpl.IID); + Assert.AreEqual(ISmallVtableImpl.VftablePtr, VtableEntries.Entries.SmallImpl.Vtable); + Assert.AreEqual(new Guid(0x4321, 0x7654, 0xA987, 0x21, 0x32, 0x43, 0x54, 0x65, 0x76, 0x87, 0x98), VtableEntries.Entries.SmallImpl.IID); + } +} + static class Assert { [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",