diff --git a/TUnit.Core.SourceGenerator.Tests/GenericMethodWithDataSourceTests.Generic_Method_With_MethodDataSource_Should_Generate_Tests.verified.txt b/TUnit.Core.SourceGenerator.Tests/GenericMethodWithDataSourceTests.Generic_Method_With_MethodDataSource_Should_Generate_Tests.verified.txt index d41e5d35c6..47d4802429 100644 --- a/TUnit.Core.SourceGenerator.Tests/GenericMethodWithDataSourceTests.Generic_Method_With_MethodDataSource_Should_Generate_Tests.verified.txt +++ b/TUnit.Core.SourceGenerator.Tests/GenericMethodWithDataSourceTests.Generic_Method_With_MethodDataSource_Should_Generate_Tests.verified.txt @@ -1,4 +1,4 @@ -// +// #pragma warning disable #nullable enable @@ -2262,9 +2262,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassWithClassDataSour { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource<>' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource<>).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource<>"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -2354,9 +2364,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassWithClassDataSour { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -2453,9 +2473,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassWithClassDataSour { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassWithClassDataSource"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -2600,9 +2630,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassGenericMethodWith { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources<>' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources<>).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources<>"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -2729,9 +2769,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassGenericMethodWith { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -2865,9 +2915,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassGenericMethodWith { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -3001,9 +3061,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassGenericMethodWith { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => @@ -3137,9 +3207,19 @@ internal sealed class TUnit_TestProject_Bugs__4431_GenericClassGenericMethodWith { PropertyName = "DataSource", PropertyType = typeof(global::TUnit.TestProject.Bugs._4431.TestDataSource), - Setter = (instance, value) => throw new global::System.NotSupportedException( - "Init-only property 'DataSource' on generic type 'global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources' cannot be set. " + - "Use a regular settable property or constructor injection instead."), + Setter = (instance, value) => + { + var backingField = typeof(global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources).GetField("k__BackingField", + global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new global::System.InvalidOperationException("Could not find backing field for property DataSource on type global::TUnit.TestProject.Bugs._4431.GenericClassGenericMethodWithDataSources"); + } + }, ValueFactory = () => throw new global::System.InvalidOperationException("ValueFactory should be provided by TestDataCombination"), NestedPropertyInjections = global::System.Array.Empty(), NestedPropertyValueFactory = obj => diff --git a/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs b/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs index 47200a9693..9bc3728075 100644 --- a/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/PropertyInjectionSourceGenerator.cs @@ -167,7 +167,7 @@ private sealed record PropertyWithClass( Property: propertyModel); } - private static PropertyDataSourceModel ExtractPropertyModel(IPropertySymbol property, AttributeData attribute) + private static PropertyDataSourceModel ExtractPropertyModel(IPropertySymbol property, AttributeData attribute, INamedTypeSymbol? containingTypeOverride = null) { var propertyType = property.Type; var isNullableValueType = propertyType is INamedTypeSymbol @@ -176,6 +176,14 @@ private static PropertyDataSourceModel ExtractPropertyModel(IPropertySymbol prop ConstructedFrom.SpecialType: SpecialType.System_Nullable_T }; + // Check if the original property type is a type parameter (e.g., T Provider { get; }) + // We need to use the type parameter name in UnsafeAccessor for generic types + string? propertyTypeAsTypeParameter = null; + if (property.OriginalDefinition.Type is ITypeParameterSymbol typeParam) + { + propertyTypeAsTypeParameter = typeParam.Name; + } + // Format constructor arguments var ctorArgs = attribute.ConstructorArguments .Select(FormatTypedConstant) @@ -190,14 +198,39 @@ private static PropertyDataSourceModel ExtractPropertyModel(IPropertySymbol prop }) .ToArray(); + // Use the override if provided (for closed generic types), otherwise use the declaring type + var containingType = containingTypeOverride ?? property.ContainingType; + + // For generic types, extract the open generic type definition and type parameters + string? openGenericType = null; + string? typeParameters = null; + string? typeArguments = null; + string? typeConstraints = null; + + if (containingType.IsGenericType) + { + var originalDefinition = containingType.OriginalDefinition; + openGenericType = originalDefinition.ToDisplayString(); + typeParameters = string.Join(", ", originalDefinition.TypeParameters.Select(tp => tp.Name)); + typeArguments = string.Join(", ", containingType.TypeArguments.Select(ta => ta.ToDisplayString())); + typeConstraints = GetTypeParameterConstraints(originalDefinition.TypeParameters); + } + return new PropertyDataSourceModel { PropertyName = property.Name, PropertyTypeFullyQualified = GetNonNullableTypeName(propertyType), PropertyTypeForTypeof = GetNonNullableTypeString(propertyType), - ContainingTypeFullyQualified = property.ContainingType.ToDisplayString(), + ContainingTypeFullyQualified = containingType.ToDisplayString(), + ContainingTypeClrName = GetClrTypeName(containingType), + ContainingTypeOpenGeneric = openGenericType, + GenericTypeParameters = typeParameters, + GenericTypeArguments = typeArguments, + GenericTypeConstraints = typeConstraints, IsInitOnly = property.SetMethod?.IsInitOnly == true, + IsContainingTypeGeneric = containingType.IsGenericType, IsStatic = property.IsStatic, + PropertyTypeAsTypeParameter = propertyTypeAsTypeParameter, IsValueType = propertyType.IsValueType, IsNullableValueType = isNullableValueType, AttributeTypeName = attribute.AttributeClass!.ToDisplayString(), @@ -476,7 +509,8 @@ private static bool IsConcreteGenericType(INamedTypeSymbol type) if (attr.AttributeClass != null && attr.AttributeClass.AllInterfaces.Contains(dataSourceInterface, SymbolEqualityComparer.Default)) { - dataSourceProperties.Add(ExtractPropertyModel(property, attr)); + // Pass currentType as the containing type override for closed generic types + dataSourceProperties.Add(ExtractPropertyModel(property, attr, currentType)); break; } } @@ -567,12 +601,14 @@ private static void GeneratePropertyInjectionSource(SourceProductionContext cont sb.AppendLine(" public bool ShouldInitialize => true;"); sb.AppendLine(); - // Generate UnsafeAccessor methods for init-only properties + // Generate UnsafeAccessor methods for init-only properties on non-generic types foreach (var prop in model.Properties) { - if (prop.IsInitOnly) + if (prop.IsInitOnly && !prop.IsContainingTypeGeneric) { var backingFieldName = $"<{prop.PropertyName}>k__BackingField"; + + // For non-generic types: use regular UnsafeAccessor on .NET 8+ sb.AppendLine("#if NET8_0_OR_GREATER"); sb.AppendLine($" [global::System.Runtime.CompilerServices.UnsafeAccessor(global::System.Runtime.CompilerServices.UnsafeAccessorKind.Field, Name = \"{backingFieldName}\")]"); sb.AppendLine($" private static extern ref {prop.PropertyTypeFullyQualified} Get{prop.PropertyName}BackingField({prop.ContainingTypeFullyQualified} instance);"); @@ -587,16 +623,39 @@ private static void GeneratePropertyInjectionSource(SourceProductionContext cont foreach (var prop in model.Properties) { - GeneratePropertyMetadata(sb, prop, model.ClassFullyQualifiedName); + GeneratePropertyMetadata(sb, prop, model.ClassFullyQualifiedName, model.SafeClassName); } sb.AppendLine(" }"); sb.AppendLine("}"); + // Generate generic accessor classes for init-only properties on generic types + // These must be outside the property source class and be generic themselves + foreach (var prop in model.Properties) + { + if (prop.IsInitOnly && prop.IsContainingTypeGeneric && prop.GenericTypeParameters != null && prop.ContainingTypeOpenGeneric != null) + { + var backingFieldName = $"<{prop.PropertyName}>k__BackingField"; + var accessorClassName = $"{model.SafeClassName}_{prop.PropertyName}_GenericAccessor"; + var constraintsClause = prop.GenericTypeConstraints != null ? $" {prop.GenericTypeConstraints}" : ""; + // Use type parameter name if property type is a type parameter (e.g., T), otherwise use concrete type + var returnType = prop.PropertyTypeAsTypeParameter ?? prop.PropertyTypeFullyQualified; + + sb.AppendLine(); + sb.AppendLine("#if NET9_0_OR_GREATER"); + sb.AppendLine($"internal static class {accessorClassName}<{prop.GenericTypeParameters}>{constraintsClause}"); + sb.AppendLine("{"); + sb.AppendLine($" [global::System.Runtime.CompilerServices.UnsafeAccessor(global::System.Runtime.CompilerServices.UnsafeAccessorKind.Field, Name = \"{backingFieldName}\")]"); + sb.AppendLine($" public static extern ref {returnType} GetBackingField({prop.ContainingTypeOpenGeneric} instance);"); + sb.AppendLine("}"); + sb.AppendLine("#endif"); + } + } + context.AddSource(fileName, sb.ToString()); } - private static void GeneratePropertyMetadata(StringBuilder sb, PropertyDataSourceModel prop, string classTypeName) + private static void GeneratePropertyMetadata(StringBuilder sb, PropertyDataSourceModel prop, string classTypeName, string safeClassName) { var ctorArgsStr = string.Join(", ", prop.ConstructorArgs); @@ -628,20 +687,37 @@ private static void GeneratePropertyMetadata(StringBuilder sb, PropertyDataSourc if (prop.IsInitOnly) { - sb.AppendLine("#if NET8_0_OR_GREATER"); - if (prop.ContainingTypeFullyQualified != classTypeName) + if (prop.IsContainingTypeGeneric && prop.GenericTypeArguments != null) { - sb.AppendLine($" Get{prop.PropertyName}BackingField(({prop.ContainingTypeFullyQualified})typedInstance) = {castExpression};"); + // For generic types: .NET 9+ uses generic accessor class, older versions use reflection + var accessorClassName = $"{safeClassName}_{prop.PropertyName}_GenericAccessor"; + + sb.AppendLine("#if NET9_0_OR_GREATER"); + sb.AppendLine($" {accessorClassName}<{prop.GenericTypeArguments}>.GetBackingField(typedInstance) = {castExpression};"); + sb.AppendLine("#else"); + sb.AppendLine($" var backingField = typeof({prop.ContainingTypeFullyQualified}).GetField(\"<{prop.PropertyName}>k__BackingField\","); + sb.AppendLine(" global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic);"); + sb.AppendLine(" backingField.SetValue(typedInstance, value);"); + sb.AppendLine("#endif"); } else { - sb.AppendLine($" Get{prop.PropertyName}BackingField(typedInstance) = {castExpression};"); + // For non-generic types: .NET 8+ uses UnsafeAccessor + sb.AppendLine("#if NET8_0_OR_GREATER"); + if (prop.ContainingTypeFullyQualified != classTypeName) + { + sb.AppendLine($" Get{prop.PropertyName}BackingField(({prop.ContainingTypeFullyQualified})typedInstance) = {castExpression};"); + } + else + { + sb.AppendLine($" Get{prop.PropertyName}BackingField(typedInstance) = {castExpression};"); + } + sb.AppendLine("#else"); + sb.AppendLine($" var backingField = typeof({prop.ContainingTypeFullyQualified}).GetField(\"<{prop.PropertyName}>k__BackingField\","); + sb.AppendLine(" global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic);"); + sb.AppendLine(" backingField.SetValue(typedInstance, value);"); + sb.AppendLine("#endif"); } - sb.AppendLine("#else"); - sb.AppendLine($" var backingField = typeof({prop.ContainingTypeFullyQualified}).GetField(\"<{prop.PropertyName}>k__BackingField\","); - sb.AppendLine(" global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic);"); - sb.AppendLine(" backingField.SetValue(typedInstance, value);"); - sb.AppendLine("#endif"); } else if (prop.IsStatic) { @@ -723,12 +799,14 @@ private static void GenerateGenericPropertyInjectionSource(SourceProductionConte sb.AppendLine(" public bool ShouldInitialize => true;"); sb.AppendLine(); - // Generate UnsafeAccessor methods for init-only properties + // Generate UnsafeAccessor methods for init-only properties on non-generic types foreach (var prop in model.DataSourceProperties) { - if (prop.IsInitOnly) + if (prop.IsInitOnly && !prop.IsContainingTypeGeneric) { var backingFieldName = $"<{prop.PropertyName}>k__BackingField"; + + // For non-generic types: use regular UnsafeAccessor on .NET 8+ sb.AppendLine("#if NET8_0_OR_GREATER"); sb.AppendLine($" [global::System.Runtime.CompilerServices.UnsafeAccessor(global::System.Runtime.CompilerServices.UnsafeAccessorKind.Field, Name = \"{backingFieldName}\")]"); sb.AppendLine($" private static extern ref {prop.PropertyTypeFullyQualified} Get{prop.PropertyName}BackingField({prop.ContainingTypeFullyQualified} instance);"); @@ -743,12 +821,34 @@ private static void GenerateGenericPropertyInjectionSource(SourceProductionConte foreach (var prop in model.DataSourceProperties) { - GeneratePropertyMetadata(sb, prop, model.ConcreteTypeFullyQualified); + GeneratePropertyMetadata(sb, prop, model.ConcreteTypeFullyQualified, model.SafeTypeName); } sb.AppendLine(" }"); sb.AppendLine("}"); + // Generate generic accessor classes for init-only properties on generic types + foreach (var prop in model.DataSourceProperties) + { + if (prop.IsInitOnly && prop.IsContainingTypeGeneric && prop.GenericTypeParameters != null && prop.ContainingTypeOpenGeneric != null) + { + var backingFieldName = $"<{prop.PropertyName}>k__BackingField"; + var accessorClassName = $"{model.SafeTypeName}_{prop.PropertyName}_GenericAccessor"; + var constraintsClause = prop.GenericTypeConstraints != null ? $" {prop.GenericTypeConstraints}" : ""; + // Use type parameter name if property type is a type parameter (e.g., T), otherwise use concrete type + var returnType = prop.PropertyTypeAsTypeParameter ?? prop.PropertyTypeFullyQualified; + + sb.AppendLine(); + sb.AppendLine("#if NET9_0_OR_GREATER"); + sb.AppendLine($"internal static class {accessorClassName}<{prop.GenericTypeParameters}>{constraintsClause}"); + sb.AppendLine("{"); + sb.AppendLine($" [global::System.Runtime.CompilerServices.UnsafeAccessor(global::System.Runtime.CompilerServices.UnsafeAccessorKind.Field, Name = \"{backingFieldName}\")]"); + sb.AppendLine($" public static extern ref {returnType} GetBackingField({prop.ContainingTypeOpenGeneric} instance);"); + sb.AppendLine("}"); + sb.AppendLine("#endif"); + } + } + context.AddSource(fileName, sb.ToString()); } @@ -958,5 +1058,208 @@ private static string GetNonNullableTypeString(ITypeSymbol typeSymbol) private static string GetNonNullableTypeName(ITypeSymbol typeSymbol) => GetNonNullableTypeString(typeSymbol); + /// + /// Converts a type symbol to CLR type name format suitable for Type.GetType() and UnsafeAccessorType. + /// For generic types, produces format like: "Namespace.Type`1[[TypeArg, Assembly]]" + /// + private static string? GetClrTypeName(INamedTypeSymbol typeSymbol) + { + if (!typeSymbol.IsGenericType) + { + return null; // Not needed for non-generic types + } + + var sb = new StringBuilder(); + + // Build the namespace and containing types + if (typeSymbol.ContainingNamespace != null && !typeSymbol.ContainingNamespace.IsGlobalNamespace) + { + sb.Append(typeSymbol.ContainingNamespace.ToDisplayString()); + sb.Append('.'); + } + + // Handle nested types + var containingTypes = new Stack(); + var current = typeSymbol.ContainingType; + while (current != null) + { + containingTypes.Push(current); + current = current.ContainingType; + } + + foreach (var containingType in containingTypes) + { + sb.Append(containingType.MetadataName); + sb.Append('+'); + } + + // Add the type name with generic arity (e.g., "GenericType`1") + sb.Append(typeSymbol.MetadataName); + + // Add type arguments in CLR format: [[TypeArg1, Assembly], [TypeArg2, Assembly]] + if (typeSymbol.TypeArguments.Length > 0) + { + sb.Append('['); + for (int i = 0; i < typeSymbol.TypeArguments.Length; i++) + { + if (i > 0) sb.Append(','); + sb.Append('['); + sb.Append(GetAssemblyQualifiedTypeName(typeSymbol.TypeArguments[i])); + sb.Append(']'); + } + sb.Append(']'); + } + + // Add assembly name for the containing type + if (typeSymbol.ContainingAssembly != null) + { + sb.Append(", "); + sb.Append(typeSymbol.ContainingAssembly.Name); + } + + return sb.ToString(); + } + + /// + /// Gets the assembly-qualified type name for a type symbol. + /// + private static string GetAssemblyQualifiedTypeName(ITypeSymbol typeSymbol) + { + var sb = new StringBuilder(); + + // Handle generic types recursively + if (typeSymbol is INamedTypeSymbol { IsGenericType: true } namedType) + { + // Build namespace + if (namedType.ContainingNamespace != null && !namedType.ContainingNamespace.IsGlobalNamespace) + { + sb.Append(namedType.ContainingNamespace.ToDisplayString()); + sb.Append('.'); + } + + // Handle nested types + var containingTypes = new Stack(); + var current = namedType.ContainingType; + while (current != null) + { + containingTypes.Push(current); + current = current.ContainingType; + } + + foreach (var containingType in containingTypes) + { + sb.Append(containingType.MetadataName); + sb.Append('+'); + } + + sb.Append(namedType.MetadataName); + + // Add type arguments recursively + if (namedType.TypeArguments.Length > 0) + { + sb.Append('['); + for (int i = 0; i < namedType.TypeArguments.Length; i++) + { + if (i > 0) sb.Append(','); + sb.Append('['); + sb.Append(GetAssemblyQualifiedTypeName(namedType.TypeArguments[i])); + sb.Append(']'); + } + sb.Append(']'); + } + } + else if (typeSymbol is INamedTypeSymbol simpleNamedType) + { + // Build namespace + if (simpleNamedType.ContainingNamespace != null && !simpleNamedType.ContainingNamespace.IsGlobalNamespace) + { + sb.Append(simpleNamedType.ContainingNamespace.ToDisplayString()); + sb.Append('.'); + } + + // Handle nested types + var containingTypes = new Stack(); + var current = simpleNamedType.ContainingType; + while (current != null) + { + containingTypes.Push(current); + current = current.ContainingType; + } + + foreach (var containingType in containingTypes) + { + sb.Append(containingType.MetadataName); + sb.Append('+'); + } + + sb.Append(simpleNamedType.MetadataName); + } + else + { + // Fallback for other type kinds (arrays, pointers, etc.) + sb.Append(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + .Replace("global::", "")); + } + + // Add assembly name + if (typeSymbol.ContainingAssembly != null) + { + sb.Append(", "); + sb.Append(typeSymbol.ContainingAssembly.Name); + } + + return sb.ToString(); + } + + /// + /// Generates type parameter constraints string (e.g., "where T : class" or "where T1 : class where T2 : struct, new()"). + /// + private static string? GetTypeParameterConstraints(ImmutableArray typeParameters) + { + var constraintParts = new List(); + + foreach (var tp in typeParameters) + { + var constraints = new List(); + + // Primary constraints (must come first) + if (tp.HasReferenceTypeConstraint) + { + constraints.Add("class"); + } + else if (tp.HasValueTypeConstraint) + { + constraints.Add("struct"); + } + else if (tp.HasNotNullConstraint) + { + constraints.Add("notnull"); + } + else if (tp.HasUnmanagedTypeConstraint) + { + constraints.Add("unmanaged"); + } + + // Type constraints (base class and interfaces) + foreach (var constraintType in tp.ConstraintTypes) + { + constraints.Add(constraintType.ToDisplayString()); + } + + // Constructor constraint (must come last) + if (tp.HasConstructorConstraint) + { + constraints.Add("new()"); + } + + if (constraints.Count > 0) + { + constraintParts.Add($"where {tp.Name} : {string.Join(", ", constraints)}"); + } + } + + return constraintParts.Count > 0 ? string.Join(" ", constraintParts) : null; + } + #endregion } diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 3916bebdfd..ff515d62e6 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -1538,16 +1538,35 @@ private static void GeneratePropertyInjections(CodeWriter writer, INamedTypeSymb { // For init-only properties, use UnsafeAccessor on .NET 8+ (but not for generic types) // UnsafeAccessor doesn't work with open generic types - var containingTypeName = property.ContainingType.GloballyQualified(); - var isGenericContainingType = property.ContainingType.IsGenericType; + // IMPORTANT: Use currentType (which is the closed generic type from the inheritance chain) + // instead of property.ContainingType (which is the open generic type definition) + var containingTypeName = currentType.GloballyQualified(); + var isGenericContainingType = currentType.IsGenericType; if (isGenericContainingType) { - // For generic types, init-only properties with data sources are not supported - // UnsafeAccessor doesn't work with open generic types and reflection is not AOT-compatible - writer.AppendLine($"Setter = (instance, value) => throw new global::System.NotSupportedException("); - writer.AppendLine($" \"Init-only property '{property.Name}' on generic type '{containingTypeName}' cannot be set. \" +"); - writer.AppendLine($" \"Use a regular settable property or constructor injection instead.\"),"); + // For init-only properties on generic types, use reflection with the closed generic type. + // UnsafeAccessor doesn't work with generic base classes, but reflection does. + // This is AOT-compatible because we use the closed generic type known at compile time. + writer.AppendLine("Setter = (instance, value) =>"); + writer.AppendLine("{"); + writer.Indent(); + writer.AppendLine($"var backingField = typeof({containingTypeName}).GetField(\"<{property.Name}>k__BackingField\","); + writer.AppendLine(" global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.NonPublic);"); + writer.AppendLine("if (backingField != null)"); + writer.AppendLine("{"); + writer.Indent(); + writer.AppendLine("backingField.SetValue(instance, value);"); + writer.Unindent(); + writer.AppendLine("}"); + writer.AppendLine("else"); + writer.AppendLine("{"); + writer.Indent(); + writer.AppendLine($"throw new global::System.InvalidOperationException(\"Could not find backing field for property {property.Name} on type {containingTypeName}\");"); + writer.Unindent(); + writer.AppendLine("}"); + writer.Unindent(); + writer.AppendLine("},"); } else { diff --git a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs index 71af56a4bd..28fe2326b7 100644 --- a/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs +++ b/TUnit.Core.SourceGenerator/Models/Extracted/PropertyInjectionModel.cs @@ -69,16 +69,58 @@ internal sealed record PropertyDataSourceModel : IEquatable public required string ContainingTypeFullyQualified { get; init; } + /// + /// CLR type name format for UnsafeAccessorType attribute (e.g., "Namespace.GenericType`1[[Namespace.TypeArg, Assembly]]") + /// Only populated for generic containing types. + /// + public required string? ContainingTypeClrName { get; init; } + + /// + /// The open generic type definition with type parameters (e.g., "global::NS.GenericBase<T>") + /// Only populated for generic containing types. + /// + public required string? ContainingTypeOpenGeneric { get; init; } + + /// + /// Comma-separated list of type parameter names (e.g., "T" or "T1, T2") + /// Only populated for generic containing types. + /// + public required string? GenericTypeParameters { get; init; } + + /// + /// Comma-separated list of concrete type arguments (e.g., "global::NS.ProviderType") + /// Only populated for generic containing types. + /// + public required string? GenericTypeArguments { get; init; } + + /// + /// Type parameter constraints (e.g., "where T : class" or "where T1 : class where T2 : struct") + /// Only populated for generic containing types that have constraints. + /// + public required string? GenericTypeConstraints { get; init; } + /// /// Whether the property has an init-only setter /// public required bool IsInitOnly { get; init; } + /// + /// Whether the containing type (where the property is declared) is a generic type + /// + public required bool IsContainingTypeGeneric { get; init; } + /// /// Whether the property is static /// public required bool IsStatic { get; init; } + /// + /// If the property type is a type parameter in the original definition (e.g., "T"), + /// this contains the type parameter name. Otherwise null. + /// Used for UnsafeAccessor generation on generic types. + /// + public required string? PropertyTypeAsTypeParameter { get; init; } + /// /// Whether the property type is a value type /// @@ -112,8 +154,15 @@ public bool Equals(PropertyDataSourceModel? other) && PropertyTypeFullyQualified == other.PropertyTypeFullyQualified && PropertyTypeForTypeof == other.PropertyTypeForTypeof && ContainingTypeFullyQualified == other.ContainingTypeFullyQualified + && ContainingTypeClrName == other.ContainingTypeClrName + && ContainingTypeOpenGeneric == other.ContainingTypeOpenGeneric + && GenericTypeParameters == other.GenericTypeParameters + && GenericTypeArguments == other.GenericTypeArguments + && GenericTypeConstraints == other.GenericTypeConstraints && IsInitOnly == other.IsInitOnly + && IsContainingTypeGeneric == other.IsContainingTypeGeneric && IsStatic == other.IsStatic + && PropertyTypeAsTypeParameter == other.PropertyTypeAsTypeParameter && IsValueType == other.IsValueType && IsNullableValueType == other.IsNullableValueType && AttributeTypeName == other.AttributeTypeName @@ -129,8 +178,15 @@ public override int GetHashCode() hash = (hash * 397) ^ PropertyTypeFullyQualified.GetHashCode(); hash = (hash * 397) ^ PropertyTypeForTypeof.GetHashCode(); hash = (hash * 397) ^ ContainingTypeFullyQualified.GetHashCode(); + hash = (hash * 397) ^ (ContainingTypeClrName?.GetHashCode() ?? 0); + hash = (hash * 397) ^ (ContainingTypeOpenGeneric?.GetHashCode() ?? 0); + hash = (hash * 397) ^ (GenericTypeParameters?.GetHashCode() ?? 0); + hash = (hash * 397) ^ (GenericTypeArguments?.GetHashCode() ?? 0); + hash = (hash * 397) ^ (GenericTypeConstraints?.GetHashCode() ?? 0); hash = (hash * 397) ^ IsInitOnly.GetHashCode(); + hash = (hash * 397) ^ IsContainingTypeGeneric.GetHashCode(); hash = (hash * 397) ^ IsStatic.GetHashCode(); + hash = (hash * 397) ^ (PropertyTypeAsTypeParameter?.GetHashCode() ?? 0); hash = (hash * 397) ^ IsValueType.GetHashCode(); hash = (hash * 397) ^ IsNullableValueType.GetHashCode(); hash = (hash * 397) ^ AttributeTypeName.GetHashCode(); diff --git a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs index 9c950da9c5..bf11d1706a 100644 --- a/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs +++ b/TUnit.Core/PropertyInjection/PropertyInjectionPlanBuilder.cs @@ -39,6 +39,7 @@ public static PropertyInjectionPlan BuildSourceGeneratedPlan(Type type) WalkInheritanceChain(type, currentType => { var propertySource = PropertySourceRegistry.GetSource(currentType); + if (propertySource?.ShouldInitialize == true) { foreach (var metadata in propertySource.GetPropertyMetadata()) diff --git a/TUnit.Core/PropertyInjection/PropertySetterFactory.cs b/TUnit.Core/PropertyInjection/PropertySetterFactory.cs index 8bc1662cdf..030dfa9fd7 100644 --- a/TUnit.Core/PropertyInjection/PropertySetterFactory.cs +++ b/TUnit.Core/PropertyInjection/PropertySetterFactory.cs @@ -67,10 +67,34 @@ internal static class PropertySetterFactory #endif } - var backingField = GetBackingField(property); - if (backingField != null) + // Check if the declaring type is an open generic type definition + // In this case, we need to resolve the backing field at runtime using the instance's actual type + var declaringType = property.DeclaringType; + if (declaringType != null && declaringType.IsGenericTypeDefinition) + { + // For open generic types, we must resolve the backing field at runtime + // because we don't know the closed generic type until we have an instance + return (instance, value) => + { + var instanceType = instance.GetType(); + var backingField = GetBackingField(property, instanceType); + if (backingField != null) + { + backingField.SetValue(instance, value); + } + else + { + throw new InvalidOperationException( + $"Property '{property.Name}' on type '{declaringType.Name}' " + + $"is not writable and no backing field was found for instance type '{instanceType.Name}'."); + } + }; + } + + var backingFieldStatic = GetBackingField(property); + if (backingFieldStatic != null) { - return (instance, value) => backingField.SetValue(instance, value); + return (instance, value) => backingFieldStatic.SetValue(instance, value); } throw new InvalidOperationException( @@ -84,7 +108,7 @@ internal static class PropertySetterFactory #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] #endif - private static FieldInfo? GetBackingField(PropertyInfo property) + private static FieldInfo? GetBackingField(PropertyInfo property, Type? instanceType = null) { var declaringType = property.DeclaringType; if (declaringType == null) @@ -92,6 +116,17 @@ internal static class PropertySetterFactory return null; } + // If the declaring type is an open generic type definition (e.g., GenericBase), + // we need to find the closed generic type from the instance type's hierarchy + if (declaringType.IsGenericTypeDefinition && instanceType != null) + { + declaringType = FindClosedGenericType(instanceType, declaringType); + if (declaringType == null) + { + return null; + } + } + var backingFieldFlags = BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.FlattenHierarchy; // Try compiler-generated backing field name @@ -128,6 +163,24 @@ internal static class PropertySetterFactory return null; } + /// + /// Finds the closed generic type in the inheritance hierarchy that matches the open generic type definition. + /// + private static Type? FindClosedGenericType(Type instanceType, Type openGenericTypeDefinition) + { + var currentType = instanceType; + while (currentType != null && currentType != typeof(object)) + { + if (currentType.IsGenericType && + currentType.GetGenericTypeDefinition() == openGenericTypeDefinition) + { + return currentType; + } + currentType = currentType.BaseType; + } + return null; + } + /// /// Helper method to get field with proper trimming suppression. /// diff --git a/TUnit.Core/PropertySourceRegistry.cs b/TUnit.Core/PropertySourceRegistry.cs index a95239b460..476114bc37 100644 --- a/TUnit.Core/PropertySourceRegistry.cs +++ b/TUnit.Core/PropertySourceRegistry.cs @@ -102,7 +102,7 @@ public static PropertyInjectionData[] DiscoverInjectableProperties([DynamicallyA { try { - var injection = CreatePropertyInjection(property); + var injection = CreatePropertyInjection(property, type); injectableProperties.Add(injection); } catch (Exception ex) @@ -150,9 +150,9 @@ private static PropertyDataSource ConvertToPropertyDataSource(PropertyInjectionM #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] #endif - private static PropertyInjectionData CreatePropertyInjection(System.Reflection.PropertyInfo property) + private static PropertyInjectionData CreatePropertyInjection(System.Reflection.PropertyInfo property, Type? testClassType = null) { - var setter = CreatePropertySetter(property); + var setter = CreatePropertySetter(property, testClassType); return new PropertyInjectionData { @@ -170,7 +170,7 @@ private static PropertyInjectionData CreatePropertyInjection(System.Reflection.P #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Backing field access for init-only properties requires reflection")] #endif - private static Action CreatePropertySetter(System.Reflection.PropertyInfo property) + private static Action CreatePropertySetter(System.Reflection.PropertyInfo property, Type? testClassType = null) { if (property.CanWrite && property.SetMethod != null) { @@ -187,7 +187,7 @@ private static PropertyInjectionData CreatePropertyInjection(System.Reflection.P #endif } - var backingField = GetBackingField(property); + var backingField = GetBackingField(property, testClassType); if (backingField != null) { return (instance, value) => backingField.SetValue(instance, value); @@ -204,7 +204,7 @@ private static PropertyInjectionData CreatePropertyInjection(System.Reflection.P #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Backing field discovery needed for init-only properties in reflection mode")] #endif - private static System.Reflection.FieldInfo? GetBackingField(System.Reflection.PropertyInfo property) + private static System.Reflection.FieldInfo? GetBackingField(System.Reflection.PropertyInfo property, Type? testClassType = null) { var declaringType = property.DeclaringType; if (declaringType == null) @@ -212,6 +212,17 @@ private static PropertyInjectionData CreatePropertyInjection(System.Reflection.P return null; } + // If the declaring type is an open generic type definition (e.g., GenericBase), + // we need to find the closed generic type from the test class hierarchy + if (declaringType.IsGenericTypeDefinition && testClassType != null) + { + declaringType = FindClosedGenericType(testClassType, declaringType); + if (declaringType == null) + { + return null; + } + } + var backingFieldFlags = System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.FlattenHierarchy; var backingFieldName = $"<{property.Name}>k__BackingField"; @@ -248,6 +259,24 @@ private static PropertyInjectionData CreatePropertyInjection(System.Reflection.P return null; } + /// + /// Finds the closed generic type in the inheritance hierarchy that matches the open generic type definition + /// + private static Type? FindClosedGenericType(Type testClassType, Type openGenericTypeDefinition) + { + var currentType = testClassType; + while (currentType != null && currentType != typeof(object)) + { + if (currentType.IsGenericType && + currentType.GetGenericTypeDefinition() == openGenericTypeDefinition) + { + return currentType; + } + currentType = currentType.BaseType; + } + return null; + } + /// /// Helper method to get field with proper trimming suppression /// diff --git a/TUnit.TestProject/Bugs/4431/CompositionPatternTests.cs b/TUnit.TestProject/Bugs/4431/CompositionPatternTests.cs new file mode 100644 index 0000000000..9c15837e7b --- /dev/null +++ b/TUnit.TestProject/Bugs/4431/CompositionPatternTests.cs @@ -0,0 +1,300 @@ +using TUnit.Core.Interfaces; +using TUnit.TestProject.Attributes; + +namespace TUnit.TestProject.Bugs._4431; + +/// +/// Tests that replicate the user's pattern from issue #4431 comments. +/// The user is trying to use composition (new T()) instead of inheritance, +/// which means TUnit's DI system doesn't process the ClassDataSource attributes. +/// + +#region User's Pattern - Composition (This Pattern Does NOT Work) + +/// +/// Interface for database providers (matches user's IDbProvider). +/// +public interface IDbProvider4431 +{ + string GetConnectionString(); +} + +/// +/// Data source that implements IAsyncInitializer. +/// In the user's case, this would be a container like Postgres. +/// +public class DatabaseContainer4431 : IAsyncInitializer +{ + public bool IsInitialized { get; private set; } + public string? ConnectionString { get; private set; } + + public Task InitializeAsync() + { + Console.WriteLine("DatabaseContainer4431.InitializeAsync starting"); + IsInitialized = true; + ConnectionString = "Server=container;Database=test"; + Console.WriteLine("DatabaseContainer4431.InitializeAsync completed"); + return Task.CompletedTask; + } +} + +/// +/// User's ParentTest1 pattern - a sealed class with ClassDataSource. +/// This class is NOT in the test's inheritance chain. +/// +public sealed class ProviderWithClassDataSource4431 : IDbProvider4431 +{ + [ClassDataSource(Shared = SharedType.PerTestSession)] + public DatabaseContainer4431 Database { get; init; } = null!; + + public string GetConnectionString() + { + // This will throw NullReferenceException because Database is not injected + // when this class is created via new T() + return Database.ConnectionString!; + } +} + +/// +/// User's TestBaseDatabase pattern - creates provider via new T(). +/// This is the problematic pattern because new T() doesn't go through TUnit's DI. +/// +public abstract class TestBaseDatabaseWithComposition4431 where T : IDbProvider4431, new() +{ + // This creates T via plain constructor - TUnit doesn't process ClassDataSource attributes + protected readonly T Provider = new T(); +} + +/// +/// This test demonstrates the user's exact pattern. +/// It is expected to FAIL because the composition pattern doesn't work with TUnit's DI. +/// +[EngineTest(ExpectedResult.Failure)] +public class CompositionPatternDoesNotWork_4431 : TestBaseDatabaseWithComposition4431 +{ + [Test] + public async Task Composition_DoesNotInjectDataSources() + { + // This WILL FAIL with NullReferenceException + // Because Provider was created with new T(), not by TUnit's DI + // The ClassDataSource attribute on ProviderWithClassDataSource4431.Database is never processed + var connectionString = Provider.GetConnectionString(); + await Assert.That(connectionString).IsNotNull(); + } +} + +#endregion + +#region Correct Pattern - Using Inheritance + +/// +/// The CORRECT way to achieve the user's goal: use inheritance instead of composition. +/// The base class has the ClassDataSource, and child tests inherit from it. +/// +public abstract class TestBaseDatabaseWithInheritance4431 +{ + [ClassDataSource(Shared = SharedType.PerTestSession)] + public DatabaseContainer4431 Database { get; init; } = null!; + + public string GetConnectionString() => Database.ConnectionString!; +} + +/// +/// Test using the correct inheritance pattern. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(InheritancePatternWorks_4431))] +public class InheritancePatternWorks_4431 : TestBaseDatabaseWithInheritance4431 +{ + [Test] + public async Task Inheritance_InjectsDataSources() + { + // This WILL PASS because Database is in the inheritance chain + // TUnit processes ClassDataSource on base classes + await Assert.That(Database).IsNotNull(); + await Assert.That(Database.IsInitialized).IsTrue(); + await Assert.That(GetConnectionString()).IsEqualTo("Server=container;Database=test"); + } +} + +#endregion + +#region Alternative Pattern - ClassDataSource on Property Using Interface Type + +/// +/// Alternative pattern: Use ClassDataSource directly on the test class, +/// referencing the provider type that implements IDbProvider. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(DirectClassDataSourcePattern_4431))] +public class DirectClassDataSourcePattern_4431 +{ + // Use ClassDataSource to get an instance that TUnit manages + [ClassDataSource(Shared = SharedType.PerTestSession)] + public ProviderWithClassDataSource4431 Provider { get; init; } = null!; + + [Test] + public async Task DirectClassDataSource_InjectsProviderAndDependencies() + { + // TUnit will: + // 1. Create ProviderWithClassDataSource4431 + // 2. Inject its ClassDataSource property + // 3. Initialize DatabaseContainer4431 (IAsyncInitializer) + // 4. Inject the fully initialized provider here + await Assert.That(Provider).IsNotNull(); + await Assert.That(Provider.Database).IsNotNull(); + await Assert.That(Provider.Database.IsInitialized).IsTrue(); + await Assert.That(Provider.GetConnectionString()).IsEqualTo("Server=container;Database=test"); + } +} + +#endregion + +#region Option 3: Non-Generic ClassDataSource with Type Inference + +/// +/// Test using the non-generic ClassDataSource on the concrete class. +/// The type is inferred from the property type. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(NonGenericClassDataSourceWithTypeInference_4431))] +public class NonGenericClassDataSourceWithTypeInference_4431 +{ + // The non-generic [ClassDataSource] infers the type from the property type + [ClassDataSource(Shared = [SharedType.PerTestSession])] + public ProviderWithClassDataSource4431 Provider { get; init; } = default!; + + [Test] + public async Task NonGenericClassDataSource_InfersTypeFromProperty() + { + // TUnit infers the type from the property type (ProviderWithClassDataSource4431) + // and properly injects it with all nested dependencies + await Assert.That(Provider).IsNotNull(); + await Assert.That(Provider.Database).IsNotNull(); + await Assert.That(Provider.Database.IsInitialized).IsTrue(); + await Assert.That(Provider.GetConnectionString()).IsEqualTo("Server=container;Database=test"); + } +} + +/// +/// Generic base class that uses [ClassDataSource] (non-generic) to infer type from property. +/// This allows the property type to be a generic type parameter T. +/// +public abstract class GenericBaseWithInferredClassDataSource where T : class, IDbProvider4431 +{ + // The non-generic [ClassDataSource] infers the type from the property type T + [ClassDataSource(Shared = [SharedType.PerTestSession])] + public T Provider { get; init; } = default!; +} + +/// +/// Test using the non-generic ClassDataSource with a generic base class. +/// This is the cleanest solution for the user's scenario. +/// Currently fails due to backing field lookup issue with generic types. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(NonGenericClassDataSourceWithGenericBase_First_4431))] +public class NonGenericClassDataSourceWithGenericBase_First_4431 + : GenericBaseWithInferredClassDataSource +{ + [Test] + public async Task NonGenericClassDataSource_OnGenericBase_InfersTypeFromProperty() + { + // TUnit should infer the type from the property type (ProviderWithClassDataSource4431) + // and properly inject it with all nested dependencies + await Assert.That(Provider).IsNotNull(); + await Assert.That(Provider.Database).IsNotNull(); + await Assert.That(Provider.Database.IsInitialized).IsTrue(); + await Assert.That(Provider.GetConnectionString()).IsEqualTo("Server=container;Database=test"); + } +} + +/// +/// Second test using a different provider type with the same generic base. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(NonGenericClassDataSourceWithGenericBase_Second_4431))] +public class NonGenericClassDataSourceWithGenericBase_Second_4431 + : GenericBaseWithInferredClassDataSource +{ + [Test] + public async Task NonGenericClassDataSource_OnGenericBase_WorksWithDifferentTypes() + { + await Assert.That(Provider).IsNotNull(); + await Assert.That(Provider.Database).IsNotNull(); + await Assert.That(Provider.Database.IsInitialized).IsTrue(); + await Assert.That(Provider.GetConnectionString()).IsEqualTo("Server=secondary;Database=other"); + } +} + +#endregion + +#region Additional Test Scenarios + +/// +/// Second data source for testing with different databases. +/// +public class SecondDatabaseContainer4431 : IAsyncInitializer +{ + public bool IsInitialized { get; private set; } + public string? ConnectionString { get; private set; } + + public Task InitializeAsync() + { + IsInitialized = true; + ConnectionString = "Server=secondary;Database=other"; + return Task.CompletedTask; + } +} + +/// +/// Provider that uses the second database. +/// +public sealed class SecondProviderWithClassDataSource4431 : IDbProvider4431 +{ + [ClassDataSource(Shared = SharedType.PerTestSession)] + public SecondDatabaseContainer4431 Database { get; init; } = null!; + + public string GetConnectionString() => Database.ConnectionString!; +} + +/// +/// Concrete test class that directly uses ClassDataSource for the first provider. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(GenericBaseWithConcreteProvider_First_4431))] +public class GenericBaseWithConcreteProvider_First_4431 +{ + [ClassDataSource(Shared = SharedType.PerTestSession)] + public ProviderWithClassDataSource4431 Provider { get; init; } = null!; + + [Test] + public async Task DirectClassDataSource_WithNestedInjection_Works() + { + await Assert.That(Provider).IsNotNull(); + await Assert.That(Provider.Database).IsNotNull(); + await Assert.That(Provider.Database.IsInitialized).IsTrue(); + } +} + +/// +/// Second concrete test class using a different provider. +/// +[EngineTest(ExpectedResult.Pass)] +[NotInParallel(nameof(GenericBaseWithConcreteProvider_Second_4431))] +public class GenericBaseWithConcreteProvider_Second_4431 +{ + [ClassDataSource(Shared = SharedType.PerTestSession)] + public SecondProviderWithClassDataSource4431 Provider { get; init; } = null!; + + [Test] + public async Task DirectClassDataSource_WithDifferentProvider_Works() + { + await Assert.That(Provider).IsNotNull(); + await Assert.That(Provider.Database).IsNotNull(); + await Assert.That(Provider.Database.IsInitialized).IsTrue(); + await Assert.That(Provider.GetConnectionString()).IsEqualTo("Server=secondary;Database=other"); + } +} + +#endregion diff --git a/TUnit.TestProject/Bugs/4431/UnsafeAccessorGenericTest.cs b/TUnit.TestProject/Bugs/4431/UnsafeAccessorGenericTest.cs new file mode 100644 index 0000000000..c405f7aebf --- /dev/null +++ b/TUnit.TestProject/Bugs/4431/UnsafeAccessorGenericTest.cs @@ -0,0 +1,65 @@ +#if NET8_0_OR_GREATER +using System.Runtime.CompilerServices; + +namespace TUnit.TestProject.Bugs._4431; + +/// +/// Minimal reproduction test for UnsafeAccessor with generic types. +/// This test demonstrates that UnsafeAccessor does NOT work with fields on generic base classes. +/// The reflection-based approach works correctly. +/// +public class UnsafeAccessorGenericTest +{ + [Test] + [Category("KnownLimitation")] + public async Task UnsafeAccessor_FailsWithGenericBaseClass() + { + // Create an instance of the derived class + var derivedInstance = new NonGenericClassDataSourceWithGenericBase_First_4431(); + var providerValue = new ProviderWithClassDataSource4431(); + + // UnsafeAccessor does NOT work with fields on generic base classes + // This is a known .NET limitation + var ex = Assert.Throws(() => + { + UnsafeAccessorHelper.SetProviderField(derivedInstance, providerValue); + }); + + // The error message shows the open generic type (with backtick notation) + await Assert.That(ex!.Message).Contains("GenericBaseWithInferredClassDataSource"); + } + + [Test] + public async Task Reflection_ShouldWork_WithGenericBaseClass() + { + // Create an instance of the derived class + var derivedInstance = new NonGenericClassDataSourceWithGenericBase_First_4431(); + var providerValue = new ProviderWithClassDataSource4431(); + + // Reflection works correctly with the closed generic type + var closedGenericBaseType = typeof(GenericBaseWithInferredClassDataSource); + var backingField = closedGenericBaseType.GetField("k__BackingField", + System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic); + + await Assert.That(backingField).IsNotNull(); + + backingField!.SetValue(derivedInstance, providerValue); + + await Assert.That(derivedInstance.Provider).IsNotNull(); + await Assert.That(derivedInstance.Provider).IsSameReferenceAs(providerValue); + } +} + +internal static class UnsafeAccessorHelper +{ + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "k__BackingField")] + internal static extern ref ProviderWithClassDataSource4431 GetProviderBackingField( + GenericBaseWithInferredClassDataSource instance); + + public static void SetProviderField(object instance, ProviderWithClassDataSource4431 value) + { + var typedInstance = (GenericBaseWithInferredClassDataSource)instance; + GetProviderBackingField(typedInstance) = value; + } +} +#endif