diff --git a/src/EFCore.Cosmos/Extensions/CosmosEntityTypeExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosEntityTypeExtensions.cs index e9125995a27..f207cbed36d 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosEntityTypeExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosEntityTypeExtensions.cs @@ -155,6 +155,58 @@ public static void SetPartitionKeyPropertyName( => entityType.FindAnnotation(CosmosAnnotationNames.PartitionKeyName) ?.GetConfigurationSource(); + /// + /// Returns the property that is used to store the partition key. + /// + /// The entity type to get the partition key property for. + /// The name of the partition key property. + public static IReadOnlyProperty? GetPartitionKeyProperty(this IReadOnlyEntityType entityType) + { + var partitionKeyPropertyName = entityType.GetPartitionKeyPropertyName(); + return partitionKeyPropertyName == null + ? null + : entityType.FindProperty(partitionKeyPropertyName); + } + + /// + /// Returns the property that is used to store the partition key. + /// + /// The entity type to get the partition key property for. + /// The name of the partition key property. + public static IMutableProperty? GetPartitionKeyProperty(this IMutableEntityType entityType) + { + var partitionKeyPropertyName = entityType.GetPartitionKeyPropertyName(); + return partitionKeyPropertyName == null + ? null + : entityType.FindProperty(partitionKeyPropertyName); + } + + /// + /// Returns the property that is used to store the partition key. + /// + /// The entity type to get the partition key property for. + /// The name of the partition key property. + public static IConventionProperty? GetPartitionKeyProperty(this IConventionEntityType entityType) + { + var partitionKeyPropertyName = entityType.GetPartitionKeyPropertyName(); + return partitionKeyPropertyName == null + ? null + : entityType.FindProperty(partitionKeyPropertyName); + } + + /// + /// Returns the property that is used to store the partition key. + /// + /// The entity type to get the partition key property for. + /// The name of the partition key property. + public static IProperty? GetPartitionKeyProperty(this IEntityType entityType) + { + var partitionKeyPropertyName = entityType.GetPartitionKeyPropertyName(); + return partitionKeyPropertyName == null + ? null + : entityType.FindProperty(partitionKeyPropertyName); + } + /// /// Returns the name of the property that is used to store the ETag. /// diff --git a/src/EFCore.Cosmos/Metadata/Conventions/CosmosKeyDiscoveryConvention.cs b/src/EFCore.Cosmos/Metadata/Conventions/CosmosKeyDiscoveryConvention.cs index ee2e81c9a9a..2c51be86d70 100644 --- a/src/EFCore.Cosmos/Metadata/Conventions/CosmosKeyDiscoveryConvention.cs +++ b/src/EFCore.Cosmos/Metadata/Conventions/CosmosKeyDiscoveryConvention.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections.Generic; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; using Microsoft.EntityFrameworkCore.Metadata.Builders; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; @@ -18,7 +17,7 @@ public class CosmosKeyDiscoveryConvention : IEntityTypeAnnotationChangedConvention { /// - /// Creates a new instance of . + /// Creates a new instance of . /// /// Parameter object containing dependencies for this convention. public CosmosKeyDiscoveryConvention(ProviderConventionSetBuilderDependencies dependencies) diff --git a/src/EFCore.Cosmos/Metadata/Conventions/CosmosManyToManyJoinEntityTypeConvention.cs b/src/EFCore.Cosmos/Metadata/Conventions/CosmosManyToManyJoinEntityTypeConvention.cs new file mode 100644 index 00000000000..f62c738ba28 --- /dev/null +++ b/src/EFCore.Cosmos/Metadata/Conventions/CosmosManyToManyJoinEntityTypeConvention.cs @@ -0,0 +1,195 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Builders; +using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Metadata.Conventions +{ + /// + /// A convention that creates a join entity type for a many-to-many relationship + /// and adds a partition key to it if the related types share one. + /// + public class CosmosManyToManyJoinEntityTypeConvention : + ManyToManyJoinEntityTypeConvention, + IEntityTypeAnnotationChangedConvention + { + /// + /// Creates a new instance of . + /// + /// Parameter object containing dependencies for this convention. + public CosmosManyToManyJoinEntityTypeConvention(ProviderConventionSetBuilderDependencies dependencies) + : base(dependencies) + { + } + + /// + /// Called after an annotation is changed on an entity type. + /// + /// The builder for the entity type. + /// The annotation name. + /// The new annotation. + /// The old annotation. + /// Additional information associated with convention execution. + public virtual void ProcessEntityTypeAnnotationChanged( + IConventionEntityTypeBuilder entityTypeBuilder, + string name, + IConventionAnnotation? annotation, + IConventionAnnotation? oldAnnotation, + IConventionContext context) + { + Check.NotNull(entityTypeBuilder, nameof(entityTypeBuilder)); + Check.NotEmpty(name, nameof(name)); + Check.NotNull(context, nameof(context)); + + if (name == CosmosAnnotationNames.PartitionKeyName + || name == CosmosAnnotationNames.ContainerName) + { + foreach (var skipNavigation in entityTypeBuilder.Metadata.GetSkipNavigations()) + { + ProcessJoinPartitionKey(skipNavigation); + } + } + } + + /// + public override void ProcessSkipNavigationForeignKeyChanged( + IConventionSkipNavigationBuilder skipNavigationBuilder, + IConventionForeignKey? foreignKey, + IConventionForeignKey? oldForeignKey, + IConventionContext context) + { + base.ProcessSkipNavigationForeignKeyChanged(skipNavigationBuilder, foreignKey, oldForeignKey, context); + + if (oldForeignKey != null) + { + ProcessJoinPartitionKey(skipNavigationBuilder.Metadata); + } + } + + /// + protected override void CreateJoinEntityType(string joinEntityTypeName, IConventionSkipNavigation skipNavigation) + { + if (ShouldSharePartitionKey(skipNavigation)) + { + var model = skipNavigation.DeclaringEntityType.Model; + var joinEntityTypeBuilder = model.Builder.SharedTypeEntity(joinEntityTypeName, typeof(Dictionary))!; + ConfigurePartitionKeyJoinEntityType(skipNavigation, joinEntityTypeBuilder); + } + else + { + base.CreateJoinEntityType(joinEntityTypeName, skipNavigation); + } + } + + private void ConfigurePartitionKeyJoinEntityType(IConventionSkipNavigation skipNavigation, IConventionEntityTypeBuilder joinEntityTypeBuilder) + { + var principalPartitionKey = skipNavigation.DeclaringEntityType.GetPartitionKeyProperty()!; + var partitionKey = joinEntityTypeBuilder.Property(principalPartitionKey.ClrType, principalPartitionKey.Name)!.Metadata; + joinEntityTypeBuilder.HasPartitionKey(partitionKey.Name); + + CreateSkipNavigationForeignKey(skipNavigation, joinEntityTypeBuilder, partitionKey); + CreateSkipNavigationForeignKey(skipNavigation.Inverse!, joinEntityTypeBuilder, partitionKey); + } + + private IConventionForeignKey CreateSkipNavigationForeignKey( + IConventionSkipNavigation skipNavigation, + IConventionEntityTypeBuilder joinEntityTypeBuilder, + IConventionProperty partitionKeyProperty) + { + if (skipNavigation.ForeignKey != null + && !skipNavigation.Builder.CanSetForeignKey(null)) + { + return skipNavigation.ForeignKey; + } + + var principalKey = skipNavigation.DeclaringEntityType.FindPrimaryKey(); + if (principalKey == null + || principalKey.Properties.All(p => p.Name != partitionKeyProperty.Name)) + { + return CreateSkipNavigationForeignKey(skipNavigation, joinEntityTypeBuilder); + } + + if (skipNavigation.ForeignKey?.Properties.Contains(partitionKeyProperty) == true) + { + return skipNavigation.ForeignKey; + } + + var dependentProperties = new IConventionProperty[principalKey.Properties.Count]; + for (var i = 0; i < principalKey.Properties.Count; i++) + { + var principalProperty = principalKey.Properties[i]; + if (principalProperty.Name == partitionKeyProperty.Name) + { + dependentProperties[i] = partitionKeyProperty; + } + else + { + dependentProperties[i] = joinEntityTypeBuilder.CreateUniqueProperty( + principalProperty.ClrType, principalProperty.Name, required: true)!.Metadata; + } + } + + var foreignKey = joinEntityTypeBuilder.HasRelationship(skipNavigation.DeclaringEntityType, dependentProperties, principalKey)! + .IsUnique(false)! + .Metadata; + + skipNavigation.Builder.HasForeignKey(foreignKey); + + return foreignKey; + } + + private void ProcessJoinPartitionKey(IConventionSkipNavigation skipNavigation) + { + var inverseSkipNavigation = skipNavigation.Inverse; + if (skipNavigation.JoinEntityType != null + && skipNavigation.IsCollection + && inverseSkipNavigation != null + && inverseSkipNavigation.IsCollection + && inverseSkipNavigation.JoinEntityType == skipNavigation.JoinEntityType) + { + var joinEntityType = skipNavigation.JoinEntityType; + var joinEntityTypeBuilder = joinEntityType.Builder; + if (ShouldSharePartitionKey(skipNavigation)) + { + var principalPartitionKey = skipNavigation.DeclaringEntityType.GetPartitionKeyProperty()!; + var partitionKey = joinEntityType.GetPartitionKeyProperty(); + if ((partitionKey != null + && (!joinEntityTypeBuilder.CanSetPartitionKey(principalPartitionKey.Name) + || (skipNavigation.ForeignKey!.Properties.Contains(partitionKey) + && inverseSkipNavigation.ForeignKey!.Properties.Contains(partitionKey)))) + || !skipNavigation.Builder.CanSetForeignKey(null) + || !inverseSkipNavigation.Builder.CanSetForeignKey(null)) + { + return; + } + + ConfigurePartitionKeyJoinEntityType(skipNavigation, joinEntityTypeBuilder); + } + else + { + var partitionKey = joinEntityType.GetPartitionKeyProperty(); + if (partitionKey != null + && joinEntityTypeBuilder.HasPartitionKey(null) != null + && ((skipNavigation.ForeignKey!.Properties.Contains(partitionKey) + && skipNavigation.Builder.CanSetForeignKey(null)) + || (inverseSkipNavigation.ForeignKey!.Properties.Contains(partitionKey) + && inverseSkipNavigation.Builder.CanSetForeignKey(null)))) + { + CreateSkipNavigationForeignKey(skipNavigation, joinEntityTypeBuilder); + CreateSkipNavigationForeignKey(inverseSkipNavigation, joinEntityTypeBuilder); + } + } + } + } + + private bool ShouldSharePartitionKey(IConventionSkipNavigation skipNavigation) + => skipNavigation.DeclaringEntityType.GetContainer() == skipNavigation.TargetEntityType.GetContainer() + && skipNavigation.DeclaringEntityType.GetPartitionKeyPropertyName() != null + && skipNavigation.Inverse?.DeclaringEntityType.GetPartitionKeyPropertyName() + == skipNavigation.DeclaringEntityType.GetPartitionKeyPropertyName(); + } +} diff --git a/src/EFCore.Cosmos/Metadata/Conventions/Internal/CosmosConventionSetBuilder.cs b/src/EFCore.Cosmos/Metadata/Conventions/Internal/CosmosConventionSetBuilder.cs index 5371cfbb05a..5873eb983af 100644 --- a/src/EFCore.Cosmos/Metadata/Conventions/Internal/CosmosConventionSetBuilder.cs +++ b/src/EFCore.Cosmos/Metadata/Conventions/Internal/CosmosConventionSetBuilder.cs @@ -100,9 +100,19 @@ public override ConventionSet CreateConventionSet() ReplaceConvention(conventionSet.NavigationRemovedConventions, relationshipDiscoveryConvention); + ManyToManyJoinEntityTypeConvention manyToManyJoinEntityTypeConvention = new CosmosManyToManyJoinEntityTypeConvention(Dependencies); + ReplaceConvention(conventionSet.SkipNavigationAddedConventions, manyToManyJoinEntityTypeConvention); + + ReplaceConvention(conventionSet.SkipNavigationRemovedConventions, manyToManyJoinEntityTypeConvention); + + ReplaceConvention(conventionSet.SkipNavigationInverseChangedConventions, manyToManyJoinEntityTypeConvention); + + ReplaceConvention(conventionSet.SkipNavigationForeignKeyChangedConventions, manyToManyJoinEntityTypeConvention); + conventionSet.EntityTypeAnnotationChangedConventions.Add(discriminatorConvention); conventionSet.EntityTypeAnnotationChangedConventions.Add(storeKeyConvention); conventionSet.EntityTypeAnnotationChangedConventions.Add((CosmosKeyDiscoveryConvention)keyDiscoveryConvention); + conventionSet.EntityTypeAnnotationChangedConventions.Add((CosmosManyToManyJoinEntityTypeConvention)manyToManyJoinEntityTypeConvention); ReplaceConvention(conventionSet.PropertyAddedConventions, keyDiscoveryConvention); diff --git a/src/EFCore/Metadata/Conventions/ManyToManyJoinEntityTypeConvention.cs b/src/EFCore/Metadata/Conventions/ManyToManyJoinEntityTypeConvention.cs index 72934753e46..41e6b0370d6 100644 --- a/src/EFCore/Metadata/Conventions/ManyToManyJoinEntityTypeConvention.cs +++ b/src/EFCore/Metadata/Conventions/ManyToManyJoinEntityTypeConvention.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Linq; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata.Builders; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; @@ -41,7 +39,7 @@ public virtual void ProcessSkipNavigationAdded( IConventionSkipNavigationBuilder skipNavigationBuilder, IConventionContext context) { - CreateJoinEntityType(skipNavigationBuilder); + TryCreateJoinEntityType(skipNavigationBuilder); } /// @@ -51,7 +49,7 @@ public virtual void ProcessSkipNavigationInverseChanged( IConventionSkipNavigation? oldInverse, IConventionContext context) { - CreateJoinEntityType(skipNavigationBuilder); + TryCreateJoinEntityType(skipNavigationBuilder); } /// @@ -87,27 +85,44 @@ public virtual void ProcessSkipNavigationRemoved( } } - private void CreateJoinEntityType(IConventionSkipNavigationBuilder skipNavigationBuilder) + private void TryCreateJoinEntityType(IConventionSkipNavigationBuilder skipNavigationBuilder) { - var skipNavigation = (SkipNavigation)skipNavigationBuilder.Metadata; - var inverseSkipNavigation = skipNavigation.Inverse; - if (skipNavigation.ForeignKey != null - || !skipNavigation.IsCollection - || inverseSkipNavigation == null - || inverseSkipNavigation.ForeignKey != null - || !inverseSkipNavigation.IsCollection) + var skipNavigation = skipNavigationBuilder.Metadata; + if (ShouldCreateJoinType(skipNavigation)) { - return; + CreateJoinEntityType(GenerateJoinTypeName(skipNavigation), skipNavigation); } + } - Check.DebugAssert( - inverseSkipNavigation.Inverse == skipNavigation, + /// + /// Checks whether a new join antity type is needed. + /// + /// The target skip navigation. + /// A value indicating whether a new join antity type is needed. + protected virtual bool ShouldCreateJoinType(IConventionSkipNavigation skipNavigation) + { + var inverseSkipNavigation = skipNavigation.Inverse; + return skipNavigation.ForeignKey == null + && skipNavigation.IsCollection + && inverseSkipNavigation != null + && inverseSkipNavigation.ForeignKey == null + && inverseSkipNavigation.IsCollection; + } + + /// + /// Generates a unique name for the new joint entity type. + /// + /// The target skip navigation. + /// A unique entity type name. + protected virtual string GenerateJoinTypeName(IConventionSkipNavigation skipNavigation) + { + var inverseSkipNavigation = skipNavigation.Inverse; + Check.DebugAssert(inverseSkipNavigation?.Inverse == skipNavigation, "Inverse's inverse should be the original skip navigation"); var declaringEntityType = skipNavigation.DeclaringEntityType; var inverseEntityType = inverseSkipNavigation.DeclaringEntityType; var model = declaringEntityType.Model; - var joinEntityTypeName = declaringEntityType.ShortName(); var inverseName = inverseEntityType.ShortName(); joinEntityTypeName = StringComparer.Ordinal.Compare(joinEntityTypeName, inverseName) < 0 @@ -123,26 +138,49 @@ private void CreateJoinEntityType(IConventionSkipNavigationBuilder skipNavigatio int.MaxValue); } - var joinEntityTypeBuilder = model.Builder.SharedTypeEntity( - joinEntityTypeName, Model.DefaultPropertyBagType, ConfigurationSource.Convention)!; + return joinEntityTypeName; + } - var leftForeignKey = CreateSkipNavigationForeignKey(skipNavigation, joinEntityTypeBuilder); - var rightForeignKey = CreateSkipNavigationForeignKey(inverseSkipNavigation, joinEntityTypeBuilder); + /// + /// Create a join entity type and configures the corresponding foreign keys. + /// + /// The name for the new entity type. + /// The target skip navigation. + protected virtual void CreateJoinEntityType( + string joinEntityTypeName, + IConventionSkipNavigation skipNavigation) + { + var model = skipNavigation.DeclaringEntityType.Model; - skipNavigation.Builder.HasForeignKey(leftForeignKey, ConfigurationSource.Convention); - inverseSkipNavigation.Builder.HasForeignKey(rightForeignKey, ConfigurationSource.Convention); + var joinEntityTypeBuilder = model.Builder.SharedTypeEntity(joinEntityTypeName, Model.DefaultPropertyBagType)!; + + var inverseSkipNavigation = skipNavigation.Inverse!; + CreateSkipNavigationForeignKey(skipNavigation, joinEntityTypeBuilder); + CreateSkipNavigationForeignKey(inverseSkipNavigation, joinEntityTypeBuilder); } - private static ForeignKey CreateSkipNavigationForeignKey( - SkipNavigation skipNavigation, - InternalEntityTypeBuilder joinEntityTypeBuilder) - => joinEntityTypeBuilder - .HasRelationship( - skipNavigation.DeclaringEntityType, - ConfigurationSource.Convention, - required: true, - skipNavigation.Inverse!.Name)! - .IsUnique(false, ConfigurationSource.Convention)! - .Metadata; + /// + /// Creates a foreign key on the given entity type to be used by the given skip navigation. + /// + /// The target skip navigation. + /// The join entity type. + /// The created foreign key. + protected virtual IConventionForeignKey CreateSkipNavigationForeignKey( + IConventionSkipNavigation skipNavigation, + IConventionEntityTypeBuilder joinEntityTypeBuilder) + { + var foreignKey = ((InternalEntityTypeBuilder)joinEntityTypeBuilder) + .HasRelationship( + (EntityType)skipNavigation.DeclaringEntityType, + ConfigurationSource.Convention, + required: true, + skipNavigation.Inverse!.Name)! + .IsUnique(false, ConfigurationSource.Convention)! + .Metadata; + + skipNavigation.Builder.HasForeignKey(foreignKey); + + return foreignKey; + } } } diff --git a/test/EFCore.Cosmos.Tests/ModelBuilding/CosmosModelBuilderGenericTest.cs b/test/EFCore.Cosmos.Tests/ModelBuilding/CosmosModelBuilderGenericTest.cs index e19682fec6b..d93401b2221 100644 --- a/test/EFCore.Cosmos.Tests/ModelBuilding/CosmosModelBuilderGenericTest.cs +++ b/test/EFCore.Cosmos.Tests/ModelBuilding/CosmosModelBuilderGenericTest.cs @@ -1,9 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Linq; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Conventions; using Microsoft.EntityFrameworkCore.TestUtilities; @@ -366,10 +363,92 @@ public virtual void Can_use_shared_type_as_join_entity_with_partition_keys() Assert.Equal(2, joinType.GetForeignKeys().Count()); Assert.Equal(3, joinType.FindPrimaryKey().Properties.Count); Assert.Equal(6, joinType.GetProperties().Count()); + Assert.Equal("DbContext", joinType.GetContainer()); Assert.Equal("PartitionId", joinType.GetPartitionKeyPropertyName()); Assert.Equal("PartitionId", joinType.FindPrimaryKey().Properties.Last().Name); } + [ConditionalFact] + public virtual void Can_use_implicit_join_entity_with_partition_keys() + { + var modelBuilder = CreateModelBuilder(); + + modelBuilder.Ignore(); + modelBuilder.Ignore(); + + modelBuilder.Entity(mb => + { + mb.Ignore(e => e.Dependents); + mb.Property("PartitionId"); + mb.HasPartitionKey("PartitionId"); + }); + + modelBuilder.Entity(mb => + { + mb.Property("PartitionId"); + mb.HasPartitionKey("PartitionId"); + }); + + modelBuilder.Entity() + .HasMany(e => e.Dependents) + .WithMany(e => e.ManyToManyPrincipals); + + var model = modelBuilder.FinalizeModel(); + + var joinType = model.FindEntityType("ManyToManyNavPrincipalNavDependent"); + Assert.NotNull(joinType); + Assert.Equal(2, joinType.GetForeignKeys().Count()); + Assert.Equal(3, joinType.FindPrimaryKey().Properties.Count); + Assert.Equal(6, joinType.GetProperties().Count()); + Assert.Equal("DbContext", joinType.GetContainer()); + Assert.Equal("PartitionId", joinType.GetPartitionKeyPropertyName()); + Assert.Equal("PartitionId", joinType.FindPrimaryKey().Properties.Last().Name); + } + + [ConditionalFact] + public virtual void Can_use_implicit_join_entity_with_partition_keys_changed() + { + var modelBuilder = CreateModelBuilder(); + + modelBuilder.Ignore(); + modelBuilder.Ignore(); + + modelBuilder.Entity(mb => + { + mb.Property("PartitionId"); + mb.HasPartitionKey("PartitionId"); + }); + + modelBuilder.Entity(mb => + { + mb.Property("PartitionId"); + mb.HasPartitionKey("PartitionId"); + }); + + modelBuilder.Entity(mb => + { + mb.Property("Partition2Id"); + mb.HasPartitionKey("Partition2Id"); + }); + + modelBuilder.Entity(mb => + { + mb.Property("Partition2Id"); + mb.HasPartitionKey("Partition2Id"); + }); + + var model = modelBuilder.FinalizeModel(); + + var joinType = model.FindEntityType("ManyToManyNavPrincipalNavDependent"); + Assert.NotNull(joinType); + Assert.Equal(2, joinType.GetForeignKeys().Count()); + Assert.Equal(3, joinType.FindPrimaryKey().Properties.Count); + Assert.Equal(6, joinType.GetProperties().Count()); + Assert.Equal("DbContext", joinType.GetContainer()); + Assert.Equal("Partition2Id", joinType.GetPartitionKeyPropertyName()); + Assert.Equal("Partition2Id", joinType.FindPrimaryKey().Properties.Last().Name); + } + public override void Join_type_is_automatically_configured_by_convention() { // Many-to-many not configured by convention on Cosmos