From f2ddc94fad5665130ffd09448ccc4882e6acacee Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Fri, 20 May 2022 13:51:07 -0700 Subject: [PATCH] Query: Use GetDiscriminatorValue for TPT and TPC Add model validation when values are not unique Resolves #28054 --- .../RelationalModelValidator.cs | 20 ++++++++- .../RelationalEntityTypeExtensions.cs | 9 ---- .../Query/EntityProjectionExpression.cs | 2 +- ...lationalSqlTranslatingExpressionVisitor.cs | 2 +- src/EFCore/Metadata/IReadOnlyEntityType.cs | 3 +- .../RelationalModelValidatorTest.cs | 44 ++++++++++++++++--- 6 files changed, 61 insertions(+), 19 deletions(-) diff --git a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs index 53c17d48f14..9c913997cf6 100644 --- a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs +++ b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs @@ -1391,6 +1391,24 @@ protected override void ValidateInheritanceMapping( ValidateNonTphMapping(entityType, forTables: false); ValidateNonTphMapping(entityType, forTables: true); + + var derivedTypes = entityType.GetDerivedTypesInclusive().ToList(); + var discriminatorValues = new Dictionary(); + foreach (var derivedType in derivedTypes) + { + if (!derivedType.ClrType.IsInstantiable()) + { + continue; + } + var discriminatorValue = derivedType.ShortName(); + + if (discriminatorValues.TryGetValue(discriminatorValue, out var duplicateEntityType)) + { + throw new InvalidOperationException("TBD"); + } + + discriminatorValues[discriminatorValue] = derivedType; + } } } @@ -1469,7 +1487,7 @@ private static void ValidateNonTphMapping(IEntityType rootEntityType, bool forTa { return; } - + var internalForeignKey = rootEntityType.FindRowInternalForeignKeys(storeObject.Value).FirstOrDefault(); if (internalForeignKey != null && derivedTypes.Count > 1 diff --git a/src/EFCore.Relational/Metadata/Internal/RelationalEntityTypeExtensions.cs b/src/EFCore.Relational/Metadata/Internal/RelationalEntityTypeExtensions.cs index a47423326c0..8bab3bf5e09 100644 --- a/src/EFCore.Relational/Metadata/Internal/RelationalEntityTypeExtensions.cs +++ b/src/EFCore.Relational/Metadata/Internal/RelationalEntityTypeExtensions.cs @@ -31,15 +31,6 @@ public static IEnumerable GetViewOrTableMappings(this IEntity ?? entityType.FindRuntimeAnnotationValue(RelationalAnnotationNames.TableMappings)) ?? Enumerable.Empty(); - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public static IReadOnlyList GetTptDiscriminatorValues(this IReadOnlyEntityType entityType) - => entityType.GetConcreteDerivedTypesInclusive().Select(et => et.ShortName()).ToList(); - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore.Relational/Query/EntityProjectionExpression.cs b/src/EFCore.Relational/Query/EntityProjectionExpression.cs index 15e8efdd055..bab17066aae 100644 --- a/src/EFCore.Relational/Query/EntityProjectionExpression.cs +++ b/src/EFCore.Relational/Query/EntityProjectionExpression.cs @@ -123,7 +123,7 @@ public virtual EntityProjectionExpression UpdateEntityType(IEntityType derivedTy var discriminatorExpression = DiscriminatorExpression; if (DiscriminatorExpression is CaseExpression caseExpression) { - var entityTypesToSelect = derivedType.GetTptDiscriminatorValues(); + var entityTypesToSelect = derivedType.GetConcreteDerivedTypesInclusive().Select(e => e.GetDiscriminatorValue()).ToList(); var whenClauses = caseExpression.WhenClauses .Where(wc => entityTypesToSelect.Contains((string)((SqlConstantExpression)wc.Result).Value!)) .ToList(); diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index dd6922d2749..aff2c0ca956 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -1066,7 +1066,7 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp } // TPT or TPC - var discriminatorValues = derivedType.GetTptDiscriminatorValues(); + var discriminatorValues = derivedType.GetConcreteDerivedTypesInclusive().Select(e => e.GetDiscriminatorValue()).ToList(); if (entityReferenceExpression.SubqueryEntity != null) { var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; diff --git a/src/EFCore/Metadata/IReadOnlyEntityType.cs b/src/EFCore/Metadata/IReadOnlyEntityType.cs index f3e53bfee95..610b21828d0 100644 --- a/src/EFCore/Metadata/IReadOnlyEntityType.cs +++ b/src/EFCore/Metadata/IReadOnlyEntityType.cs @@ -70,7 +70,8 @@ bool GetIsDiscriminatorMappingComplete() /// /// The discriminator value for this entity type. object? GetDiscriminatorValue() - => this[CoreAnnotationNames.DiscriminatorValue]; + => this[CoreAnnotationNames.DiscriminatorValue] + ?? (GetDiscriminatorPropertyName() == null ? (object)ShortName() : null); /// /// Gets all types in the model from which a given entity type derives, starting with the root. diff --git a/test/EFCore.Relational.Tests/Infrastructure/RelationalModelValidatorTest.cs b/test/EFCore.Relational.Tests/Infrastructure/RelationalModelValidatorTest.cs index 5d619bab8b6..0b56eeeacc5 100644 --- a/test/EFCore.Relational.Tests/Infrastructure/RelationalModelValidatorTest.cs +++ b/test/EFCore.Relational.Tests/Infrastructure/RelationalModelValidatorTest.cs @@ -485,9 +485,9 @@ public virtual void Passes_for_incompatible_uniquified_check_constraints_with_sh var modelBuilder = CreateConventionalModelBuilder(); modelBuilder.Entity().HasOne().WithOne(b => b.A).HasForeignKey(a => a.Id).HasPrincipalKey(b => b.Id).IsRequired(); - modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 0"); + modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 0"); modelBuilder.Entity().ToTable("Table"); - modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 10"); + modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 10"); modelBuilder.Entity().ToTable("Table"); var model = Validate(modelBuilder); @@ -502,9 +502,9 @@ public virtual void Passes_for_compatible_shared_check_constraints_with_shared_t var modelBuilder = CreateConventionalModelBuilder(); modelBuilder.Entity().HasOne().WithOne(b => b.A).HasForeignKey(a => a.Id).HasPrincipalKey(b => b.Id).IsRequired(); - modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 0"); + modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 0"); modelBuilder.Entity().ToTable("Table"); - modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 0"); + modelBuilder.Entity().HasCheckConstraint("CK_Table_SomeCK", "Id > 0"); modelBuilder.Entity().ToTable("Table"); var model = Validate(modelBuilder); @@ -1853,7 +1853,7 @@ public virtual void Detects_clashing_entity_types_in_views_TPC() modelBuilder.Entity().UseTpcMappingStrategy(); modelBuilder.Entity().ToTable("Cat").ToView("Cat"); modelBuilder.Entity().ToTable("Dog").ToView("Cat"); - + VerifyError( RelationalStrings.NonTphViewClash(nameof(Dog), nameof(Cat), "Cat"), modelBuilder); @@ -1945,7 +1945,7 @@ public virtual void Passes_on_valid_view_sharing_with_TPC() Validate(modelBuilder); } - + [ConditionalFact] public virtual void Detects_view_sharing_on_base_with_TPC() { @@ -2511,6 +2511,38 @@ public virtual void Detects_triggers_on_unmapped_entity_types() VerifyError(RelationalStrings.TriggerOnUnmappedEntityType("Animal_Trigger", "Animal"), modelBuilder); } + [ConditionalFact] + public virtual void Throws_when_non_tph_entity_type_short_names_are_not_unique() + { + var modelBuilder = CreateConventionalModelBuilder(); + modelBuilder.Entity().UseTpcMappingStrategy(); + modelBuilder.Entity().ToTable("TpcDerived1"); + modelBuilder.Entity().ToTable("TpcDerived2"); + + VerifyError("TBD", modelBuilder); + } + + private class TpcBase + { + public int Id { get; set; } + } + + private class Outer + { + public class TpcDerived : TpcBase + { + public string Value { get; set; } + } + } + + private class Outer2 + { + public class TpcDerived : TpcBase + { + public string Value { get; set; } + } + } + protected override void SetBaseType(IMutableEntityType entityType, IMutableEntityType baseEntityType) { base.SetBaseType(entityType, baseEntityType);