diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs index daff8f2a26..ccc4f8d02d 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs @@ -6,6 +6,37 @@ namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; public static class InstanceFactoryGenerator { + /// + /// Checks if the given type has a ClassConstructor attribute on the class/base types OR at the assembly level. + /// + public static bool HasClassConstructorAttribute(INamedTypeSymbol namedTypeSymbol) + { + var hasOnClass = namedTypeSymbol.GetAttributesIncludingBaseTypes() + .Any(a => a.AttributeClass?.GloballyQualifiedNonGeneric() == WellKnownFullyQualifiedClassNames.ClassConstructorAttribute.WithGlobalPrefix); + + if (hasOnClass) + { + return true; + } + + return namedTypeSymbol.ContainingAssembly.GetAttributes() + .Any(a => a.AttributeClass?.GloballyQualifiedNonGeneric() == WellKnownFullyQualifiedClassNames.ClassConstructorAttribute.WithGlobalPrefix); + } + + /// + /// Generates the ClassConstructor throw-stub InstanceFactory. + /// + public static void GenerateClassConstructorStub(CodeWriter writer) + { + writer.AppendLine("InstanceFactory = (typeArgs, args) =>"); + writer.AppendLine("{"); + writer.Indent(); + writer.AppendLine("// ClassConstructor attribute is present - instance creation handled at runtime"); + writer.AppendLine("throw new global::System.NotSupportedException(\"Instance creation for classes with ClassConstructor attribute is handled at runtime\");"); + writer.Unindent(); + writer.AppendLine("},"); + } + /// /// Generates code to create an instance of a type with proper required property handling. /// This handles required properties that don't have data sources by initializing them with defaults. @@ -46,26 +77,11 @@ public static void GenerateInstanceFactory(CodeWriter writer, ITypeSymbol typeSy { var className = typeSymbol.GloballyQualified(); - // Check if the class has a ClassConstructor attribute first, before any other checks - if (typeSymbol is INamedTypeSymbol namedTypeSymbol) + // Check if the class has a ClassConstructor attribute first (class, base types, or assembly level) + if (typeSymbol is INamedTypeSymbol namedTypeSymbol && HasClassConstructorAttribute(namedTypeSymbol)) { - var hasClassConstructor = namedTypeSymbol.GetAttributesIncludingBaseTypes() - .Any(a => a.AttributeClass?.GloballyQualifiedNonGeneric() == WellKnownFullyQualifiedClassNames.ClassConstructorAttribute.WithGlobalPrefix); - - if (hasClassConstructor) - { - // If class has ClassConstructor attribute, generate a factory that throws - // The actual instance creation will be handled by ClassConstructorHelper at runtime - // This applies to both generic and non-generic classes - writer.AppendLine("InstanceFactory = (typeArgs, args) =>"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("// ClassConstructor attribute is present - instance creation handled at runtime"); - writer.AppendLine("throw new global::System.NotSupportedException(\"Instance creation for classes with ClassConstructor attribute is handled at runtime\");"); - writer.Unindent(); - writer.AppendLine("},"); - return; - } + GenerateClassConstructorStub(writer); + return; } // Check if this is a generic type definition diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index c8d0d8022f..60ffbb62b1 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -2979,31 +2979,38 @@ private static void GenerateGenericTestWithConcreteTypes( GenerateMetadataForConcreteInstantiation(writer, testMethod); // Generate instance factory that works with generic types - writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); - writer.AppendLine("{"); - writer.Indent(); - - if (testMethod.IsGenericType) + if (InstanceFactoryGenerator.HasClassConstructorAttribute(testMethod.TypeSymbol)) { - // For generic classes, we need to use runtime type construction - var openGenericTypeName = GetOpenGenericTypeName(testMethod.TypeSymbol); - writer.AppendLine($"var genericType = typeof({openGenericTypeName});"); - writer.AppendLine("if (typeArgs.Length > 0)"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("var closedType = genericType.MakeGenericType(typeArgs);"); - writer.AppendLine("return global::System.Activator.CreateInstance(closedType, args)!;"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("throw new global::System.InvalidOperationException(\"No type arguments provided for generic class\");"); + InstanceFactoryGenerator.GenerateClassConstructorStub(writer); } else { - writer.AppendLine($"return new {className}();"); - } + writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); + writer.AppendLine("{"); + writer.Indent(); - writer.Unindent(); - writer.AppendLine("},"); + if (testMethod.IsGenericType) + { + // For generic classes, we need to use runtime type construction + var openGenericTypeName = GetOpenGenericTypeName(testMethod.TypeSymbol); + writer.AppendLine($"var genericType = typeof({openGenericTypeName});"); + writer.AppendLine("if (typeArgs.Length > 0)"); + writer.AppendLine("{"); + writer.Indent(); + writer.AppendLine("var closedType = genericType.MakeGenericType(typeArgs);"); + writer.AppendLine("return global::System.Activator.CreateInstance(closedType, args)!;"); + writer.Unindent(); + writer.AppendLine("}"); + writer.AppendLine("throw new global::System.InvalidOperationException(\"No type arguments provided for generic class\");"); + } + else + { + writer.AppendLine($"return new {className}();"); + } + + writer.Unindent(); + writer.AppendLine("},"); + } // Generate concrete instantiations dictionary writer.AppendLine("ConcreteInstantiations = new global::System.Collections.Generic.Dictionary"); @@ -4490,53 +4497,60 @@ private static void GenerateConcreteTestMetadata( GenerateConcreteMetadataWithFilteredDataSources(writer, testMethod, specificArgumentsAttribute, typeArguments); // Generate instance factory - writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); - writer.AppendLine("{"); - writer.Indent(); - - // Check if the class has a constructor that requires arguments - var hasParameterizedConstructor = false; - var constructorParamCount = 0; - - if (testMethod.IsGenericType) + if (InstanceFactoryGenerator.HasClassConstructorAttribute(testMethod.TypeSymbol)) { - // Find the primary constructor or first public constructor - var constructor = testMethod.TypeSymbol.Constructors - .Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public) - .OrderByDescending(c => c.Parameters.Length) - .FirstOrDefault(); + InstanceFactoryGenerator.GenerateClassConstructorStub(writer); + } + else + { + writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); + writer.AppendLine("{"); + writer.Indent(); - if (constructor is { Parameters.Length: > 0 }) + // Check if the class has a constructor that requires arguments + var hasParameterizedConstructor = false; + var constructorParamCount = 0; + + if (testMethod.IsGenericType) { - hasParameterizedConstructor = true; - constructorParamCount = constructor.Parameters.Length; + // Find the primary constructor or first public constructor + var constructor = testMethod.TypeSymbol.Constructors + .Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public) + .OrderByDescending(c => c.Parameters.Length) + .FirstOrDefault(); + + if (constructor is { Parameters.Length: > 0 }) + { + hasParameterizedConstructor = true; + constructorParamCount = constructor.Parameters.Length; + } } - } - if (hasParameterizedConstructor) - { - // For classes with constructor parameters, use the specific constructor arguments from the Arguments attribute - if (specificArgumentsAttribute is { ConstructorArguments.Length: > 0 } && - specificArgumentsAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array) + if (hasParameterizedConstructor) { - var argumentValues = specificArgumentsAttribute.ConstructorArguments[0].Values; - var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg))); + // For classes with constructor parameters, use the specific constructor arguments from the Arguments attribute + if (specificArgumentsAttribute is { ConstructorArguments.Length: > 0 } && + specificArgumentsAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array) + { + var argumentValues = specificArgumentsAttribute.ConstructorArguments[0].Values; + var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg))); - writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), new object[] {{ {constructorArgs} }})!;"); + writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), new object[] {{ {constructorArgs} }})!;"); + } + else + { + // Fallback to using args if no specific Arguments attribute + writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), args)!;"); + } } else { - // Fallback to using args if no specific Arguments attribute - writer.AppendLine($"return ({concreteClassName})global::System.Activator.CreateInstance(typeof({concreteClassName}), args)!;"); + writer.AppendLine($"return new {concreteClassName}();"); } - } - else - { - writer.AppendLine($"return new {concreteClassName}();"); - } - writer.Unindent(); - writer.AppendLine("},"); + writer.Unindent(); + writer.AppendLine("},"); + } // Generate strongly-typed test invoker writer.AppendLine("InvokeTypedTest = static (instance, args, cancellationToken) =>"); @@ -5082,59 +5096,66 @@ private static void GenerateConcreteTestMetadataForNonGeneric( SourceInformationWriter.GenerateMethodInformation(writer, compilation, testMethod.TypeSymbol, testMethod.MethodSymbol, null, ','); // Generate instance factory - writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); - writer.AppendLine("{"); - writer.Indent(); - - // Check if the class has a constructor that requires arguments - var hasParameterizedConstructor = false; - var constructorParamCount = 0; - - // Find the primary constructor or first public constructor - var constructor = testMethod.TypeSymbol.Constructors - .Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public) - .OrderByDescending(c => c.Parameters.Length) - .FirstOrDefault(); - - if (constructor is { Parameters.Length: > 0 }) + if (InstanceFactoryGenerator.HasClassConstructorAttribute(testMethod.TypeSymbol)) { - hasParameterizedConstructor = true; - constructorParamCount = constructor.Parameters.Length; + InstanceFactoryGenerator.GenerateClassConstructorStub(writer); } - - if (hasParameterizedConstructor) + else { - // For classes with constructor parameters, check if we have Arguments attribute - var isArgumentsAttribute = classDataSourceAttribute?.AttributeClass?.Name == "ArgumentsAttribute"; + writer.AppendLine("InstanceFactory = static (typeArgs, args) =>"); + writer.AppendLine("{"); + writer.Indent(); - if (isArgumentsAttribute && classDataSourceAttribute is { ConstructorArguments.Length: > 0 } && - classDataSourceAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array) + // Check if the class has a constructor that requires arguments + var hasParameterizedConstructor = false; + var constructorParamCount = 0; + + // Find the primary constructor or first public constructor + var constructor = testMethod.TypeSymbol.Constructors + .Where(c => !c.IsStatic && c.DeclaredAccessibility == Accessibility.Public) + .OrderByDescending(c => c.Parameters.Length) + .FirstOrDefault(); + + if (constructor is { Parameters.Length: > 0 }) + { + hasParameterizedConstructor = true; + constructorParamCount = constructor.Parameters.Length; + } + + if (hasParameterizedConstructor) { - var argumentValues = classDataSourceAttribute.ConstructorArguments[0].Values; - var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg))); + // For classes with constructor parameters, check if we have Arguments attribute + var isArgumentsAttribute = classDataSourceAttribute?.AttributeClass?.Name == "ArgumentsAttribute"; - writer.AppendLine($"return new {className}({constructorArgs});"); + if (isArgumentsAttribute && classDataSourceAttribute is { ConstructorArguments.Length: > 0 } && + classDataSourceAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Array) + { + var argumentValues = classDataSourceAttribute.ConstructorArguments[0].Values; + var constructorArgs = string.Join(", ", argumentValues.Select(arg => TypedConstantParser.GetRawTypedConstantValue(arg))); + + writer.AppendLine($"return new {className}({constructorArgs});"); + } + else + { + // Use the args parameter if no specific arguments are provided + writer.AppendLine($"if (args.Length >= {constructorParamCount})"); + writer.AppendLine("{"); + writer.Indent(); + writer.AppendLine($"return new {className}({string.Join(", ", Enumerable.Range(0, constructorParamCount).Select(i => $"args[{i}]"))});"); + writer.Unindent(); + writer.AppendLine("}"); + writer.AppendLine("throw new global::System.InvalidOperationException(\"Not enough arguments provided for class constructor\");"); + } } else { - // Use the args parameter if no specific arguments are provided - writer.AppendLine($"if (args.Length >= {constructorParamCount})"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine($"return new {className}({string.Join(", ", Enumerable.Range(0, constructorParamCount).Select(i => $"args[{i}]"))});"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("throw new global::System.InvalidOperationException(\"Not enough arguments provided for class constructor\");"); + // No constructor parameters needed + writer.AppendLine($"return new {className}();"); } - } - else - { - // No constructor parameters needed - writer.AppendLine($"return new {className}();"); - } - writer.Unindent(); - writer.AppendLine("},"); + writer.Unindent(); + writer.AppendLine("},"); + } // Generate typed invoker GenerateTypedInvokers(writer, testMethod, className); diff --git a/TUnit.Engine/Building/TestBuilder.cs b/TUnit.Engine/Building/TestBuilder.cs index 7d2d0d22ea..6d308f8d8f 100644 --- a/TUnit.Engine/Building/TestBuilder.cs +++ b/TUnit.Engine/Building/TestBuilder.cs @@ -221,70 +221,35 @@ public async Task> BuildTestsFromMetadataAsy if (needsInstanceForMethodDataSources) { - try - { - // Try to resolve class generic types using class data for early instance creation - if (metadata.TestClassType.IsGenericTypeDefinition) - { - var tempTestData = new TestData - { - TestClassInstanceFactory = () => Task.FromResult(null!), - ClassDataSourceAttributeIndex = classDataAttributeIndex, - ClassDataLoopIndex = classDataLoopIndex, - ClassData = classData, - MethodDataSourceAttributeIndex = 0, - MethodDataLoopIndex = 0, - MethodData = [], - RepeatIndex = 0, - InheritanceDepth = metadata.InheritanceDepth - }; - - try - { - var resolution = TestGenericTypeResolver.Resolve(metadata, tempTestData); - instanceForMethodDataSources = metadata.InstanceFactory(resolution.ResolvedClassGenericArguments, classData); - } - catch (GenericTypeResolutionException) when (classData.Length == 0) - { - // If we can't resolve from constructor args, try to infer from data sources - var resolvedTypes = TryInferClassGenericsFromDataSources(metadata); - instanceForMethodDataSources = metadata.InstanceFactory(resolvedTypes, classData); - } - } - else - { - // Non-generic class - instanceForMethodDataSources = metadata.InstanceFactory([], classData); - } + var instanceResult = await CreateInstanceForMethodDataSources( + metadata, classDataAttributeIndex, classDataLoopIndex, classData, testBuilderContext); - // Initialize property data sources on the early instance so that - // method data sources can access fully-initialized properties. - // This is critical for scenarios like: - // [ClassDataSource>] public required ErrFixture Fixture { get; init; } - // public IEnumerable> TestExecutions => [() => Fixture.Value]; - // [MethodDataSource("TestExecutions")] [Test] public void MyTest(T value) { } - if (instanceForMethodDataSources != null) - { - var tempObjectBag = new ConcurrentDictionary(); - var tempEvents = new TestContextEvents(); - - await _objectLifecycleService.RegisterObjectAsync( - instanceForMethodDataSources, - tempObjectBag, - metadata.MethodMetadata, - tempEvents, - cancellationToken); - - // Discovery: only IAsyncDiscoveryInitializer is initialized - await ObjectInitializer.InitializeForDiscoveryAsync(instanceForMethodDataSources); - } - } - catch (Exception ex) + if (!instanceResult.Success) { - var failedTest = CreateFailedTestForInstanceDataSourceError(metadata, ex); + var failedTest = CreateFailedTestForInstanceDataSourceError(metadata, instanceResult.Exception!); tests.Add(failedTest); continue; } + + instanceForMethodDataSources = instanceResult.Instance; + + // Initialize property data sources on the early instance so that + // method data sources can access fully-initialized properties. + if (instanceForMethodDataSources != null) + { + var tempObjectBag = new ConcurrentDictionary(); + var tempEvents = new TestContextEvents(); + + await _objectLifecycleService.RegisterObjectAsync( + instanceForMethodDataSources, + tempObjectBag, + metadata.MethodMetadata, + tempEvents, + cancellationToken); + + // Discovery: only IAsyncDiscoveryInitializer is initialized + await ObjectInitializer.InitializeForDiscoveryAsync(instanceForMethodDataSources); + } } var methodDataAttributeIndex = 0; @@ -1607,7 +1572,7 @@ public async IAsyncEnumerable BuildTestsStreamingAsync( if (needsInstanceForMethodDataSources) { var instanceResult = await CreateInstanceForMethodDataSources( - metadata, classDataAttributeIndex, classDataLoopIndex, classData); + metadata, classDataAttributeIndex, classDataLoopIndex, classData, contextAccessor.Current); if (!instanceResult.Success) { @@ -1683,11 +1648,25 @@ await _objectLifecycleService.RegisterObjectAsync( #if NET6_0_OR_GREATER [RequiresUnreferencedCode("Generic type resolution for instance creation uses reflection")] #endif - private Task CreateInstanceForMethodDataSources( - TestMetadata metadata, int classDataAttributeIndex, int classDataLoopIndex, object?[] classData) + private async Task CreateInstanceForMethodDataSources( + TestMetadata metadata, int classDataAttributeIndex, int classDataLoopIndex, object?[] classData, TestBuilderContext testBuilderContext) { try { + // Try ClassConstructor first - if one is configured, it handles instance creation + var attributes = testBuilderContext.InitializedAttributes ?? metadata.GetOrCreateAttributes(); + var instance = await ClassConstructorHelper.TryCreateInstanceWithClassConstructor( + attributes, + metadata.TestClassType, + testBuilderContext, + metadata.TestSessionId); + + if (instance != null) + { + return InstanceCreationResult.CreateSuccess(instance); + } + + // Fall back to InstanceFactory if (metadata.TestClassType.IsGenericTypeDefinition) { var tempTestData = new TestData @@ -1706,22 +1685,22 @@ private Task CreateInstanceForMethodDataSources( try { var resolution = TestGenericTypeResolver.Resolve(metadata, tempTestData); - return Task.FromResult(InstanceCreationResult.CreateSuccess(metadata.InstanceFactory(resolution.ResolvedClassGenericArguments, classData))); + return InstanceCreationResult.CreateSuccess(metadata.InstanceFactory(resolution.ResolvedClassGenericArguments, classData)); } catch (GenericTypeResolutionException) when (classData.Length == 0) { var resolvedTypes = TryInferClassGenericsFromDataSources(metadata); - return Task.FromResult(InstanceCreationResult.CreateSuccess(metadata.InstanceFactory(resolvedTypes, classData))); + return InstanceCreationResult.CreateSuccess(metadata.InstanceFactory(resolvedTypes, classData)); } } else { - return Task.FromResult(InstanceCreationResult.CreateSuccess(metadata.InstanceFactory([], classData))); + return InstanceCreationResult.CreateSuccess(metadata.InstanceFactory([], classData)); } } catch (Exception ex) { - return Task.FromResult(InstanceCreationResult.CreateFailure(ex)); + return InstanceCreationResult.CreateFailure(ex); } }