From 3579d257a8b52d4500404ffb84341342f63f1d54 Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Sat, 19 Nov 2022 10:27:12 +0000 Subject: [PATCH] Fix value generation with converter between different numeric types (#29601) --- .../TemporaryNumberValueGeneratorFactory.cs | 131 ++++++++------- .../StoreGeneratedTestBase.cs | 149 +++++++++++++++++ .../StoreGeneratedSqlServerTest.cs | 152 ++++++++++++++++++ 3 files changed, 377 insertions(+), 55 deletions(-) diff --git a/src/EFCore/ValueGeneration/TemporaryNumberValueGeneratorFactory.cs b/src/EFCore/ValueGeneration/TemporaryNumberValueGeneratorFactory.cs index a4624fa88f9..18550c75e18 100644 --- a/src/EFCore/ValueGeneration/TemporaryNumberValueGeneratorFactory.cs +++ b/src/EFCore/ValueGeneration/TemporaryNumberValueGeneratorFactory.cs @@ -27,70 +27,91 @@ public class TemporaryNumberValueGeneratorFactory : ValueGeneratorFactory /// The newly created value generator. public override ValueGenerator Create(IProperty property, IEntityType entityType) { - var type = (property.GetValueConverter()?.ProviderClrType ?? property.GetTypeMapping().ClrType).UnwrapEnumType(); + var type = property.GetTypeMapping().ClrType.UnwrapEnumType(); - if (type == typeof(int)) + var generator = TryCreate(); + if (generator != null) { - return new TemporaryIntValueGenerator(); + return generator; } - if (type == typeof(long)) + type = property.GetValueConverter()?.ProviderClrType.UnwrapEnumType(); + if (type != null) { - return new TemporaryLongValueGenerator(); - } - - if (type == typeof(short)) - { - return new TemporaryShortValueGenerator(); - } - - if (type == typeof(byte)) - { - return new TemporaryByteValueGenerator(); - } - - if (type == typeof(char)) - { - return new TemporaryCharValueGenerator(); - } - - if (type == typeof(ulong)) - { - return new TemporaryULongValueGenerator(); - } - - if (type == typeof(uint)) - { - return new TemporaryUIntValueGenerator(); - } - - if (type == typeof(ushort)) - { - return new TemporaryUShortValueGenerator(); - } - - if (type == typeof(sbyte)) - { - return new TemporarySByteValueGenerator(); - } - - if (type == typeof(decimal)) - { - return new TemporaryDecimalValueGenerator(); - } - - if (type == typeof(float)) - { - return new TemporaryFloatValueGenerator(); - } - - if (type == typeof(double)) - { - return new TemporaryDoubleValueGenerator(); + generator = TryCreate(); + if (generator != null) + { + return generator; + } } throw new ArgumentException( CoreStrings.InvalidValueGeneratorFactoryProperty( nameof(TemporaryNumberValueGeneratorFactory), property.Name, property.DeclaringEntityType.DisplayName())); + + ValueGenerator? TryCreate() + { + if (type == typeof(int)) + { + return new TemporaryIntValueGenerator(); + } + + if (type == typeof(long)) + { + return new TemporaryLongValueGenerator(); + } + + if (type == typeof(short)) + { + return new TemporaryShortValueGenerator(); + } + + if (type == typeof(byte)) + { + return new TemporaryByteValueGenerator(); + } + + if (type == typeof(char)) + { + return new TemporaryCharValueGenerator(); + } + + if (type == typeof(ulong)) + { + return new TemporaryULongValueGenerator(); + } + + if (type == typeof(uint)) + { + return new TemporaryUIntValueGenerator(); + } + + if (type == typeof(ushort)) + { + return new TemporaryUShortValueGenerator(); + } + + if (type == typeof(sbyte)) + { + return new TemporarySByteValueGenerator(); + } + + if (type == typeof(decimal)) + { + return new TemporaryDecimalValueGenerator(); + } + + if (type == typeof(float)) + { + return new TemporaryFloatValueGenerator(); + } + + if (type == typeof(double)) + { + return new TemporaryDoubleValueGenerator(); + } + + return null; + } } } diff --git a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs index 1f791426274..249eca7f6fb 100644 --- a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs +++ b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs @@ -2198,6 +2198,145 @@ public virtual void Insert_update_and_delete_with_wrapped_int_key() }); } + protected class LongToIntPrincipal + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public ICollection Dependents { get; } = new List(); + public ICollection RequiredDependents { get; } = new List(); + public ICollection OptionalDependents { get; } = new List(); + } + + protected class LongToIntDependentShadow + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public LongToIntPrincipal? Principal { get; set; } + } + + protected class LongToIntDependentRequired + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public long PrincipalId { get; set; } + public LongToIntPrincipal Principal { get; set; } = null!; + } + + protected class LongToIntDependentOptional + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public long? PrincipalId { get; set; } + public LongToIntPrincipal? Principal { get; set; } + } + + [ConditionalFact] + public virtual void Insert_update_and_delete_with_long_to_int_conversion() + { + var id1 = 0L; + ExecuteWithStrategyInTransaction( + context => + { + var principal1 = context.Add( + new LongToIntPrincipal + { + Dependents = { new LongToIntDependentShadow(), new LongToIntDependentShadow() }, + OptionalDependents = { new LongToIntDependentOptional(), new LongToIntDependentOptional() }, + RequiredDependents = { new LongToIntDependentRequired(), new LongToIntDependentRequired() } + }).Entity; + + context.SaveChanges(); + + id1 = principal1.Id; + Assert.NotEqual(0L, id1); + foreach (var dependent in principal1.Dependents) + { + Assert.NotEqual(0L, dependent.Id); + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, context.Entry(dependent).Property("PrincipalId").CurrentValue!.Value); + } + + foreach (var dependent in principal1.OptionalDependents) + { + Assert.NotEqual(0L, dependent.Id); + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId); + } + + foreach (var dependent in principal1.RequiredDependents) + { + Assert.NotEqual(0L, dependent.Id); + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId); + } + }, + context => + { + var principal1 = context.Set() + .Include(e => e.Dependents) + .Include(e => e.OptionalDependents) + .Include(e => e.RequiredDependents) + .Single(); + + Assert.Equal(principal1.Id, id1); + foreach (var dependent in principal1.Dependents) + { + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, context.Entry(dependent).Property("PrincipalId").CurrentValue!.Value); + } + + foreach (var dependent in principal1.OptionalDependents) + { + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId!.Value); + } + + foreach (var dependent in principal1.RequiredDependents) + { + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId); + } + + principal1.Dependents.Remove(principal1.Dependents.First()); + principal1.OptionalDependents.Remove(principal1.OptionalDependents.First()); + principal1.RequiredDependents.Remove(principal1.RequiredDependents.First()); + + context.SaveChanges(); + }, + context => + { + var dependents1 = context.Set().Include(e => e.Principal).ToList(); + Assert.Equal(2, dependents1.Count); + Assert.Null( + context.Entry(dependents1.Single(e => e.Principal == null)) + .Property("PrincipalId").CurrentValue); + + var optionalDependents1 = context.Set().Include(e => e.Principal).ToList(); + Assert.Equal(2, optionalDependents1.Count); + Assert.Null(optionalDependents1.Single(e => e.Principal == null).PrincipalId); + + var requiredDependents1 = context.Set().Include(e => e.Principal).ToList(); + Assert.Single(requiredDependents1); + + context.Remove(dependents1.Single(e => e.Principal != null)); + context.Remove(optionalDependents1.Single(e => e.Principal != null)); + context.Remove(requiredDependents1.Single()); + context.Remove(requiredDependents1.Single().Principal); + + context.SaveChanges(); + }, + context => + { + Assert.Equal(1, context.Set().Count()); + Assert.Equal(1, context.Set().Count()); + Assert.Equal(0, context.Set().Count()); + }); + } + protected class WrappedStringClass { public string? Value { get; set; } @@ -4579,6 +4718,16 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con entity.Property(e => e.NonKey).HasValueGenerator(); }); + modelBuilder.Entity( + entity => + { + var keyConverter = new ValueConverter( + v => (int)v, + v => (long)v); + + entity.Property(e => e.Id).HasConversion(keyConverter); + }); + modelBuilder.Entity( entity => { diff --git a/test/EFCore.SqlServer.FunctionalTests/StoreGeneratedSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/StoreGeneratedSqlServerTest.cs index 7512cb4947f..706c8d519db 100644 --- a/test/EFCore.SqlServer.FunctionalTests/StoreGeneratedSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/StoreGeneratedSqlServerTest.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.ComponentModel.DataAnnotations.Schema; using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; // ReSharper disable InconsistentNaming @@ -262,6 +263,145 @@ protected class WrappedIntHiLoRecordDependentRequired public WrappedIntHiLoRecordPrincipal Principal { get; set; } = null!; } + protected class LongToDecimalPrincipal + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public ICollection Dependents { get; } = new List(); + public ICollection RequiredDependents { get; } = new List(); + public ICollection OptionalDependents { get; } = new List(); + } + + protected class LongToDecimalDependentShadow + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public LongToDecimalPrincipal? Principal { get; set; } + } + + protected class LongToDecimalDependentRequired + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public long PrincipalId { get; set; } + public LongToDecimalPrincipal Principal { get; set; } = null!; + } + + protected class LongToDecimalDependentOptional + { + [DatabaseGenerated(DatabaseGeneratedOption.Identity)] + public long Id { get; set; } + + public long? PrincipalId { get; set; } + public LongToDecimalPrincipal? Principal { get; set; } + } + + [ConditionalFact] + public virtual void Insert_update_and_delete_with_long_to_decimal_conversion() + { + var id1 = 0L; + ExecuteWithStrategyInTransaction( + context => + { + var principal1 = context.Add( + new LongToDecimalPrincipal + { + Dependents = { new LongToDecimalDependentShadow(), new LongToDecimalDependentShadow() }, + OptionalDependents = { new LongToDecimalDependentOptional(), new LongToDecimalDependentOptional() }, + RequiredDependents = { new LongToDecimalDependentRequired(), new LongToDecimalDependentRequired() } + }).Entity; + + context.SaveChanges(); + + id1 = principal1.Id; + Assert.NotEqual(0L, id1); + foreach (var dependent in principal1.Dependents) + { + Assert.NotEqual(0L, dependent.Id); + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, context.Entry(dependent).Property("PrincipalId").CurrentValue!.Value); + } + + foreach (var dependent in principal1.OptionalDependents) + { + Assert.NotEqual(0L, dependent.Id); + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId); + } + + foreach (var dependent in principal1.RequiredDependents) + { + Assert.NotEqual(0L, dependent.Id); + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId); + } + }, + context => + { + var principal1 = context.Set() + .Include(e => e.Dependents) + .Include(e => e.OptionalDependents) + .Include(e => e.RequiredDependents) + .Single(); + + Assert.Equal(principal1.Id, id1); + foreach (var dependent in principal1.Dependents) + { + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, context.Entry(dependent).Property("PrincipalId").CurrentValue!.Value); + } + + foreach (var dependent in principal1.OptionalDependents) + { + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId!.Value); + } + + foreach (var dependent in principal1.RequiredDependents) + { + Assert.Same(principal1, dependent.Principal); + Assert.Equal(id1, dependent.PrincipalId); + } + + principal1.Dependents.Remove(principal1.Dependents.First()); + principal1.OptionalDependents.Remove(principal1.OptionalDependents.First()); + principal1.RequiredDependents.Remove(principal1.RequiredDependents.First()); + + context.SaveChanges(); + }, + context => + { + var dependents1 = context.Set().Include(e => e.Principal).ToList(); + Assert.Equal(2, dependents1.Count); + Assert.Null( + context.Entry(dependents1.Single(e => e.Principal == null)) + .Property("PrincipalId").CurrentValue); + + var optionalDependents1 = context.Set().Include(e => e.Principal).ToList(); + Assert.Equal(2, optionalDependents1.Count); + Assert.Null(optionalDependents1.Single(e => e.Principal == null).PrincipalId); + + var requiredDependents1 = context.Set().Include(e => e.Principal).ToList(); + Assert.Single(requiredDependents1); + + context.Remove(dependents1.Single(e => e.Principal != null)); + context.Remove(optionalDependents1.Single(e => e.Principal != null)); + context.Remove(requiredDependents1.Single()); + context.Remove(requiredDependents1.Single().Principal); + + context.SaveChanges(); + }, + context => + { + Assert.Equal(1, context.Set().Count()); + Assert.Equal(1, context.Set().Count()); + Assert.Equal(0, context.Set().Count()); + }); + } + [ConditionalFact] public virtual void Insert_update_and_delete_with_wrapped_int_key_using_hi_lo() { @@ -767,6 +907,18 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con modelBuilder.Entity().Property(e => e.Id).UseHiLo(); modelBuilder.Entity().Property(e => e.Id).UseHiLo(); + modelBuilder.Entity( + entity => + { + var keyConverter = new ValueConverter( + v => new decimal(v), + v => decimal.ToInt64(v)); + + entity.Property(e => e.Id) + .HasPrecision(18, 0) + .HasConversion(keyConverter); + }); + base.OnModelCreating(modelBuilder, context); }