diff --git a/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj b/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj index 390b88f..d5c0b32 100644 --- a/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj +++ b/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj @@ -10,6 +10,8 @@ + + diff --git a/src/DependencyInjection.Tests/Regressions.cs b/src/DependencyInjection.Tests/Regressions.cs new file mode 100644 index 0000000..8bb0114 --- /dev/null +++ b/src/DependencyInjection.Tests/Regressions.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Moq; +using Spectre.Console.Cli; + +namespace Tests.Regressions; + +public class Regressions +{ + [Fact] + public void CovariantRegistrationSatisfiesIntefaceConstraints() + { + var collection = new ServiceCollection(); + collection.AddServices(typeof(ICommand)); + + var provider = collection.BuildServiceProvider(); + + var command = provider.GetRequiredService(); + + Assert.Equal(0, command.Execute(new CommandContext([], Mock.Of(), "my", null), + new MySetting { Base = "", Name = "" })); + } +} + +public interface ISetting +{ + string Name { get; set; } +} + +public class BaseSetting : CommandSettings +{ + [CommandArgument(0, "")] + public required string Base { get; init; } +} + +public class MySetting : BaseSetting, ISetting +{ + [CommandOption("--name")] + public required string Name { get; set; } +} + +public class MyCommand : BaseCommand { } + +public abstract class BaseCommand : Command where TSettings : BaseSetting, ISetting +{ + public override int Execute(CommandContext context, TSettings settings) + { + Console.WriteLine($"Base: {settings.Base}, Name: {settings.Name}"); + return 0; + } +} \ No newline at end of file diff --git a/src/DependencyInjection/ConstraintsChecker.cs b/src/DependencyInjection/ConstraintsChecker.cs new file mode 100644 index 0000000..1285fe8 --- /dev/null +++ b/src/DependencyInjection/ConstraintsChecker.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection.Metadata; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Devlooped.Extensions.DependencyInjection; + +static class ConstraintsChecker +{ + public static bool SatisfiesConstraints(this ITypeSymbol typeArgument, ITypeParameterSymbol typeParameter) + { + // Check reference type constraint + if (typeParameter.HasReferenceTypeConstraint && !typeArgument.IsReferenceType) + return false; + + // Check value type constraint + if (typeParameter.HasValueTypeConstraint && !typeArgument.IsValueType) + return false; + + // Check base class and interface constraints + foreach (var constraint in typeParameter.ConstraintTypes) + { + if (constraint.TypeKind == TypeKind.Class) + { + if (!typeArgument.GetBaseTypes().Any(baseType => SymbolEqualityComparer.Default.Equals(baseType, constraint))) + return false; + } + else if (constraint.TypeKind == TypeKind.Interface) + { + if (!typeArgument.AllInterfaces.Any(interfaceSymbol => SymbolEqualityComparer.Default.Equals(interfaceSymbol, constraint))) + return false; + } + } + + // Constructor constraint (optional, not typically needed here) + if (typeParameter.HasConstructorConstraint) + { + // Check for parameterless constructor (simplified) + var hasParameterlessConstructor = typeArgument.GetMembers(".ctor") + .OfType() + .Any(ctor => ctor.Parameters.Length == 0); + if (!hasParameterlessConstructor) + return false; + } + + return true; + } + + static IEnumerable GetBaseTypes(this ITypeSymbol typeSymbol) + { + var currentType = typeSymbol.BaseType; + while (currentType != null && currentType.SpecialType != SpecialType.System_Object) + { + yield return currentType; + currentType = currentType.BaseType; + } + } +} diff --git a/src/DependencyInjection/IncrementalGenerator.cs b/src/DependencyInjection/IncrementalGenerator.cs index 92e2d7b..b8b4b94 100644 --- a/src/DependencyInjection/IncrementalGenerator.cs +++ b/src/DependencyInjection/IncrementalGenerator.cs @@ -452,6 +452,9 @@ void AddServices(IEnumerable services, Compilation compilation foreach (var iface in type.AllInterfaces) { + if (!compilation.HasImplicitConversion(type, iface)) + continue; + var ifaceName = iface.ToFullName(compilation); if (!registered.Contains(ifaceName)) { @@ -476,9 +479,11 @@ void AddServices(IEnumerable services, Compilation compilation baseType = baseType.BaseType; } - foreach (var candidate in candidates.Select(x => iface.ConstructedFrom.Construct(x)) + foreach (var candidate in candidates + .Where(x => x.SatisfiesConstraints(iface.TypeParameters[0])) + .Select(x => iface.ConstructedFrom.Construct(x)) + .Where(x => x != null && compilation.HasImplicitConversion(type, x)) .ToImmutableHashSet(SymbolEqualityComparer.Default) - .Where(x => x != null) .Select(x => x!.ToFullName(compilation))) { if (!registered.Contains(candidate))