Skip to content

Commit

Permalink
Query: Use GetDiscriminatorValue for TPT and TPC
Browse files Browse the repository at this point in the history
Add model validation when values are not unique
Resolves #28054
  • Loading branch information
smitpatel committed May 20, 2022
1 parent 49328ec commit f2ddc94
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 19 deletions.
20 changes: 19 additions & 1 deletion src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,24 @@ protected override void ValidateInheritanceMapping(

ValidateNonTphMapping(entityType, forTables: false);
ValidateNonTphMapping(entityType, forTables: true);

var derivedTypes = entityType.GetDerivedTypesInclusive().ToList();
var discriminatorValues = new Dictionary<object, IEntityType>();
foreach (var derivedType in derivedTypes)
{
if (!derivedType.ClrType.IsInstantiable())
{
continue;
}
var discriminatorValue = derivedType.ShortName();

if (discriminatorValues.TryGetValue(discriminatorValue, out var duplicateEntityType))
{
throw new InvalidOperationException("TBD");
}

discriminatorValues[discriminatorValue] = derivedType;
}
}
}

Expand Down Expand Up @@ -1469,7 +1487,7 @@ private static void ValidateNonTphMapping(IEntityType rootEntityType, bool forTa
{
return;
}

var internalForeignKey = rootEntityType.FindRowInternalForeignKeys(storeObject.Value).FirstOrDefault();
if (internalForeignKey != null
&& derivedTypes.Count > 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,6 @@ public static IEnumerable<ITableMappingBase> GetViewOrTableMappings(this IEntity
?? entityType.FindRuntimeAnnotationValue(RelationalAnnotationNames.TableMappings))
?? Enumerable.Empty<ITableMappingBase>();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static IReadOnlyList<string> GetTptDiscriminatorValues(this IReadOnlyEntityType entityType)
=> entityType.GetConcreteDerivedTypesInclusive().Select(et => et.ShortName()).ToList();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Query/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public virtual EntityProjectionExpression UpdateEntityType(IEntityType derivedTy
var discriminatorExpression = DiscriminatorExpression;
if (DiscriminatorExpression is CaseExpression caseExpression)
{
var entityTypesToSelect = derivedType.GetTptDiscriminatorValues();
var entityTypesToSelect = derivedType.GetConcreteDerivedTypesInclusive().Select(e => e.GetDiscriminatorValue()).ToList();
var whenClauses = caseExpression.WhenClauses
.Where(wc => entityTypesToSelect.Contains((string)((SqlConstantExpression)wc.Result).Value!))
.ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
}

// TPT or TPC
var discriminatorValues = derivedType.GetTptDiscriminatorValues();
var discriminatorValues = derivedType.GetConcreteDerivedTypesInclusive().Select(e => e.GetDiscriminatorValue()).ToList();
if (entityReferenceExpression.SubqueryEntity != null)
{
var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression;
Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/Metadata/IReadOnlyEntityType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ bool GetIsDiscriminatorMappingComplete()
/// </summary>
/// <returns>The discriminator value for this entity type.</returns>
object? GetDiscriminatorValue()
=> this[CoreAnnotationNames.DiscriminatorValue];
=> this[CoreAnnotationNames.DiscriminatorValue]
?? (GetDiscriminatorPropertyName() == null ? (object)ShortName() : null);

/// <summary>
/// Gets all types in the model from which a given entity type derives, starting with the root.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,9 @@ public virtual void Passes_for_incompatible_uniquified_check_constraints_with_sh
var modelBuilder = CreateConventionalModelBuilder();

modelBuilder.Entity<A>().HasOne<B>().WithOne(b => b.A).HasForeignKey<A>(a => a.Id).HasPrincipalKey<B>(b => b.Id).IsRequired();
modelBuilder.Entity<A>().HasCheckConstraint("CK_Table_SomeCK", "Id > 0");
modelBuilder.Entity<A>().HasCheckConstraint("CK_Table_SomeCK", "Id > 0");
modelBuilder.Entity<A>().ToTable("Table");
modelBuilder.Entity<B>().HasCheckConstraint("CK_Table_SomeCK", "Id > 10");
modelBuilder.Entity<B>().HasCheckConstraint("CK_Table_SomeCK", "Id > 10");
modelBuilder.Entity<B>().ToTable("Table");

var model = Validate(modelBuilder);
Expand All @@ -502,9 +502,9 @@ public virtual void Passes_for_compatible_shared_check_constraints_with_shared_t
var modelBuilder = CreateConventionalModelBuilder();

modelBuilder.Entity<A>().HasOne<B>().WithOne(b => b.A).HasForeignKey<A>(a => a.Id).HasPrincipalKey<B>(b => b.Id).IsRequired();
modelBuilder.Entity<A>().HasCheckConstraint("CK_Table_SomeCK", "Id > 0");
modelBuilder.Entity<A>().HasCheckConstraint("CK_Table_SomeCK", "Id > 0");
modelBuilder.Entity<A>().ToTable("Table");
modelBuilder.Entity<B>().HasCheckConstraint("CK_Table_SomeCK", "Id > 0");
modelBuilder.Entity<B>().HasCheckConstraint("CK_Table_SomeCK", "Id > 0");
modelBuilder.Entity<B>().ToTable("Table");

var model = Validate(modelBuilder);
Expand Down Expand Up @@ -1853,7 +1853,7 @@ public virtual void Detects_clashing_entity_types_in_views_TPC()
modelBuilder.Entity<Animal>().UseTpcMappingStrategy();
modelBuilder.Entity<Cat>().ToTable("Cat").ToView("Cat");
modelBuilder.Entity<Dog>().ToTable("Dog").ToView("Cat");

VerifyError(
RelationalStrings.NonTphViewClash(nameof(Dog), nameof(Cat), "Cat"),
modelBuilder);
Expand Down Expand Up @@ -1945,7 +1945,7 @@ public virtual void Passes_on_valid_view_sharing_with_TPC()

Validate(modelBuilder);
}

[ConditionalFact]
public virtual void Detects_view_sharing_on_base_with_TPC()
{
Expand Down Expand Up @@ -2511,6 +2511,38 @@ public virtual void Detects_triggers_on_unmapped_entity_types()
VerifyError(RelationalStrings.TriggerOnUnmappedEntityType("Animal_Trigger", "Animal"), modelBuilder);
}

[ConditionalFact]
public virtual void Throws_when_non_tph_entity_type_short_names_are_not_unique()
{
var modelBuilder = CreateConventionalModelBuilder();
modelBuilder.Entity<TpcBase>().UseTpcMappingStrategy();
modelBuilder.Entity<Outer.TpcDerived>().ToTable("TpcDerived1");
modelBuilder.Entity<Outer2.TpcDerived>().ToTable("TpcDerived2");

VerifyError("TBD", modelBuilder);
}

private class TpcBase
{
public int Id { get; set; }
}

private class Outer
{
public class TpcDerived : TpcBase
{
public string Value { get; set; }
}
}

private class Outer2
{
public class TpcDerived : TpcBase
{
public string Value { get; set; }
}
}

protected override void SetBaseType(IMutableEntityType entityType, IMutableEntityType baseEntityType)
{
base.SetBaseType(entityType, baseEntityType);
Expand Down

0 comments on commit f2ddc94

Please sign in to comment.