Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions TUnit.Assertions.Tests/WeakReferenceAssertionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public async Task Test_WeakReference_IsAlive()
var target = new object();
var weakRef = new WeakReference(target);
await Assert.That(weakRef).IsAlive();
GC.KeepAlive(target);
}

[Test]
Expand All @@ -28,6 +29,7 @@ public async Task Test_WeakReference_DoesNotTrackResurrection()
var target = new object();
var weakRef = new WeakReference(target, trackResurrection: false);
await Assert.That(weakRef).DoesNotTrackResurrection();
GC.KeepAlive(target);
}

[Test]
Expand All @@ -36,6 +38,7 @@ public async Task Test_WeakReference_TrackResurrection()
var target = new object();
var weakRef = new WeakReference(target, trackResurrection: true);
await Assert.That(weakRef).TrackResurrection();
GC.KeepAlive(target);
}

private static WeakReference CreateWeakReferenceToCollectedObject()
Expand Down
86 changes: 77 additions & 9 deletions TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3919,29 +3919,32 @@ private static bool ValidateClassTypeConstraints(INamedTypeSymbol classSymbol, I
// Check specific type constraints
foreach (var constraintType in typeParam.ConstraintTypes)
{
// Substitute type parameters in the constraint type with the actual type arguments
var substitutedConstraint = SubstituteTypeParameters(constraintType, typeParams, typeArguments);

// For interface constraints, check if the type implements the interface
if (constraintType.TypeKind == TypeKind.Interface)
if (substitutedConstraint.TypeKind == TypeKind.Interface)
{
if (!typeArg.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, constraintType)))
if (!TypeImplementsInterface(typeArg, substitutedConstraint))
{
return false;
}
}
// For base class constraints, check if the type derives from the class
else if (constraintType.TypeKind == TypeKind.Class)
else if (substitutedConstraint.TypeKind == TypeKind.Class)
{
var baseType = typeArg.BaseType;
var found = false;
while (baseType != null)
{
if (SymbolEqualityComparer.Default.Equals(baseType, constraintType))
if (SymbolEqualityComparer.Default.Equals(baseType, substitutedConstraint))
{
found = true;
break;
}
baseType = baseType.BaseType;
}
if (!found && !SymbolEqualityComparer.Default.Equals(typeArg, constraintType))
if (!found && !SymbolEqualityComparer.Default.Equals(typeArg, substitutedConstraint))
{
return false;
}
Expand All @@ -3952,6 +3955,67 @@ private static bool ValidateClassTypeConstraints(INamedTypeSymbol classSymbol, I
return true;
}

private static ITypeSymbol SubstituteTypeParameters(ITypeSymbol type, ImmutableArray<ITypeParameterSymbol> typeParams, ITypeSymbol[] typeArguments)
{
// If the type is a type parameter, substitute it with the corresponding type argument
if (type is ITypeParameterSymbol typeParam)
{
for (var i = 0; i < typeParams.Length; i++)
{
if (SymbolEqualityComparer.Default.Equals(typeParams[i], typeParam))
{
return typeArguments[i];
}
}
return type;
}

// If the type is a named type with type arguments (e.g., IComparable<T>), substitute recursively
if (type is INamedTypeSymbol { IsGenericType: true } namedType)
{
var originalTypeArgs = namedType.TypeArguments;
var newTypeArgs = new ITypeSymbol[originalTypeArgs.Length];
var anyChanged = false;

for (var i = 0; i < originalTypeArgs.Length; i++)
{
newTypeArgs[i] = SubstituteTypeParameters(originalTypeArgs[i], typeParams, typeArguments);
if (!SymbolEqualityComparer.Default.Equals(newTypeArgs[i], originalTypeArgs[i]))
{
anyChanged = true;
}
}

if (anyChanged)
{
return namedType.OriginalDefinition.Construct(newTypeArgs);
}
}

return type;
}

private static bool TypeImplementsInterface(ITypeSymbol type, ITypeSymbol interfaceType)
{
// Check if the type directly implements the interface
if (type.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, interfaceType)))
{
return true;
}

// For generic interfaces, also check if the type implements a constructed version
if (interfaceType is INamedTypeSymbol { IsGenericType: true } genericInterface)
{
var originalDef = genericInterface.OriginalDefinition;
return type.AllInterfaces.Any(i =>
i.IsGenericType &&
SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, originalDef) &&
((IEnumerable<ITypeSymbol>)i.TypeArguments).SequenceEqual(genericInterface.TypeArguments, SymbolEqualityComparer.Default));
}

return false;
}

private static ITypeSymbol[]? InferClassTypesFromMethodArguments(INamedTypeSymbol classSymbol, IMethodSymbol methodSymbol, AttributeData argAttr, Compilation compilation)
{
if (argAttr.ConstructorArguments.Length == 0)
Expand Down Expand Up @@ -4710,6 +4774,7 @@ private static bool ValidateTypeConstraints(IMethodSymbol method, ITypeSymbol[]
private static bool ValidateTypeParameterConstraints(IEnumerable<ITypeParameterSymbol> typeParams, ITypeSymbol[] typeArguments)
{
var typeParamsList = typeParams.ToList();
var typeParamsArray = typeParamsList.ToImmutableArray();

for (var i = 0; i < typeParamsList.Count; i++)
{
Expand Down Expand Up @@ -4737,22 +4802,25 @@ private static bool ValidateTypeParameterConstraints(IEnumerable<ITypeParameterS
// Check interface constraints
foreach (var constraintType in typeParam.ConstraintTypes)
{
if (constraintType.TypeKind == TypeKind.Interface)
// Substitute type parameters in the constraint type with the actual type arguments
var substitutedConstraint = SubstituteTypeParameters(constraintType, typeParamsArray, typeArguments);

if (substitutedConstraint.TypeKind == TypeKind.Interface)
{
// Check if the type argument implements the interface
if (!typeArg.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, constraintType)))
if (!TypeImplementsInterface(typeArg, substitutedConstraint))
{
return false;
}
}
else if (constraintType.TypeKind == TypeKind.Class)
else if (substitutedConstraint.TypeKind == TypeKind.Class)
{
// Check if the type argument derives from the base class
var baseType = typeArg.BaseType;
var found = false;
while (baseType != null)
{
if (SymbolEqualityComparer.Default.Equals(baseType, constraintType))
if (SymbolEqualityComparer.Default.Equals(baseType, substitutedConstraint))
{
found = true;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ public sealed class CombinedDataSourcesAttribute : AsyncUntypedDataSourceGenerat
Type = dataGeneratorMetadata.Type,
TestSessionId = dataGeneratorMetadata.TestSessionId,
TestClassInstance = dataGeneratorMetadata.TestClassInstance,
ClassInstanceArguments = dataGeneratorMetadata.ClassInstanceArguments
ClassInstanceArguments = dataGeneratorMetadata.ClassInstanceArguments,
InstanceFactory = dataGeneratorMetadata.InstanceFactory
};

// Get data rows from this data source (need to await async enumerable)
Expand Down
60 changes: 38 additions & 22 deletions TUnit.Core/Attributes/TestData/MethodDataSourceAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ public MethodDataSourceAttribute(

// If the target type is abstract or interface, we can't create an instance of it.
// Fall back to the test class type which should be concrete.
if (targetType != null && (targetType.IsAbstract || targetType.IsInterface))
// BUT: Don't override if ClassProvidingDataSource was explicitly provided, even if it's a static class
// (static classes are abstract in IL but contain static members we can invoke)
if (ClassProvidingDataSource == null && targetType != null && (targetType.IsAbstract || targetType.IsInterface))
{
var testClassType = TestClassTypeHelper.GetTestClassType(dataGeneratorMetadata);
if (testClassType != null && !testClassType.IsAbstract && !testClassType.IsInterface)
Expand All @@ -119,13 +121,7 @@ public MethodDataSourceAttribute(
object? instance = null;
if (!methodInfo.IsStatic)
{
// Skip PlaceholderInstance as it's a sentinel value, not a real instance
var testClassInstance = dataGeneratorMetadata.TestClassInstance;
if (testClassInstance is PlaceholderInstance)
{
testClassInstance = null;
}
instance = testClassInstance ?? Activator.CreateInstance(targetType);
instance = await GetOrCreateInstanceAsync(dataGeneratorMetadata, targetType);
}

methodResult = methodInfo.Invoke(instance, Arguments);
Expand All @@ -142,13 +138,7 @@ public MethodDataSourceAttribute(
object? instance = null;
if (propertyInfo.GetMethod?.IsStatic != true)
{
// Skip PlaceholderInstance as it's a sentinel value, not a real instance
var testClassInstance = dataGeneratorMetadata.TestClassInstance;
if (testClassInstance is PlaceholderInstance)
{
testClassInstance = null;
}
instance = testClassInstance ?? Activator.CreateInstance(targetType);
instance = await GetOrCreateInstanceAsync(dataGeneratorMetadata, targetType);
}

methodResult = propertyInfo.GetValue(instance);
Expand All @@ -159,13 +149,7 @@ public MethodDataSourceAttribute(
object? instance = null;
if (!fieldInfo.IsStatic)
{
// Skip PlaceholderInstance as it's a sentinel value, not a real instance
var testClassInstance = dataGeneratorMetadata.TestClassInstance;
if (testClassInstance is PlaceholderInstance)
{
testClassInstance = null;
}
instance = testClassInstance ?? Activator.CreateInstance(targetType);
instance = await GetOrCreateInstanceAsync(dataGeneratorMetadata, targetType);
}

methodResult = fieldInfo.GetValue(instance);
Expand Down Expand Up @@ -384,4 +368,36 @@ private static bool IsAsyncEnumerable([DynamicallyAccessedMembers(DynamicallyAcc

return null;
}

/// <summary>
/// Gets an existing test class instance or creates a new one.
/// Uses InstanceFactory if available (which can perform property injection),
/// otherwise falls back to Activator.CreateInstance.
/// </summary>
[UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Reflection usage is documented. AOT-safe path available via Factory property")]
[UnconditionalSuppressMessage("Trimming", "IL2067", Justification = "Reflection usage is documented. AOT-safe path available via Factory property")]
[UnconditionalSuppressMessage("AOT", "IL3050", Justification = "Dynamic code usage is documented. AOT-safe path available via Factory property")]
private static async Task<object?> GetOrCreateInstanceAsync(DataGeneratorMetadata metadata, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] Type targetType)
{
// First check if we have a valid test class instance
var testClassInstance = metadata.TestClassInstance;
if (testClassInstance is PlaceholderInstance)
{
testClassInstance = null;
}

if (testClassInstance != null)
{
return testClassInstance;
}

// Try to use the InstanceFactory if available (which can perform property injection)
if (metadata.InstanceFactory != null)
{
return await metadata.InstanceFactory(targetType);
}

// Fall back to creating a bare instance
return Activator.CreateInstance(targetType);
}
}
16 changes: 12 additions & 4 deletions TUnit.Core/DataGeneratorMetadataCreator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ public static DataGeneratorMetadata CreateDataGeneratorMetadata(
DataGeneratorType generatorType,
object? testClassInstance,
object?[]? classInstanceArguments,
TestBuilderContextAccessor contextAccessor)
TestBuilderContextAccessor contextAccessor,
Func<Type, Task<object?>>? instanceFactory = null)
{
// Determine which parameters we're generating for
var parametersToGenerate = generatorType == DataGeneratorType.ClassParameters
Expand Down Expand Up @@ -69,7 +70,8 @@ public static DataGeneratorMetadata CreateDataGeneratorMetadata(
Type = generatorType,
TestSessionId = testSessionId,
TestClassInstance = testClassInstance,
ClassInstanceArguments = classInstanceArguments
ClassInstanceArguments = classInstanceArguments,
InstanceFactory = instanceFactory
};
}

Expand Down Expand Up @@ -112,9 +114,14 @@ public static DataGeneratorMetadata CreateForReflectionDiscovery(
/// Creates minimal DataGeneratorMetadata for discovery phase when inferring generic types.
/// This is used when we need to get data from sources to determine generic type arguments.
/// </summary>
/// <param name="dataSource">The data source attribute.</param>
/// <param name="existingMethodMetadata">Optional method metadata if available.</param>
/// <param name="instanceFactory">Optional factory for creating instances with property injection.
/// Used in reflection mode when instance data sources depend on property injection.</param>
public static DataGeneratorMetadata CreateForGenericTypeDiscovery(
IDataSourceAttribute dataSource,
MethodMetadata? existingMethodMetadata = null)
MethodMetadata? existingMethodMetadata = null,
Func<Type, Task<object?>>? instanceFactory = null)
{
var dummyParameter = new ParameterMetadata(typeof(object))
{
Expand Down Expand Up @@ -159,7 +166,8 @@ public static DataGeneratorMetadata CreateForGenericTypeDiscovery(
Type = DataGeneratorType.ClassParameters,
TestSessionId = "discovery",
TestClassInstance = null,
ClassInstanceArguments = null
ClassInstanceArguments = null,
InstanceFactory = instanceFactory
};
}

Expand Down
7 changes: 7 additions & 0 deletions TUnit.Core/Models/DataGeneratorMetadata.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,11 @@ public record DataGeneratorMetadata
public required string TestSessionId { get; init; }
public required object? TestClassInstance { get; init; }
public required object?[]? ClassInstanceArguments { get; init; }

/// <summary>
/// Optional factory for creating and initializing test class instances.
/// Used in reflection mode when instance data sources depend on property injection.
/// The factory should create an instance and perform property injection before returning it.
/// </summary>
public Func<Type, Task<object?>>? InstanceFactory { get; init; }
}
Loading
Loading