diff --git a/src/documentation/articles/intro.md b/src/documentation/articles/intro.md index d42d8188..f766f766 100644 --- a/src/documentation/articles/intro.md +++ b/src/documentation/articles/intro.md @@ -48,4 +48,4 @@ Do you want to help us? - open [issues](https://github.com/masesgroup/KEFCore/issues) to request features or report bugs :bug: - improves the project with Pull Requests -This project adheres to the Contributor [Covenant code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to coc_reporting@masesgroup.com. +This project adheres to the Contributor [Covenant code of conduct](https://github.com/masesgroup/KEFCore/blob/master/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to coc_reporting@masesgroup.com. diff --git a/src/documentation/index.md b/src/documentation/index.md index 76dd7bf5..736138a8 100644 --- a/src/documentation/index.md +++ b/src/documentation/index.md @@ -43,7 +43,7 @@ Do you want to help us? - open [issues](https://github.com/masesgroup/KEFCore/issues) to request features or report bugs :bug: - improves the project with Pull Requests -This project adheres to the Contributor [Covenant code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to coc_reporting@masesgroup.com. +This project adheres to the Contributor [Covenant code of conduct](https://github.com/masesgroup/KEFCore/blob/master/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to coc_reporting@masesgroup.com. --- ## Summary diff --git a/src/net/Common/Common.props b/src/net/Common/Common.props index 55a0edf3..69e4fa37 100644 --- a/src/net/Common/Common.props +++ b/src/net/Common/Common.props @@ -4,8 +4,8 @@ MASES s.r.l. MASES s.r.l. MASES s.r.l. - 1.0.0.0 - net6.0;net7.0 + 1.1.0.0 + net6.0;net7.0;net8.0 latest true true diff --git a/src/net/KEFCore.SerDes.Avro.Compiler/KEFCore.SerDes.Avro.Compiler.csproj b/src/net/KEFCore.SerDes.Avro.Compiler/KEFCore.SerDes.Avro.Compiler.csproj index 0dedc155..16e141b4 100644 --- a/src/net/KEFCore.SerDes.Avro.Compiler/KEFCore.SerDes.Avro.Compiler.csproj +++ b/src/net/KEFCore.SerDes.Avro.Compiler/KEFCore.SerDes.Avro.Compiler.csproj @@ -23,8 +23,8 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/net/KEFCore.SerDes.Avro/KEFCore.SerDes.Avro.csproj b/src/net/KEFCore.SerDes.Avro/KEFCore.SerDes.Avro.csproj index 31427fe6..8cdd415e 100644 --- a/src/net/KEFCore.SerDes.Avro/KEFCore.SerDes.Avro.csproj +++ b/src/net/KEFCore.SerDes.Avro/KEFCore.SerDes.Avro.csproj @@ -43,7 +43,7 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/net/KEFCore.SerDes.Protobuf/Generated/GenericValue.cs b/src/net/KEFCore.SerDes.Protobuf/Generated/GenericValue.cs index ebe8b6ab..eb526c6f 100644 --- a/src/net/KEFCore.SerDes.Protobuf/Generated/GenericValue.cs +++ b/src/net/KEFCore.SerDes.Protobuf/Generated/GenericValue.cs @@ -52,6 +52,7 @@ static GenericValueReflection() { /// [START messages] /// Our address book file is just one of these. /// + [global::System.Diagnostics.DebuggerDisplayAttribute("{ToString(),nq}")] public sealed partial class GenericValue : pb::IMessage #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE , pb::IBufferMessage diff --git a/src/net/KEFCore.SerDes.Protobuf/Generated/KeyContainer.cs b/src/net/KEFCore.SerDes.Protobuf/Generated/KeyContainer.cs index 3384d1de..39fca702 100644 --- a/src/net/KEFCore.SerDes.Protobuf/Generated/KeyContainer.cs +++ b/src/net/KEFCore.SerDes.Protobuf/Generated/KeyContainer.cs @@ -45,6 +45,7 @@ static KeyContainerReflection() { /// /// [START messages] /// + [global::System.Diagnostics.DebuggerDisplayAttribute("{ToString(),nq}")] public sealed partial class PrimaryKeyType : pb::IMessage #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE , pb::IBufferMessage @@ -226,6 +227,7 @@ public void MergeFrom(pb::CodedInputStream input) { /// /// Our address book file is just one of these. /// + [global::System.Diagnostics.DebuggerDisplayAttribute("{ToString(),nq}")] public sealed partial class KeyContainer : pb::IMessage #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE , pb::IBufferMessage diff --git a/src/net/KEFCore.SerDes.Protobuf/Generated/ValueContainer.cs b/src/net/KEFCore.SerDes.Protobuf/Generated/ValueContainer.cs index 0a3799b9..48698379 100644 --- a/src/net/KEFCore.SerDes.Protobuf/Generated/ValueContainer.cs +++ b/src/net/KEFCore.SerDes.Protobuf/Generated/ValueContainer.cs @@ -48,6 +48,7 @@ static ValueContainerReflection() { /// /// [START messages] /// + [global::System.Diagnostics.DebuggerDisplayAttribute("{ToString(),nq}")] public sealed partial class PropertyDataRecord : pb::IMessage #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE , pb::IBufferMessage @@ -360,6 +361,7 @@ public void MergeFrom(pb::CodedInputStream input) { /// /// Our address book file is just one of these. /// + [global::System.Diagnostics.DebuggerDisplayAttribute("{ToString(),nq}")] public sealed partial class ValueContainer : pb::IMessage #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE , pb::IBufferMessage diff --git a/src/net/KEFCore.SerDes.Protobuf/KEFCore.SerDes.Protobuf.csproj b/src/net/KEFCore.SerDes.Protobuf/KEFCore.SerDes.Protobuf.csproj index 977be155..7944e4d9 100644 --- a/src/net/KEFCore.SerDes.Protobuf/KEFCore.SerDes.Protobuf.csproj +++ b/src/net/KEFCore.SerDes.Protobuf/KEFCore.SerDes.Protobuf.csproj @@ -42,21 +42,21 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - - - - + + + + diff --git a/src/net/KEFCore.SerDes/KEFCore.SerDes.csproj b/src/net/KEFCore.SerDes/KEFCore.SerDes.csproj index 1f7ad704..e4167a51 100644 --- a/src/net/KEFCore.SerDes/KEFCore.SerDes.csproj +++ b/src/net/KEFCore.SerDes/KEFCore.SerDes.csproj @@ -41,13 +41,14 @@ - - - + + + + All None - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/net/KEFCore/Extensions/KafkaServiceCollectionExtensions.cs b/src/net/KEFCore/Extensions/KafkaServiceCollectionExtensions.cs index 68aed60e..4bac84d2 100644 --- a/src/net/KEFCore/Extensions/KafkaServiceCollectionExtensions.cs +++ b/src/net/KEFCore/Extensions/KafkaServiceCollectionExtensions.cs @@ -62,6 +62,9 @@ public static IServiceCollection AddEntityFrameworkKafkaDatabase(this IServiceCo .TryAdd() .TryAdd() .TryAdd() +#if NET8_0 + .TryAdd() +#endif .TryAdd(p => p.GetRequiredService()) .TryAddProviderSpecificServices( b => b diff --git a/src/net/KEFCore/KEFCore.csproj b/src/net/KEFCore/KEFCore.csproj index 7160be21..80f00cc9 100644 --- a/src/net/KEFCore/KEFCore.csproj +++ b/src/net/KEFCore/KEFCore.csproj @@ -16,6 +16,22 @@ True False + + + + + + + + + + + + + + + + @@ -66,7 +82,7 @@ - + diff --git a/src/net/KEFCore/Properties/KafkaStrings.Designer.cs b/src/net/KEFCore/Properties/KafkaStrings.Designer.cs index bb45ce69..e13bd4af 100644 --- a/src/net/KEFCore/Properties/KafkaStrings.Designer.cs +++ b/src/net/KEFCore/Properties/KafkaStrings.Designer.cs @@ -43,6 +43,12 @@ public static string InvalidDerivedTypeInEntityProjection(object? derivedType, o GetString("InvalidDerivedTypeInEntityProjection", nameof(derivedType), nameof(entityType)), derivedType, entityType); + /// + /// A 'GroupBy' operation which is not composed into aggregate or projection of elements is not supported. + /// + public static string NonComposedGroupByNotSupported + => GetString("NonComposedGroupByNotSupported"); + /// /// There is no query string because the Kafka provider does not use a string-based query language. /// diff --git a/src/net/KEFCore/Properties/KafkaStrings.resx b/src/net/KEFCore/Properties/KafkaStrings.resx index 44584806..63fd4aac 100644 --- a/src/net/KEFCore/Properties/KafkaStrings.resx +++ b/src/net/KEFCore/Properties/KafkaStrings.resx @@ -134,6 +134,9 @@ Transactions are not supported by the Kafka store. See http://go.microsoft.com/fwlink/?LinkId=800142 Warning KafkaEventId.TransactionIgnoredWarning + + A 'GroupBy' operation which is not composed into aggregate or projection of elements is not supported. + There is no query string because the Kafka provider does not use a string-based query language. diff --git a/src/net/KEFCore/Query/Internal/KafkaQueryContext.cs b/src/net/KEFCore/Query/Internal/KafkaQueryContext.cs index 80d7ddb5..a53856e2 100644 --- a/src/net/KEFCore/Query/Internal/KafkaQueryContext.cs +++ b/src/net/KEFCore/Query/Internal/KafkaQueryContext.cs @@ -17,7 +17,6 @@ */ using MASES.EntityFrameworkCore.KNet.Storage.Internal; -using System.Collections.Concurrent; namespace MASES.EntityFrameworkCore.KNet.Query.Internal; /// diff --git a/src/net/KEFCore/Query/Internal/KafkaQueryableMethodTranslatingExpressionVisitor.cs b/src/net/KEFCore/Query/Internal/KafkaQueryableMethodTranslatingExpressionVisitor.cs index 634a3aba..98c1aae3 100644 --- a/src/net/KEFCore/Query/Internal/KafkaQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/net/KEFCore/Query/Internal/KafkaQueryableMethodTranslatingExpressionVisitor.cs @@ -91,6 +91,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } #if NET6_0 /// + [Obsolete] protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType) { throw new NotImplementedException(); diff --git a/src/net/KEFCore/Query/Internal8/AnonymousObject.cs b/src/net/KEFCore/Query/Internal8/AnonymousObject.cs new file mode 100644 index 00000000..2eafcefe --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/AnonymousObject.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using JetBrains.Annotations; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public readonly struct AnonymousObject +{ + private readonly object[] _values; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public static readonly ConstructorInfo AnonymousObjectCtor + = typeof(AnonymousObject).GetTypeInfo() + .DeclaredConstructors + .Single(c => c.GetParameters().Length == 1); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [UsedImplicitly] + public AnonymousObject(object[] values) + { + _values = values; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public static bool operator ==(AnonymousObject x, AnonymousObject y) + => x.Equals(y); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public static bool operator !=(AnonymousObject x, AnonymousObject y) + => !x.Equals(y); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override bool Equals(object? obj) + => obj is not null + && (obj is AnonymousObject anonymousObject + && _values.SequenceEqual(anonymousObject._values)); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override int GetHashCode() + { + var hash = new HashCode(); + foreach (var value in _values) + { + hash.Add(value); + } + + return hash.ToHashCode(); + } +} diff --git a/src/net/KEFCore/Query/Internal8/CollectionResultShaperExpression.cs b/src/net/KEFCore/Query/Internal8/CollectionResultShaperExpression.cs new file mode 100644 index 00000000..a9e9ec92 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/CollectionResultShaperExpression.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class CollectionResultShaperExpression : Expression, IPrintableExpression +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public CollectionResultShaperExpression( + Expression projection, + Expression innerShaper, + INavigationBase? navigation, + Type elementType) + { + Projection = projection; + InnerShaper = innerShaper; + Navigation = navigation; + ElementType = elementType; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Projection { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression InnerShaper { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual INavigationBase? Navigation { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Type ElementType { get; } + + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + public override Type Type + => Navigation?.ClrType ?? typeof(List<>).MakeGenericType(ElementType); + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var projection = visitor.Visit(Projection); + var innerShaper = visitor.Visit(InnerShaper); + + return Update(projection, innerShaper); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual CollectionResultShaperExpression Update( + Expression projection, + Expression innerShaper) + => projection != Projection || innerShaper != InnerShaper + ? new CollectionResultShaperExpression(projection, innerShaper, Navigation, ElementType) + : this; + + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine("CollectionResultShaperExpression:"); + using (expressionPrinter.Indent()) + { + expressionPrinter.Append("("); + expressionPrinter.Visit(Projection); + expressionPrinter.Append(", "); + expressionPrinter.Visit(InnerShaper); + expressionPrinter.AppendLine($", {Navigation?.Name}, {ElementType.ShortDisplayName()})"); + } + } +} diff --git a/src/net/KEFCore/Query/Internal8/EntityProjectionExpression.cs b/src/net/KEFCore/Query/Internal8/EntityProjectionExpression.cs new file mode 100644 index 00000000..53821020 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/EntityProjectionExpression.cs @@ -0,0 +1,191 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using MASES.EntityFrameworkCore.KNet.Internal; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class EntityProjectionExpression : Expression, IPrintableExpression +{ + private readonly IReadOnlyDictionary _readExpressionMap; + private readonly Dictionary _navigationExpressionsCache = new(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public EntityProjectionExpression( + IEntityType entityType, + IReadOnlyDictionary readExpressionMap) + { + EntityType = entityType; + _readExpressionMap = readExpressionMap; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IEntityType EntityType { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Type Type + => EntityType.ClrType; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual EntityProjectionExpression UpdateEntityType(IEntityType derivedType) + { + if (!derivedType.GetAllBaseTypes().Contains(EntityType)) + { + throw new InvalidOperationException( + KafkaStrings.InvalidDerivedTypeInEntityProjection( + derivedType.DisplayName(), EntityType.DisplayName())); + } + + var readExpressionMap = new Dictionary(); + foreach (var (property, methodCallExpression) in _readExpressionMap) + { + if (derivedType.IsAssignableFrom(property.DeclaringType) + || property.DeclaringType.IsAssignableFrom(derivedType)) + { + readExpressionMap[property] = methodCallExpression; + } + } + + return new EntityProjectionExpression(derivedType, readExpressionMap); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual MethodCallExpression BindProperty(IProperty property) + { + if (property.DeclaringType is not IEntityType entityType) + { + if (EntityType != property.DeclaringType) + { + throw new InvalidOperationException( + KafkaStrings.UnableToBindMemberToEntityProjection("property", property.Name, EntityType.DisplayName())); + } + } + else if (!EntityType.IsAssignableFrom(entityType) + && !entityType.IsAssignableFrom(EntityType)) + { + throw new InvalidOperationException( + KafkaStrings.UnableToBindMemberToEntityProjection("property", property.Name, EntityType.DisplayName())); + } + + return _readExpressionMap[property]; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void AddNavigationBinding(INavigation navigation, StructuralTypeShaperExpression shaper) + { + if (!EntityType.IsAssignableFrom(navigation.DeclaringEntityType) + && !navigation.DeclaringEntityType.IsAssignableFrom(EntityType)) + { + throw new InvalidOperationException( + KafkaStrings.UnableToBindMemberToEntityProjection("navigation", navigation.Name, EntityType.DisplayName())); + } + + _navigationExpressionsCache[navigation] = shaper; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual StructuralTypeShaperExpression? BindNavigation(INavigation navigation) + { + if (!EntityType.IsAssignableFrom(navigation.DeclaringEntityType) + && !navigation.DeclaringEntityType.IsAssignableFrom(EntityType)) + { + throw new InvalidOperationException( + KafkaStrings.UnableToBindMemberToEntityProjection("navigation", navigation.Name, EntityType.DisplayName())); + } + + return _navigationExpressionsCache.TryGetValue(navigation, out var expression) + ? expression + : null; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual EntityProjectionExpression Clone() + { + var readExpressionMap = new Dictionary(_readExpressionMap); + var entityProjectionExpression = new EntityProjectionExpression(EntityType, readExpressionMap); + foreach (var (navigation, entityShaperExpression) in _navigationExpressionsCache) + { + entityProjectionExpression._navigationExpressionsCache[navigation] = new StructuralTypeShaperExpression( + entityShaperExpression.StructuralType, + ((EntityProjectionExpression)entityShaperExpression.ValueBufferExpression).Clone(), + entityShaperExpression.IsNullable); + } + + return entityProjectionExpression; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine(nameof(EntityProjectionExpression) + ":"); + using (expressionPrinter.Indent()) + { + foreach (var (property, methodCallExpression) in _readExpressionMap) + { + expressionPrinter.Append(property + " -> "); + expressionPrinter.Visit(methodCallExpression); + expressionPrinter.AppendLine(); + } + } + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaExpressionTranslatingExpressionVisitor.cs b/src/net/KEFCore/Query/Internal8/KafkaExpressionTranslatingExpressionVisitor.cs new file mode 100644 index 00000000..a3760ac3 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaExpressionTranslatingExpressionVisitor.cs @@ -0,0 +1,1725 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.RegularExpressions; +using JetBrains.Annotations; +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaExpressionTranslatingExpressionVisitor : ExpressionVisitor +{ + private const string RuntimeParameterPrefix = QueryCompilationContext.QueryParameterPrefix + "entity_equality_"; + + private static readonly List SingleResultMethodInfos = new() + { + QueryableMethods.FirstWithPredicate, + QueryableMethods.FirstWithoutPredicate, + QueryableMethods.FirstOrDefaultWithPredicate, + QueryableMethods.FirstOrDefaultWithoutPredicate, + QueryableMethods.SingleWithPredicate, + QueryableMethods.SingleWithoutPredicate, + QueryableMethods.SingleOrDefaultWithPredicate, + QueryableMethods.SingleOrDefaultWithoutPredicate, + QueryableMethods.LastWithPredicate, + QueryableMethods.LastWithoutPredicate, + QueryableMethods.LastOrDefaultWithPredicate, + QueryableMethods.LastOrDefaultWithoutPredicate + //QueryableMethodProvider.ElementAtMethodInfo, + //QueryableMethodProvider.ElementAtOrDefaultMethodInfo + }; + + private static readonly MemberInfo ValueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; + + private static readonly MethodInfo ParameterValueExtractorMethod = + typeof(KafkaExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor))!; + + private static readonly MethodInfo ParameterListValueExtractorMethod = + typeof(KafkaExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor))!; + + private static readonly MethodInfo GetParameterValueMethodInfo = + typeof(KafkaExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue))!; + + private static readonly MethodInfo LikeMethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod( + nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string) })!; + + private static readonly MethodInfo LikeMethodInfoWithEscape = typeof(DbFunctionsExtensions).GetRuntimeMethod( + nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string), typeof(string) })!; + + private static readonly MethodInfo RandomMethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod( + nameof(DbFunctionsExtensions.Random), new[] { typeof(DbFunctions) })!; + + private static readonly MethodInfo RandomNextDoubleMethodInfo = typeof(Random).GetRuntimeMethod( + nameof(Random.NextDouble), Type.EmptyTypes)!; + + private static readonly MethodInfo KafkaLikeMethodInfo = + typeof(KafkaExpressionTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(KafkaLike))!; + + private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!; + + // Regex special chars defined here: + // https://msdn.microsoft.com/en-us/library/4edbef7e(v=vs.110).aspx + private static readonly char[] RegexSpecialChars + = { '.', '$', '^', '{', '[', '(', '|', ')', '*', '+', '?', '\\' }; + + private static readonly string DefaultEscapeRegexCharsPattern = BuildEscapeRegexCharsPattern(RegexSpecialChars); + + private static readonly TimeSpan RegexTimeout = TimeSpan.FromMilliseconds(value: 1000.0); + + private static string BuildEscapeRegexCharsPattern(IEnumerable regexSpecialChars) + => string.Join("|", regexSpecialChars.Select(c => @"\" + c)); + + private readonly QueryCompilationContext _queryCompilationContext; + private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; + private readonly EntityReferenceFindingExpressionVisitor _entityReferenceFindingExpressionVisitor; + private readonly IModel _model; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaExpressionTranslatingExpressionVisitor( + QueryCompilationContext queryCompilationContext, + QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) + { + _queryCompilationContext = queryCompilationContext; + _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; + _entityReferenceFindingExpressionVisitor = new EntityReferenceFindingExpressionVisitor(); + _model = queryCompilationContext.Model; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string? TranslationErrorDetails { get; private set; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual void AddTranslationErrorDetails(string details) + { + if (TranslationErrorDetails == null) + { + TranslationErrorDetails = details; + } + else + { + TranslationErrorDetails += Environment.NewLine + details; + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression? Translate(Expression expression) + { + TranslationErrorDetails = null; + + return TranslateInternal(expression); + } + + private Expression? TranslateInternal(Expression expression) + { + var result = Visit(expression); + + return result == QueryCompilationContext.NotTranslatedExpression + || _entityReferenceFindingExpressionVisitor.Find(result) + ? null + : result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + if (binaryExpression.Left.Type == typeof(object[]) + && binaryExpression is { Left: NewArrayExpression, NodeType: ExpressionType.Equal }) + { + return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right)); + } + + if (binaryExpression.NodeType is ExpressionType.Equal or ExpressionType.NotEqual + && (binaryExpression.Left.IsNullConstantExpression() || binaryExpression.Right.IsNullConstantExpression())) + { + var nonNullExpression = binaryExpression.Left.IsNullConstantExpression() ? binaryExpression.Right : binaryExpression.Left; + if (nonNullExpression is MethodCallExpression nonNullMethodCallExpression + && nonNullMethodCallExpression.Method.DeclaringType == typeof(Queryable) + && nonNullMethodCallExpression.Method.IsGenericMethod + && SingleResultMethodInfos.Contains(nonNullMethodCallExpression.Method.GetGenericMethodDefinition())) + { + var source = nonNullMethodCallExpression.Arguments[0]; + if (nonNullMethodCallExpression.Arguments.Count == 2) + { + source = Expression.Call( + QueryableMethods.Where.MakeGenericMethod(source.Type.GetSequenceType()), + source, + nonNullMethodCallExpression.Arguments[1]); + } + + var translatedSubquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(source); + if (translatedSubquery != null) + { + var projection = translatedSubquery.ShaperExpression; + if (projection is NewExpression + || RemoveConvert(projection) is StructuralTypeShaperExpression { IsNullable: false } + || RemoveConvert(projection) is CollectionResultShaperExpression) + { + var anySubquery = Expression.Call( + QueryableMethods.AnyWithoutPredicate.MakeGenericMethod(translatedSubquery.Type.GetSequenceType()), + translatedSubquery); + + return Visit( + binaryExpression.NodeType == ExpressionType.Equal + ? Expression.Not(anySubquery) + : anySubquery); + } + + static Expression RemoveConvert(Expression e) + => e is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary + ? RemoveConvert(unary.Operand) + : e; + } + } + } + + if (binaryExpression.NodeType == ExpressionType.Equal + || binaryExpression.NodeType == ExpressionType.NotEqual + && binaryExpression.Left.Type == typeof(Type)) + { + if (IsGetTypeMethodCall(binaryExpression.Left, out var entityReference1) + && IsTypeConstant(binaryExpression.Right, out var type1)) + { + return ProcessGetType(entityReference1!, type1!, binaryExpression.NodeType == ExpressionType.Equal); + } + + if (IsGetTypeMethodCall(binaryExpression.Right, out var entityReference2) + && IsTypeConstant(binaryExpression.Left, out var type2)) + { + return ProcessGetType(entityReference2!, type2!, binaryExpression.NodeType == ExpressionType.Equal); + } + } + + var newLeft = Visit(binaryExpression.Left); + var newRight = Visit(binaryExpression.Right); + + if (newLeft == QueryCompilationContext.NotTranslatedExpression + || newRight == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (binaryExpression.NodeType is ExpressionType.Equal or ExpressionType.NotEqual + // Visited expression could be null, We need to pass MemberInitExpression + && TryRewriteEntityEquality( + binaryExpression.NodeType, + newLeft, + newRight, + equalsMethod: false, + out var result)) + { + return result; + } + + if (IsConvertedToNullable(newLeft, binaryExpression.Left) + || IsConvertedToNullable(newRight, binaryExpression.Right)) + { + newLeft = ConvertToNullable(newLeft); + newRight = ConvertToNullable(newRight); + } + + if (binaryExpression.NodeType is ExpressionType.Equal or ExpressionType.NotEqual + && TryUseComparer(newLeft, newRight, out var updatedExpression)) + { + if (binaryExpression.NodeType == ExpressionType.NotEqual) + { + updatedExpression = Expression.IsFalse(updatedExpression!); + } + + return updatedExpression!; + } + + return Expression.MakeBinary( + binaryExpression.NodeType, + newLeft, + newRight, + binaryExpression.IsLiftedToNull, + binaryExpression.Method, + binaryExpression.Conversion); + + Expression ProcessGetType(StructuralTypeReferenceExpression typeReference, Type comparisonType, bool match) + { + if (typeReference.StructuralType is not IEntityType entityType + || (entityType.BaseType == null + && !entityType.GetDirectlyDerivedTypes().Any())) + { + // No hierarchy + return Expression.Constant((typeReference.StructuralType.ClrType == comparisonType) == match); + } + + if (entityType.GetAllBaseTypes().Any(e => e.ClrType == comparisonType)) + { + // EntitySet will never contain a type of base type + return Expression.Constant(!match); + } + + var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType); + // If no derived type matches then fail the translation + if (derivedType != null) + { + // If the derived type is abstract type then predicate will always be false + if (derivedType.IsAbstract()) + { + return Expression.Constant(!match); + } + + // Or add predicate for matching that particular type discriminator value + // All hierarchies have discriminator property + var discriminatorProperty = entityType.FindDiscriminatorProperty()!; + var boundProperty = BindProperty(typeReference, discriminatorProperty, discriminatorProperty.ClrType); + // KeyValueComparer is not null at runtime + var valueComparer = discriminatorProperty.GetKeyValueComparer(); + + var result = valueComparer.ExtractEqualsBody( + boundProperty!, + Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); + + return match ? result : Expression.Not(result); + } + + return QueryCompilationContext.NotTranslatedExpression; + } + + bool IsGetTypeMethodCall(Expression expression, out StructuralTypeReferenceExpression? typeReference) + { + typeReference = null; + if (expression is not MethodCallExpression methodCallExpression + || methodCallExpression.Method != GetTypeMethodInfo) + { + return false; + } + + typeReference = Visit(methodCallExpression.Object) as StructuralTypeReferenceExpression; + return typeReference != null; + } + + static bool IsTypeConstant(Expression expression, out Type? type) + { + type = null; + if (expression is not UnaryExpression + { + NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked, + Operand: ConstantExpression constantExpression + }) + { + return false; + } + + type = constantExpression.Value as Type; + return type != null; + } + } + + private static bool TryUseComparer( + Expression? newLeft, + Expression? newRight, + out Expression? updatedExpression) + { + updatedExpression = null; + + if (newLeft == null + || newRight == null) + { + return false; + } + + var property = FindProperty(newLeft) ?? FindProperty(newRight); + var comparer = property?.GetValueComparer(); + + if (comparer == null) + { + return false; + } + + MethodInfo? objectEquals = null; + MethodInfo? exactMatch = null; + + var converter = property?.GetValueConverter(); + foreach (var candidate in comparer + .GetType() + .GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where( + m => m.Name == "Equals" && m.GetParameters().Length == 2) + .ToList()) + { + var parameters = candidate.GetParameters(); + var leftType = parameters[0].ParameterType; + var rightType = parameters[1].ParameterType; + + if (leftType == typeof(object) + && rightType == typeof(object)) + { + objectEquals = candidate; + continue; + } + + var matchingLeft = leftType.IsAssignableFrom(newLeft.Type) + ? newLeft + : converter != null + && leftType.IsAssignableFrom(converter.ModelClrType) + && converter.ProviderClrType.IsAssignableFrom(newLeft.Type) + ? ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + newLeft, + converter.ConvertFromProviderExpression.Body) + : null; + + var matchingRight = rightType.IsAssignableFrom(newRight.Type) + ? newRight + : converter != null + && rightType.IsAssignableFrom(converter.ModelClrType) + && converter.ProviderClrType.IsAssignableFrom(newRight.Type) + ? ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + newRight, + converter.ConvertFromProviderExpression.Body) + : null; + + if (matchingLeft != null && matchingRight != null) + { + exactMatch = candidate; + newLeft = matchingLeft; + newRight = matchingRight; + break; + } + } + + if (exactMatch == null + && (!property!.ClrType.IsAssignableFrom(newLeft.Type)) + || !property!.ClrType.IsAssignableFrom(newRight.Type)) + { + return false; + } + + updatedExpression = + exactMatch != null + ? Expression.Call( + Expression.Constant(comparer, comparer.GetType()), + exactMatch, + newLeft, + newRight) + : Expression.Call( + Expression.Constant(comparer, comparer.GetType()), + objectEquals!, + Expression.Convert(newLeft, typeof(object)), + Expression.Convert(newRight, typeof(object))); + + return true; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) + { + var test = Visit(conditionalExpression.Test); + var ifTrue = Visit(conditionalExpression.IfTrue); + var ifFalse = Visit(conditionalExpression.IfFalse); + + if (test == QueryCompilationContext.NotTranslatedExpression + || ifTrue == QueryCompilationContext.NotTranslatedExpression + || ifFalse == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (test.Type == typeof(bool?)) + { + test = Expression.Equal(test, Expression.Constant(true, typeof(bool?))); + } + + if (IsConvertedToNullable(ifTrue, conditionalExpression.IfTrue) + || IsConvertedToNullable(ifFalse, conditionalExpression.IfFalse)) + { + ifTrue = ConvertToNullable(ifTrue); + ifFalse = ConvertToNullable(ifFalse); + } + + return Expression.Condition(test, ifTrue, ifFalse); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case EntityProjectionExpression: + case StructuralTypeReferenceExpression: + return extensionExpression; + + case StructuralTypeShaperExpression shaper: + return new StructuralTypeReferenceExpression(shaper); + + case ProjectionBindingExpression projectionBindingExpression: + return ((KafkaQueryExpression)projectionBindingExpression.QueryExpression) + .GetProjection(projectionBindingExpression); + + default: + return QueryCompilationContext.NotTranslatedExpression; + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitInvocation(InvocationExpression invocationExpression) + => QueryCompilationContext.NotTranslatedExpression; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitLambda(Expression lambdaExpression) + => throw new InvalidOperationException(CoreStrings.TranslationFailed(lambdaExpression.Print())); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitListInit(ListInitExpression listInitExpression) + => QueryCompilationContext.NotTranslatedExpression; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMember(MemberExpression memberExpression) + { + var innerExpression = Visit(memberExpression.Expression); + + // when visiting unary we remove converts from nullable to non-nullable + // however if this happens for memberExpression.Expression we are unable to bind + if (innerExpression != null + && memberExpression.Expression != null + && innerExpression.Type != memberExpression.Expression.Type + && innerExpression.Type.IsNullableType() + && innerExpression.Type.UnwrapNullableType() == memberExpression.Expression.Type) + { + innerExpression = Expression.Convert(innerExpression, memberExpression.Expression.Type); + } + + if (memberExpression.Expression != null + && innerExpression == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type) is Expression result) + { + return result; + } + + var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); + if (innerExpression != null + && innerExpression.Type.IsNullableType() + && ShouldApplyNullProtectionForMemberAccess(innerExpression.Type, memberExpression.Member.Name)) + { + updatedMemberExpression = ConvertToNullable(updatedMemberExpression); + + return Expression.Condition( + // Since inner is nullable type this is fine. + Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)), + Expression.Default(updatedMemberExpression.Type), + updatedMemberExpression); + } + + return updatedMemberExpression; + + static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string memberName) + => !(callerType.IsGenericType + && callerType.GetGenericTypeDefinition() == typeof(Nullable<>) + && memberName is nameof(Nullable.Value) or nameof(Nullable.HasValue)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + { + var expression = Visit(memberAssignment.Expression); + if (expression == QueryCompilationContext.NotTranslatedExpression) + { + return memberAssignment.Update(Expression.Convert(expression, memberAssignment.Expression.Type)); + } + + if (IsConvertedToNullable(expression, memberAssignment.Expression)) + { + expression = ConvertToNonNullable(expression); + } + + return memberAssignment.Update(expression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + { + var newExpression = Visit(memberInitExpression.NewExpression); + if (newExpression == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var newBindings = new MemberBinding[memberInitExpression.Bindings.Count]; + for (var i = 0; i < newBindings.Length; i++) + { + if (memberInitExpression.Bindings[i].BindingType != MemberBindingType.Assignment) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]); + if (((MemberAssignment)newBindings[i]).Expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + } + + return memberInitExpression.Update((NewExpression)newExpression, newBindings); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) + { + return methodCallExpression; + } + + // EF.Property case + if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) + { + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type) + ?? throw new InvalidOperationException(CoreStrings.QueryUnableToTranslateEFProperty(methodCallExpression.Print())); + } + + // EF Indexer property + if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) + { + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type) + ?? QueryCompilationContext.NotTranslatedExpression; + } + + // Subquery case + var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); + if (subqueryTranslation != null) + { + var subquery = (KafkaQueryExpression)subqueryTranslation.QueryExpression; + if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var shaperExpression = subqueryTranslation.ShaperExpression; + var innerExpression = shaperExpression; + Type? convertedType = null; + if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression) + { + convertedType = unaryExpression.Type; + innerExpression = unaryExpression.Operand; + } + + if (innerExpression is StructuralTypeShaperExpression shaper + && (convertedType == null + || convertedType.IsAssignableFrom(shaper.Type))) + { + return new StructuralTypeReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression)); + } + + if (!(innerExpression is ProjectionBindingExpression projectionBindingExpression + && (convertedType == null + || convertedType.MakeNullable() == innerExpression.Type))) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (projectionBindingExpression.ProjectionMember == null) + { + // We don't lift scalar subquery with client eval + return QueryCompilationContext.NotTranslatedExpression; + } + + return ProcessSingleResultScalar( + subquery, + subquery.GetProjection(projectionBindingExpression), + methodCallExpression.Type); + } + + if (methodCallExpression.Method == LikeMethodInfo + || methodCallExpression.Method == LikeMethodInfoWithEscape) + { + // EF.Functions.Like + var visitedArguments = new Expression[3]; + visitedArguments[2] = Expression.Constant(null, typeof(string)); + // Skip first DbFunctions argument + for (var i = 1; i < methodCallExpression.Arguments.Count; i++) + { + var argument = Visit(methodCallExpression.Arguments[i]); + if (TranslationFailed(methodCallExpression.Arguments[i], argument)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + visitedArguments[i - 1] = argument; + } + + return Expression.Call(KafkaLikeMethodInfo, visitedArguments); + } + + if (methodCallExpression.Method == RandomMethodInfo) + { + return Expression.Call(Expression.New(typeof(Random)), RandomNextDoubleMethodInfo); + } + + Expression? @object = null; + Expression[] arguments; + var method = methodCallExpression.Method; + + if (method.Name == nameof(object.Equals) + && methodCallExpression is { Object: not null, Arguments.Count: 1 }) + { + var left = Visit(methodCallExpression.Object); + var right = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteEntityEquality( + ExpressionType.Equal, + left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Object : left, + right == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : right, + equalsMethod: true, + out var result)) + { + return result; + } + + if (TranslationFailed(methodCallExpression.Object, left) + || TranslationFailed(methodCallExpression.Arguments[0], right)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + @object = left; + arguments = new[] { right }; + } + else if (method.Name == nameof(object.Equals) + && methodCallExpression.Object == null + && methodCallExpression.Arguments.Count == 2) + { + if (methodCallExpression.Arguments[0].Type == typeof(object[]) + && methodCallExpression.Arguments[0] is NewArrayExpression) + { + return Visit( + ConvertObjectArrayEqualityComparison( + methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])); + } + + var left = Visit(methodCallExpression.Arguments[0]); + var right = Visit(methodCallExpression.Arguments[1]); + + if (TryUseComparer(left, right, out var updatedExpression)) + { + return updatedExpression!; + } + + if (TryRewriteEntityEquality( + ExpressionType.Equal, + left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : left, + right == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[1] : right, + equalsMethod: true, + out var result)) + { + return result; + } + + if (TranslationFailed(methodCallExpression.Arguments[0], left) + || TranslationFailed(methodCallExpression.Arguments[1], right)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + arguments = new[] { left, right }; + } + else if (method.IsGenericMethod + && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) + { + var enumerable = Visit(methodCallExpression.Arguments[0]); + var item = Visit(methodCallExpression.Arguments[1]); + + if (TryRewriteContainsEntity( + enumerable, + item == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[1] : item, + out var result)) + { + return result; + } + + if (TranslationFailed(methodCallExpression.Arguments[0], enumerable) + || TranslationFailed(methodCallExpression.Arguments[1], item)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + arguments = new[] { enumerable, item }; + } + else if (methodCallExpression.Arguments.Count == 1 + && method.IsContainsMethod()) + { + var enumerable = Visit(methodCallExpression.Object); + var item = Visit(methodCallExpression.Arguments[0]); + + if (TryRewriteContainsEntity( + enumerable, + item == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : item, + out var result)) + { + return result; + } + + if (TranslationFailed(methodCallExpression.Object, enumerable) + || TranslationFailed(methodCallExpression.Arguments[0], item)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + @object = enumerable; + arguments = new[] { item }; + } + else + { + @object = Visit(methodCallExpression.Object); + if (TranslationFailed(methodCallExpression.Object, @object)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + arguments = new Expression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + var argument = Visit(methodCallExpression.Arguments[i]); + if (TranslationFailed(methodCallExpression.Arguments[i], argument)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + arguments[i] = argument; + } + } + + // if the nullability of arguments change, we have no easy/reliable way to adjust the actual methodInfo to match the new type, + // so we are forced to cast back to the original type + var parameterTypes = methodCallExpression.Method.GetParameters().Select(p => p.ParameterType).ToArray(); + for (var i = 0; i < arguments.Length; i++) + { + var argument = arguments[i]; + if (IsConvertedToNullable(argument, methodCallExpression.Arguments[i]) + && !parameterTypes[i].IsAssignableFrom(argument.Type)) + { + argument = ConvertToNonNullable(argument); + } + + arguments[i] = argument; + } + + // if object is nullable, add null safeguard before calling the function + // we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard + if (methodCallExpression.Object != null + && @object!.Type.IsNullableType() + && methodCallExpression.Method.Name != nameof(Nullable.GetValueOrDefault)) + { + var result = (Expression)methodCallExpression.Update( + Expression.Convert(@object, methodCallExpression.Object.Type), + arguments); + + result = ConvertToNullable(result); + var objectNullCheck = Expression.Equal(@object, Expression.Constant(null, @object.Type)); + // instance.Equals(argument) should translate to + // instance == null ? argument == null : instance.Equals(argument) + if (method.Name == nameof(object.Equals)) + { + var argument = arguments[0]; + if (argument.NodeType == ExpressionType.Convert + && argument is UnaryExpression unaryExpression + && argument.Type == unaryExpression.Operand.Type.UnwrapNullableType()) + { + argument = unaryExpression.Operand; + } + + if (!argument.Type.IsNullableType()) + { + argument = Expression.Convert(argument, argument.Type.MakeNullable()); + } + + return Expression.Condition( + objectNullCheck, + ConvertToNullable(Expression.Equal(argument, Expression.Constant(null, argument.Type))), + result); + } + + return Expression.Condition(objectNullCheck, Expression.Constant(null, result.Type), result); + } + + return methodCallExpression.Update(@object, arguments); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitNew(NewExpression newExpression) + { + var newArguments = new List(); + foreach (var argument in newExpression.Arguments) + { + var newArgument = Visit(argument); + if (newArgument == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (IsConvertedToNullable(newArgument, argument)) + { + newArgument = ConvertToNonNullable(newArgument); + } + + newArguments.Add(newArgument); + } + + return newExpression.Update(newArguments); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) + { + var newExpressions = new List(); + foreach (var expression in newArrayExpression.Expressions) + { + var newExpression = Visit(expression); + if (newExpression == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (IsConvertedToNullable(newExpression, expression)) + { + newExpression = ConvertToNonNullable(newExpression); + } + + newExpressions.Add(newExpression); + } + + return newArrayExpression.Update(newExpressions); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitParameter(ParameterExpression parameterExpression) + { + if (parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) == true) + { + return Expression.Call( + GetParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterExpression.Name)); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print())); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) + { + if (typeBinaryExpression.NodeType == ExpressionType.TypeIs + && Visit(typeBinaryExpression.Expression) is StructuralTypeReferenceExpression typeReference) + { + if (typeReference.StructuralType is not IEntityType entityType) + { + return Expression.Constant(typeReference.StructuralType.ClrType == typeBinaryExpression.TypeOperand); + } + + if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + { + return Expression.Constant(true); + } + + var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); + if (derivedType != null) + { + // All hierarchies have discriminator property + var discriminatorProperty = entityType.FindDiscriminatorProperty()!; + var boundProperty = BindProperty(typeReference, discriminatorProperty, discriminatorProperty.ClrType); + // KeyValueComparer is not null at runtime + var valueComparer = discriminatorProperty.GetKeyValueComparer(); + + var equals = valueComparer.ExtractEqualsBody( + boundProperty!, + Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); + + foreach (var derivedDerivedType in derivedType.GetDerivedTypes()) + { + equals = Expression.OrElse( + equals, + valueComparer.ExtractEqualsBody( + boundProperty!, + Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); + } + + return equals; + } + } + + return QueryCompilationContext.NotTranslatedExpression; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + var newOperand = Visit(unaryExpression.Operand); + if (newOperand == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (newOperand is StructuralTypeReferenceExpression typeReference + && unaryExpression.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs) + { + return typeReference.Convert(unaryExpression.Type); + } + + if (unaryExpression.NodeType == ExpressionType.Convert + && newOperand.Type == unaryExpression.Type) + { + return newOperand; + } + + if (unaryExpression.NodeType == ExpressionType.Convert + && IsConvertedToNullable(newOperand, unaryExpression)) + { + return newOperand; + } + + var result = (Expression)Expression.MakeUnary(unaryExpression.NodeType, newOperand, unaryExpression.Type); + if (result is UnaryExpression + { + NodeType: ExpressionType.Convert, + Operand: UnaryExpression { NodeType: ExpressionType.Convert } innerUnary + } outerUnary) + { + var innerMostType = innerUnary.Operand.Type; + var intermediateType = innerUnary.Type; + var outerMostType = outerUnary.Type; + + if (outerMostType == innerMostType + && intermediateType == innerMostType.UnwrapNullableType()) + { + result = innerUnary.Operand; + } + else if (outerMostType == typeof(object) + && intermediateType == innerMostType.UnwrapNullableType()) + { + result = Expression.Convert(innerUnary.Operand, typeof(object)); + } + } + + return result; + } + + private Expression? TryBindMember(Expression? source, MemberIdentity member, Type type) + { + if (source is not StructuralTypeReferenceExpression typeReference) + { + return null; + } + + var entityType = typeReference.StructuralType; + + var property = member.MemberInfo != null + ? entityType.FindProperty(member.MemberInfo) + : entityType.FindProperty(member.Name!); + + if (property != null) + { + return BindProperty(typeReference, property, type); + } + + AddTranslationErrorDetails( + CoreStrings.QueryUnableToTranslateMember( + member.Name, + typeReference.StructuralType.DisplayName())); + + return null; + } + + private Expression? BindProperty(StructuralTypeReferenceExpression typeReference, IProperty property, Type type) + { + if (typeReference.Parameter != null) + { + var valueBufferExpression = Visit(typeReference.Parameter.ValueBufferExpression); + if (valueBufferExpression == QueryCompilationContext.NotTranslatedExpression) + { + return null; + } + + var result = ((EntityProjectionExpression)valueBufferExpression).BindProperty(property); + + // if the result type change was just nullability change e.g from int to int? + // we want to preserve the new type for null propagation + return result.Type != type + && !(result.Type.IsNullableType() + && !type.IsNullableType() + && result.Type.UnwrapNullableType() == type) + ? Expression.Convert(result, type) + : result; + } + + if (typeReference.Subquery != null) + { + var entityShaper = (StructuralTypeShaperExpression)typeReference.Subquery.ShaperExpression; + var kafkaQueryExpression = (KafkaQueryExpression)typeReference.Subquery.QueryExpression; + + var projectionBindingExpression = (ProjectionBindingExpression)entityShaper.ValueBufferExpression; + var entityProjectionExpression = (EntityProjectionExpression)kafkaQueryExpression.GetProjection( + projectionBindingExpression); + var readValueExpression = entityProjectionExpression.BindProperty(property); + + return ProcessSingleResultScalar( + kafkaQueryExpression, + readValueExpression, + type); + } + + return null; + } + + private static Expression ProcessSingleResultScalar( + KafkaQueryExpression kafkaQueryExpression, + Expression readValueExpression, + Type type) + { + if (kafkaQueryExpression.ServerQueryExpression is not NewExpression) + { + // The terminating operator is not applied + // It is of FirstOrDefault kind + // So we change to single column projection and then apply it. + kafkaQueryExpression.ReplaceProjection( + new Dictionary { { new ProjectionMember(), readValueExpression } }); + kafkaQueryExpression.ApplyProjection(); + } + + var serverQuery = kafkaQueryExpression.ServerQueryExpression; + serverQuery = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; + if (serverQuery is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Type == typeof(object)) + { + serverQuery = unaryExpression.Operand; + } + + var valueBufferVariable = Expression.Variable(typeof(ValueBuffer)); + var readExpression = valueBufferVariable.CreateValueBufferReadValueExpression(type, index: 0, property: null); + return Expression.Block( + variables: new[] { valueBufferVariable }, + Expression.Assign(valueBufferVariable, serverQuery), + Expression.Condition( + Expression.MakeMemberAccess(valueBufferVariable, ValueBufferIsEmpty), + Expression.Default(type), + readExpression)); + } + + [UsedImplicitly] + private static T GetParameterValue(QueryContext queryContext, string parameterName) + => (T)queryContext.ParameterValues[parameterName]!; + + private static bool IsConvertedToNullable(Expression result, Expression original) + => result.Type.IsNullableType() + && !original.Type.IsNullableType() + && result.Type.UnwrapNullableType() == original.Type; + + private static Expression ConvertToNullable(Expression expression) + => !expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.MakeNullable()) + : expression; + + private static Expression ConvertToNonNullable(Expression expression) + => expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) + : expression; + + private static IProperty? FindProperty(Expression? expression) + { + if (expression?.NodeType == ExpressionType.Convert + && expression.Type == typeof(object)) + { + expression = ((UnaryExpression)expression).Operand; + } + + if (expression?.NodeType == ExpressionType.Convert + && expression.Type.IsNullableType() + && expression is UnaryExpression unaryExpression + && (expression.Type.UnwrapNullableType() == unaryExpression.Type + || expression.Type == unaryExpression.Type)) + { + expression = unaryExpression.Operand; + } + + if (expression is MethodCallExpression { Method.IsGenericMethod: true } readValueMethodCall + && readValueMethodCall.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) + { + return readValueMethodCall.Arguments[2].GetConstantValue(); + } + + return null; + } + + private bool TryRewriteContainsEntity(Expression? source, Expression item, [NotNullWhen(true)] out Expression? result) + { + result = null; + + if (item is not StructuralTypeReferenceExpression { StructuralType: IEntityType entityType }) + { + return false; + } + + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityOnKeylessEntityNotSupported( + nameof(Queryable.Contains), entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityOnCompositeKeyEntitySubqueryNotSupported( + nameof(Queryable.Contains), entityType.DisplayName())); + } + + var property = primaryKeyProperties[0]; + Expression rewrittenSource; + switch (source) + { + case ConstantExpression constantExpression: + var values = constantExpression.GetConstantValue(); + var propertyValueList = + (IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(property.ClrType.MakeNullable()))!; + var propertyGetter = property.GetGetter(); + foreach (var value in values) + { + propertyValueList.Add(propertyGetter.GetClrValue(value)); + } + + rewrittenSource = Expression.Constant(propertyValueList); + break; + + case MethodCallExpression { Method.IsGenericMethod: true } methodCallExpression + when methodCallExpression.Method.GetGenericMethodDefinition() == GetParameterValueMethodInfo: + var parameterName = methodCallExpression.Arguments[1].GetConstantValue(); + var lambda = Expression.Lambda( + Expression.Call( + ParameterListValueExtractorMethod.MakeGenericMethod(entityType.ClrType, property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterName, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter + ); + + var newParameterName = + $"{RuntimeParameterPrefix}" + + $"{parameterName[QueryCompilationContext.QueryParameterPrefix.Length..]}_{property.Name}"; + + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + break; + + default: + return false; + } + + result = Visit( + Expression.Call( + EnumerableMethods.Contains.MakeGenericMethod(property.ClrType.MakeNullable()), + rewrittenSource, + CreatePropertyAccessExpression(item, property))); + + return true; + } + + private bool TryRewriteEntityEquality( + ExpressionType nodeType, + Expression left, + Expression right, + bool equalsMethod, + [NotNullWhen(true)] out Expression? result) + { + var leftEntityReference = left is StructuralTypeReferenceExpression { StructuralType: IEntityType } l ? l : null; + var rightEntityReference = right is StructuralTypeReferenceExpression { StructuralType: IEntityType } r ? r : null; + + if (leftEntityReference == null + && rightEntityReference == null) + { + result = null; + return false; + } + + if (IsNullConstantExpression(left) + || IsNullConstantExpression(right)) + { + var nonNullEntityReference = (IsNullConstantExpression(left) ? rightEntityReference : leftEntityReference)!; + var entityType1 = (IEntityType)nonNullEntityReference.StructuralType; + var primaryKeyProperties1 = entityType1.FindPrimaryKey()?.Properties; + if (primaryKeyProperties1 == null) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityOnKeylessEntityNotSupported( + nodeType == ExpressionType.Equal + ? equalsMethod ? nameof(object.Equals) : "==" + : equalsMethod + ? "!" + nameof(object.Equals) + : "!=", + entityType1.DisplayName())); + } + + result = Visit( + primaryKeyProperties1.Select( + p => + Expression.MakeBinary( + nodeType, CreatePropertyAccessExpression(nonNullEntityReference, p), + Expression.Constant(null, p.ClrType.MakeNullable()))) + .Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r))); + + return true; + } + + var leftEntityType = (IEntityType?)leftEntityReference?.StructuralType; + var rightEntityType = (IEntityType?)rightEntityReference?.StructuralType; + var entityType = leftEntityType ?? rightEntityType; + + Check.DebugAssert(entityType != null, "At least either side should be entityReference so entityType should be non-null."); + + if (leftEntityType != null + && rightEntityType != null + && leftEntityType.GetRootType() != rightEntityType.GetRootType()) + { + result = Expression.Constant(false); + return true; + } + + var primaryKeyProperties = entityType.FindPrimaryKey()?.Properties; + if (primaryKeyProperties == null) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityOnKeylessEntityNotSupported( + nodeType == ExpressionType.Equal + ? equalsMethod ? nameof(object.Equals) : "==" + : equalsMethod + ? "!" + nameof(object.Equals) + : "!=", + entityType.DisplayName())); + } + + if (primaryKeyProperties.Count > 1 + && (leftEntityReference?.Subquery != null + || rightEntityReference?.Subquery != null)) + { + throw new InvalidOperationException( + CoreStrings.EntityEqualityOnCompositeKeyEntitySubqueryNotSupported( + nodeType == ExpressionType.Equal + ? equalsMethod ? nameof(object.Equals) : "==" + : equalsMethod + ? "!" + nameof(object.Equals) + : "!=", + entityType.DisplayName())); + } + + result = Visit( + primaryKeyProperties.Select( + p => + Expression.MakeBinary( + nodeType, + CreatePropertyAccessExpression(left, p), + CreatePropertyAccessExpression(right, p))) + .Aggregate( + (l, r) => nodeType == ExpressionType.Equal + ? Expression.AndAlso(l, r) + : Expression.OrElse(l, r))); + + return true; + } + + private Expression CreatePropertyAccessExpression(Expression target, IProperty property) + { + switch (target) + { + case ConstantExpression constantExpression: + return Expression.Constant( + constantExpression.Value is null + ? null + : property.GetGetter().GetClrValue(constantExpression.Value), + property.ClrType.MakeNullable()); + + case MethodCallExpression { Method.IsGenericMethod: true } methodCallExpression + when methodCallExpression.Method.GetGenericMethodDefinition() == GetParameterValueMethodInfo: + var parameterName = methodCallExpression.Arguments[1].GetConstantValue(); + var lambda = Expression.Lambda( + Expression.Call( + ParameterValueExtractorMethod.MakeGenericMethod(property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterName, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter); + + var newParameterName = + $"{RuntimeParameterPrefix}" + + $"{parameterName[QueryCompilationContext.QueryParameterPrefix.Length..]}_{property.Name}"; + + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + + case MemberInitExpression memberInitExpression + when memberInitExpression.Bindings.SingleOrDefault( + mb => mb.Member.Name == property.Name) is MemberAssignment memberAssignment: + return memberAssignment.Expression.Type.IsNullableType() + ? memberAssignment.Expression + : Expression.Convert(memberAssignment.Expression, property.ClrType.MakeNullable()); + + case NewExpression newExpression + when CanEvaluate(newExpression): + return CreatePropertyAccessExpression(GetValue(newExpression), property); + + case MemberInitExpression memberInitExpression + when CanEvaluate(memberInitExpression): + return CreatePropertyAccessExpression(GetValue(memberInitExpression), property); + + default: + return target.CreateEFPropertyExpression(property); + } + } + + private static T? ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) + { + var baseParameter = context.ParameterValues[baseParameterName]; + return baseParameter == null ? (T?)(object?)null : (T?)property.GetGetter().GetClrValue(baseParameter); + } + + private static List? ParameterListValueExtractor( + QueryContext context, + string baseParameterName, + IProperty property) + { + if (!(context.ParameterValues[baseParameterName] is IEnumerable baseListParameter)) + { + return null; + } + + var getter = property.GetGetter(); + return baseListParameter.Select(e => e != null ? (TProperty?)getter.GetClrValue(e) : (TProperty?)(object?)null).ToList(); + } + + private static ConstantExpression GetValue(Expression expression) + => Expression.Constant( + Expression.Lambda>(Expression.Convert(expression, typeof(object))) + .Compile(preferInterpretation: true) + .Invoke(), + expression.Type); + + private static bool CanEvaluate(Expression expression) + { +#pragma warning disable IDE0066 // Convert switch statement to expression + switch (expression) +#pragma warning restore IDE0066 // Convert switch statement to expression + { + case ConstantExpression: + return true; + + case NewExpression newExpression: + return newExpression.Arguments.All(CanEvaluate); + + case MemberInitExpression memberInitExpression: + return CanEvaluate(memberInitExpression.NewExpression) + && memberInitExpression.Bindings.All( + mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); + + default: + return false; + } + } + + private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right) + { + var leftExpressions = ((NewArrayExpression)left).Expressions; + var rightExpressions = ((NewArrayExpression)right).Expressions; + + return leftExpressions.Zip( + rightExpressions, + (l, r) => + { + l = RemoveObjectConvert(l); + r = RemoveObjectConvert(r); + if (l.Type.IsNullableType()) + { + r = r.Type.IsNullableType() ? r : Expression.Convert(r, l.Type); + } + else if (r.Type.IsNullableType()) + { + l = l.Type.IsNullableType() ? l : Expression.Convert(l, r.Type); + } + + return ExpressionExtensions.CreateEqualsExpression(l, r); + }) + .Aggregate((a, b) => Expression.AndAlso(a, b)); + + static Expression RemoveObjectConvert(Expression expression) + => expression is UnaryExpression unaryExpression + && expression.Type == typeof(object) + && expression.NodeType == ExpressionType.Convert + ? unaryExpression.Operand + : expression; + } + + private static bool IsNullConstantExpression(Expression expression) + => expression is ConstantExpression { Value: null }; + + [DebuggerStepThrough] + private static bool TranslationFailed(Expression? original, Expression? translation) + => original != null + && (translation == QueryCompilationContext.NotTranslatedExpression || translation is StructuralTypeReferenceExpression); + + private static bool KafkaLike(string matchExpression, string pattern, string escapeCharacter) + { + //TODO: this fixes https://github.com/aspnet/EntityFramework/issues/8656 by insisting that + // the "escape character" is a string but just using the first character of that string, + // but we may later want to allow the complete string as the "escape character" + // in which case we need to change the way we construct the regex below. + var singleEscapeCharacter = + (escapeCharacter == null || escapeCharacter.Length == 0) + ? (char?)null + : escapeCharacter.First(); + + if (matchExpression == null + || pattern == null) + { + return false; + } + + if (matchExpression.Equals(pattern, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + if (matchExpression.Length == 0 + || pattern.Length == 0) + { + return false; + } + + var escapeRegexCharsPattern + = singleEscapeCharacter == null + ? DefaultEscapeRegexCharsPattern + : BuildEscapeRegexCharsPattern(RegexSpecialChars.Where(c => c != singleEscapeCharacter)); + + var regexPattern + = Regex.Replace( + pattern, + escapeRegexCharsPattern, + c => @"\" + c, + default, + RegexTimeout); + + var stringBuilder = new StringBuilder(); + + for (var i = 0; i < regexPattern.Length; i++) + { + var c = regexPattern[i]; + var escaped = i > 0 && regexPattern[i - 1] == singleEscapeCharacter; + + switch (c) + { + case '_': + { + stringBuilder.Append(escaped ? '_' : '.'); + break; + } + case '%': + { + stringBuilder.Append(escaped ? "%" : ".*"); + break; + } + default: + { + if (c != singleEscapeCharacter) + { + stringBuilder.Append(c); + } + + break; + } + } + } + + regexPattern = stringBuilder.ToString(); + + return Regex.IsMatch( + matchExpression, + @"\A" + regexPattern + @"\s*\z", + RegexOptions.IgnoreCase | RegexOptions.Singleline, + RegexTimeout); + } + + private sealed class EntityReferenceFindingExpressionVisitor : ExpressionVisitor + { + private bool _found; + + public bool Find(Expression expression) + { + _found = false; + + Visit(expression); + + return _found; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (_found) + { + return expression; + } + + if (expression is StructuralTypeReferenceExpression) + { + _found = true; + return expression; + } + + return base.Visit(expression); + } + } + + private sealed class StructuralTypeReferenceExpression : Expression + { + public StructuralTypeReferenceExpression(StructuralTypeShaperExpression parameter) + { + Parameter = parameter; + StructuralType = parameter.StructuralType; + } + + public StructuralTypeReferenceExpression(ShapedQueryExpression subquery) + { + Subquery = subquery; + StructuralType = ((StructuralTypeShaperExpression)subquery.ShaperExpression).StructuralType; + } + + private StructuralTypeReferenceExpression(StructuralTypeReferenceExpression typeReference, IEntityType type) + { + Parameter = typeReference.Parameter; + Subquery = typeReference.Subquery; + StructuralType = type; + } + + public new StructuralTypeShaperExpression? Parameter { get; } + public ShapedQueryExpression? Subquery { get; } + public ITypeBase StructuralType { get; } + + public override Type Type + => StructuralType.ClrType; + + public override ExpressionType NodeType + => ExpressionType.Extension; + + public Expression Convert(Type type) + { + if (type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface + { + return this; + } + + return StructuralType is IEntityType entityType + && entityType.GetDerivedTypes().FirstOrDefault(et => et.ClrType == type) is IEntityType derivedEntityType + ? new StructuralTypeReferenceExpression(this, derivedEntityType) + : QueryCompilationContext.NotTranslatedExpression; + } + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaProjectionBindingExpressionVisitor.cs b/src/net/KEFCore/Query/Internal8/KafkaProjectionBindingExpressionVisitor.cs new file mode 100644 index 00000000..a1ec6cbc --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaProjectionBindingExpressionVisitor.cs @@ -0,0 +1,532 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaProjectionBindingExpressionVisitor : ExpressionVisitor +{ + private readonly KafkaQueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; + private readonly KafkaExpressionTranslatingExpressionVisitor _expressionTranslatingExpressionVisitor; + + private KafkaQueryExpression _queryExpression; + private bool _indexBasedBinding; + + private Dictionary? _entityProjectionCache; + + private readonly Dictionary _projectionMapping = new(); + private List? _clientProjections; + private readonly Stack _projectionMembers = new(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaProjectionBindingExpressionVisitor( + KafkaQueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor, + KafkaExpressionTranslatingExpressionVisitor expressionTranslatingExpressionVisitor) + { + _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; + _expressionTranslatingExpressionVisitor = expressionTranslatingExpressionVisitor; + _queryExpression = null!; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Translate(KafkaQueryExpression queryExpression, Expression expression) + { + _queryExpression = queryExpression; + _indexBasedBinding = false; + + _projectionMembers.Push(new ProjectionMember()); + var result = Visit(expression); + + if (result == QueryCompilationContext.NotTranslatedExpression) + { + _indexBasedBinding = true; + _projectionMapping.Clear(); + _entityProjectionCache = new Dictionary(); + _clientProjections = new List(); + + result = Visit(expression); + + _queryExpression.ReplaceProjection(_clientProjections); + _clientProjections = null; + } + else + { + _queryExpression.ReplaceProjection(_projectionMapping); + _projectionMapping.Clear(); + } + + _queryExpression = null!; + _projectionMembers.Clear(); + result = MatchTypes(result, expression.Type); + + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression == null) + { + return null; + } + + if (expression is not (NewExpression or MemberInitExpression or StructuralTypeShaperExpression or IncludeExpression)) + { + if (_indexBasedBinding) + { + switch (expression) + { + case ConstantExpression: + return expression; + + case ProjectionBindingExpression projectionBindingExpression: + var mappedProjection = _queryExpression.GetProjection(projectionBindingExpression); + if (mappedProjection is EntityProjectionExpression entityProjection) + { + return AddClientProjection(entityProjection, typeof(ValueBuffer)); + } + + if (mappedProjection is not KafkaQueryExpression) + { + return AddClientProjection(mappedProjection, expression.Type.MakeNullable()); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(projectionBindingExpression.Print())); + + case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression: + { + var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + materializeCollectionNavigationExpression.Subquery)!; + _clientProjections!.Add(subquery.QueryExpression); + return new CollectionResultShaperExpression( + new ProjectionBindingExpression( + _queryExpression, _clientProjections.Count - 1, typeof(IEnumerable)), + subquery.ShaperExpression, + materializeCollectionNavigationExpression.Navigation, + materializeCollectionNavigationExpression.Navigation.ClrType.GetSequenceType()); + } + + case MethodCallExpression methodCallExpression: + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.DeclaringType == typeof(Enumerable) + && methodCallExpression.Method.Name == nameof(Enumerable.ToList) + && methodCallExpression.Arguments.Count == 1 + && methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) != null) + { + var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + methodCallExpression.Arguments[0]); + if (subquery != null) + { + _clientProjections!.Add(subquery.QueryExpression); + return new CollectionResultShaperExpression( + new ProjectionBindingExpression( + _queryExpression, _clientProjections.Count - 1, typeof(IEnumerable)), + subquery.ShaperExpression, + null, + methodCallExpression.Method.GetGenericArguments()[0]); + } + } + else + { + var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); + if (subquery != null) + { + // This simplifies the check when subquery is translated and can be lifted as scalar. + var scalarTranslation = _expressionTranslatingExpressionVisitor.Translate(subquery); + if (scalarTranslation != null) + { + return AddClientProjection(scalarTranslation, expression.Type.MakeNullable()); + } + + if (subquery.ResultCardinality == ResultCardinality.Enumerable) + { + _clientProjections!.Add(subquery.QueryExpression); + var projectionBindingExpression = new ProjectionBindingExpression( + _queryExpression, _clientProjections.Count - 1, typeof(IEnumerable)); + return new CollectionResultShaperExpression( + projectionBindingExpression, subquery.ShaperExpression, navigation: null, + subquery.ShaperExpression.Type); + } + else + { + _clientProjections!.Add(subquery.QueryExpression); + var projectionBindingExpression = new ProjectionBindingExpression( + _queryExpression, _clientProjections.Count - 1, typeof(ValueBuffer)); + return new SingleResultShaperExpression(projectionBindingExpression, subquery.ShaperExpression); + } + } + } + + break; + } + + var translation = _expressionTranslatingExpressionVisitor.Translate(expression); + return translation != null + ? AddClientProjection(translation, expression.Type.MakeNullable()) + : base.Visit(expression); + } + else + { + var translation = _expressionTranslatingExpressionVisitor.Translate(expression); + if (translation == null) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + _projectionMapping[_projectionMembers.Peek()] = translation; + + return new ProjectionBindingExpression(_queryExpression, _projectionMembers.Peek(), expression.Type.MakeNullable()); + } + } + + return base.Visit(expression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + var left = MatchTypes(Visit(binaryExpression.Left), binaryExpression.Left.Type); + var right = MatchTypes(Visit(binaryExpression.Right), binaryExpression.Right.Type); + + return binaryExpression.Update(left, VisitAndConvert(binaryExpression.Conversion, "VisitBinary"), right); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) + { + var test = Visit(conditionalExpression.Test); + var ifTrue = Visit(conditionalExpression.IfTrue); + var ifFalse = Visit(conditionalExpression.IfFalse); + + if (test.Type == typeof(bool?)) + { + test = Expression.Equal(test, Expression.Constant(true, typeof(bool?))); + } + + ifTrue = MatchTypes(ifTrue, conditionalExpression.IfTrue.Type); + ifFalse = MatchTypes(ifFalse, conditionalExpression.IfFalse.Type); + + return conditionalExpression.Update(test, ifTrue, ifFalse); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + if (extensionExpression is StructuralTypeShaperExpression shaper) + { + EntityProjectionExpression entityProjectionExpression; + if (shaper.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression) + { + entityProjectionExpression = + (EntityProjectionExpression)((KafkaQueryExpression)projectionBindingExpression.QueryExpression) + .GetProjection(projectionBindingExpression); + } + else + { + entityProjectionExpression = (EntityProjectionExpression)shaper.ValueBufferExpression; + } + + if (_indexBasedBinding) + { + if (!_entityProjectionCache!.TryGetValue(entityProjectionExpression, out var entityProjectionBinding)) + { + entityProjectionBinding = AddClientProjection(entityProjectionExpression, typeof(ValueBuffer)); + _entityProjectionCache[entityProjectionExpression] = entityProjectionBinding; + } + + return shaper.Update(entityProjectionBinding); + } + + _projectionMapping[_projectionMembers.Peek()] = entityProjectionExpression; + + return shaper.Update( + new ProjectionBindingExpression(_queryExpression, _projectionMembers.Peek(), typeof(ValueBuffer))); + } + + if (extensionExpression is IncludeExpression includeExpression) + { + return _indexBasedBinding + ? base.VisitExtension(includeExpression) + : QueryCompilationContext.NotTranslatedExpression; + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(extensionExpression.Print())); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ElementInit VisitElementInit(ElementInit elementInit) + => elementInit.Update(elementInit.Arguments.Select(e => MatchTypes(Visit(e), e.Type))); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMember(MemberExpression memberExpression) + { + var expression = Visit(memberExpression.Expression); + Expression updatedMemberExpression = memberExpression.Update( + expression != null ? MatchTypes(expression, memberExpression.Expression!.Type) : expression); + + if (expression?.Type.IsNullableValueType() == true) + { + var nullableReturnType = memberExpression.Type.MakeNullable(); + if (!memberExpression.Type.IsNullableType()) + { + updatedMemberExpression = Expression.Convert(updatedMemberExpression, nullableReturnType); + } + + updatedMemberExpression = Expression.Condition( + Expression.Equal(expression, Expression.Default(expression.Type)), + Expression.Constant(null, nullableReturnType), + updatedMemberExpression); + } + + return updatedMemberExpression; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + { + var expression = memberAssignment.Expression; + Expression? visitedExpression; + if (_indexBasedBinding) + { + visitedExpression = Visit(memberAssignment.Expression); + } + else + { + var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member); + _projectionMembers.Push(projectionMember); + + visitedExpression = Visit(memberAssignment.Expression); + if (visitedExpression == QueryCompilationContext.NotTranslatedExpression) + { + return memberAssignment.Update(Expression.Convert(visitedExpression, memberAssignment.Expression.Type)); + } + + _projectionMembers.Pop(); + } + + visitedExpression = MatchTypes(visitedExpression, expression.Type); + + return memberAssignment.Update(visitedExpression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + { + var newExpression = Visit(memberInitExpression.NewExpression); + if (newExpression == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var newBindings = new MemberBinding[memberInitExpression.Bindings.Count]; + for (var i = 0; i < newBindings.Length; i++) + { + if (memberInitExpression.Bindings[i].BindingType != MemberBindingType.Assignment) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]); + if (((MemberAssignment)newBindings[i]).Expression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + } + + return memberInitExpression.Update((NewExpression)newExpression, newBindings); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + var @object = Visit(methodCallExpression.Object); + var arguments = new Expression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < methodCallExpression.Arguments.Count; i++) + { + var argument = methodCallExpression.Arguments[i]; + arguments[i] = MatchTypes(Visit(argument), argument.Type); + } + + Expression updatedMethodCallExpression = methodCallExpression.Update( + @object != null ? MatchTypes(@object, methodCallExpression.Object!.Type) : @object!, + arguments); + + if (@object?.Type.IsNullableType() == true + && !methodCallExpression.Object!.Type.IsNullableType()) + { + var nullableReturnType = methodCallExpression.Type.MakeNullable(); + if (!methodCallExpression.Type.IsNullableType()) + { + updatedMethodCallExpression = Expression.Convert(updatedMethodCallExpression, nullableReturnType); + } + + return Expression.Condition( + Expression.Equal(@object, Expression.Default(@object.Type)), + Expression.Constant(null, nullableReturnType), + updatedMethodCallExpression); + } + + return updatedMethodCallExpression; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitNew(NewExpression newExpression) + { + if (newExpression.Arguments.Count == 0) + { + return newExpression; + } + + if (!_indexBasedBinding + && newExpression.Members == null) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var newArguments = new Expression[newExpression.Arguments.Count]; + for (var i = 0; i < newArguments.Length; i++) + { + var argument = newExpression.Arguments[i]; + Expression? visitedArgument; + if (_indexBasedBinding) + { + visitedArgument = Visit(argument); + } + else + { + var projectionMember = _projectionMembers.Peek().Append(newExpression.Members![i]); + _projectionMembers.Push(projectionMember); + visitedArgument = Visit(argument); + if (visitedArgument == QueryCompilationContext.NotTranslatedExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + _projectionMembers.Pop(); + } + + newArguments[i] = MatchTypes(visitedArgument, argument.Type); + } + + return newExpression.Update(newArguments); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) + => newArrayExpression.Update(newArrayExpression.Expressions.Select(e => MatchTypes(Visit(e), e.Type))); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + var operand = Visit(unaryExpression.Operand); + + return unaryExpression.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked + && unaryExpression.Type == operand.Type + ? operand + : unaryExpression.Update(MatchTypes(operand, unaryExpression.Operand.Type)); + } + + private static Expression MatchTypes(Expression expression, Type targetType) + { + if (targetType != expression.Type + && targetType.TryGetElementType(typeof(IQueryable<>)) == null) + { + Check.DebugAssert(targetType.MakeNullable() == expression.Type, "Not a nullable to non-nullable conversion"); + + expression = Expression.Convert(expression, targetType); + } + + return expression; + } + + private ProjectionBindingExpression AddClientProjection(Expression expression, Type type) + { + var existingIndex = _clientProjections!.FindIndex(e => e.Equals(expression)); + if (existingIndex == -1) + { + _clientProjections.Add(expression); + existingIndex = _clientProjections.Count - 1; + } + + return new ProjectionBindingExpression(_queryExpression, existingIndex, type); + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryContext.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryContext.cs new file mode 100644 index 00000000..a53856e2 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryContext.cs @@ -0,0 +1,48 @@ +/* +* Copyright 2023 MASES s.r.l. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +* Refer to LICENSE for more information. +*/ + +using MASES.EntityFrameworkCore.KNet.Storage.Internal; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaQueryContext : QueryContext +{ + private readonly IKafkaCluster _cluster; + /// + /// Retrieve for the specified + /// + /// + /// + public virtual IEnumerable GetValueBuffers(IEntityType entityType) + { + return _cluster.GetValueBuffers(entityType); + } + /// + /// Default initializer + /// + public KafkaQueryContext(QueryContextDependencies dependencies, IKafkaCluster cluster) + : base(dependencies) + { + _cluster = cluster; + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryContextFactory.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryContextFactory.cs new file mode 100644 index 00000000..20ce8d0b --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryContextFactory.cs @@ -0,0 +1,49 @@ +/* +* Copyright 2023 MASES s.r.l. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +* Refer to LICENSE for more information. +*/ + +using MASES.EntityFrameworkCore.KNet.Storage.Internal; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaQueryContextFactory : IQueryContextFactory +{ + private readonly IKafkaCluster _cluster; + /// + /// Default initializer + /// + public KafkaQueryContextFactory( + QueryContextDependencies dependencies, + IKafkaClusterCache clusterCache, + IDbContextOptions contextOptions) + { + _cluster = clusterCache.GetCluster(contextOptions); + Dependencies = dependencies; + } + + /// + /// Dependencies for this service. + /// + protected virtual QueryContextDependencies Dependencies { get; } + /// + public virtual QueryContext Create() => new KafkaQueryContext(Dependencies, _cluster); +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryExpression.Helper.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryExpression.Helper.cs new file mode 100644 index 00000000..66a371a3 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryExpression.Helper.cs @@ -0,0 +1,230 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Diagnostics.CodeAnalysis; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +public partial class KafkaQueryExpression +{ + private sealed class ResultEnumerable : IEnumerable + { + private readonly Func _getElement; + + public ResultEnumerable(Func getElement) + { + _getElement = getElement; + } + + public IEnumerator GetEnumerator() + => new ResultEnumerator(_getElement()); + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + + private sealed class ResultEnumerator : IEnumerator + { + private readonly ValueBuffer _value; + private bool _moved; + + public ResultEnumerator(ValueBuffer value) + { + _value = value; + _moved = _value.IsEmpty; + } + + public bool MoveNext() + { + if (!_moved) + { + _moved = true; + + return _moved; + } + + return false; + } + + public void Reset() + => _moved = false; + + object IEnumerator.Current + => Current; + + public ValueBuffer Current + => !_moved ? ValueBuffer.Empty : _value; + + void IDisposable.Dispose() + { + } + } + } + + private sealed class ProjectionMemberRemappingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _queryExpression; + private readonly Dictionary _projectionMemberMappings; + + public ProjectionMemberRemappingExpressionVisitor( + Expression queryExpression, + Dictionary projectionMemberMappings) + { + _queryExpression = queryExpression; + _projectionMemberMappings = projectionMemberMappings; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is ProjectionBindingExpression projectionBindingExpression) + { + Check.DebugAssert( + projectionBindingExpression.ProjectionMember != null, + "ProjectionBindingExpression must have projection member."); + + return new ProjectionBindingExpression( + _queryExpression, + _projectionMemberMappings[projectionBindingExpression.ProjectionMember], + projectionBindingExpression.Type); + } + + return base.Visit(expression); + } + } + + private sealed class ProjectionMemberToIndexConvertingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _queryExpression; + private readonly Dictionary _projectionMemberMappings; + + public ProjectionMemberToIndexConvertingExpressionVisitor( + Expression queryExpression, + Dictionary projectionMemberMappings) + { + _queryExpression = queryExpression; + _projectionMemberMappings = projectionMemberMappings; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is ProjectionBindingExpression projectionBindingExpression) + { + Check.DebugAssert( + projectionBindingExpression.ProjectionMember != null, + "ProjectionBindingExpression must have projection member."); + + return new ProjectionBindingExpression( + _queryExpression, + _projectionMemberMappings[projectionBindingExpression.ProjectionMember], + projectionBindingExpression.Type); + } + + return base.Visit(expression); + } + } + + private sealed class ProjectionIndexRemappingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _oldExpression; + private readonly Expression _newExpression; + private readonly int[] _indexMap; + + public ProjectionIndexRemappingExpressionVisitor( + Expression oldExpression, + Expression newExpression, + int[] indexMap) + { + _oldExpression = oldExpression; + _newExpression = newExpression; + _indexMap = indexMap; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is ProjectionBindingExpression projectionBindingExpression + && ReferenceEquals(projectionBindingExpression.QueryExpression, _oldExpression)) + { + Check.DebugAssert( + projectionBindingExpression.Index != null, + "ProjectionBindingExpression must have index."); + + return new ProjectionBindingExpression( + _newExpression, + _indexMap[projectionBindingExpression.Index.Value], + projectionBindingExpression.Type); + } + + return base.Visit(expression); + } + } + + private sealed class EntityShaperNullableMarkingExpressionVisitor : ExpressionVisitor + { + protected override Expression VisitExtension(Expression extensionExpression) + => extensionExpression is StructuralTypeShaperExpression shaper + ? shaper.MakeNullable() + : base.VisitExtension(extensionExpression); + } + + private sealed class QueryExpressionReplacingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _oldQuery; + private readonly Expression _newQuery; + + public QueryExpressionReplacingExpressionVisitor(Expression oldQuery, Expression newQuery) + { + _oldQuery = oldQuery; + _newQuery = newQuery; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + => expression is ProjectionBindingExpression projectionBindingExpression + && ReferenceEquals(projectionBindingExpression.QueryExpression, _oldQuery) + ? projectionBindingExpression.ProjectionMember != null + ? new ProjectionBindingExpression( + _newQuery, projectionBindingExpression.ProjectionMember!, projectionBindingExpression.Type) + : new ProjectionBindingExpression( + _newQuery, projectionBindingExpression.Index!.Value, projectionBindingExpression.Type) + : base.Visit(expression); + } + + private sealed class CloningExpressionVisitor : ExpressionVisitor + { + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is KafkaQueryExpression kafkaQueryExpression) + { + var clonedKafkaQueryExpression = new KafkaQueryExpression( + kafkaQueryExpression.ServerQueryExpression, kafkaQueryExpression._valueBufferParameter) + { + _groupingParameter = kafkaQueryExpression._groupingParameter, + _singleResultMethodInfo = kafkaQueryExpression._singleResultMethodInfo, + _scalarServerQuery = kafkaQueryExpression._scalarServerQuery + }; + + clonedKafkaQueryExpression._clientProjections.AddRange( + kafkaQueryExpression._clientProjections.Select(e => Visit(e))); + clonedKafkaQueryExpression._projectionMappingExpressions.AddRange( + kafkaQueryExpression._projectionMappingExpressions); + foreach (var (projectionMember, value) in kafkaQueryExpression._projectionMapping) + { + clonedKafkaQueryExpression._projectionMapping[projectionMember] = Visit(value); + } + + return clonedKafkaQueryExpression; + } + + if (expression is EntityProjectionExpression entityProjectionExpression) + { + return entityProjectionExpression.Clone(); + } + + return base.Visit(expression); + } + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryExpression.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryExpression.cs new file mode 100644 index 00000000..f4042e62 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryExpression.cs @@ -0,0 +1,1312 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using MASES.EntityFrameworkCore.KNet.Internal; +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public partial class KafkaQueryExpression : Expression, IPrintableExpression +{ + private static readonly ConstructorInfo ValueBufferConstructor + = typeof(ValueBuffer).GetConstructors().Single(ci => ci.GetParameters().Length == 1); + + private static readonly PropertyInfo ValueBufferCountMemberInfo + = typeof(ValueBuffer).GetTypeInfo().GetProperty(nameof(ValueBuffer.Count))!; + + private static readonly MethodInfo LeftJoinMethodInfo = typeof(KafkaQueryExpression).GetTypeInfo() + .GetDeclaredMethods(nameof(LeftJoin)).Single(mi => mi.GetParameters().Length == 7); + + private static readonly ConstructorInfo ResultEnumerableConstructor + = typeof(ResultEnumerable).GetConstructors().Single(); + + private readonly ParameterExpression _valueBufferParameter; + private ParameterExpression? _groupingParameter; + private MethodInfo? _singleResultMethodInfo; + private bool _scalarServerQuery; + + private CloningExpressionVisitor? _cloningExpressionVisitor; + + private Dictionary _projectionMapping = new(); + private readonly List _clientProjections = new(); + private readonly List _projectionMappingExpressions = new(); + + private KafkaQueryExpression( + Expression serverQueryExpression, + ParameterExpression valueBufferParameter) + { + ServerQueryExpression = serverQueryExpression; + _valueBufferParameter = valueBufferParameter; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaQueryExpression(IEntityType entityType) + { + _valueBufferParameter = Parameter(typeof(ValueBuffer), "valueBuffer"); + ServerQueryExpression = new KafkaTableExpression(entityType); + var propertyExpressionsMap = new Dictionary(); + var selectorExpressions = new List(); + foreach (var property in entityType.GetAllBaseTypesInclusive().SelectMany(et => et.GetDeclaredProperties())) + { + var propertyExpression = CreateReadValueExpression(property.ClrType, property.GetIndex(), property); + selectorExpressions.Add(propertyExpression); + + Check.DebugAssert( + property.GetIndex() == selectorExpressions.Count - 1, + "Properties should be ordered in same order as their indexes."); + propertyExpressionsMap[property] = propertyExpression; + _projectionMappingExpressions.Add(propertyExpression); + } + + var discriminatorProperty = entityType.FindDiscriminatorProperty(); + if (discriminatorProperty != null) + { + var keyValueComparer = discriminatorProperty.GetKeyValueComparer(); + foreach (var derivedEntityType in entityType.GetDerivedTypes()) + { + var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive() + .Select( + e => keyValueComparer.ExtractEqualsBody( + propertyExpressionsMap[discriminatorProperty], + Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType))) + .Aggregate((l, r) => OrElse(l, r)); + + foreach (var property in derivedEntityType.GetDeclaredProperties()) + { + // We read nullable value from property of derived type since it could be null. + var typeToRead = property.ClrType.MakeNullable(); + var propertyExpression = Condition( + entityCheck, + CreateReadValueExpression(typeToRead, property.GetIndex(), property), + Default(typeToRead)); + + selectorExpressions.Add(propertyExpression); + var readExpression = CreateReadValueExpression(propertyExpression.Type, selectorExpressions.Count - 1, property); + propertyExpressionsMap[property] = readExpression; + _projectionMappingExpressions.Add(readExpression); + } + } + + // Force a selector if entity projection has complex expressions. + var selectorLambda = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + CurrentParameter); + + ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(typeof(ValueBuffer), typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda); + } + + var entityProjection = new EntityProjectionExpression(entityType, propertyExpressionsMap); + _projectionMapping[new ProjectionMember()] = entityProjection; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression ServerQueryExpression { get; private set; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ParameterExpression CurrentParameter + => _groupingParameter ?? _valueBufferParameter; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ReplaceProjection(IReadOnlyList clientProjections) + { + _projectionMapping.Clear(); + _projectionMappingExpressions.Clear(); + _clientProjections.Clear(); + _clientProjections.AddRange(clientProjections); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ReplaceProjection(IReadOnlyDictionary projectionMapping) + { + _projectionMapping.Clear(); + _projectionMappingExpressions.Clear(); + _clientProjections.Clear(); + var selectorExpressions = new List(); + foreach (var (projectionMember, expression) in projectionMapping) + { + if (expression is EntityProjectionExpression entityProjectionExpression) + { + _projectionMapping[projectionMember] = AddEntityProjection(entityProjectionExpression); + } + else + { + selectorExpressions.Add(expression); + var readExpression = CreateReadValueExpression( + expression.Type, selectorExpressions.Count - 1, InferPropertyFromInner(expression)); + _projectionMapping[projectionMember] = readExpression; + _projectionMappingExpressions.Add(readExpression); + } + } + + if (selectorExpressions.Count == 0) + { + // No server correlated term in projection so add dummy 1. + selectorExpressions.Add(Constant(1)); + } + + var selectorLambda = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), + CurrentParameter); + + ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda); + + _groupingParameter = null; + + EntityProjectionExpression AddEntityProjection(EntityProjectionExpression entityProjectionExpression) + { + var readExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + var expression = entityProjectionExpression.BindProperty(property); + selectorExpressions.Add(expression); + var newExpression = CreateReadValueExpression(expression.Type, selectorExpressions.Count - 1, property); + readExpressionMap[property] = newExpression; + _projectionMappingExpressions.Add(newExpression); + } + + var result = new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressionMap); + + // Also compute nested entity projections + foreach (var navigation in entityProjectionExpression.EntityType.GetAllBaseTypes() + .Concat(entityProjectionExpression.EntityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredNavigations())) + { + var boundEntityShaperExpression = entityProjectionExpression.BindNavigation(navigation); + if (boundEntityShaperExpression != null) + { + var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; + var newInnerEntityProjection = AddEntityProjection(innerEntityProjection); + boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); + result.AddNavigationBinding(navigation, boundEntityShaperExpression); + } + } + + return result; + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression GetProjection(ProjectionBindingExpression projectionBindingExpression) + => projectionBindingExpression.ProjectionMember != null + ? _projectionMapping[projectionBindingExpression.ProjectionMember] + : _clientProjections[projectionBindingExpression.Index!.Value]; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ApplyProjection() + { + if (_scalarServerQuery) + { + _projectionMapping[new ProjectionMember()] = Constant(0); + return; + } + + var selectorExpressions = new List(); + if (_clientProjections.Count > 0) + { + for (var i = 0; i < _clientProjections.Count; i++) + { + var projection = _clientProjections[i]; + switch (projection) + { + case EntityProjectionExpression entityProjectionExpression: + { + var indexMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + selectorExpressions.Add(entityProjectionExpression.BindProperty(property)); + indexMap[property] = selectorExpressions.Count - 1; + } + + _clientProjections[i] = Constant(indexMap); + break; + } + + case KafkaQueryExpression kafkaQueryExpression: + { + var singleResult = kafkaQueryExpression._scalarServerQuery + || kafkaQueryExpression._singleResultMethodInfo != null; + kafkaQueryExpression.ApplyProjection(); + var serverQuery = kafkaQueryExpression.ServerQueryExpression; + if (singleResult) + { + serverQuery = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; + } + + selectorExpressions.Add(serverQuery); + _clientProjections[i] = Constant(selectorExpressions.Count - 1); + break; + } + + default: + selectorExpressions.Add(projection); + _clientProjections[i] = Constant(selectorExpressions.Count - 1); + break; + } + } + } + else + { + var newProjectionMapping = new Dictionary(); + foreach (var (projectionMember, expression) in _projectionMapping) + { + if (expression is EntityProjectionExpression entityProjectionExpression) + { + var indexMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + selectorExpressions.Add(entityProjectionExpression.BindProperty(property)); + indexMap[property] = selectorExpressions.Count - 1; + } + + newProjectionMapping[projectionMember] = Constant(indexMap); + } + else + { + selectorExpressions.Add(expression); + newProjectionMapping[projectionMember] = Constant(selectorExpressions.Count - 1); + } + } + + _projectionMapping = newProjectionMapping; + _projectionMappingExpressions.Clear(); + } + + var selectorLambda = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), + CurrentParameter); + + ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda); + + _groupingParameter = null; + + if (_singleResultMethodInfo != null) + { + ServerQueryExpression = Call( + _singleResultMethodInfo.MakeGenericMethod(CurrentParameter.Type), + ServerQueryExpression); + + ConvertToEnumerable(); + + _singleResultMethodInfo = null; + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void UpdateServerQueryExpression(Expression serverQueryExpression) + => ServerQueryExpression = serverQueryExpression; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ApplySetOperation(MethodInfo setOperationMethodInfo, KafkaQueryExpression source2) + { + Check.DebugAssert(_groupingParameter == null, "Cannot apply set operation after GroupBy without flattening."); + if (_clientProjections.Count == 0) + { + var projectionMapping = new Dictionary(); + var source1SelectorExpressions = new List(); + var source2SelectorExpressions = new List(); + foreach (var (key, value1, value2) in _projectionMapping.Join( + source2._projectionMapping, kv => kv.Key, kv => kv.Key, + (kv1, kv2) => (kv1.Key, Value1: kv1.Value, Value2: kv2.Value))) + { + if (value1 is EntityProjectionExpression entityProjection1 + && value2 is EntityProjectionExpression entityProjection2) + { + var map = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjection1.EntityType)) + { + var expressionToAdd1 = entityProjection1.BindProperty(property); + var expressionToAdd2 = entityProjection2.BindProperty(property); + source1SelectorExpressions.Add(expressionToAdd1); + source2SelectorExpressions.Add(expressionToAdd2); + var type = expressionToAdd1.Type; + if (!type.IsNullableType() + && expressionToAdd2.Type.IsNullableType()) + { + type = expressionToAdd2.Type; + } + + map[property] = CreateReadValueExpression(type, source1SelectorExpressions.Count - 1, property); + } + + projectionMapping[key] = new EntityProjectionExpression(entityProjection1.EntityType, map); + } + else + { + source1SelectorExpressions.Add(value1); + source2SelectorExpressions.Add(value2); + var type = value1.Type; + if (!type.IsNullableType() + && value2.Type.IsNullableType()) + { + type = value2.Type; + } + + projectionMapping[key] = CreateReadValueExpression( + type, source1SelectorExpressions.Count - 1, InferPropertyFromInner(value1)); + } + } + + _projectionMapping = projectionMapping; + + ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(ServerQueryExpression.Type.GetSequenceType(), typeof(ValueBuffer)), + ServerQueryExpression, + Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + source1SelectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + CurrentParameter)); + + source2.ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(source2.ServerQueryExpression.Type.GetSequenceType(), typeof(ValueBuffer)), + source2.ServerQueryExpression, + Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + source2SelectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + source2.CurrentParameter)); + } + else + { + throw new InvalidOperationException(KafkaStrings.SetOperationsNotAllowedAfterClientEvaluation); + } + + ServerQueryExpression = Call( + setOperationMethodInfo.MakeGenericMethod(typeof(ValueBuffer)), ServerQueryExpression, source2.ServerQueryExpression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ApplyDefaultIfEmpty() + { + if (_clientProjections.Count != 0) + { + throw new InvalidOperationException(KafkaStrings.DefaultIfEmptyAppliedAfterProjection); + } + + var projectionMapping = new Dictionary(); + foreach (var (projectionMember, expression) in _projectionMapping) + { + projectionMapping[projectionMember] = expression is EntityProjectionExpression entityProjectionExpression + ? MakeEntityProjectionNullable(entityProjectionExpression) + : MakeReadValueNullable(expression); + } + + _projectionMapping = projectionMapping; + var projectionMappingExpressions = _projectionMappingExpressions.Select(e => MakeReadValueNullable(e)).ToList(); + _projectionMappingExpressions.Clear(); + _projectionMappingExpressions.AddRange(projectionMappingExpressions); + _groupingParameter = null; + + ServerQueryExpression = Call( + EnumerableMethods.DefaultIfEmptyWithArgument.MakeGenericMethod(typeof(ValueBuffer)), + ServerQueryExpression, + Constant(new ValueBuffer(Enumerable.Repeat((object?)null, _projectionMappingExpressions.Count).ToArray()))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ApplyDistinct() + { + Check.DebugAssert(!_scalarServerQuery && _singleResultMethodInfo == null, "Cannot apply distinct on single result query"); + Check.DebugAssert(_groupingParameter == null, "Cannot apply distinct after GroupBy before flattening."); + + var selectorExpressions = new List(); + if (_clientProjections.Count == 0) + { + selectorExpressions.AddRange(_projectionMappingExpressions); + if (selectorExpressions.Count == 0) + { + // No server correlated term in projection so add dummy 1. + selectorExpressions.Add(Constant(1)); + } + } + else + { + for (var i = 0; i < _clientProjections.Count; i++) + { + var projection = _clientProjections[i]; + if (projection is KafkaQueryExpression) + { + throw new InvalidOperationException(KafkaStrings.DistinctOnSubqueryNotSupported); + } + + if (projection is EntityProjectionExpression entityProjectionExpression) + { + _clientProjections[i] = TraverseEntityProjection( + selectorExpressions, entityProjectionExpression, makeNullable: false); + } + else + { + selectorExpressions.Add(projection); + _clientProjections[i] = CreateReadValueExpression( + projection.Type, selectorExpressions.Count - 1, InferPropertyFromInner(projection)); + } + } + } + + var selectorLambda = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), + CurrentParameter); + + ServerQueryExpression = Call( + EnumerableMethods.Distinct.MakeGenericMethod(typeof(ValueBuffer)), + Call( + EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual GroupByShaperExpression ApplyGrouping( + Expression groupingKey, + Expression shaperExpression, + bool defaultElementSelector) + { + var source = ServerQueryExpression; + Expression? selector; + if (defaultElementSelector) + { + selector = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + _projectionMappingExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + _valueBufferParameter); + } + else + { + var selectMethodCall = (MethodCallExpression)ServerQueryExpression; + source = selectMethodCall.Arguments[0]; + selector = selectMethodCall.Arguments[1]; + } + + _groupingParameter = Parameter(typeof(IGrouping), "grouping"); + var groupingKeyAccessExpression = PropertyOrField(_groupingParameter, nameof(IGrouping.Key)); + var groupingKeyExpressions = new List(); + groupingKey = GetGroupingKey(groupingKey, groupingKeyExpressions, groupingKeyAccessExpression); + var keySelector = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + groupingKeyExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + _valueBufferParameter); + + ServerQueryExpression = Call( + EnumerableMethods.GroupByWithKeyElementSelector.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), typeof(ValueBuffer)), + source, + keySelector, + selector); + + var clonedKafkaQueryExpression = Clone(); + clonedKafkaQueryExpression.UpdateServerQueryExpression(_groupingParameter); + clonedKafkaQueryExpression._groupingParameter = null; + + return new GroupByShaperExpression( + groupingKey, + new ShapedQueryExpression( + clonedKafkaQueryExpression, + new QueryExpressionReplacingExpressionVisitor(this, clonedKafkaQueryExpression).Visit(shaperExpression))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression AddInnerJoin( + KafkaQueryExpression innerQueryExpression, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector, + Expression outerShaperExpression, + Expression innerShaperExpression) + => AddJoin( + innerQueryExpression, outerKeySelector, innerKeySelector, outerShaperExpression, innerShaperExpression, + innerNullable: false); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression AddLeftJoin( + KafkaQueryExpression innerQueryExpression, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector, + Expression outerShaperExpression, + Expression innerShaperExpression) + => AddJoin( + innerQueryExpression, outerKeySelector, innerKeySelector, outerShaperExpression, innerShaperExpression, + innerNullable: true); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression AddSelectMany( + KafkaQueryExpression innerQueryExpression, + Expression outerShaperExpression, + Expression innerShaperExpression, + bool innerNullable) + => AddJoin(innerQueryExpression, null, null, outerShaperExpression, innerShaperExpression, innerNullable); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual StructuralTypeShaperExpression AddNavigationToWeakEntityType( + EntityProjectionExpression entityProjectionExpression, + INavigation navigation, + KafkaQueryExpression innerQueryExpression, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector) + { + Check.DebugAssert(_clientProjections.Count == 0, "Cannot expand weak entity navigation after client projection yet."); + var outerParameter = Parameter(typeof(ValueBuffer), "outer"); + var innerParameter = Parameter(typeof(ValueBuffer), "inner"); + var replacingVisitor = new ReplacingExpressionVisitor( + new Expression[] { CurrentParameter, innerQueryExpression.CurrentParameter }, + new Expression[] { outerParameter, innerParameter }); + + var selectorExpressions = _projectionMappingExpressions.Select(e => replacingVisitor.Visit(e)).ToList(); + var outerIndex = selectorExpressions.Count; + var innerEntityProjection = (EntityProjectionExpression)innerQueryExpression._projectionMapping[new ProjectionMember()]; + var innerReadExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(innerEntityProjection.EntityType)) + { + var propertyExpression = innerEntityProjection.BindProperty(property); + propertyExpression = MakeReadValueNullable(propertyExpression); + + selectorExpressions.Add(propertyExpression); + var readValueExpression = CreateReadValueExpression(propertyExpression.Type, selectorExpressions.Count - 1, property); + innerReadExpressionMap[property] = readValueExpression; + _projectionMappingExpressions.Add(readValueExpression); + } + + innerEntityProjection = new EntityProjectionExpression(innerEntityProjection.EntityType, innerReadExpressionMap); + + var resultSelector = Lambda( + New( + ValueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions + .Select(e => replacingVisitor.Visit(e)) + .Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + outerParameter, + innerParameter); + + ServerQueryExpression = Call( + LeftJoinMethodInfo.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)), + ServerQueryExpression, + innerQueryExpression.ServerQueryExpression, + outerKeySelector, + innerKeySelector, + resultSelector, + Constant(new ValueBuffer(Enumerable.Repeat((object?)null, selectorExpressions.Count - outerIndex).ToArray())), + Constant(null, typeof(IEqualityComparer<>).MakeGenericType(outerKeySelector.ReturnType))); + + var entityShaper = new StructuralTypeShaperExpression(innerEntityProjection.EntityType, innerEntityProjection, nullable: true); + entityProjectionExpression.AddNavigationBinding(navigation, entityShaper); + + return entityShaper; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ShapedQueryExpression Clone(Expression shaperExpression) + { + var clonedKafkaQueryExpression = Clone(); + + return new ShapedQueryExpression( + clonedKafkaQueryExpression, + new QueryExpressionReplacingExpressionVisitor(this, clonedKafkaQueryExpression).Visit(shaperExpression)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression GetSingleScalarProjection() + { + var expression = CreateReadValueExpression(ServerQueryExpression.Type, 0, null); + _projectionMapping.Clear(); + _projectionMappingExpressions.Clear(); + _clientProjections.Clear(); + _projectionMapping[new ProjectionMember()] = expression; + _projectionMappingExpressions.Add(expression); + _groupingParameter = null; + + _scalarServerQuery = true; + ConvertToEnumerable(); + + return new ProjectionBindingExpression(this, new ProjectionMember(), expression.Type.MakeNullable()); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ConvertToSingleResult(MethodInfo methodInfo) + => _singleResultMethodInfo = methodInfo; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Type Type + => typeof(IEnumerable); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine(nameof(KafkaQueryExpression) + ": "); + using (expressionPrinter.Indent()) + { + expressionPrinter.AppendLine(nameof(ServerQueryExpression) + ": "); + using (expressionPrinter.Indent()) + { + expressionPrinter.Visit(ServerQueryExpression); + } + + expressionPrinter.AppendLine(); + if (_clientProjections.Count > 0) + { + expressionPrinter.AppendLine("ClientProjections:"); + using (expressionPrinter.Indent()) + { + for (var i = 0; i < _clientProjections.Count; i++) + { + expressionPrinter.AppendLine(); + expressionPrinter.Append(i.ToString()).Append(" -> "); + expressionPrinter.Visit(_clientProjections[i]); + } + } + } + else + { + expressionPrinter.AppendLine("ProjectionMapping:"); + using (expressionPrinter.Indent()) + { + foreach (var (projectionMember, expression) in _projectionMapping) + { + expressionPrinter.Append("Member: " + projectionMember + " Projection: "); + expressionPrinter.Visit(expression); + expressionPrinter.AppendLine(","); + } + } + } + + expressionPrinter.AppendLine(); + } + } + + private KafkaQueryExpression Clone() + { + _cloningExpressionVisitor ??= new CloningExpressionVisitor(); + + return (KafkaQueryExpression)_cloningExpressionVisitor.Visit(this); + } + + private static Expression GetGroupingKey(Expression key, List groupingExpressions, Expression groupingKeyAccessExpression) + { + switch (key) + { + case NewExpression newExpression: + var arguments = new Expression[newExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + arguments[i] = GetGroupingKey(newExpression.Arguments[i], groupingExpressions, groupingKeyAccessExpression); + } + + return newExpression.Update(arguments); + + case MemberInitExpression memberInitExpression: + if (memberInitExpression.Bindings.Any(mb => mb is not MemberAssignment)) + { + goto default; + } + + var updatedNewExpression = (NewExpression)GetGroupingKey( + memberInitExpression.NewExpression, groupingExpressions, groupingKeyAccessExpression); + var memberBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + for (var i = 0; i < memberBindings.Length; i++) + { + var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; + memberBindings[i] = memberAssignment.Update( + GetGroupingKey( + memberAssignment.Expression, + groupingExpressions, + groupingKeyAccessExpression)); + } + + return memberInitExpression.Update(updatedNewExpression, memberBindings); + + case StructuralTypeShaperExpression { ValueBufferExpression: ProjectionBindingExpression projectionBindingExpression } shaper: + var entityProjectionExpression = + (EntityProjectionExpression)((KafkaQueryExpression)projectionBindingExpression.QueryExpression) + .GetProjection(projectionBindingExpression); + var readExpressions = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + readExpressions[property] = (MethodCallExpression)GetGroupingKey( + entityProjectionExpression.BindProperty(property), + groupingExpressions, + groupingKeyAccessExpression); + } + + return shaper.Update( + new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressions)); + + default: + var index = groupingExpressions.Count; + groupingExpressions.Add(key); + return groupingKeyAccessExpression.CreateValueBufferReadValueExpression( + key.Type, + index, + InferPropertyFromInner(key)); + } + } + + private Expression AddJoin( + KafkaQueryExpression innerQueryExpression, + LambdaExpression? outerKeySelector, + LambdaExpression? innerKeySelector, + Expression outerShaperExpression, + Expression innerShaperExpression, + bool innerNullable) + { + var transparentIdentifierType = TransparentIdentifierFactory.Create(outerShaperExpression.Type, innerShaperExpression.Type); + var outerMemberInfo = transparentIdentifierType.GetTypeInfo().GetDeclaredField("Outer")!; + var innerMemberInfo = transparentIdentifierType.GetTypeInfo().GetDeclaredField("Inner")!; + var outerClientEval = _clientProjections.Count > 0; + var innerClientEval = innerQueryExpression._clientProjections.Count > 0; + var resultSelectorExpressions = new List(); + var outerParameter = Parameter(typeof(ValueBuffer), "outer"); + var innerParameter = Parameter(typeof(ValueBuffer), "inner"); + var replacingVisitor = new ReplacingExpressionVisitor( + new Expression[] { CurrentParameter, innerQueryExpression.CurrentParameter }, + new Expression[] { outerParameter, innerParameter }); + int outerIndex; + + if (outerClientEval) + { + // Outer projection are already populated + if (innerClientEval) + { + // Add inner to projection and update indexes + var indexMap = new int[innerQueryExpression._clientProjections.Count]; + for (var i = 0; i < innerQueryExpression._clientProjections.Count; i++) + { + var projectionToAdd = innerQueryExpression._clientProjections[i]; + projectionToAdd = MakeNullable(projectionToAdd, innerNullable); + _clientProjections.Add(projectionToAdd); + indexMap[i] = _clientProjections.Count - 1; + } + + innerQueryExpression._clientProjections.Clear(); + + innerShaperExpression = + new ProjectionIndexRemappingExpressionVisitor(innerQueryExpression, this, indexMap).Visit(innerShaperExpression); + } + else + { + // Apply inner projection mapping and convert projection member binding to indexes + var mapping = ConvertProjectionMappingToClientProjections(innerQueryExpression._projectionMapping, innerNullable); + innerShaperExpression = + new ProjectionMemberToIndexConvertingExpressionVisitor(this, mapping).Visit(innerShaperExpression); + } + + // TODO: We still need to populate and generate result selector + // Further for a subquery in projection we may need to update correlation terms used inside it. + throw new NotImplementedException(); + } + + if (innerClientEval) + { + // Since inner projections are populated, we need to populate outer also + var mapping = ConvertProjectionMappingToClientProjections(_projectionMapping); + outerShaperExpression = new ProjectionMemberToIndexConvertingExpressionVisitor(this, mapping).Visit(outerShaperExpression); + + var indexMap = new int[innerQueryExpression._clientProjections.Count]; + for (var i = 0; i < innerQueryExpression._clientProjections.Count; i++) + { + var projectionToAdd = innerQueryExpression._clientProjections[i]; + projectionToAdd = MakeNullable(projectionToAdd, innerNullable); + _clientProjections.Add(projectionToAdd); + indexMap[i] = _clientProjections.Count - 1; + } + + innerQueryExpression._clientProjections.Clear(); + + innerShaperExpression = + new ProjectionIndexRemappingExpressionVisitor(innerQueryExpression, this, indexMap).Visit(innerShaperExpression); + // TODO: We still need to populate and generate result selector + // Further for a subquery in projection we may need to update correlation terms used inside it. + throw new NotImplementedException(); + } + else + { + var projectionMapping = new Dictionary(); + var mapping = new Dictionary(); + foreach (var (projectionMember, expression) in _projectionMapping) + { + var newProjectionMember = projectionMember.Prepend(outerMemberInfo); + mapping[projectionMember] = newProjectionMember; + if (expression is EntityProjectionExpression entityProjectionExpression) + { + projectionMapping[newProjectionMember] = TraverseEntityProjection( + resultSelectorExpressions, entityProjectionExpression, makeNullable: false); + } + else + { + resultSelectorExpressions.Add(expression); + projectionMapping[newProjectionMember] = CreateReadValueExpression( + expression.Type, resultSelectorExpressions.Count - 1, InferPropertyFromInner(expression)); + } + } + + outerShaperExpression = new ProjectionMemberRemappingExpressionVisitor(this, mapping).Visit(outerShaperExpression); + mapping.Clear(); + + outerIndex = resultSelectorExpressions.Count; + foreach (var projection in innerQueryExpression._projectionMapping) + { + var newProjectionMember = projection.Key.Prepend(innerMemberInfo); + mapping[projection.Key] = newProjectionMember; + if (projection.Value is EntityProjectionExpression entityProjectionExpression) + { + projectionMapping[newProjectionMember] = TraverseEntityProjection( + resultSelectorExpressions, entityProjectionExpression, innerNullable); + } + else + { + var expression = projection.Value; + if (innerNullable) + { + expression = MakeReadValueNullable(expression); + } + + resultSelectorExpressions.Add(expression); + projectionMapping[newProjectionMember] = CreateReadValueExpression( + expression.Type, resultSelectorExpressions.Count - 1, InferPropertyFromInner(projection.Value)); + } + } + + innerShaperExpression = new ProjectionMemberRemappingExpressionVisitor(this, mapping).Visit(innerShaperExpression); + mapping.Clear(); + + _projectionMapping = projectionMapping; + } + + var resultSelector = Lambda( + New( + ValueBufferConstructor, NewArrayInit( + typeof(object), + resultSelectorExpressions.Select( + (e, i) => + { + var expression = replacingVisitor.Visit(e); + if (innerNullable + && i > outerIndex) + { + expression = MakeReadValueNullable(expression); + } + + if (expression.Type.IsValueType) + { + expression = Convert(expression, typeof(object)); + } + + return expression; + }))), + outerParameter, + innerParameter); + + if (outerKeySelector != null + && innerKeySelector != null) + { + var comparer = ((InferPropertyFromInner(outerKeySelector.Body) + ?? InferPropertyFromInner(outerKeySelector.Body)) + as IProperty)?.GetValueComparer(); + + if (comparer?.Type != outerKeySelector.ReturnType) + { + comparer = null; + } + + if (innerNullable) + { + ServerQueryExpression = Call( + LeftJoinMethodInfo.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)), + ServerQueryExpression, + innerQueryExpression.ServerQueryExpression, + outerKeySelector, + innerKeySelector, + resultSelector, + Constant(new ValueBuffer(Enumerable.Repeat((object?)null, resultSelectorExpressions.Count - outerIndex).ToArray())), + Constant(comparer, typeof(IEqualityComparer<>).MakeGenericType(outerKeySelector.ReturnType))); + } + else + { + ServerQueryExpression = comparer == null + ? Call( + EnumerableMethods.Join.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)), + ServerQueryExpression, + innerQueryExpression.ServerQueryExpression, + outerKeySelector, + innerKeySelector, + resultSelector) + : Call( + EnumerableMethods.JoinWithComparer.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), outerKeySelector.ReturnType, typeof(ValueBuffer)), + ServerQueryExpression, + innerQueryExpression.ServerQueryExpression, + outerKeySelector, + innerKeySelector, + resultSelector, + Constant(comparer, typeof(IEqualityComparer<>).MakeGenericType(outerKeySelector.ReturnType))); + } + } + else + { + // inner nullable should do something different here + // Issue#17536 + ServerQueryExpression = Call( + EnumerableMethods.SelectManyWithCollectionSelector.MakeGenericMethod( + typeof(ValueBuffer), typeof(ValueBuffer), typeof(ValueBuffer)), + ServerQueryExpression, + Lambda(innerQueryExpression.ServerQueryExpression, CurrentParameter), + resultSelector); + } + + if (innerNullable) + { + innerShaperExpression = new EntityShaperNullableMarkingExpressionVisitor().Visit(innerShaperExpression); + } + + return New( + transparentIdentifierType.GetTypeInfo().DeclaredConstructors.Single(), + new[] { outerShaperExpression, innerShaperExpression }, outerMemberInfo, innerMemberInfo); + + static Expression MakeNullable(Expression expression, bool nullable) + => nullable + ? expression is EntityProjectionExpression entityProjection + ? MakeEntityProjectionNullable(entityProjection) + : MakeReadValueNullable(expression) + : expression; + } + + private void ConvertToEnumerable() + { + if (_scalarServerQuery || _singleResultMethodInfo != null) + { + if (ServerQueryExpression.Type != typeof(ValueBuffer)) + { + if (ServerQueryExpression.Type.IsValueType) + { + ServerQueryExpression = Convert(ServerQueryExpression, typeof(object)); + } + + ServerQueryExpression = New( + ResultEnumerableConstructor, + Lambda>( + New( + ValueBufferConstructor, + NewArrayInit(typeof(object), ServerQueryExpression)))); + } + else + { + ServerQueryExpression = New( + ResultEnumerableConstructor, + Lambda>(ServerQueryExpression)); + } + } + } + + private MethodCallExpression CreateReadValueExpression(Type type, int index, IPropertyBase? property) + => (MethodCallExpression)_valueBufferParameter.CreateValueBufferReadValueExpression(type, index, property); + + private static IEnumerable GetAllPropertiesInHierarchy(IEntityType entityType) + => entityType.GetAllBaseTypes().Concat(entityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredProperties()); + + private static IPropertyBase? InferPropertyFromInner(Expression expression) + => expression is MethodCallExpression { Method.IsGenericMethod: true } methodCallExpression + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod + ? methodCallExpression.Arguments[2].GetConstantValue() + : null; + + private static EntityProjectionExpression MakeEntityProjectionNullable(EntityProjectionExpression entityProjectionExpression) + { + var readExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + readExpressionMap[property] = MakeReadValueNullable(entityProjectionExpression.BindProperty(property)); + } + + var result = new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressionMap); + + // Also compute nested entity projections + foreach (var navigation in entityProjectionExpression.EntityType.GetAllBaseTypes() + .Concat(entityProjectionExpression.EntityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredNavigations())) + { + var boundEntityShaperExpression = entityProjectionExpression.BindNavigation(navigation); + if (boundEntityShaperExpression != null) + { + var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; + var newInnerEntityProjection = MakeEntityProjectionNullable(innerEntityProjection); + boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); + result.AddNavigationBinding(navigation, boundEntityShaperExpression); + } + } + + return result; + } + + private Dictionary ConvertProjectionMappingToClientProjections( + Dictionary projectionMapping, + bool makeNullable = false) + { + var mapping = new Dictionary(); + var entityProjectionCache = new Dictionary(ReferenceEqualityComparer.Instance); + foreach (var projection in projectionMapping) + { + var projectionMember = projection.Key; + var projectionToAdd = projection.Value; + + if (projectionToAdd is EntityProjectionExpression entityProjection) + { + if (!entityProjectionCache.TryGetValue(entityProjection, out var value)) + { + var entityProjectionToCache = entityProjection; + if (makeNullable) + { + entityProjection = MakeEntityProjectionNullable(entityProjection); + } + + _clientProjections.Add(entityProjection); + value = _clientProjections.Count - 1; + entityProjectionCache[entityProjectionToCache] = value; + } + + mapping[projectionMember] = value; + } + else + { + if (makeNullable) + { + projectionToAdd = MakeReadValueNullable(projectionToAdd); + } + + var existingIndex = _clientProjections.FindIndex(e => e.Equals(projectionToAdd)); + if (existingIndex == -1) + { + _clientProjections.Add(projectionToAdd); + existingIndex = _clientProjections.Count - 1; + } + + mapping[projectionMember] = existingIndex; + } + } + + projectionMapping.Clear(); + + return mapping; + } + + private static IEnumerable LeftJoin( + IEnumerable outer, + IEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + TInner defaultValue, + IEqualityComparer? comparer) + => (comparer == null + ? outer.GroupJoin(inner, outerKeySelector, innerKeySelector, (oe, ies) => new { oe, ies }) + : outer.GroupJoin(inner, outerKeySelector, innerKeySelector, (oe, ies) => new { oe, ies }, comparer)) + .SelectMany(t => t.ies.DefaultIfEmpty(defaultValue), (t, i) => resultSelector(t.oe, i)); + + private static MethodCallExpression MakeReadValueNullable(Expression expression) + { + Check.DebugAssert(expression is MethodCallExpression, "Expression must be method call expression."); + + var methodCallExpression = (MethodCallExpression)expression; + + return methodCallExpression.Type.IsNullableType() + ? methodCallExpression + : Call( + ExpressionExtensions.ValueBufferTryReadValueMethod.MakeGenericMethod(methodCallExpression.Type.MakeNullable()), + methodCallExpression.Arguments); + } + + private EntityProjectionExpression TraverseEntityProjection( + List selectorExpressions, + EntityProjectionExpression entityProjectionExpression, + bool makeNullable) + { + var readExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + var expression = entityProjectionExpression.BindProperty(property); + if (makeNullable) + { + expression = MakeReadValueNullable(expression); + } + + selectorExpressions.Add(expression); + var newExpression = CreateReadValueExpression(expression.Type, selectorExpressions.Count - 1, property); + readExpressionMap[property] = newExpression; + } + + var result = new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressionMap); + + // Also compute nested entity projections + foreach (var navigation in entityProjectionExpression.EntityType.GetAllBaseTypes() + .Concat(entityProjectionExpression.EntityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredNavigations())) + { + var boundEntityShaperExpression = entityProjectionExpression.BindNavigation(navigation); + if (boundEntityShaperExpression != null) + { + var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; + var newInnerEntityProjection = TraverseEntityProjection(selectorExpressions, innerEntityProjection, makeNullable); + boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); + result.AddNavigationBinding(navigation, boundEntityShaperExpression); + } + } + + return result; + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryTranslationPreprocessor.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryTranslationPreprocessor.cs new file mode 100644 index 00000000..7646f922 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryTranslationPreprocessor.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using MASES.EntityFrameworkCore.KNet.Internal; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaQueryTranslationPreprocessor : QueryTranslationPreprocessor +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaQueryTranslationPreprocessor( + QueryTranslationPreprocessorDependencies dependencies, + QueryCompilationContext queryCompilationContext) + : base(dependencies, queryCompilationContext) + { + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression Process(Expression query) + { + var result = base.Process(query); + + if (result is MethodCallExpression { Method.IsGenericMethod: true } methodCallExpression + && (methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.GroupByWithKeySelector + || methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.GroupByWithKeyElementSelector)) + { + throw new InvalidOperationException( + CoreStrings.TranslationFailedWithDetails(methodCallExpression.Print(), KafkaStrings.NonComposedGroupByNotSupported)); + } + + return result; + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryTranslationPreprocessorFactory.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryTranslationPreprocessorFactory.cs new file mode 100644 index 00000000..c3da12f8 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryTranslationPreprocessorFactory.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaQueryTranslationPreprocessorFactory : IQueryTranslationPreprocessorFactory +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaQueryTranslationPreprocessorFactory( + QueryTranslationPreprocessorDependencies dependencies) + { + Dependencies = dependencies; + } + + /// + /// Dependencies for this service. + /// + protected virtual QueryTranslationPreprocessorDependencies Dependencies { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual QueryTranslationPreprocessor Create(QueryCompilationContext queryCompilationContext) + => new KafkaQueryTranslationPreprocessor(Dependencies, queryCompilationContext); +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryableMethodTranslatingExpressionVisitor.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryableMethodTranslatingExpressionVisitor.cs new file mode 100644 index 00000000..758b96df --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryableMethodTranslatingExpressionVisitor.cs @@ -0,0 +1,1476 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor +{ + private readonly KafkaExpressionTranslatingExpressionVisitor _expressionTranslator; + private readonly SharedTypeEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor; + private readonly KafkaProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor; + private readonly IModel _model; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaQueryableMethodTranslatingExpressionVisitor( + QueryableMethodTranslatingExpressionVisitorDependencies dependencies, + QueryCompilationContext queryCompilationContext) + : base(dependencies, queryCompilationContext, subquery: false) + { + _expressionTranslator = new KafkaExpressionTranslatingExpressionVisitor(queryCompilationContext, this); + _weakEntityExpandingExpressionVisitor = new SharedTypeEntityExpandingExpressionVisitor(_expressionTranslator); + _projectionBindingExpressionVisitor = new KafkaProjectionBindingExpressionVisitor(this, _expressionTranslator); + _model = queryCompilationContext.Model; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected KafkaQueryableMethodTranslatingExpressionVisitor( + KafkaQueryableMethodTranslatingExpressionVisitor parentVisitor) + : base(parentVisitor.Dependencies, parentVisitor.QueryCompilationContext, subquery: true) + { + _expressionTranslator = new KafkaExpressionTranslatingExpressionVisitor(QueryCompilationContext, parentVisitor); + _weakEntityExpandingExpressionVisitor = new SharedTypeEntityExpandingExpressionVisitor(_expressionTranslator); + _projectionBindingExpressionVisitor = new KafkaProjectionBindingExpressionVisitor(this, _expressionTranslator); + _model = parentVisitor._model; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVisitor() + => new KafkaQueryableMethodTranslatingExpressionVisitor(this); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case GroupByShaperExpression groupByShaperExpression: + var groupShapedQueryExpression = groupByShaperExpression.GroupingEnumerable; + + return ((KafkaQueryExpression)groupShapedQueryExpression.QueryExpression) + .Clone(groupShapedQueryExpression.ShaperExpression); + + case ShapedQueryExpression shapedQueryExpression: + return ((KafkaQueryExpression)shapedQueryExpression.QueryExpression) + .Clone(shapedQueryExpression.ShaperExpression); + + default: + return base.VisitExtension(extensionExpression); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Arguments.Count == 1 + && methodCallExpression.Arguments[0].Type.TryGetSequenceType() != null + && (string.Equals(methodCallExpression.Method.Name, "AsSplitQuery", StringComparison.Ordinal) + || string.Equals(methodCallExpression.Method.Name, "AsSingleQuery", StringComparison.Ordinal))) + { + return Visit(methodCallExpression.Arguments[0]); + } + + return base.VisitMethodCall(methodCallExpression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType) + => CreateShapedQueryExpressionStatic(entityType); + + private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityType entityType) + { + var queryExpression = new KafkaQueryExpression(entityType); + + return new ShapedQueryExpression( + queryExpression, + new StructuralTypeShaperExpression( + entityType, + new ProjectionBindingExpression( + queryExpression, + new ProjectionMember(), + typeof(ValueBuffer)), + false)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateAll(ShapedQueryExpression source, LambdaExpression predicate) + { + predicate = Expression.Lambda(Expression.Not(predicate.Body), predicate.Parameters); + var newSource = TranslateWhere(source, predicate); + if (newSource == null) + { + return null; + } + + source = newSource; + + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + if (source.ShaperExpression is GroupByShaperExpression) + { + kafkaQueryExpression.ReplaceProjection(new Dictionary()); + } + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Not( + Expression.Call( + EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression))); + + return source.UpdateShaperExpression(Expression.Convert(kafkaQueryExpression.GetSingleScalarProjection(), typeof(bool))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateAny(ShapedQueryExpression source, LambdaExpression? predicate) + { + if (predicate != null) + { + var newSource = TranslateWhere(source, predicate); + if (newSource == null) + { + return null; + } + + source = newSource; + } + + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + if (source.ShaperExpression is GroupByShaperExpression) + { + kafkaQueryExpression.ReplaceProjection(new Dictionary()); + } + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression)); + + return source.UpdateShaperExpression(Expression.Convert(kafkaQueryExpression.GetSingleScalarProjection(), typeof(bool))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateAverage( + ShapedQueryExpression source, + LambdaExpression? selector, + Type resultType) + => TranslateScalarAggregate(source, selector, nameof(Enumerable.Average), resultType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateCast(ShapedQueryExpression source, Type resultType) + => source.ShaperExpression.Type != resultType + ? source.UpdateShaperExpression(Expression.Convert(source.ShaperExpression, resultType)) + : source; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2) + => TranslateSetOperation(EnumerableMethods.Concat, source1, source2); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item) + { + var anyLambdaParameter = Expression.Parameter(item.Type, "p"); + var anyLambda = Expression.Lambda( + ExpressionExtensions.CreateEqualsExpression(anyLambdaParameter, item), + anyLambdaParameter); + + return TranslateAny(source, anyLambda); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate) + { + if (predicate != null) + { + var newSource = TranslateWhere(source, predicate); + if (newSource == null) + { + return null; + } + + source = newSource; + } + + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + if (source.ShaperExpression is GroupByShaperExpression) + { + kafkaQueryExpression.ReplaceProjection(new Dictionary()); + } + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression)); + + return source.UpdateShaperExpression(Expression.Convert(kafkaQueryExpression.GetSingleScalarProjection(), typeof(int))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateDefaultIfEmpty(ShapedQueryExpression source, Expression? defaultValue) + { + if (defaultValue == null) + { + ((KafkaQueryExpression)source.QueryExpression).ApplyDefaultIfEmpty(); + return source.UpdateShaperExpression(MarkShaperNullable(source.ShaperExpression)); + } + + return null; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateDistinct(ShapedQueryExpression source) + { + ((KafkaQueryExpression)source.QueryExpression).ApplyDistinct(); + + return source; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateElementAtOrDefault( + ShapedQueryExpression source, + Expression index, + bool returnDefault) + => null; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateExcept(ShapedQueryExpression source1, ShapedQueryExpression source2) + => TranslateSetOperation(EnumerableMethods.Except, source1, source2); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateFirstOrDefault( + ShapedQueryExpression source, + LambdaExpression? predicate, + Type returnType, + bool returnDefault) + => TranslateSingleResultOperator( + source, + predicate, + returnType, + returnDefault + ? EnumerableMethods.FirstOrDefaultWithoutPredicate + : EnumerableMethods.FirstWithoutPredicate); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateGroupBy( + ShapedQueryExpression source, + LambdaExpression keySelector, + LambdaExpression? elementSelector, + LambdaExpression? resultSelector) + { + var remappedKeySelector = RemapLambdaBody(source, keySelector); + + var translatedKey = TranslateGroupingKey(remappedKeySelector); + if (translatedKey != null) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + var defaultElementSelector = elementSelector == null || elementSelector.Body == elementSelector.Parameters[0]; + if (!defaultElementSelector) + { + source = TranslateSelect(source, elementSelector!); + } + + var groupByShaper = kafkaQueryExpression.ApplyGrouping(translatedKey, source.ShaperExpression, defaultElementSelector); + + if (resultSelector == null) + { + return source.UpdateShaperExpression(groupByShaper); + } + + var original1 = resultSelector.Parameters[0]; + var original2 = resultSelector.Parameters[1]; + + var newResultSelectorBody = new ReplacingExpressionVisitor( + new Expression[] { original1, original2 }, + new[] { groupByShaper.KeySelector, groupByShaper }).Visit(resultSelector.Body); + + newResultSelectorBody = ExpandSharedTypeEntities(kafkaQueryExpression, newResultSelectorBody); + var newShaper = _projectionBindingExpressionVisitor.Translate(kafkaQueryExpression, newResultSelectorBody); + + return source.UpdateShaperExpression(newShaper); + } + + return null; + } + + private Expression? TranslateGroupingKey(Expression expression) + { + switch (expression) + { + case NewExpression newExpression: + if (newExpression.Arguments.Count == 0) + { + return newExpression; + } + + var newArguments = new Expression[newExpression.Arguments.Count]; + for (var i = 0; i < newArguments.Length; i++) + { + var key = TranslateGroupingKey(newExpression.Arguments[i]); + if (key == null) + { + return null; + } + + newArguments[i] = key; + } + + return newExpression.Update(newArguments); + + case MemberInitExpression memberInitExpression: + var updatedNewExpression = (NewExpression?)TranslateGroupingKey(memberInitExpression.NewExpression); + if (updatedNewExpression == null) + { + return null; + } + + var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count]; + for (var i = 0; i < newBindings.Length; i++) + { + var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i]; + var visitedExpression = TranslateGroupingKey(memberAssignment.Expression); + if (visitedExpression == null) + { + return null; + } + + newBindings[i] = memberAssignment.Update(visitedExpression); + } + + return memberInitExpression.Update(updatedNewExpression, newBindings); + + case StructuralTypeShaperExpression { ValueBufferExpression: ProjectionBindingExpression } shaper: + return shaper; + + default: + var translation = TranslateExpression(expression); + if (translation == null) + { + return null; + } + + return translation.Type == expression.Type + ? translation + : Expression.Convert(translation, expression.Type); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateGroupJoin( + ShapedQueryExpression outer, + ShapedQueryExpression inner, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector, + LambdaExpression resultSelector) + => null; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateIntersect(ShapedQueryExpression source1, ShapedQueryExpression source2) + => TranslateSetOperation(EnumerableMethods.Intersect, source1, source2); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateJoin( + ShapedQueryExpression outer, + ShapedQueryExpression inner, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector, + LambdaExpression resultSelector) + { + var (newOuterKeySelector, newInnerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector); + + if (newOuterKeySelector == null + || newInnerKeySelector == null) + { + return null; + } + + (outerKeySelector, innerKeySelector) = (newOuterKeySelector, newInnerKeySelector); + + var outerShaperExpression = ((KafkaQueryExpression)outer.QueryExpression).AddInnerJoin( + (KafkaQueryExpression)inner.QueryExpression, + outerKeySelector, + innerKeySelector, + outer.ShaperExpression, + inner.ShaperExpression); + + outer = outer.UpdateShaperExpression(outerShaperExpression); + + return TranslateTwoParameterSelector(outer, resultSelector); + } + + private (LambdaExpression? OuterKeySelector, LambdaExpression? InnerKeySelector) ProcessJoinKeySelector( + ShapedQueryExpression outer, + ShapedQueryExpression inner, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector) + { + var left = RemapLambdaBody(outer, outerKeySelector); + var right = RemapLambdaBody(inner, innerKeySelector); + + var joinCondition = TranslateExpression(ExpressionExtensions.CreateEqualsExpression(left, right)); + + var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition); + + if (outerKeyBody == null + || innerKeyBody == null) + { + return (null, null); + } + + outerKeySelector = Expression.Lambda(outerKeyBody, ((KafkaQueryExpression)outer.QueryExpression).CurrentParameter); + innerKeySelector = Expression.Lambda(innerKeyBody, ((KafkaQueryExpression)inner.QueryExpression).CurrentParameter); + + return AlignKeySelectorTypes(outerKeySelector, innerKeySelector); + } + + private static (Expression?, Expression?) DecomposeJoinCondition(Expression? joinCondition) + { + var leftExpressions = new List(); + var rightExpressions = new List(); + + return ProcessJoinCondition(joinCondition, leftExpressions, rightExpressions) + ? leftExpressions.Count == 1 + ? (leftExpressions[0], rightExpressions[0]) + : (CreateAnonymousObject(leftExpressions), CreateAnonymousObject(rightExpressions)) + : (null, null); + + // Kafka joins need to use AnonymousObject to perform correct key comparison for server side joins + static Expression CreateAnonymousObject(List expressions) + => Expression.New( + AnonymousObject.AnonymousObjectCtor, + Expression.NewArrayInit( + typeof(object), + expressions.Select(e => Expression.Convert(e, typeof(object))))); + } + + private static bool ProcessJoinCondition( + Expression? joinCondition, + List leftExpressions, + List rightExpressions) + { + if (joinCondition is BinaryExpression binaryExpression) + { + if (binaryExpression.NodeType == ExpressionType.Equal) + { + leftExpressions.Add(binaryExpression.Left); + rightExpressions.Add(binaryExpression.Right); + + return true; + } + + if (binaryExpression.NodeType == ExpressionType.AndAlso) + { + return ProcessJoinCondition(binaryExpression.Left, leftExpressions, rightExpressions) + && ProcessJoinCondition(binaryExpression.Right, leftExpressions, rightExpressions); + } + } + + if (joinCondition is MethodCallExpression { Method.Name: nameof(object.Equals), Arguments.Count: 2 } methodCallExpression + && ((methodCallExpression.Method.IsStatic + && methodCallExpression.Method.DeclaringType == typeof(object)) + || typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType))) + { + leftExpressions.Add(methodCallExpression.Arguments[0]); + rightExpressions.Add(methodCallExpression.Arguments[1]); + + return true; + } + + return false; + } + + private static (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) + AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector) + { + if (outerKeySelector.Body.Type != innerKeySelector.Body.Type) + { + if (IsConvertedToNullable(outerKeySelector.Body, innerKeySelector.Body)) + { + innerKeySelector = Expression.Lambda( + Expression.Convert(innerKeySelector.Body, outerKeySelector.Body.Type), innerKeySelector.Parameters); + } + else if (IsConvertedToNullable(innerKeySelector.Body, outerKeySelector.Body)) + { + outerKeySelector = Expression.Lambda( + Expression.Convert(outerKeySelector.Body, innerKeySelector.Body.Type), outerKeySelector.Parameters); + } + } + + return (outerKeySelector, innerKeySelector); + + static bool IsConvertedToNullable(Expression outer, Expression inner) + => outer.Type.IsNullableType() + && !inner.Type.IsNullableType() + && outer.Type.UnwrapNullableType() == inner.Type; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateLastOrDefault( + ShapedQueryExpression source, + LambdaExpression? predicate, + Type returnType, + bool returnDefault) + => TranslateSingleResultOperator( + source, + predicate, + returnType, + returnDefault + ? EnumerableMethods.LastOrDefaultWithoutPredicate + : EnumerableMethods.LastWithoutPredicate); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateLeftJoin( + ShapedQueryExpression outer, + ShapedQueryExpression inner, + LambdaExpression outerKeySelector, + LambdaExpression innerKeySelector, + LambdaExpression resultSelector) + { + var (newOuterKeySelector, newInnerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector); + + if (newOuterKeySelector == null + || newInnerKeySelector == null) + { + return null; + } + + (outerKeySelector, innerKeySelector) = (newOuterKeySelector, newInnerKeySelector); + + var outerShaperExpression = ((KafkaQueryExpression)outer.QueryExpression).AddLeftJoin( + (KafkaQueryExpression)inner.QueryExpression, + outerKeySelector, + innerKeySelector, + outer.ShaperExpression, + inner.ShaperExpression); + + outer = outer.UpdateShaperExpression(outerShaperExpression); + + return TranslateTwoParameterSelector(outer, resultSelector); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate) + { + if (predicate != null) + { + var newSource = TranslateWhere(source, predicate); + if (newSource == null) + { + return null; + } + + source = newSource; + } + + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + if (source.ShaperExpression is GroupByShaperExpression) + { + kafkaQueryExpression.ReplaceProjection(new Dictionary()); + } + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod( + kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression)); + + return source.UpdateShaperExpression(Expression.Convert(kafkaQueryExpression.GetSingleScalarProjection(), typeof(long))); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateMax( + ShapedQueryExpression source, + LambdaExpression? selector, + Type resultType) + => TranslateScalarAggregate(source, selector, nameof(Enumerable.Max), resultType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) + => TranslateScalarAggregate(source, selector, nameof(Enumerable.Min), resultType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType) + { + if (source.ShaperExpression is StructuralTypeShaperExpression { StructuralType: IEntityType entityType } shaper) + { + if (entityType.ClrType == resultType) + { + return source; + } + + var parameterExpression = Expression.Parameter(shaper.Type); + var predicate = Expression.Lambda(Expression.TypeIs(parameterExpression, resultType), parameterExpression); + var newSource = TranslateWhere(source, predicate); + if (newSource == null) + { + // EntityType is not part of hierarchy + return null; + } + + source = newSource; + + var baseType = entityType.GetAllBaseTypes().SingleOrDefault(et => et.ClrType == resultType); + if (baseType != null) + { + return source.UpdateShaperExpression(shaper.WithType(baseType)); + } + + var derivedType = entityType.GetDerivedTypes().Single(et => et.ClrType == resultType); + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + var projectionBindingExpression = (ProjectionBindingExpression)shaper.ValueBufferExpression; + var projectionMember = projectionBindingExpression.ProjectionMember; + Check.DebugAssert(new ProjectionMember().Equals(projectionMember), "Invalid ProjectionMember when processing OfType"); + + var entityProjectionExpression = + (EntityProjectionExpression)kafkaQueryExpression.GetProjection(projectionBindingExpression); + kafkaQueryExpression.ReplaceProjection( + new Dictionary + { + { projectionMember, entityProjectionExpression.UpdateEntityType(derivedType) } + }); + + return source.UpdateShaperExpression(shaper.WithType(derivedType)); + } + + return null; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateOrderBy( + ShapedQueryExpression source, + LambdaExpression keySelector, + bool ascending) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + var newKeySelector = TranslateLambdaExpression(source, keySelector); + if (newKeySelector == null) + { + return null; + } + + keySelector = newKeySelector; + + var orderBy = ascending ? EnumerableMethods.OrderBy : EnumerableMethods.OrderByDescending; + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + orderBy.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type, keySelector.ReturnType), + kafkaQueryExpression.ServerQueryExpression, + keySelector)); + + return source; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateReverse(ShapedQueryExpression source) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.Reverse.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression)); + + return source; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression source, LambdaExpression selector) + { + if (selector.Body == selector.Parameters[0]) + { + return source; + } + + var newSelectorBody = RemapLambdaBody(source, selector); + var queryExpression = (KafkaQueryExpression)source.QueryExpression; + var newShaper = _projectionBindingExpressionVisitor.Translate(queryExpression, newSelectorBody); + + return source.UpdateShaperExpression(newShaper); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSelectMany( + ShapedQueryExpression source, + LambdaExpression collectionSelector, + LambdaExpression resultSelector) + { + var defaultIfEmpty = new DefaultIfEmptyFindingExpressionVisitor().IsOptional(collectionSelector); + var collectionSelectorBody = RemapLambdaBody(source, collectionSelector); + + if (Visit(collectionSelectorBody) is ShapedQueryExpression inner) + { + var outerShaperExpression = ((KafkaQueryExpression)source.QueryExpression).AddSelectMany( + (KafkaQueryExpression)inner.QueryExpression, source.ShaperExpression, inner.ShaperExpression, defaultIfEmpty); + + source = source.UpdateShaperExpression(outerShaperExpression); + + return TranslateTwoParameterSelector(source, resultSelector); + } + + return null; + } + + private sealed class DefaultIfEmptyFindingExpressionVisitor : ExpressionVisitor + { + private bool _defaultIfEmpty; + + public bool IsOptional(LambdaExpression lambdaExpression) + { + _defaultIfEmpty = false; + + Visit(lambdaExpression.Body); + + return _defaultIfEmpty; + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument) + { + _defaultIfEmpty = true; + } + + return base.VisitMethodCall(methodCallExpression); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSelectMany(ShapedQueryExpression source, LambdaExpression selector) + { + var innerParameter = Expression.Parameter(selector.ReturnType.GetSequenceType(), "i"); + var resultSelector = Expression.Lambda( + innerParameter, Expression.Parameter(source.Type.GetSequenceType()), innerParameter); + + return TranslateSelectMany(source, selector, resultSelector); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSingleOrDefault( + ShapedQueryExpression source, + LambdaExpression? predicate, + Type returnType, + bool returnDefault) + => TranslateSingleResultOperator( + source, + predicate, + returnType, + returnDefault + ? EnumerableMethods.SingleOrDefaultWithoutPredicate + : EnumerableMethods.SingleWithoutPredicate); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSkip(ShapedQueryExpression source, Expression count) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + var newCount = TranslateExpression(count); + if (newCount == null) + { + return null; + } + + count = newCount; + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.Skip.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression, + count)); + + return source; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate) + => null; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) + => TranslateScalarAggregate(source, selector, nameof(Enumerable.Sum), resultType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + var newCount = TranslateExpression(count); + if (newCount == null) + { + return null; + } + + count = newCount; + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.Take.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression, + count)); + + return source; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateTakeWhile(ShapedQueryExpression source, LambdaExpression predicate) + => null; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateThenBy( + ShapedQueryExpression source, + LambdaExpression keySelector, + bool ascending) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + var newKeySelector = TranslateLambdaExpression(source, keySelector); + if (newKeySelector == null) + { + return null; + } + + keySelector = newKeySelector; + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + (ascending ? EnumerableMethods.ThenBy : EnumerableMethods.ThenByDescending) + .MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type, keySelector.ReturnType), + kafkaQueryExpression.ServerQueryExpression, + keySelector)); + + return source; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateUnion(ShapedQueryExpression source1, ShapedQueryExpression source2) + => TranslateSetOperation(EnumerableMethods.Union, source1, source2); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateWhere(ShapedQueryExpression source, LambdaExpression predicate) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + var newPredicate = TranslateLambdaExpression(source, predicate, preserveType: true); + if (newPredicate == null) + { + return null; + } + + predicate = newPredicate; + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.Where.MakeGenericMethod(kafkaQueryExpression.CurrentParameter.Type), + kafkaQueryExpression.ServerQueryExpression, + predicate)); + + return source; + } + + private Expression? TranslateExpression(Expression expression, bool preserveType = false) + { + var translation = _expressionTranslator.Translate(expression); + if (translation == null && _expressionTranslator.TranslationErrorDetails != null) + { + AddTranslationErrorDetails(_expressionTranslator.TranslationErrorDetails); + } + + if (expression != null + && translation != null + && preserveType + && expression.Type != translation.Type) + { + translation = expression.Type == typeof(bool) + ? Expression.Equal(translation, Expression.Constant(true, translation.Type)) + : Expression.Convert(translation, expression.Type); + } + + return translation; + } + + private LambdaExpression? TranslateLambdaExpression( + ShapedQueryExpression shapedQueryExpression, + LambdaExpression lambdaExpression, + bool preserveType = false) + { + var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression), preserveType); + + return lambdaBody != null + ? Expression.Lambda( + lambdaBody, + ((KafkaQueryExpression)shapedQueryExpression.QueryExpression).CurrentParameter) + : null; + } + + private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression) + { + var lambdaBody = ReplacingExpressionVisitor.Replace( + lambdaExpression.Parameters.Single(), shapedQueryExpression.ShaperExpression, lambdaExpression.Body); + + return ExpandSharedTypeEntities((KafkaQueryExpression)shapedQueryExpression.QueryExpression, lambdaBody); + } + + private Expression ExpandSharedTypeEntities(KafkaQueryExpression queryExpression, Expression lambdaBody) + => _weakEntityExpandingExpressionVisitor.Expand(queryExpression, lambdaBody); + + private sealed class SharedTypeEntityExpandingExpressionVisitor : ExpressionVisitor + { + private readonly KafkaExpressionTranslatingExpressionVisitor _expressionTranslator; + + private KafkaQueryExpression _queryExpression; + + public SharedTypeEntityExpandingExpressionVisitor(KafkaExpressionTranslatingExpressionVisitor expressionTranslator) + { + _expressionTranslator = expressionTranslator; + _queryExpression = null!; + } + + public string? TranslationErrorDetails + => _expressionTranslator.TranslationErrorDetails; + + public Expression Expand(KafkaQueryExpression queryExpression, Expression lambdaBody) + { + _queryExpression = queryExpression; + + return Visit(lambdaBody); + } + + protected override Expression VisitMember(MemberExpression memberExpression) + { + var innerExpression = Visit(memberExpression.Expression); + + return TryExpand(innerExpression, MemberIdentity.Create(memberExpression.Member)) + ?? memberExpression.Update(innerExpression); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var navigationName)) + { + source = Visit(source); + + return TryExpand(source, MemberIdentity.Create(navigationName)) + ?? methodCallExpression.Update(null!, new[] { source, methodCallExpression.Arguments[1] }); + } + + return base.VisitMethodCall(methodCallExpression); + } + + protected override Expression VisitExtension(Expression extensionExpression) + => extensionExpression is StructuralTypeShaperExpression or ShapedQueryExpression or GroupByShaperExpression + ? extensionExpression + : base.VisitExtension(extensionExpression); + + private Expression? TryExpand(Expression? source, MemberIdentity member) + { + source = source.UnwrapTypeConversion(out var convertedType); + if (source is not StructuralTypeShaperExpression shaper) + { + return null; + } + + if (shaper.StructuralType is not IEntityType) + { + return null; + } + + var entityType = (IEntityType)shaper.StructuralType; + + if (convertedType != null) + { + entityType = entityType.GetRootType().GetDerivedTypesInclusive() + .FirstOrDefault(et => et.ClrType == convertedType); + + if (entityType == null) + { + return null; + } + } + + var navigation = member.MemberInfo != null + ? entityType.FindNavigation(member.MemberInfo) + : entityType.FindNavigation(member.Name!); + + if (navigation == null) + { + return null; + } + + var targetEntityType = navigation.TargetEntityType; + if (targetEntityType == null + || !targetEntityType.IsOwned()) + { + return null; + } + + var foreignKey = navigation.ForeignKey; + if (navigation.IsCollection) + { + var innerShapedQuery = CreateShapedQueryExpressionStatic(targetEntityType); + var innerQueryExpression = (KafkaQueryExpression)innerShapedQuery.QueryExpression; + + var makeNullable = foreignKey.PrincipalKey.Properties + .Concat(foreignKey.Properties) + .Select(p => p.ClrType) + .Any(t => t.IsNullableType()); + + var outerKey = shaper.CreateKeyValuesExpression( + navigation.IsOnDependent + ? foreignKey.Properties + : foreignKey.PrincipalKey.Properties, + makeNullable); + var innerKey = innerShapedQuery.ShaperExpression.CreateKeyValuesExpression( + navigation.IsOnDependent + ? foreignKey.PrincipalKey.Properties + : foreignKey.Properties, + makeNullable); + + var keyComparison = ExpressionExtensions.CreateEqualsExpression(outerKey, innerKey); + + var predicate = makeNullable + ? Expression.AndAlso( + outerKey is NewArrayExpression newArrayExpression + ? newArrayExpression.Expressions + .Select( + e => + { + var left = (e as UnaryExpression)?.Operand ?? e; + + return Expression.NotEqual(left, Expression.Constant(null, left.Type)); + }) + .Aggregate((l, r) => Expression.AndAlso(l, r)) + : Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)), + keyComparison) + : keyComparison; + + var correlationPredicate = _expressionTranslator.Translate(predicate)!; + innerQueryExpression.UpdateServerQueryExpression( + Expression.Call( + EnumerableMethods.Where.MakeGenericMethod(innerQueryExpression.CurrentParameter.Type), + innerQueryExpression.ServerQueryExpression, + Expression.Lambda(correlationPredicate, innerQueryExpression.CurrentParameter))); + + return innerShapedQuery; + } + + var entityProjectionExpression = + shaper.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression + ? (EntityProjectionExpression)_queryExpression.GetProjection(projectionBindingExpression) + : (EntityProjectionExpression)shaper.ValueBufferExpression; + var innerShaper = entityProjectionExpression.BindNavigation(navigation); + if (innerShaper == null) + { + var innerShapedQuery = CreateShapedQueryExpressionStatic(targetEntityType); + var innerQueryExpression = (KafkaQueryExpression)innerShapedQuery.QueryExpression; + + var makeNullable = foreignKey.PrincipalKey.Properties + .Concat(foreignKey.Properties) + .Select(p => p.ClrType) + .Any(t => t.IsNullableType()); + + var outerKey = shaper.CreateKeyValuesExpression( + navigation.IsOnDependent + ? foreignKey.Properties + : foreignKey.PrincipalKey.Properties, + makeNullable); + var innerKey = innerShapedQuery.ShaperExpression.CreateKeyValuesExpression( + navigation.IsOnDependent + ? foreignKey.PrincipalKey.Properties + : foreignKey.Properties, + makeNullable); + + if (foreignKey.Properties.Count > 1) + { + outerKey = Expression.New(AnonymousObject.AnonymousObjectCtor, outerKey); + innerKey = Expression.New(AnonymousObject.AnonymousObjectCtor, innerKey); + } + + var outerKeySelector = Expression.Lambda(_expressionTranslator.Translate(outerKey)!, _queryExpression.CurrentParameter); + var innerKeySelector = Expression.Lambda( + _expressionTranslator.Translate(innerKey)!, innerQueryExpression.CurrentParameter); + (outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector); + innerShaper = _queryExpression.AddNavigationToWeakEntityType( + entityProjectionExpression, navigation, innerQueryExpression, outerKeySelector, innerKeySelector); + } + + return innerShaper; + } + } + + private ShapedQueryExpression TranslateTwoParameterSelector(ShapedQueryExpression source, LambdaExpression resultSelector) + { + var transparentIdentifierType = source.ShaperExpression.Type; + var transparentIdentifierParameter = Expression.Parameter(transparentIdentifierType); + + Expression original1 = resultSelector.Parameters[0]; + var replacement1 = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer"); + Expression original2 = resultSelector.Parameters[1]; + var replacement2 = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner"); + var newResultSelector = Expression.Lambda( + new ReplacingExpressionVisitor( + new[] { original1, original2 }, new[] { replacement1, replacement2 }) + .Visit(resultSelector.Body), + transparentIdentifierParameter); + + return TranslateSelect(source, newResultSelector); + } + + private static Expression AccessField( + Type transparentIdentifierType, + Expression targetExpression, + string fieldName) + => Expression.Field(targetExpression, transparentIdentifierType.GetTypeInfo().GetDeclaredField(fieldName)!); + + private ShapedQueryExpression? TranslateScalarAggregate( + ShapedQueryExpression source, + LambdaExpression? selector, + string methodName, + Type returnType) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + selector = selector == null + || selector.Body == selector.Parameters[0] + ? Expression.Lambda( + kafkaQueryExpression.GetProjection( + new ProjectionBindingExpression( + kafkaQueryExpression, new ProjectionMember(), returnType)), + kafkaQueryExpression.CurrentParameter) + : TranslateLambdaExpression(source, selector, preserveType: true); + + if (selector == null + || selector.Body is EntityProjectionExpression) + { + return null; + } + + var method = GetMethod(); + method = method.GetGenericArguments().Length == 2 + ? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType) + : method.MakeGenericMethod(typeof(ValueBuffer)); + + kafkaQueryExpression.UpdateServerQueryExpression( + Expression.Call(method, kafkaQueryExpression.ServerQueryExpression, selector)); + + return source.UpdateShaperExpression(Expression.Convert(kafkaQueryExpression.GetSingleScalarProjection(), returnType)); + + MethodInfo GetMethod() + => methodName switch + { + nameof(Enumerable.Average) => EnumerableMethods.GetAverageWithSelector(selector.ReturnType), + nameof(Enumerable.Max) => EnumerableMethods.GetMaxWithSelector(selector.ReturnType), + nameof(Enumerable.Min) => EnumerableMethods.GetMinWithSelector(selector.ReturnType), + nameof(Enumerable.Sum) => EnumerableMethods.GetSumWithSelector(selector.ReturnType), + _ => throw new InvalidOperationException(CoreStrings.UnknownEntity("Aggregate Operator")) + }; + } + + private ShapedQueryExpression? TranslateSingleResultOperator( + ShapedQueryExpression source, + LambdaExpression? predicate, + Type returnType, + MethodInfo method) + { + var kafkaQueryExpression = (KafkaQueryExpression)source.QueryExpression; + + if (predicate != null) + { + var newSource = TranslateWhere(source, predicate); + if (newSource == null) + { + return null; + } + + source = newSource; + } + + kafkaQueryExpression.ConvertToSingleResult(method); + + return source.ShaperExpression.Type != returnType + ? source.UpdateShaperExpression(Expression.Convert(source.ShaperExpression, returnType)) + : source; + } + + private static ShapedQueryExpression TranslateSetOperation( + MethodInfo setOperationMethodInfo, + ShapedQueryExpression source1, + ShapedQueryExpression source2) + { + var kafkaQueryExpression1 = (KafkaQueryExpression)source1.QueryExpression; + var kafkaQueryExpression2 = (KafkaQueryExpression)source2.QueryExpression; + + kafkaQueryExpression1.ApplySetOperation(setOperationMethodInfo, kafkaQueryExpression2); + + if (setOperationMethodInfo.Equals(EnumerableMethods.Except)) + { + return source1; + } + + var makeNullable = setOperationMethodInfo != EnumerableMethods.Intersect; + + return source1.UpdateShaperExpression( + MatchShaperNullabilityForSetOperation( + source1.ShaperExpression, source2.ShaperExpression, makeNullable)); + } + + private static Expression MatchShaperNullabilityForSetOperation(Expression shaper1, Expression shaper2, bool makeNullable) + { + switch (shaper1) + { + case StructuralTypeShaperExpression entityShaperExpression1 + when shaper2 is StructuralTypeShaperExpression entityShaperExpression2: + return entityShaperExpression1.IsNullable != entityShaperExpression2.IsNullable + ? entityShaperExpression1.MakeNullable(makeNullable) + : entityShaperExpression1; + + case NewExpression newExpression1 + when shaper2 is NewExpression newExpression2: + var newArguments = new Expression[newExpression1.Arguments.Count]; + for (var i = 0; i < newArguments.Length; i++) + { + newArguments[i] = MatchShaperNullabilityForSetOperation( + newExpression1.Arguments[i], newExpression2.Arguments[i], makeNullable); + } + + return newExpression1.Update(newArguments); + + case MemberInitExpression memberInitExpression1 + when shaper2 is MemberInitExpression memberInitExpression2: + var newExpression = (NewExpression)MatchShaperNullabilityForSetOperation( + memberInitExpression1.NewExpression, memberInitExpression2.NewExpression, makeNullable); + + var memberBindings = new MemberBinding[memberInitExpression1.Bindings.Count]; + for (var i = 0; i < memberBindings.Length; i++) + { + var memberAssignment = memberInitExpression1.Bindings[i] as MemberAssignment; + Check.DebugAssert(memberAssignment != null, "Only member assignment bindings are supported"); + + memberBindings[i] = memberAssignment.Update( + MatchShaperNullabilityForSetOperation( + memberAssignment.Expression, ((MemberAssignment)memberInitExpression2.Bindings[i]).Expression, + makeNullable)); + } + + return memberInitExpression1.Update(newExpression, memberBindings); + + default: + return shaper1; + } + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaQueryableMethodTranslatingExpressionVisitorFactory.cs b/src/net/KEFCore/Query/Internal8/KafkaQueryableMethodTranslatingExpressionVisitorFactory.cs new file mode 100644 index 00000000..8ae12099 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaQueryableMethodTranslatingExpressionVisitorFactory.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaQueryableMethodTranslatingExpressionVisitorFactory : IQueryableMethodTranslatingExpressionVisitorFactory +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaQueryableMethodTranslatingExpressionVisitorFactory( + QueryableMethodTranslatingExpressionVisitorDependencies dependencies) + { + Dependencies = dependencies; + } + + /// + /// Dependencies for this service. + /// + protected virtual QueryableMethodTranslatingExpressionVisitorDependencies Dependencies { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext) + => new KafkaQueryableMethodTranslatingExpressionVisitor(Dependencies, queryCompilationContext); +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs new file mode 100644 index 00000000..b2a53832 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs @@ -0,0 +1,191 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using MASES.EntityFrameworkCore.KNet.Internal; +using System.Collections; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public partial class KafkaShapedQueryCompilingExpressionVisitor +{ + private sealed class QueryingEnumerable : IAsyncEnumerable, IEnumerable, IQueryingEnumerable + { + private readonly QueryContext _queryContext; + private readonly IEnumerable _innerEnumerable; + private readonly Func _shaper; + private readonly Type _contextType; + private readonly IDiagnosticsLogger _queryLogger; + private readonly bool _standAloneStateManager; + private readonly bool _threadSafetyChecksEnabled; + + public QueryingEnumerable( + QueryContext queryContext, + IEnumerable innerEnumerable, + Func shaper, + Type contextType, + bool standAloneStateManager, + bool threadSafetyChecksEnabled) + { + _queryContext = queryContext; + _innerEnumerable = innerEnumerable; + _shaper = shaper; + _contextType = contextType; + _queryLogger = queryContext.QueryLogger; + _standAloneStateManager = standAloneStateManager; + _threadSafetyChecksEnabled = threadSafetyChecksEnabled; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => new Enumerator(this, cancellationToken); + + public IEnumerator GetEnumerator() + => new Enumerator(this); + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + + public string ToQueryString() + => KafkaStrings.NoQueryStrings; + + private sealed class Enumerator : IEnumerator, IAsyncEnumerator + { + private readonly QueryContext _queryContext; + private readonly IEnumerable _innerEnumerable; + private readonly Func _shaper; + private readonly Type _contextType; + private readonly IDiagnosticsLogger _queryLogger; + private readonly bool _standAloneStateManager; + private readonly CancellationToken _cancellationToken; + private readonly IConcurrencyDetector? _concurrencyDetector; + private readonly IExceptionDetector _exceptionDetector; + + private IEnumerator? _enumerator; + + public Enumerator(QueryingEnumerable queryingEnumerable, CancellationToken cancellationToken = default) + { + _queryContext = queryingEnumerable._queryContext; + _innerEnumerable = queryingEnumerable._innerEnumerable; + _shaper = queryingEnumerable._shaper; + _contextType = queryingEnumerable._contextType; + _queryLogger = queryingEnumerable._queryLogger; + _standAloneStateManager = queryingEnumerable._standAloneStateManager; + _cancellationToken = cancellationToken; + _exceptionDetector = _queryContext.ExceptionDetector; + Current = default!; + + _concurrencyDetector = queryingEnumerable._threadSafetyChecksEnabled + ? _queryContext.ConcurrencyDetector + : null; + } + + public T Current { get; private set; } + + object IEnumerator.Current + => Current!; + + public bool MoveNext() + { + try + { + _concurrencyDetector?.EnterCriticalSection(); + + try + { + return MoveNextHelper(); + } + finally + { + _concurrencyDetector?.ExitCriticalSection(); + } + } + catch (Exception exception) + { + if (_exceptionDetector.IsCancellation(exception)) + { + _queryLogger.QueryCanceled(_contextType); + } + else + { + _queryLogger.QueryIterationFailed(_contextType, exception); + } + + throw; + } + } + + public ValueTask MoveNextAsync() + { + try + { + _concurrencyDetector?.EnterCriticalSection(); + + try + { + _cancellationToken.ThrowIfCancellationRequested(); + + return ValueTask.FromResult(MoveNextHelper()); + } + finally + { + _concurrencyDetector?.ExitCriticalSection(); + } + } + catch (Exception exception) + { + if (_exceptionDetector.IsCancellation(exception, _cancellationToken)) + { + _queryLogger.QueryCanceled(_contextType); + } + else + { + _queryLogger.QueryIterationFailed(_contextType, exception); + } + + throw; + } + } + + private bool MoveNextHelper() + { + if (_enumerator == null) + { + EntityFrameworkEventSource.Log.QueryExecuting(); + + _enumerator = _innerEnumerable.GetEnumerator(); + _queryContext.InitializeStateManager(_standAloneStateManager); + } + + var hasNext = _enumerator.MoveNext(); + + Current = hasNext + ? _shaper(_queryContext, _enumerator.Current) + : default!; + + return hasNext; + } + + public void Dispose() + { + _enumerator?.Dispose(); + _enumerator = null; + } + + public ValueTask DisposeAsync() + { + var enumerator = _enumerator; + _enumerator = null; + + return enumerator.DisposeAsyncIfAvailable(); + } + + public void Reset() + => throw new NotSupportedException(CoreStrings.EnumerableResetNotSupported); + } + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.ShaperExpressionProcessingExpressionVisitor.cs b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.ShaperExpressionProcessingExpressionVisitor.cs new file mode 100644 index 00000000..a86416a3 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.ShaperExpressionProcessingExpressionVisitor.cs @@ -0,0 +1,411 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using static System.Linq.Expressions.Expression; +using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +public partial class KafkaShapedQueryCompilingExpressionVisitor +{ + private sealed class ShaperExpressionProcessingExpressionVisitor : ExpressionVisitor + { + private static readonly MethodInfo IncludeReferenceMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(IncludeReference))!; + + private static readonly MethodInfo IncludeCollectionMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(IncludeCollection))!; + + private static readonly MethodInfo MaterializeCollectionMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(MaterializeCollection))!; + + private static readonly MethodInfo MaterializeSingleResultMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(MaterializeSingleResult))!; + + private static readonly MethodInfo CollectionAccessorAddMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo().GetDeclaredMethod(nameof(IClrCollectionAccessor.Add))!; + + private readonly KafkaShapedQueryCompilingExpressionVisitor _kafkaShapedQueryCompilingExpressionVisitor; + private readonly bool _tracking; + private ParameterExpression? _valueBufferParameter; + + private readonly Dictionary _mapping = new(); + private readonly List _variables = new(); + private readonly List _expressions = new(); + private readonly Dictionary> _materializationContextBindings = new(); + + public ShaperExpressionProcessingExpressionVisitor( + KafkaShapedQueryCompilingExpressionVisitor kafkaShapedQueryCompilingExpressionVisitor, + KafkaQueryExpression kafkaQueryExpression, + bool tracking) + { + _kafkaShapedQueryCompilingExpressionVisitor = kafkaShapedQueryCompilingExpressionVisitor; + _valueBufferParameter = kafkaQueryExpression.CurrentParameter; + _tracking = tracking; + } + + private ShaperExpressionProcessingExpressionVisitor( + KafkaShapedQueryCompilingExpressionVisitor kafkaShapedQueryCompilingExpressionVisitor, + bool tracking) + { + _kafkaShapedQueryCompilingExpressionVisitor = kafkaShapedQueryCompilingExpressionVisitor; + _tracking = tracking; + } + + public LambdaExpression ProcessShaper(Expression shaperExpression) + { + var result = Visit(shaperExpression); + _expressions.Add(result); + result = Block(_variables, _expressions); + + // If parameter is null then the projection is not really server correlated so we can just put anything. + _valueBufferParameter ??= Parameter(typeof(ValueBuffer)); + + return Lambda(result, QueryCompilationContext.QueryContextParameter, _valueBufferParameter); + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case StructuralTypeShaperExpression shaper: + { + var key = shaper.ValueBufferExpression; + if (!_mapping.TryGetValue(key, out var variable)) + { + variable = Parameter(shaper.StructuralType.ClrType); + _variables.Add(variable); + var innerShaper = + _kafkaShapedQueryCompilingExpressionVisitor.InjectEntityMaterializers(shaper); + innerShaper = Visit(innerShaper); + _expressions.Add(Assign(variable, innerShaper)); + _mapping[key] = variable; + } + + return variable; + } + + case ProjectionBindingExpression projectionBindingExpression: + { + var key = projectionBindingExpression; + if (!_mapping.TryGetValue(key, out var variable)) + { + variable = Parameter(projectionBindingExpression.Type); + _variables.Add(variable); + var queryExpression = (KafkaQueryExpression)projectionBindingExpression.QueryExpression; + _valueBufferParameter ??= queryExpression.CurrentParameter; + + var projectionIndex = queryExpression.GetProjection(projectionBindingExpression).GetConstantValue(); + + // We don't need to pass property when reading at top-level + _expressions.Add( + Assign( + variable, queryExpression.CurrentParameter.CreateValueBufferReadValueExpression( + projectionBindingExpression.Type, projectionIndex, property: null))); + _mapping[key] = variable; + } + + return variable; + } + + case IncludeExpression includeExpression: + { + var entity = Visit(includeExpression.EntityExpression); + var entityClrType = includeExpression.EntityExpression.Type; + var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType; + var inverseNavigation = includeExpression.Navigation.Inverse; + var relatedEntityClrType = includeExpression.Navigation.TargetEntityType.ClrType; + if (includingClrType != entityClrType + && includingClrType.IsAssignableFrom(entityClrType)) + { + includingClrType = entityClrType; + } + + if (includeExpression.Navigation.IsCollection) + { + var collectionResultShaperExpression = (CollectionResultShaperExpression)includeExpression.NavigationExpression; + var shaperLambda = new ShaperExpressionProcessingExpressionVisitor( + _kafkaShapedQueryCompilingExpressionVisitor, _tracking) + .ProcessShaper(collectionResultShaperExpression.InnerShaper); + _expressions.Add( + Call( + IncludeCollectionMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), + QueryCompilationContext.QueryContextParameter, + Visit(collectionResultShaperExpression.Projection), + Constant(shaperLambda.Compile()), + entity, + Constant(includeExpression.Navigation), + Constant(inverseNavigation, typeof(INavigationBase)), + Constant( + GenerateFixup( + includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation) + .Compile()), + Constant(_tracking), +#pragma warning disable EF1001 // Internal EF Core API usage. + Constant(includeExpression.SetLoaded))); +#pragma warning restore EF1001 // Internal EF Core API usage. + } + else + { + _expressions.Add( + Call( + IncludeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), + QueryCompilationContext.QueryContextParameter, + entity, + Visit(includeExpression.NavigationExpression), + Constant(includeExpression.Navigation), + Constant(inverseNavigation, typeof(INavigationBase)), + Constant( + GenerateFixup( + includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation) + .Compile()), + Constant(_tracking))); + } + + return entity; + } + + case CollectionResultShaperExpression collectionResultShaperExpression: + { + var navigation = collectionResultShaperExpression.Navigation; + var collectionAccessor = navigation?.GetCollectionAccessor(); + var collectionType = collectionAccessor?.CollectionType ?? collectionResultShaperExpression.Type; + var elementType = collectionResultShaperExpression.ElementType; + var shaperLambda = new ShaperExpressionProcessingExpressionVisitor( + _kafkaShapedQueryCompilingExpressionVisitor, _tracking) + .ProcessShaper(collectionResultShaperExpression.InnerShaper); + + return Call( + MaterializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType), + QueryCompilationContext.QueryContextParameter, + Visit(collectionResultShaperExpression.Projection), + Constant(shaperLambda.Compile()), + Constant(collectionAccessor, typeof(IClrCollectionAccessor))); + } + + case SingleResultShaperExpression singleResultShaperExpression: + { + var shaperLambda = new ShaperExpressionProcessingExpressionVisitor( + _kafkaShapedQueryCompilingExpressionVisitor, _tracking) + .ProcessShaper(singleResultShaperExpression.InnerShaper); + + return Call( + MaterializeSingleResultMethodInfo.MakeGenericMethod(singleResultShaperExpression.Type), + QueryCompilationContext.QueryContextParameter, + Visit(singleResultShaperExpression.Projection), + Constant(shaperLambda.Compile())); + } + } + + return base.VisitExtension(extensionExpression); + } + + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + if (binaryExpression is { NodeType: ExpressionType.Assign, Left: ParameterExpression parameterExpression } + && parameterExpression.Type == typeof(MaterializationContext)) + { + var newExpression = (NewExpression)binaryExpression.Right; + + var projectionBindingExpression = (ProjectionBindingExpression)newExpression.Arguments[0]; + var queryExpression = (KafkaQueryExpression)projectionBindingExpression.QueryExpression; + _valueBufferParameter ??= queryExpression.CurrentParameter; + + _materializationContextBindings[parameterExpression] + = queryExpression.GetProjection(projectionBindingExpression).GetConstantValue>(); + + var updatedExpression = newExpression.Update( + new[] { Constant(ValueBuffer.Empty), newExpression.Arguments[1] }); + + return MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); + } + + if (binaryExpression is + { NodeType: ExpressionType.Assign, Left: MemberExpression { Member: FieldInfo { IsInitOnly: true } } memberExpression }) + { + return memberExpression.Assign(Visit(binaryExpression.Right)); + } + + return base.VisitBinary(binaryExpression); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) + { + var property = methodCallExpression.Arguments[2].GetConstantValue(); + var indexMap = _materializationContextBindings[ + (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object!]; + + Check.DebugAssert( + property != null || methodCallExpression.Type.IsNullableType(), "Must read nullable value without property"); + + return Call( + methodCallExpression.Method, + _valueBufferParameter!, + Constant(indexMap[property!]), + methodCallExpression.Arguments[2]); + } + + return base.VisitMethodCall(methodCallExpression); + } + + private static void IncludeReference( + QueryContext queryContext, + TEntity entity, + TIncludedEntity? relatedEntity, + INavigationBase navigation, + INavigationBase? inverseNavigation, + Action fixup, + bool trackingQuery) + where TIncludingEntity : class, TEntity + where TEntity : class + where TIncludedEntity : class + { + if (entity is TIncludingEntity includingEntity) + { + if (trackingQuery + && navigation.DeclaringEntityType.FindPrimaryKey() != null) + { + // For non-null relatedEntity StateManager will set the flag + if (relatedEntity == null) + { + queryContext.SetNavigationIsLoaded(includingEntity, navigation); + } + } + else + { + navigation.SetIsLoadedWhenNoTracking(includingEntity); + if (relatedEntity != null) + { + fixup(includingEntity, relatedEntity); + if (inverseNavigation is { IsCollection: false }) + { + inverseNavigation.SetIsLoadedWhenNoTracking(relatedEntity); + } + } + } + } + } + + private static void IncludeCollection( + QueryContext queryContext, + IEnumerable innerValueBuffers, + Func innerShaper, + TEntity entity, + INavigationBase navigation, + INavigationBase? inverseNavigation, + Action fixup, + bool trackingQuery, + bool setLoaded) + where TIncludingEntity : class, TEntity + where TEntity : class + where TIncludedEntity : class + { + if (entity is TIncludingEntity includingEntity) + { + if (!navigation.IsShadowProperty()) + { + navigation.GetCollectionAccessor()!.GetOrCreate(includingEntity, forMaterialization: true); + } + + if (setLoaded) + { + if (trackingQuery) + { + queryContext.SetNavigationIsLoaded(entity, navigation); + } + else + { + navigation.SetIsLoadedWhenNoTracking(entity); + } + } + + foreach (var valueBuffer in innerValueBuffers) + { + var relatedEntity = innerShaper(queryContext, valueBuffer); + + if (!trackingQuery) + { + fixup(includingEntity, relatedEntity); + inverseNavigation?.SetIsLoadedWhenNoTracking(relatedEntity); + } + } + } + } + + private static TCollection MaterializeCollection( + QueryContext queryContext, + IEnumerable innerValueBuffers, + Func innerShaper, + IClrCollectionAccessor? clrCollectionAccessor) + where TCollection : class, ICollection + { + var collection = (TCollection)(clrCollectionAccessor?.Create() ?? new List()); + + foreach (var valueBuffer in innerValueBuffers) + { + var element = innerShaper(queryContext, valueBuffer); + collection.Add(element); + } + + return collection; + } + + private static TResult? MaterializeSingleResult( + QueryContext queryContext, + ValueBuffer valueBuffer, + Func innerShaper) + => valueBuffer.IsEmpty + ? default + : innerShaper(queryContext, valueBuffer); + + private static LambdaExpression GenerateFixup( + Type entityType, + Type relatedEntityType, + INavigationBase navigation, + INavigationBase? inverseNavigation) + { + var entityParameter = Parameter(entityType); + var relatedEntityParameter = Parameter(relatedEntityType); + var expressions = new List(); + + if (!navigation.IsShadowProperty()) + { + expressions.Add( + navigation.IsCollection + ? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation) + : AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation)); + } + + if (inverseNavigation != null + && !inverseNavigation.IsShadowProperty()) + { + expressions.Add( + inverseNavigation.IsCollection + ? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation) + : AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation)); + } + + return Lambda(Block(typeof(void), expressions), entityParameter, relatedEntityParameter); + } + + private static Expression AssignReferenceNavigation( + ParameterExpression entity, + ParameterExpression relatedEntity, + INavigationBase navigation) + => entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity); + + private static Expression AddToCollectionNavigation( + ParameterExpression entity, + ParameterExpression relatedEntity, + INavigationBase navigation) + => Call( + Constant(navigation.GetCollectionAccessor()), + CollectionAccessorAddMethodInfo, + entity, + relatedEntity, + Constant(true)); + } +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.cs b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.cs new file mode 100644 index 00000000..1470fe98 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryCompilingExpressionVisitor.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +using static Expression; + +public partial class KafkaShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingExpressionVisitor +{ + private readonly Type _contextType; + private readonly bool _threadSafetyChecksEnabled; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaShapedQueryCompilingExpressionVisitor( + ShapedQueryCompilingExpressionVisitorDependencies dependencies, + QueryCompilationContext queryCompilationContext) + : base(dependencies, queryCompilationContext) + { + _contextType = queryCompilationContext.ContextType; + _threadSafetyChecksEnabled = dependencies.CoreSingletonOptions.AreThreadSafetyChecksEnabled; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case KafkaTableExpression kafkaTableExpression: + return Call( + TableMethodInfo, + QueryCompilationContext.QueryContextParameter, + Constant(kafkaTableExpression.EntityType)); + } + + return base.VisitExtension(extensionExpression); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQueryExpression) + { + var kafkaQueryExpression = (KafkaQueryExpression)shapedQueryExpression.QueryExpression; + kafkaQueryExpression.ApplyProjection(); + + var shaperExpression = new ShaperExpressionProcessingExpressionVisitor( + this, kafkaQueryExpression, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) + .ProcessShaper(shapedQueryExpression.ShaperExpression); + var innerEnumerable = Visit(kafkaQueryExpression.ServerQueryExpression); + + return New( + typeof(QueryingEnumerable<>).MakeGenericType(shaperExpression.ReturnType).GetConstructors()[0], + QueryCompilationContext.QueryContextParameter, + innerEnumerable, + Constant(shaperExpression.Compile()), + Constant(_contextType), + Constant( + QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution), + Constant(_threadSafetyChecksEnabled)); + } + + private static readonly MethodInfo TableMethodInfo + = typeof(KafkaShapedQueryCompilingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(Table))!; + + private static IEnumerable Table( + QueryContext queryContext, + IEntityType entityType) + => ((KafkaQueryContext)queryContext).GetValueBuffers(entityType); +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaShapedQueryExpressionVisitorFactory.cs b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryExpressionVisitorFactory.cs new file mode 100644 index 00000000..1728a65d --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaShapedQueryExpressionVisitorFactory.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaShapedQueryCompilingExpressionVisitorFactory : IShapedQueryCompilingExpressionVisitorFactory +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaShapedQueryCompilingExpressionVisitorFactory( + ShapedQueryCompilingExpressionVisitorDependencies dependencies) + { + Dependencies = dependencies; + } + + /// + /// Dependencies for this service. + /// + protected virtual ShapedQueryCompilingExpressionVisitorDependencies Dependencies { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext queryCompilationContext) + => new KafkaShapedQueryCompilingExpressionVisitor(Dependencies, queryCompilationContext); +} diff --git a/src/net/KEFCore/Query/Internal8/KafkaTableExpression.cs b/src/net/KEFCore/Query/Internal8/KafkaTableExpression.cs new file mode 100644 index 00000000..524ebeaf --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/KafkaTableExpression.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class KafkaTableExpression : Expression, IPrintableExpression +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public KafkaTableExpression(IEntityType entityType) + { + EntityType = entityType; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Type Type + => typeof(IEnumerable); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IEntityType EntityType { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + => this; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + => expressionPrinter.Append(nameof(KafkaTableExpression) + ": Entity: " + EntityType.DisplayName()); +} diff --git a/src/net/KEFCore/Query/Internal8/SingleResultShaperExpression.cs b/src/net/KEFCore/Query/Internal8/SingleResultShaperExpression.cs new file mode 100644 index 00000000..db581cb1 --- /dev/null +++ b/src/net/KEFCore/Query/Internal8/SingleResultShaperExpression.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace MASES.EntityFrameworkCore.KNet.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SingleResultShaperExpression : Expression, IPrintableExpression +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SingleResultShaperExpression( + Expression projection, + Expression innerShaper) + { + Projection = projection; + InnerShaper = innerShaper; + Type = innerShaper.Type; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var projection = visitor.Visit(Projection); + var innerShaper = visitor.Visit(InnerShaper); + + return Update(projection, innerShaper); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SingleResultShaperExpression Update(Expression projection, Expression innerShaper) + => projection != Projection || innerShaper != InnerShaper + ? new SingleResultShaperExpression(projection, innerShaper) + : this; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Type Type { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Projection { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression InnerShaper { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine($"{nameof(SingleResultShaperExpression)}:"); + using (expressionPrinter.Indent()) + { + expressionPrinter.Append("("); + expressionPrinter.Visit(Projection); + expressionPrinter.Append(", "); + expressionPrinter.Visit(InnerShaper); + expressionPrinter.AppendLine(")"); + } + } +} diff --git a/src/net/KEFCore/Shared8/Check.cs b/src/net/KEFCore/Shared8/Check.cs new file mode 100644 index 00000000..2ebf5ec1 --- /dev/null +++ b/src/net/KEFCore/Shared8/Check.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Diagnostics.CodeAnalysis; +using JetBrains.Annotations; + +namespace Microsoft.EntityFrameworkCore.Utilities; + +[DebuggerStepThrough] +internal static class Check +{ + [ContractAnnotation("value:null => halt")] + [return: NotNull] + public static T NotNull([NoEnumeration] [AllowNull] [NotNull] T value, [InvokerParameterName] string parameterName) + { + if (value is null) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentNullException(parameterName); + } + + return value; + } + + [ContractAnnotation("value:null => halt")] + public static IReadOnlyList NotEmpty( + [NotNull] IReadOnlyList? value, + [InvokerParameterName] string parameterName) + { + NotNull(value, parameterName); + + if (value.Count == 0) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentException(AbstractionsStrings.CollectionArgumentIsEmpty(parameterName)); + } + + return value; + } + + [ContractAnnotation("value:null => halt")] + public static string NotEmpty([NotNull] string? value, [InvokerParameterName] string parameterName) + { + if (value is null) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentNullException(parameterName); + } + + if (value.Trim().Length == 0) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentException(AbstractionsStrings.ArgumentIsEmpty(parameterName)); + } + + return value; + } + + public static string? NullButNotEmpty(string? value, [InvokerParameterName] string parameterName) + { + if (value is not null && value.Length == 0) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentException(AbstractionsStrings.ArgumentIsEmpty(parameterName)); + } + + return value; + } + + public static IReadOnlyList HasNoNulls( + [NotNull] IReadOnlyList? value, + [InvokerParameterName] string parameterName) + where T : class + { + NotNull(value, parameterName); + + if (value.Any(e => e == null)) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentException(parameterName); + } + + return value; + } + + public static IReadOnlyList HasNoEmptyElements( + [NotNull] IReadOnlyList? value, + [InvokerParameterName] string parameterName) + { + NotNull(value, parameterName); + + if (value.Any(s => string.IsNullOrWhiteSpace(s))) + { + NotEmpty(parameterName, nameof(parameterName)); + + throw new ArgumentException(AbstractionsStrings.CollectionArgumentHasEmptyElements(parameterName)); + } + + return value; + } + + [Conditional("DEBUG")] + public static void DebugAssert([DoesNotReturnIf(false)] bool condition, string message) + { + if (!condition) + { + throw new UnreachableException($"Check.DebugAssert failed: {message}"); + } + } + + [Conditional("DEBUG")] + [DoesNotReturn] + public static void DebugFail(string message) + => throw new UnreachableException($"Check.DebugFail failed: {message}"); +} diff --git a/src/net/KEFCore/Shared8/CodeAnnotations.cs b/src/net/KEFCore/Shared8/CodeAnnotations.cs new file mode 100644 index 00000000..d81a2c26 --- /dev/null +++ b/src/net/KEFCore/Shared8/CodeAnnotations.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +#nullable enable + +namespace JetBrains.Annotations; + +[AttributeUsage(AttributeTargets.Parameter)] +internal sealed class InvokerParameterNameAttribute : Attribute +{ +} + +[AttributeUsage(AttributeTargets.Parameter)] +internal sealed class NoEnumerationAttribute : Attribute +{ +} + +[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] +internal sealed class ContractAnnotationAttribute : Attribute +{ + public string Contract { get; } + + public bool ForceFullStates { get; } + + public ContractAnnotationAttribute(string contract) + : this(contract, false) + { + } + + public ContractAnnotationAttribute(string contract, bool forceFullStates) + { + Contract = contract; + ForceFullStates = forceFullStates; + } +} + +[AttributeUsage(AttributeTargets.All)] +internal sealed class UsedImplicitlyAttribute : Attribute +{ + public UsedImplicitlyAttribute() + : this(ImplicitUseKindFlags.Default, ImplicitUseTargetFlags.Default) + { + } + + public UsedImplicitlyAttribute(ImplicitUseKindFlags useKindFlags) + : this(useKindFlags, ImplicitUseTargetFlags.Default) + { + } + + public UsedImplicitlyAttribute(ImplicitUseTargetFlags targetFlags) + : this(ImplicitUseKindFlags.Default, targetFlags) + { + } + + public UsedImplicitlyAttribute( + ImplicitUseKindFlags useKindFlags, + ImplicitUseTargetFlags targetFlags) + { + UseKindFlags = useKindFlags; + TargetFlags = targetFlags; + } + + public ImplicitUseKindFlags UseKindFlags { get; } + public ImplicitUseTargetFlags TargetFlags { get; } +} + +[AttributeUsage(AttributeTargets.Constructor | AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Delegate)] +internal sealed class StringFormatMethodAttribute : Attribute +{ + public StringFormatMethodAttribute(string formatParameterName) + { + FormatParameterName = formatParameterName; + } + + public string FormatParameterName { get; } +} + +[Flags] +internal enum ImplicitUseKindFlags +{ + Default = Access | Assign | InstantiatedWithFixedConstructorSignature, + Access = 1, + Assign = 2, + InstantiatedWithFixedConstructorSignature = 4, + InstantiatedNoFixedConstructorSignature = 8 +} + +[Flags] +internal enum ImplicitUseTargetFlags +{ + Default = Itself, + Itself = 1, + Members = 2, + WithMembers = Itself | Members +} diff --git a/src/net/KEFCore/Shared8/DictionaryExtensions.cs b/src/net/KEFCore/Shared8/DictionaryExtensions.cs new file mode 100644 index 00000000..e39f1f90 --- /dev/null +++ b/src/net/KEFCore/Shared8/DictionaryExtensions.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.EntityFrameworkCore.Utilities; + +[DebuggerStepThrough] +internal static class DictionaryExtensions +{ + public static TValue GetOrAddNew( + this IDictionary source, + TKey key) + where TValue : new() + { + if (!source.TryGetValue(key, out var value)) + { + value = new TValue(); + source.Add(key, value); + } + + return value; + } + + public static TValue? Find( + this IReadOnlyDictionary source, + TKey key) + => !source.TryGetValue(key, out var value) ? default : value; + + public static bool TryGetAndRemove( + this IDictionary source, + TKey key, + [NotNullWhen(true)] out TReturn value) + { + if (source.TryGetValue(key, out var item) + && item != null) + { + source.Remove(key); + value = (TReturn)(object)item; + return true; + } + + value = default!; + return false; + } + + public static void Remove( + this IDictionary source, + Func predicate) + => source.Remove((k, v, p) => p!(k, v), predicate); + + public static void Remove( + this IDictionary source, + Func predicate, + TState? state) + { + var found = false; + var firstRemovedKey = default(TKey); + List>? pairsRemainder = null; + foreach (var pair in source) + { + if (found) + { + pairsRemainder ??= new List>(); + + pairsRemainder.Add(pair); + continue; + } + + if (!predicate(pair.Key, pair.Value, state)) + { + continue; + } + + if (!found) + { + found = true; + firstRemovedKey = pair.Key; + } + } + + if (found) + { + source.Remove(firstRemovedKey!); + if (pairsRemainder == null) + { + return; + } + + foreach (var (key, value) in pairsRemainder) + { + if (predicate(key, value, state)) + { + source.Remove(key); + } + } + } + } +} diff --git a/src/net/KEFCore/Shared8/DisposableExtensions.cs b/src/net/KEFCore/Shared8/DisposableExtensions.cs new file mode 100644 index 00000000..e43c0187 --- /dev/null +++ b/src/net/KEFCore/Shared8/DisposableExtensions.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +namespace Microsoft.EntityFrameworkCore.Utilities; + +internal static class DisposableExtensions +{ + public static ValueTask DisposeAsyncIfAvailable(this IDisposable? disposable) + { + if (disposable != null) + { + if (disposable is IAsyncDisposable asyncDisposable) + { + return asyncDisposable.DisposeAsync(); + } + + disposable.Dispose(); + } + + return default; + } +} diff --git a/src/net/KEFCore/Shared8/EnumerableExtensions.cs b/src/net/KEFCore/Shared8/EnumerableExtensions.cs new file mode 100644 index 00000000..eeeceea2 --- /dev/null +++ b/src/net/KEFCore/Shared8/EnumerableExtensions.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Collections; + +// ReSharper disable once CheckNamespace +namespace Microsoft.EntityFrameworkCore.Utilities; + +[DebuggerStepThrough] +internal static class EnumerableExtensions +{ + public static IOrderedEnumerable OrderByOrdinal( + this IEnumerable source, + Func keySelector) + => source.OrderBy(keySelector, StringComparer.Ordinal); + + public static IEnumerable Distinct( + this IEnumerable source, + Func comparer) + where T : class + => source.Distinct(new DynamicEqualityComparer(comparer)); + + private sealed class DynamicEqualityComparer : IEqualityComparer + where T : class + { + private readonly Func _func; + + public DynamicEqualityComparer(Func func) + { + _func = func; + } + + public bool Equals(T? x, T? y) + => _func(x, y); + + public int GetHashCode(T obj) + => 0; + } + + public static string Join( + this IEnumerable source, + string separator = ", ") + => string.Join(separator, source); + + public static bool StructuralSequenceEqual( + this IEnumerable first, + IEnumerable second) + { + if (ReferenceEquals(first, second)) + { + return true; + } + + using var firstEnumerator = first.GetEnumerator(); + using var secondEnumerator = second.GetEnumerator(); + while (firstEnumerator.MoveNext()) + { + if (!secondEnumerator.MoveNext() + || !StructuralComparisons.StructuralEqualityComparer + .Equals(firstEnumerator.Current, secondEnumerator.Current)) + { + return false; + } + } + + return !secondEnumerator.MoveNext(); + } + + public static bool StartsWith( + this IEnumerable first, + IEnumerable second) + { + if (ReferenceEquals(first, second)) + { + return true; + } + + using var firstEnumerator = first.GetEnumerator(); + using var secondEnumerator = second.GetEnumerator(); + + while (secondEnumerator.MoveNext()) + { + if (!firstEnumerator.MoveNext() + || !Equals(firstEnumerator.Current, secondEnumerator.Current)) + { + return false; + } + } + + return true; + } + + public static int IndexOf(this IEnumerable source, T item) + => IndexOf(source, item, EqualityComparer.Default); + + public static int IndexOf( + this IEnumerable source, + T item, + IEqualityComparer comparer) + => source.Select( + (x, index) => + comparer.Equals(item, x) ? index : -1) + .FirstOr(x => x != -1, -1); + + public static T FirstOr(this IEnumerable source, T alternate) + => source.DefaultIfEmpty(alternate).First(); + + public static T FirstOr(this IEnumerable source, Func predicate, T alternate) + => source.Where(predicate).FirstOr(alternate); + + public static bool Any(this IEnumerable source) + { + foreach (var _ in source) + { + return true; + } + + return false; + } + + public static async Task> ToListAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + var list = new List(); + await foreach (var element in source.WithCancellation(cancellationToken)) + { + list.Add(element); + } + + return list; + } + + public static List ToList(this IEnumerable source) + => source.OfType().ToList(); + + public static string Format(this IEnumerable strings) + => "{" + + string.Join( + ", ", + strings.Select(s => "'" + s + "'")) + + "}"; +} diff --git a/src/net/KEFCore/Shared8/EnumerableMethods.cs b/src/net/KEFCore/Shared8/EnumerableMethods.cs new file mode 100644 index 00000000..9b49f693 --- /dev/null +++ b/src/net/KEFCore/Shared8/EnumerableMethods.cs @@ -0,0 +1,615 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; + +namespace Microsoft.EntityFrameworkCore; + +internal static class EnumerableMethods +{ + //public static MethodInfo AggregateWithoutSeed { get; } + + //public static MethodInfo AggregateWithSeedWithoutSelector { get; } + + public static MethodInfo AggregateWithSeedSelector { get; } + + public static MethodInfo All { get; } + + public static MethodInfo AnyWithoutPredicate { get; } + + public static MethodInfo AnyWithPredicate { get; } + + //public static Append { get; } + + public static MethodInfo AsEnumerable { get; } + + public static MethodInfo Cast { get; } + + public static MethodInfo Concat { get; } + + public static MethodInfo Contains { get; } + + //public static MethodInfo ContainsWithComparer { get; } + + public static MethodInfo CountWithoutPredicate { get; } + + public static MethodInfo CountWithPredicate { get; } + + public static MethodInfo DefaultIfEmptyWithoutArgument { get; } + + public static MethodInfo DefaultIfEmptyWithArgument { get; } + + public static MethodInfo Distinct { get; } + + //public static MethodInfo DistinctWithComparer { get; } + + public static MethodInfo ElementAt { get; } + + public static MethodInfo ElementAtOrDefault { get; } + + //public static MethodInfo Empty { get; } + + public static MethodInfo Except { get; } + + //public static MethodInfo ExceptWithComparer { get; } + + public static MethodInfo FirstWithoutPredicate { get; } + + public static MethodInfo FirstWithPredicate { get; } + + public static MethodInfo FirstOrDefaultWithoutPredicate { get; } + + public static MethodInfo FirstOrDefaultWithPredicate { get; } + + public static MethodInfo GroupByWithKeySelector { get; } + + public static MethodInfo GroupByWithKeyElementSelector { get; } + + //public static MethodInfo GroupByWithKeySelectorAndComparer { get; } + + //public static MethodInfo GroupByWithKeyElementSelectorAndComparer { get; } + + public static MethodInfo GroupByWithKeyElementResultSelector { get; } + + public static MethodInfo GroupByWithKeyResultSelector { get; } + + //public static MethodInfo GroupByWithKeyResultSelectorAndComparer { get; } + + //public static MethodInfo GroupByWithKeyElementResultSelectorAndComparer { get; } + + public static MethodInfo GroupJoin { get; } + + //public static MethodInfo GroupJoinWithComparer { get; } + + public static MethodInfo Intersect { get; } + + //public static MethodInfo IntersectWithComparer { get; } + + public static MethodInfo Join { get; } + + public static MethodInfo JoinWithComparer { get; } + + public static MethodInfo LastWithoutPredicate { get; } + + public static MethodInfo LastWithPredicate { get; } + + public static MethodInfo LastOrDefaultWithoutPredicate { get; } + + public static MethodInfo LastOrDefaultWithPredicate { get; } + + public static MethodInfo LongCountWithoutPredicate { get; } + + public static MethodInfo LongCountWithPredicate { get; } + + public static MethodInfo MaxWithoutSelector { get; } + + public static MethodInfo MaxWithSelector { get; } + + public static MethodInfo MinWithoutSelector { get; } + + public static MethodInfo MinWithSelector { get; } + + public static MethodInfo OfType { get; } + + public static MethodInfo OrderBy { get; } + + //public static MethodInfo OrderByWithComparer { get; } + + public static MethodInfo OrderByDescending { get; } + + //public static MethodInfo OrderByDescendingWithComparer { get; } + + //public static MethodInfo Prepend { get; } + + //public static MethodInfo Range { get; } + + //public static MethodInfo Repeat { get; } + + public static MethodInfo Reverse { get; } + + public static MethodInfo Select { get; } + + public static MethodInfo SelectWithOrdinal { get; } + + public static MethodInfo SelectManyWithoutCollectionSelector { get; } + + //public static MethodInfo SelectManyWithoutCollectionSelectorOrdinal { get; } + + public static MethodInfo SelectManyWithCollectionSelector { get; } + + //public static MethodInfo SelectManyWithCollectionSelectorOrdinal { get; } + + public static MethodInfo SequenceEqual { get; } + + //public static MethodInfo SequenceEqualWithComparer { get; } + + public static MethodInfo SingleWithoutPredicate { get; } + + public static MethodInfo SingleWithPredicate { get; } + + public static MethodInfo SingleOrDefaultWithoutPredicate { get; } + + public static MethodInfo SingleOrDefaultWithPredicate { get; } + + public static MethodInfo Skip { get; } + + public static MethodInfo SkipWhile { get; } + + //public static MethodInfo SkipWhileOrdinal { get; } + + public static MethodInfo Take { get; } + + public static MethodInfo TakeWhile { get; } + + //public static MethodInfo TakeWhileOrdinal { get; } + + public static MethodInfo ThenBy { get; } + + //public static MethodInfo ThenByWithComparer { get; } + + public static MethodInfo ThenByDescending { get; } + + //public static MethodInfo ThenByDescendingWithComparer { get; } + + public static MethodInfo ToArray { get; } + + //public static MethodInfo ToDictionaryWithKeySelector { get; } + //public static MethodInfo ToDictionaryWithKeySelectorAndComparer { get; } + //public static MethodInfo ToDictionaryWithKeyElementSelector { get; } + //public static MethodInfo ToDictionaryWithKeyElementSelectorAndComparer { get; } + + //public static MethodInfo ToHashSet { get; } + //public static MethodInfo ToHashSetWithComparer { get; } + + public static MethodInfo ToList { get; } + + //public static MethodInfo ToLookupWithKeySelector { get; } + //public static MethodInfo ToLookupWithKeySelectorAndComparer { get; } + //public static MethodInfo ToLookupWithKeyElementSelector { get; } + //public static MethodInfo ToLookupWithKeyElementSelectorAndComparer { get; } + + public static MethodInfo Union { get; } + + //public static MethodInfo UnionWithComparer { get; } + + public static MethodInfo Where { get; } + + //public static MethodInfo WhereOrdinal { get; } + + public static MethodInfo ZipWithSelector { get; } + + // private static Dictionary SumWithoutSelectorMethods { get; } + private static Dictionary SumWithSelectorMethods { get; } + + // private static Dictionary AverageWithoutSelectorMethods { get; } + private static Dictionary AverageWithSelectorMethods { get; } + private static Dictionary MaxWithoutSelectorMethods { get; } + private static Dictionary MaxWithSelectorMethods { get; } + private static Dictionary MinWithoutSelectorMethods { get; } + private static Dictionary MinWithSelectorMethods { get; } + + // Not currently used + // + // public static bool IsSumWithoutSelector(MethodInfo methodInfo) + // => SumWithoutSelectorMethods.Values.Contains(methodInfo); + // + // public static bool IsSumWithSelector(MethodInfo methodInfo) + // => methodInfo.IsGenericMethod + // && SumWithSelectorMethods.Values.Contains(methodInfo.GetGenericMethodDefinition()); + // + // public static bool IsAverageWithoutSelector(MethodInfo methodInfo) + // => AverageWithoutSelectorMethods.Values.Contains(methodInfo); + // + // public static bool IsAverageWithSelector(MethodInfo methodInfo) + // => methodInfo.IsGenericMethod + // && AverageWithSelectorMethods.Values.Contains(methodInfo.GetGenericMethodDefinition()); + // + // public static MethodInfo GetSumWithoutSelector(Type type) + // => SumWithoutSelectorMethods[type]; + + public static MethodInfo GetSumWithSelector(Type type) + => SumWithSelectorMethods[type]; + + // public static MethodInfo GetAverageWithoutSelector(Type type) + // => AverageWithoutSelectorMethods[type]; + + public static MethodInfo GetAverageWithSelector(Type type) + => AverageWithSelectorMethods[type]; + + public static MethodInfo GetMaxWithoutSelector(Type type) + => MaxWithoutSelectorMethods.TryGetValue(type, out var method) + ? method + : MaxWithoutSelector; + + public static MethodInfo GetMaxWithSelector(Type type) + => MaxWithSelectorMethods.TryGetValue(type, out var method) + ? method + : MaxWithSelector; + + public static MethodInfo GetMinWithoutSelector(Type type) + => MinWithoutSelectorMethods.TryGetValue(type, out var method) + ? method + : MinWithoutSelector; + + public static MethodInfo GetMinWithSelector(Type type) + => MinWithSelectorMethods.TryGetValue(type, out var method) + ? method + : MinWithSelector; + + static EnumerableMethods() + { + var queryableMethodGroups = typeof(Enumerable) + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) + .GroupBy(mi => mi.Name) + .ToDictionary(e => e.Key, l => l.ToList()); + + AggregateWithSeedSelector = GetMethod( + nameof(Enumerable.Aggregate), 3, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + types[1], + typeof(Func<,,>).MakeGenericType(types[1], types[0], types[1]), + typeof(Func<,>).MakeGenericType(types[1], types[2]) + }); + + All = GetMethod( + nameof(Enumerable.All), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + AnyWithoutPredicate = GetMethod( + nameof(Enumerable.Any), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + AnyWithPredicate = GetMethod( + nameof(Enumerable.Any), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + AsEnumerable = GetMethod( + nameof(Enumerable.AsEnumerable), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + Cast = GetMethod(nameof(Enumerable.Cast), 1, _ => new[] { typeof(IEnumerable) }); + + Concat = GetMethod( + nameof(Enumerable.Concat), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + Contains = GetMethod( + nameof(Enumerable.Contains), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), types[0] }); + + CountWithoutPredicate = GetMethod( + nameof(Enumerable.Count), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + CountWithPredicate = GetMethod( + nameof(Enumerable.Count), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + DefaultIfEmptyWithoutArgument = GetMethod( + nameof(Enumerable.DefaultIfEmpty), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + DefaultIfEmptyWithArgument = GetMethod( + nameof(Enumerable.DefaultIfEmpty), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), types[0] }); + + Distinct = GetMethod(nameof(Enumerable.Distinct), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + ElementAt = GetMethod( + nameof(Enumerable.ElementAt), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(int) }); + + ElementAtOrDefault = GetMethod( + nameof(Enumerable.ElementAtOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(int) }); + + Except = GetMethod( + nameof(Enumerable.Except), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + FirstWithoutPredicate = GetMethod( + nameof(Enumerable.First), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + FirstWithPredicate = GetMethod( + nameof(Enumerable.First), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + FirstOrDefaultWithoutPredicate = GetMethod( + nameof(Enumerable.FirstOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + FirstOrDefaultWithPredicate = GetMethod( + nameof(Enumerable.FirstOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + GroupByWithKeySelector = GetMethod( + nameof(Enumerable.GroupBy), 2, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + GroupByWithKeyElementSelector = GetMethod( + nameof(Enumerable.GroupBy), 3, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType(types[0], types[1]), + typeof(Func<,>).MakeGenericType(types[0], types[2]) + }); + + GroupByWithKeyElementResultSelector = GetMethod( + nameof(Enumerable.GroupBy), 4, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType(types[0], types[1]), + typeof(Func<,>).MakeGenericType(types[0], types[2]), + typeof(Func<,,>).MakeGenericType( + types[1], typeof(IEnumerable<>).MakeGenericType(types[2]), types[3]) + }); + + GroupByWithKeyResultSelector = GetMethod( + nameof(Enumerable.GroupBy), 3, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType(types[0], types[1]), + typeof(Func<,,>).MakeGenericType( + types[1], typeof(IEnumerable<>).MakeGenericType(types[0]), types[2]) + }); + + GroupJoin = GetMethod( + nameof(Enumerable.GroupJoin), 4, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(IEnumerable<>).MakeGenericType(types[1]), + typeof(Func<,>).MakeGenericType(types[0], types[2]), + typeof(Func<,>).MakeGenericType(types[1], types[2]), + typeof(Func<,,>).MakeGenericType( + types[0], typeof(IEnumerable<>).MakeGenericType(types[1]), types[3]) + }); + + Intersect = GetMethod( + nameof(Enumerable.Intersect), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + Join = GetMethod( + nameof(Enumerable.Join), 4, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(IEnumerable<>).MakeGenericType(types[1]), + typeof(Func<,>).MakeGenericType(types[0], types[2]), + typeof(Func<,>).MakeGenericType(types[1], types[2]), + typeof(Func<,,>).MakeGenericType(types[0], types[1], types[3]) + }); + + JoinWithComparer = GetMethod( + nameof(Enumerable.Join), 4, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(IEnumerable<>).MakeGenericType(types[1]), + typeof(Func<,>).MakeGenericType(types[0], types[2]), + typeof(Func<,>).MakeGenericType(types[1], types[2]), + typeof(Func<,,>).MakeGenericType(types[0], types[1], types[3]), + typeof(IEqualityComparer<>).MakeGenericType(types[2]) + }); + + LastWithoutPredicate = GetMethod( + nameof(Enumerable.Last), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + LastWithPredicate = GetMethod( + nameof(Enumerable.Last), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + LastOrDefaultWithoutPredicate = GetMethod( + nameof(Enumerable.LastOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + LastOrDefaultWithPredicate = GetMethod( + nameof(Enumerable.LastOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + LongCountWithoutPredicate = GetMethod( + nameof(Enumerable.LongCount), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + LongCountWithPredicate = GetMethod( + nameof(Enumerable.LongCount), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + MaxWithoutSelector = GetMethod(nameof(Enumerable.Max), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + MaxWithSelector = GetMethod( + nameof(Enumerable.Max), 2, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + MinWithoutSelector = GetMethod(nameof(Enumerable.Min), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + MinWithSelector = GetMethod( + nameof(Enumerable.Min), 2, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + OfType = GetMethod(nameof(Enumerable.OfType), 1, _ => new[] { typeof(IEnumerable) }); + + OrderBy = GetMethod( + nameof(Enumerable.OrderBy), 2, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + OrderByDescending = GetMethod( + nameof(Enumerable.OrderByDescending), 2, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + Reverse = GetMethod(nameof(Enumerable.Reverse), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + Select = GetMethod( + nameof(Enumerable.Select), 2, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + SelectWithOrdinal = GetMethod( + nameof(Enumerable.Select), 2, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,,>).MakeGenericType(types[0], typeof(int), types[1]) + }); + + SelectManyWithoutCollectionSelector = GetMethod( + nameof(Enumerable.SelectMany), 2, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType( + types[0], typeof(IEnumerable<>).MakeGenericType(types[1])) + }); + + SelectManyWithCollectionSelector = GetMethod( + nameof(Enumerable.SelectMany), 3, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(Func<,>).MakeGenericType( + types[0], typeof(IEnumerable<>).MakeGenericType(types[1])), + typeof(Func<,,>).MakeGenericType(types[0], types[1], types[2]) + }); + + SequenceEqual = GetMethod( + nameof(Enumerable.SequenceEqual), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + SingleWithoutPredicate = GetMethod( + nameof(Enumerable.Single), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + SingleWithPredicate = GetMethod( + nameof(Enumerable.Single), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + SingleOrDefaultWithoutPredicate = GetMethod( + nameof(Enumerable.SingleOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + SingleOrDefaultWithPredicate = GetMethod( + nameof(Enumerable.SingleOrDefault), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + Skip = GetMethod( + nameof(Enumerable.Skip), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(int) }); + + SkipWhile = GetMethod( + nameof(Enumerable.SkipWhile), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + ToArray = GetMethod(nameof(Enumerable.ToArray), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + ToList = GetMethod(nameof(Enumerable.ToList), 1, types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + Take = GetMethod( + nameof(Enumerable.Take), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(int) }); + + TakeWhile = GetMethod( + nameof(Enumerable.TakeWhile), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + ThenBy = GetMethod( + nameof(Enumerable.ThenBy), 2, + types => new[] { typeof(IOrderedEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + ThenByDescending = GetMethod( + nameof(Enumerable.ThenByDescending), 2, + types => new[] { typeof(IOrderedEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], types[1]) }); + + Union = GetMethod( + nameof(Enumerable.Union), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0]) }); + + Where = GetMethod( + nameof(Enumerable.Where), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], typeof(bool)) }); + + ZipWithSelector = GetMethod( + nameof(Enumerable.Zip), 3, + types => new[] + { + typeof(IEnumerable<>).MakeGenericType(types[0]), + typeof(IEnumerable<>).MakeGenericType(types[1]), + typeof(Func<,,>).MakeGenericType(types[0], types[1], types[2]) + }); + + var numericTypes = new[] + { + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?) + }; + + // AverageWithoutSelectorMethods = new Dictionary(); + AverageWithSelectorMethods = new Dictionary(); + MaxWithoutSelectorMethods = new Dictionary(); + MaxWithSelectorMethods = new Dictionary(); + MinWithoutSelectorMethods = new Dictionary(); + MinWithSelectorMethods = new Dictionary(); + // SumWithoutSelectorMethods = new Dictionary(); + SumWithSelectorMethods = new Dictionary(); + + foreach (var type in numericTypes) + { + // AverageWithoutSelectorMethods[type] = GetMethod( + // nameof(Enumerable.Average), 0, types => new[] { typeof(IEnumerable<>).MakeGenericType(type) }); + AverageWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Average), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type) }); + MaxWithoutSelectorMethods[type] = GetMethod( + nameof(Enumerable.Max), 0, _ => new[] { typeof(IEnumerable<>).MakeGenericType(type) }); + MaxWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Max), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type) }); + MinWithoutSelectorMethods[type] = GetMethod( + nameof(Enumerable.Min), 0, _ => new[] { typeof(IEnumerable<>).MakeGenericType(type) }); + MinWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Min), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type) }); + // SumWithoutSelectorMethods[type] = GetMethod( + // nameof(Enumerable.Sum), 0, types => new[] { typeof(IEnumerable<>).MakeGenericType(type) }); + SumWithSelectorMethods[type] = GetMethod( + nameof(Enumerable.Sum), 1, + types => new[] { typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(Func<,>).MakeGenericType(types[0], type) }); + } + + MethodInfo GetMethod(string name, int genericParameterCount, Func parameterGenerator) + => queryableMethodGroups[name].Single( + mi => ((genericParameterCount == 0 && !mi.IsGenericMethod) + || (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount)) + && mi.GetParameters().Select(e => e.ParameterType).SequenceEqual( + parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : Array.Empty()))); + } +} diff --git a/src/net/KEFCore/Shared8/ExpressionExtensions.cs b/src/net/KEFCore/Shared8/ExpressionExtensions.cs new file mode 100644 index 00000000..0835bd7a --- /dev/null +++ b/src/net/KEFCore/Shared8/ExpressionExtensions.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Diagnostics.CodeAnalysis; + +// ReSharper disable once CheckNamespace +namespace System.Linq.Expressions; + +[DebuggerStepThrough] +internal static class ExpressionExtensions +{ + public static bool IsNullConstantExpression(this Expression expression) + => RemoveConvert(expression) is ConstantExpression { Value: null }; + + public static LambdaExpression UnwrapLambdaFromQuote(this Expression expression) + => (LambdaExpression)(expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote + ? unary.Operand + : expression); + + [return: NotNullIfNotNull("expression")] + public static Expression? UnwrapTypeConversion(this Expression? expression, out Type? convertedType) + { + convertedType = null; + while (expression is UnaryExpression + { + NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs + } unaryExpression) + { + expression = unaryExpression.Operand; + if (unaryExpression.Type != typeof(object) // Ignore object conversion + && !unaryExpression.Type.IsAssignableFrom(expression.Type)) // Ignore casting to base type/interface + { + convertedType = unaryExpression.Type; + } + } + + return expression; + } + + private static Expression RemoveConvert(Expression expression) + => expression is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unaryExpression + ? RemoveConvert(unaryExpression.Operand) + : expression; + + public static T GetConstantValue(this Expression expression) + => expression is ConstantExpression constantExpression + ? (T)constantExpression.Value! + : throw new InvalidOperationException(); +} diff --git a/src/net/KEFCore/Shared8/ExpressionVisitorExtensions.cs b/src/net/KEFCore/Shared8/ExpressionVisitorExtensions.cs new file mode 100644 index 00000000..5ca301ad --- /dev/null +++ b/src/net/KEFCore/Shared8/ExpressionVisitorExtensions.cs @@ -0,0 +1,131 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// ReSharper disable once CheckNamespace + +using System.Runtime.CompilerServices; + +namespace System.Linq.Expressions; + +#nullable enable + +[DebuggerStepThrough] +internal static class ExpressionVisitorExtensions +{ + /// + /// Dispatches the list of expressions to one of the more specialized visit methods in this class. + /// + /// The expression visitor. + /// The expressions to visit. + /// + /// The modified expression list, if any of the elements were modified; otherwise, returns the original expression list. + /// + public static IReadOnlyList Visit(this ExpressionVisitor visitor, IReadOnlyList nodes) + { + Expression[]? newNodes = null; + for (int i = 0, n = nodes.Count; i < n; i++) + { + var node = visitor.Visit(nodes[i]); + + if (newNodes is not null) + { + newNodes[i] = node; + } + else if (!ReferenceEquals(node, nodes[i])) + { + newNodes = new Expression[n]; + for (var j = 0; j < i; j++) + { + newNodes[j] = nodes[j]; + } + + newNodes[i] = node; + } + } + + return newNodes ?? nodes; + } + + /// + /// Visits an expression, casting the result back to the original expression type. + /// + /// The type of the expression. + /// The expression visitor. + /// The expression to visit. + /// The name of the calling method; used to report to report a better error message. + /// + /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. + /// + /// The visit method for this node returned a different type. + public static IReadOnlyList VisitAndConvert( + this ExpressionVisitor visitor, + IReadOnlyList nodes, + [CallerMemberName] string? callerName = null) + where T : Expression + { + T[]? newNodes = null; + for (int i = 0, n = nodes.Count; i < n; i++) + { + if (visitor.Visit(nodes[i]) is not T node) + { + throw new InvalidOperationException(CoreStrings.MustRewriteToSameNode(callerName, typeof(T).Name)); + } + + if (newNodes is not null) + { + newNodes[i] = node; + } + else if (!ReferenceEquals(node, nodes[i])) + { + newNodes = new T[n]; + for (var j = 0; j < i; j++) + { + newNodes[j] = nodes[j]; + } + + newNodes[i] = node; + } + } + + return newNodes ?? nodes; + } + + /// + /// Visits all nodes in the collection using a specified element visitor. + /// + /// The type of the nodes. + /// The expression visitor. + /// The nodes to visit. + /// + /// A delegate that visits a single element, + /// optionally replacing it with a new element. + /// + /// + /// The modified node list, if any of the elements were modified; + /// otherwise, returns the original node list. + /// + public static IReadOnlyList Visit(this ExpressionVisitor visitor, IReadOnlyList nodes, Func elementVisitor) + { + T[]? newNodes = null; + for (int i = 0, n = nodes.Count; i < n; i++) + { + var node = elementVisitor(nodes[i]); + if (newNodes is not null) + { + newNodes[i] = node; + } + else if (!ReferenceEquals(node, nodes[i])) + { + newNodes = new T[n]; + for (var j = 0; j < i; j++) + { + newNodes[j] = nodes[j]; + } + + newNodes[i] = node; + } + } + + return newNodes ?? nodes; + } +} diff --git a/src/net/KEFCore/Shared8/Graph.cs b/src/net/KEFCore/Shared8/Graph.cs new file mode 100644 index 00000000..0fc0d1ae --- /dev/null +++ b/src/net/KEFCore/Shared8/Graph.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Utilities; + +internal abstract class Graph +{ + public abstract IEnumerable Vertices { get; } + + public abstract void Clear(); + + public abstract IEnumerable GetOutgoingNeighbors(TVertex from); + + public abstract IEnumerable GetIncomingNeighbors(TVertex to); + + public ISet GetUnreachableVertices(IReadOnlyList roots) + { + var unreachableVertices = new HashSet(Vertices); + unreachableVertices.ExceptWith(roots); + var visitingQueue = new List(roots); + + var currentVertexIndex = 0; + while (currentVertexIndex < visitingQueue.Count) + { + var currentVertex = visitingQueue[currentVertexIndex]; + currentVertexIndex++; + // ReSharper disable once LoopCanBeConvertedToQuery + foreach (var neighbor in GetOutgoingNeighbors(currentVertex)) + { + if (unreachableVertices.Remove(neighbor)) + { + visitingQueue.Add(neighbor); + } + } + } + + return unreachableVertices; + } +} diff --git a/src/net/KEFCore/Shared8/HashHelpers.cs b/src/net/KEFCore/Shared8/HashHelpers.cs new file mode 100644 index 00000000..f9055d62 --- /dev/null +++ b/src/net/KEFCore/Shared8/HashHelpers.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +namespace Microsoft.EntityFrameworkCore.Utilities +{ + internal static partial class HashHelpers + { + internal static int PowerOf2(int v) + { + if ((v & (v - 1)) == 0) + { + return v; + } + + var i = 2; + while (i < v) + { + i <<= 1; + } + + return i; + } + + // must never be written to + internal static readonly int[] SizeOneIntArray = new int[1]; + + public const int HashCollisionThreshold = 100; + + // This is the maximum prime smaller than Array.MaxArrayLength + public const int MaxPrimeArrayLength = 0x7FEFFFFD; + + public const int HashPrime = 101; + + // Table of prime numbers to use as hash table sizes. + // A typical resize algorithm would pick the smallest prime number in this array + // that is larger than twice the previous capacity. + // Suppose our Hashtable currently has capacity x and enough elements are added + // such that a resize needs to occur. Resizing first computes 2x then finds the + // first prime in the table greater than 2x, i.e. if primes are ordered + // p_1, p_2, ..., p_i, ..., it finds p_n such that p_n-1 < 2x < p_n. + // Doubling is important for preserving the asymptotic complexity of the + // hashtable operations such as add. Having a prime guarantees that double + // hashing does not lead to infinite loops. IE, your hash function will be + // h1(key) + i*h2(key), 0 <= i < size. h2 and the size must be relatively prime. + // We prefer the low computation costs of higher prime numbers over the increased + // memory allocation of a fixed prime number i.e. when right sizing a HashSet. + public static readonly int[] primes = { + 3, 7, 11, 17, 23, 29, 37, 47, 59, 71, 89, 107, 131, 163, 197, 239, 293, 353, 431, 521, 631, 761, 919, + 1103, 1327, 1597, 1931, 2333, 2801, 3371, 4049, 4861, 5839, 7013, 8419, 10103, 12143, 14591, + 17519, 21023, 25229, 30293, 36353, 43627, 52361, 62851, 75431, 90523, 108631, 130363, 156437, + 187751, 225307, 270371, 324449, 389357, 467237, 560689, 672827, 807403, 968897, 1162687, 1395263, + 1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369 }; + + public static bool IsPrime(int candidate) + { + if ((candidate & 1) != 0) + { + var limit = (int)Math.Sqrt(candidate); + for (var divisor = 3; divisor <= limit; divisor += 2) + { + if ((candidate % divisor) == 0) + { + return false; + } + } + return true; + } + return candidate == 2; + } + + public static int GetPrime(int min) + { + if (min < 0) + { + throw new ArgumentException("Hashtable's capacity overflowed and went negative. Check load factor, capacity and the current size of the table."); + } + + for (var i = 0; i < primes.Length; i++) + { + var prime = primes[i]; + if (prime >= min) + { + return prime; + } + } + + //outside of our predefined table. + //compute the hard way. + for (var i = (min | 1); i < int.MaxValue; i += 2) + { + if (IsPrime(i) && ((i - 1) % HashPrime != 0)) + { + return i; + } + } + return min; + } + + // Returns size of hashtable to grow to. + public static int ExpandPrime(int oldSize) + { + var newSize = 2 * oldSize; + + // Allow the hashtables to grow to maximum possible size (~2G elements) before encountering capacity overflow. + // Note that this check works even when _items.Length overflowed thanks to the (uint) cast + if ((uint)newSize > MaxPrimeArrayLength && MaxPrimeArrayLength > oldSize) + { + Debug.Assert(MaxPrimeArrayLength == GetPrime(MaxPrimeArrayLength), "Invalid MaxPrimeArrayLength"); + return MaxPrimeArrayLength; + } + + return GetPrime(newSize); + } + } +} diff --git a/src/net/KEFCore/Shared8/IDictionaryDebugView.cs b/src/net/KEFCore/Shared8/IDictionaryDebugView.cs new file mode 100644 index 00000000..4c995f84 --- /dev/null +++ b/src/net/KEFCore/Shared8/IDictionaryDebugView.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +namespace Microsoft.EntityFrameworkCore.Utilities +{ + internal sealed class IDictionaryDebugView + { + private readonly IDictionary _dict; + + public IDictionaryDebugView(IDictionary dictionary) + { + _dict = dictionary ?? throw new ArgumentNullException(nameof(dictionary)); + } + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public KeyValuePair[] Items + { + get + { + var items = new KeyValuePair[_dict.Count]; + _dict.CopyTo(items, 0); + return items; + } + } + } + + internal sealed class DictionaryKeyCollectionDebugView + { + private readonly ICollection _collection; + + public DictionaryKeyCollectionDebugView(ICollection collection) + { + _collection = collection ?? throw new ArgumentNullException(nameof(collection)); + } + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public TKey[] Items + { + get + { + var items = new TKey[_collection.Count]; + _collection.CopyTo(items, 0); + return items; + } + } + } + + internal sealed class DictionaryValueCollectionDebugView + { + private readonly ICollection _collection; + + public DictionaryValueCollectionDebugView(ICollection collection) + { + _collection = collection ?? throw new ArgumentNullException(nameof(collection)); + } + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public TValue[] Items + { + get + { + var items = new TValue[_collection.Count]; + _collection.CopyTo(items, 0); + return items; + } + } + } +} diff --git a/src/net/KEFCore/Shared8/MemberInfoExtensions.cs b/src/net/KEFCore/Shared8/MemberInfoExtensions.cs new file mode 100644 index 00000000..06564f5c --- /dev/null +++ b/src/net/KEFCore/Shared8/MemberInfoExtensions.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +namespace System.Reflection; + +internal static class EntityFrameworkMemberInfoExtensions +{ + public static Type GetMemberType(this MemberInfo memberInfo) + => (memberInfo as PropertyInfo)?.PropertyType ?? ((FieldInfo)memberInfo).FieldType; + + public static bool IsSameAs(this MemberInfo? propertyInfo, MemberInfo? otherPropertyInfo) + => propertyInfo == null + ? otherPropertyInfo == null + : (otherPropertyInfo != null + && (Equals(propertyInfo, otherPropertyInfo) + || (propertyInfo.Name == otherPropertyInfo.Name + && propertyInfo.DeclaringType != null + && otherPropertyInfo.DeclaringType != null + && (propertyInfo.DeclaringType == otherPropertyInfo.DeclaringType + || propertyInfo.DeclaringType.GetTypeInfo().IsSubclassOf(otherPropertyInfo.DeclaringType) + || otherPropertyInfo.DeclaringType.GetTypeInfo().IsSubclassOf(propertyInfo.DeclaringType) + || propertyInfo.DeclaringType.GetTypeInfo().ImplementedInterfaces.Contains(otherPropertyInfo.DeclaringType) + || otherPropertyInfo.DeclaringType.GetTypeInfo().ImplementedInterfaces + .Contains(propertyInfo.DeclaringType))))); + + public static bool IsOverriddenBy(this MemberInfo? propertyInfo, MemberInfo? otherPropertyInfo) + => propertyInfo == null + ? otherPropertyInfo == null + : (otherPropertyInfo != null + && (Equals(propertyInfo, otherPropertyInfo) + || (propertyInfo.Name == otherPropertyInfo.Name + && propertyInfo.DeclaringType != null + && otherPropertyInfo.DeclaringType != null + && (propertyInfo.DeclaringType == otherPropertyInfo.DeclaringType + || otherPropertyInfo.DeclaringType.GetTypeInfo().IsSubclassOf(propertyInfo.DeclaringType) + || otherPropertyInfo.DeclaringType.GetTypeInfo().ImplementedInterfaces + .Contains(propertyInfo.DeclaringType))))); + + public static string GetSimpleMemberName(this MemberInfo member) + { + var name = member.Name; + var index = name.LastIndexOf('.'); + return index >= 0 ? name[(index + 1)..] : name; + } + + public static bool IsReallyVirtual(this MethodInfo method) + => method is { IsVirtual: true, IsFinal: false }; +} diff --git a/src/net/KEFCore/Shared8/MethodInfoExtensions.cs b/src/net/KEFCore/Shared8/MethodInfoExtensions.cs new file mode 100644 index 00000000..1d86c668 --- /dev/null +++ b/src/net/KEFCore/Shared8/MethodInfoExtensions.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Immutable; + +namespace System.Reflection; + +internal static class MethodInfoExtensions +{ + public static bool IsContainsMethod(this MethodInfo method) + => method is { Name: nameof(IList.Contains), DeclaringType: not null } + && method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any( + t => t == typeof(IList) + || (t.IsGenericType + && t.GetGenericTypeDefinition() is Type genericType + && (genericType == typeof(ICollection<>) + || genericType == typeof(IReadOnlySet<>) + || genericType == typeof(IImmutableSet<>)))); +} diff --git a/src/net/KEFCore/Shared8/Multigraph.cs b/src/net/KEFCore/Shared8/Multigraph.cs new file mode 100644 index 00000000..ff803950 --- /dev/null +++ b/src/net/KEFCore/Shared8/Multigraph.cs @@ -0,0 +1,386 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +namespace Microsoft.EntityFrameworkCore.Utilities; + +internal class Multigraph : Graph + where TVertex : notnull +{ + private readonly IComparer? _secondarySortComparer; + private readonly HashSet _vertices = new(); + private readonly Dictionary> _successorMap = new(); + private readonly Dictionary> _predecessorMap = new(); + + public Multigraph() + { + } + + public Multigraph(IComparer secondarySortComparer) + { + _secondarySortComparer = secondarySortComparer; + } + + public Multigraph(Comparison secondarySortComparer) + : this(Comparer.Create(secondarySortComparer)) + { + } + + public IEnumerable GetEdges(TVertex from, TVertex to) + { + if (_successorMap.TryGetValue(from, out var successorSet)) + { + if (successorSet.TryGetValue(to, out var edges)) + { + return edges is IEnumerable edgeList ? edgeList.Select(e => e.Payload) : (new[] { ((Edge)edges!).Payload }); + } + } + + return Enumerable.Empty(); + } + + public void AddVertex(TVertex vertex) + => _vertices.Add(vertex); + + public void AddVertices(IEnumerable vertices) + => _vertices.UnionWith(vertices); + + public void AddEdge(TVertex from, TVertex to, TEdge payload, bool requiresBatchingBoundary = false) + { +#if DEBUG + if (!_vertices.Contains(from)) + { + throw new InvalidOperationException(CoreStrings.GraphDoesNotContainVertex(from)); + } + + if (!_vertices.Contains(to)) + { + throw new InvalidOperationException(CoreStrings.GraphDoesNotContainVertex(to)); + } +#endif + + var edge = new Edge(payload, requiresBatchingBoundary); + + if (!_successorMap.TryGetValue(from, out var successorEdges)) + { + successorEdges = new Dictionary(); + _successorMap.Add(from, successorEdges); + } + + if (successorEdges.TryGetValue(to, out var edges)) + { + if (edges is not List edgeList) + { + edgeList = new List { (Edge)edges! }; + successorEdges[to] = edgeList; + } + + edgeList.Add(edge); + } + else + { + successorEdges.Add(to, edge); + } + + if (!_predecessorMap.TryGetValue(to, out var predecessorEdges)) + { + predecessorEdges = new Dictionary(); + _predecessorMap.Add(to, predecessorEdges); + } + + if (predecessorEdges.TryGetValue(from, out edges)) + { + if (edges is not List edgeList) + { + edgeList = new List { (Edge)edges! }; + predecessorEdges[from] = edgeList; + } + + edgeList.Add(edge); + } + else + { + predecessorEdges.Add(from, edge); + } + } + + public override void Clear() + { + _vertices.Clear(); + _successorMap.Clear(); + _predecessorMap.Clear(); + } + + public IReadOnlyList TopologicalSort() + => TopologicalSort(null, null); + + public IReadOnlyList TopologicalSort( + Func, bool> tryBreakEdge) + => TopologicalSort(tryBreakEdge, null); + + public IReadOnlyList TopologicalSort( + Func>>, string> formatCycle) + => TopologicalSort(null, formatCycle); + + public IReadOnlyList TopologicalSort( + Func, bool>? tryBreakEdge, + Func>>, string>? formatCycle, + Func? formatException = null) + { + var batches = TopologicalSortCore(withBatching: false, tryBreakEdge, formatCycle, formatException); + + Check.DebugAssert(batches.Count < 2, "TopologicalSortCore did batching but withBatching was false"); + + return batches.Count == 1 + ? batches[0] + : Array.Empty(); + } + + protected virtual string? ToString(TVertex vertex) + => vertex.ToString(); + + public IReadOnlyList> BatchingTopologicalSort() + => BatchingTopologicalSort(null, null); + + public IReadOnlyList> BatchingTopologicalSort( + Func, bool>? canBreakEdges, + Func>>, string>? formatCycle, + Func? formatException = null) + => TopologicalSortCore(withBatching: true, canBreakEdges, formatCycle, formatException); + + private IReadOnlyList> TopologicalSortCore( + bool withBatching, + Func, bool>? canBreakEdges, + Func>>, string>? formatCycle, + Func? formatException = null) + { + // Performs a breadth-first topological sort (Kahn's algorithm) + var result = new List>(); + var currentRootsQueue = new List(); + var nextRootsQueue = new List(); + var vertexesProcessed = 0; + var batchBoundaryRequired = false; + var currentBatch = new List(); + var currentBatchSet = new HashSet(); + + var predecessorCounts = new Dictionary(_predecessorMap.Count); + foreach (var (vertex, vertices) in _predecessorMap) + { + predecessorCounts[vertex] = vertices.Count; + } + + // Bootstrap the topological sort by finding all vertexes which have no predecessors + foreach (var vertex in _vertices) + { + if (!predecessorCounts.ContainsKey(vertex)) + { + currentRootsQueue.Add(vertex); + } + } + + result.Add(currentBatch); + + while (vertexesProcessed < _vertices.Count) + { + while (currentRootsQueue.Count > 0) + { + // Secondary sorting: after the first topological sorting (according to dependencies between the commands as expressed in + // the graph), we apply an optional secondary sort. + // When sorting modification commands, this ensures a deterministic ordering and prevents deadlocks between concurrent + // transactions locking the same rows in different orders. + if (_secondarySortComparer is not null) + { + currentRootsQueue.Sort(_secondarySortComparer); + } + + // If we detected in the last roots pass that a batch boundary is required, close the current batch and start a new one. + if (batchBoundaryRequired) + { + currentBatch = new List(); + result.Add(currentBatch); + currentBatchSet.Clear(); + + batchBoundaryRequired = false; + } + + foreach (var currentRoot in currentRootsQueue) + { + currentBatch.Add(currentRoot); + currentBatchSet.Add(currentRoot); + vertexesProcessed++; + + foreach (var successor in GetOutgoingNeighbors(currentRoot)) + { + predecessorCounts[successor]--; + + // If the successor has no other predecessors, add it for processing in the next roots pass. + if (predecessorCounts[successor] == 0) + { + nextRootsQueue.Add(successor); + CheckBatchingBoundary(successor); + } + } + } + + // Finished passing over the current roots, move on to the next set. + (currentRootsQueue, nextRootsQueue) = (nextRootsQueue, currentRootsQueue); + nextRootsQueue.Clear(); + } + + // We have no more roots to process. That either means we're done, or that there's a cycle which we need to break + if (vertexesProcessed < _vertices.Count) + { + var broken = false; + + var candidateVertices = predecessorCounts.Keys.ToList(); + var candidateIndex = 0; + + while ((candidateIndex < candidateVertices.Count) + && !broken + && canBreakEdges != null) + { + var candidateVertex = candidateVertices[candidateIndex]; + if (predecessorCounts[candidateVertex] == 0) + { + candidateIndex++; + continue; + } + + // Find a vertex in the unsorted portion of the graph that has edges to the candidate + var incomingNeighbor = GetIncomingNeighbors(candidateVertex) + .First( + neighbor => predecessorCounts.TryGetValue(neighbor, out var neighborPredecessors) + && neighborPredecessors > 0); + + if (canBreakEdges(incomingNeighbor, candidateVertex, GetEdges(incomingNeighbor, candidateVertex))) + { + var removed = _successorMap[incomingNeighbor].Remove(candidateVertex); + Check.DebugAssert(removed, "Candidate vertex not found in successor map"); + removed = _predecessorMap[candidateVertex].Remove(incomingNeighbor); + Check.DebugAssert(removed, "Incoming neighbor not found in predecessor map"); + + predecessorCounts[candidateVertex]--; + if (predecessorCounts[candidateVertex] == 0) + { + currentRootsQueue.Add(candidateVertex); + CheckBatchingBoundary(candidateVertex); + broken = true; + } + + continue; + } + + candidateIndex++; + } + + if (broken) + { + continue; + } + + var currentCycleVertex = _vertices.First( + v => predecessorCounts.TryGetValue(v, out var predecessorCount) && predecessorCount != 0); + var cycle = new List { currentCycleVertex }; + var finished = false; + while (!finished) + { + foreach (var predecessor in GetIncomingNeighbors(currentCycleVertex)) + { + if (!predecessorCounts.TryGetValue(predecessor, out var predecessorCount) + || predecessorCount == 0) + { + continue; + } + + predecessorCounts[currentCycleVertex] = -1; + + currentCycleVertex = predecessor; + cycle.Add(currentCycleVertex); + finished = predecessorCounts[predecessor] == -1; + break; + } + } + + cycle.Reverse(); + + // Remove any tail that's not part of the cycle + var startingVertex = cycle[0]; + for (var i = cycle.Count - 1; i >= 0; i--) + { + if (cycle[i].Equals(startingVertex)) + { + break; + } + + cycle.RemoveAt(i); + } + + ThrowCycle(cycle, formatCycle, formatException); + } + } + + return result; + + // Detect batch boundary (if batching is enabled). + // If the successor has any predecessor where the edge requires a batching boundary, and that predecessor is + // already in the current batch, then the next batch will have to be executed in a separate batch. + // TODO: Optimization: Instead of currentBatchSet, store a batch counter on each vertex, and check if later + // vertexes have a boundary-requiring dependency on a vertex with the same batch counter. + void CheckBatchingBoundary(TVertex vertex) + { + if (withBatching + && _predecessorMap[vertex].Any( + kv => + (kv.Value is Edge { RequiresBatchingBoundary: true } + || kv.Value is IEnumerable edges && edges.Any(e => e.RequiresBatchingBoundary)) + && currentBatchSet.Contains(kv.Key))) + { + batchBoundaryRequired = true; + } + } + } + + private void ThrowCycle( + List cycle, + Func>>, string>? formatCycle, + Func? formatException = null) + { + string cycleString; + if (formatCycle == null) + { + cycleString = cycle.Select(e => ToString(e)!).Join(" ->" + Environment.NewLine); + } + else + { + var currentCycleVertex = cycle.First(); + var cycleData = new List>>(); + + foreach (var vertex in cycle.Skip(1)) + { + cycleData.Add(Tuple.Create(currentCycleVertex, vertex, GetEdges(currentCycleVertex, vertex))); + currentCycleVertex = vertex; + } + + cycleString = formatCycle(cycleData); + } + + var message = formatException == null ? CoreStrings.CircularDependency(cycleString) : formatException(cycleString); + throw new InvalidOperationException(message); + } + + public override IEnumerable Vertices + => _vertices; + + public override IEnumerable GetOutgoingNeighbors(TVertex from) + => _successorMap.TryGetValue(from, out var successorSet) + ? successorSet.Keys + : Enumerable.Empty(); + + public override IEnumerable GetIncomingNeighbors(TVertex to) + => _predecessorMap.TryGetValue(to, out var predecessors) + ? predecessors.Keys + : Enumerable.Empty(); + + private record struct Edge(TEdge Payload, bool RequiresBatchingBoundary); +} diff --git a/src/net/KEFCore/Shared8/NonCapturingLazyInitializer.cs b/src/net/KEFCore/Shared8/NonCapturingLazyInitializer.cs new file mode 100644 index 00000000..42412bda --- /dev/null +++ b/src/net/KEFCore/Shared8/NonCapturingLazyInitializer.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Diagnostics.CodeAnalysis; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Internal; + +internal static class NonCapturingLazyInitializer +{ + public static TValue EnsureInitialized( + [NotNull] ref TValue? target, + TParam param, + Func valueFactory) + where TValue : class + { + var tmp = Volatile.Read(ref target); + if (tmp != null) + { + Check.DebugAssert(target != null, $"target was null in {nameof(EnsureInitialized)} after check"); + return tmp; + } + + Interlocked.CompareExchange(ref target, valueFactory(param), null); + + return target; + } + + public static TValue EnsureInitialized( + [NotNull] ref TValue? target, + TParam1 param1, + TParam2 param2, + Func valueFactory) + where TValue : class + { + var tmp = Volatile.Read(ref target); + if (tmp != null) + { + Check.DebugAssert(target != null, $"target was null in {nameof(EnsureInitialized)} after check"); + return tmp; + } + + Interlocked.CompareExchange(ref target, valueFactory(param1, param2), null); + + return target; + } + + public static TValue EnsureInitialized( + [NotNull] ref TValue? target, + TParam1 param1, + TParam2 param2, + TParam3 param3, + Func valueFactory) + where TValue : class + { + var tmp = Volatile.Read(ref target); + if (tmp != null) + { + Check.DebugAssert(target != null, $"target was null in {nameof(EnsureInitialized)} after check"); + return tmp; + } + + Interlocked.CompareExchange(ref target, valueFactory(param1, param2, param3), null); + + return target; + } + + public static TValue EnsureInitialized( + ref TValue target, + ref bool initialized, + TParam param, + Func valueFactory) + where TValue : class? + { + var alreadyInitialized = Volatile.Read(ref initialized); + if (alreadyInitialized) + { + var value = Volatile.Read(ref target); + Check.DebugAssert(target != null, $"target was null in {nameof(EnsureInitialized)} after check"); + Check.DebugAssert(value != null, $"value was null in {nameof(EnsureInitialized)} after check"); + return value; + } + + Volatile.Write(ref target, valueFactory(param)); + Volatile.Write(ref initialized, true); + + return target; + } + + public static TValue EnsureInitialized( + [NotNull] ref TValue? target, + TValue value) + where TValue : class + { + var tmp = Volatile.Read(ref target); + if (tmp != null) + { + Check.DebugAssert(target != null, $"target was null in {nameof(EnsureInitialized)} after check"); + return tmp; + } + + Interlocked.CompareExchange(ref target, value, null); + + return target; + } + + public static TValue EnsureInitialized( + [NotNull] ref TValue? target, + TParam param, + Action valueFactory) + where TValue : class + { + var tmp = Volatile.Read(ref target); + if (tmp != null) + { + Check.DebugAssert(target != null, $"target was null in {nameof(EnsureInitialized)} after check"); + return tmp; + } + + valueFactory(param); + + var tmp2 = Volatile.Read(ref target); + Check.DebugAssert( + target != null && tmp2 != null, + $"{nameof(valueFactory)} did not initialize {nameof(target)} in {nameof(EnsureInitialized)}"); + return tmp2; + } +} diff --git a/src/net/KEFCore/Shared8/OrderedDictionary.KeyCollection.cs b/src/net/KEFCore/Shared8/OrderedDictionary.KeyCollection.cs new file mode 100644 index 00000000..401aec34 --- /dev/null +++ b/src/net/KEFCore/Shared8/OrderedDictionary.KeyCollection.cs @@ -0,0 +1,160 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Collections; + +namespace Microsoft.EntityFrameworkCore.Utilities +{ + internal partial class OrderedDictionary + { + /// + /// Represents the collection of keys in a . This class cannot be inherited. + /// + [DebuggerTypeProxy(typeof(DictionaryKeyCollectionDebugView<,>))] + [DebuggerDisplay("Count = {Count}")] + internal sealed class KeyCollection : IList, IReadOnlyList + { + private readonly OrderedDictionary _orderedDictionary; + + /// + /// Gets the number of elements contained in the . + /// + /// The number of elements contained in the . + public int Count => _orderedDictionary.Count; + + /// + /// Gets the key at the specified index as an O(1) operation. + /// + /// The zero-based index of the key to get. + /// The key at the specified index. + /// is less than 0.-or- is equal to or greater than . + public TKey this[int index] => ((IList>)_orderedDictionary)[index].Key; + + TKey IList.this[int index] + { + get => this[index]; + set => throw new NotSupportedException(); + } + + bool ICollection.IsReadOnly => true; + + internal KeyCollection(OrderedDictionary orderedDictionary) + { + _orderedDictionary = orderedDictionary; + } + + /// + /// Returns an enumerator that iterates through the . + /// + /// A for the . + public Enumerator GetEnumerator() => new Enumerator(_orderedDictionary); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + int IList.IndexOf(TKey item) => _orderedDictionary.IndexOf(item); + + void IList.Insert(int index, TKey item) => throw new NotSupportedException(); + + void IList.RemoveAt(int index) => throw new NotSupportedException(); + + void ICollection.Add(TKey item) => throw new NotSupportedException(); + + void ICollection.Clear() => throw new NotSupportedException(); + + bool ICollection.Contains(TKey item) => _orderedDictionary.ContainsKey(item); + + void ICollection.CopyTo(TKey[] array, int arrayIndex) + { + ArgumentNullException.ThrowIfNull(array); + if ((uint)arrayIndex > (uint)array.Length) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + var count = Count; + if (array.Length - arrayIndex < count) + { + throw new ArgumentException(); + } + + var entries = _orderedDictionary._entries; + for (var i = 0; i < count; ++i) + { + array[i + arrayIndex] = entries[i].Key; + } + } + + bool ICollection.Remove(TKey item) => throw new NotSupportedException(); + + /// + /// Enumerates the elements of a . + /// + public struct Enumerator : IEnumerator + { + private readonly OrderedDictionary _orderedDictionary; + private readonly int _version; + private int _index; + private TKey _current; + + /// + /// Gets the element at the current position of the enumerator. + /// + /// The element in the at the current position of the enumerator. + public readonly TKey Current => _current; + + readonly object? IEnumerator.Current => _current; + + internal Enumerator(OrderedDictionary orderedDictionary) + { + _orderedDictionary = orderedDictionary; + _version = orderedDictionary._version; + _index = 0; + _current = default!; + } + + /// + /// Releases all resources used by the . + /// + public void Dispose() + { + } + + /// + /// Advances the enumerator to the next element of the . + /// + /// true if the enumerator was successfully advanced to the next element; false if the enumerator has passed the end of the collection. + /// The collection was modified after the enumerator was created. + public bool MoveNext() + { + if (_version != _orderedDictionary._version) + { + throw new InvalidOperationException(); + } + + if (_index < _orderedDictionary.Count) + { + _current = _orderedDictionary._entries[_index].Key; + ++_index; + return true; + } + _current = default!; + return false; + } + + void IEnumerator.Reset() + { + if (_version != _orderedDictionary._version) + { + throw new InvalidOperationException(); + } + + _index = 0; + _current = default!; + } + } + } + } +} diff --git a/src/net/KEFCore/Shared8/OrderedDictionary.ValueCollection.cs b/src/net/KEFCore/Shared8/OrderedDictionary.ValueCollection.cs new file mode 100644 index 00000000..71ead29c --- /dev/null +++ b/src/net/KEFCore/Shared8/OrderedDictionary.ValueCollection.cs @@ -0,0 +1,174 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Collections; + +namespace Microsoft.EntityFrameworkCore.Utilities +{ + internal partial class OrderedDictionary + { + /// + /// Represents the collection of values in a . This class cannot be inherited. + /// + [DebuggerTypeProxy(typeof(DictionaryValueCollectionDebugView<,>))] + [DebuggerDisplay("Count = {Count}")] + public sealed class ValueCollection : IList, IReadOnlyList + { + private readonly OrderedDictionary _orderedDictionary; + + /// + /// Gets the number of elements contained in the . + /// + /// The number of elements contained in the . + public int Count => _orderedDictionary.Count; + + /// + /// Gets the value at the specified index as an O(1) operation. + /// + /// The zero-based index of the value to get. + /// The value at the specified index. + /// is less than 0.-or- is equal to or greater than . + public TValue this[int index] => _orderedDictionary[index]; + + TValue IList.this[int index] + { + get => this[index]; + set => throw new NotSupportedException(); + } + + bool ICollection.IsReadOnly => true; + + internal ValueCollection(OrderedDictionary orderedDictionary) + { + _orderedDictionary = orderedDictionary; + } + + /// + /// Returns an enumerator that iterates through the . + /// + /// A for the . + public Enumerator GetEnumerator() => new Enumerator(_orderedDictionary); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + int IList.IndexOf(TValue item) + { + var comparer = EqualityComparer.Default; + var entries = _orderedDictionary._entries; + var count = Count; + for (var i = 0; i < count; ++i) + { + if (comparer.Equals(entries[i].Value, item)) + { + return i; + } + } + return -1; + } + + void IList.Insert(int index, TValue item) => throw new NotSupportedException(); + + void IList.RemoveAt(int index) => throw new NotSupportedException(); + + void ICollection.Add(TValue item) => throw new NotSupportedException(); + + void ICollection.Clear() => throw new NotSupportedException(); + + bool ICollection.Contains(TValue item) => ((IList)this).IndexOf(item) >= 0; + + void ICollection.CopyTo(TValue[] array, int arrayIndex) + { + ArgumentNullException.ThrowIfNull(array); + + if ((uint)arrayIndex > (uint)array.Length) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + var count = Count; + if (array.Length - arrayIndex < count) + { + throw new ArgumentException(); + } + + var entries = _orderedDictionary._entries; + for (var i = 0; i < count; ++i) + { + array[i + arrayIndex] = entries[i].Value; + } + } + + bool ICollection.Remove(TValue item) => throw new NotSupportedException(); + + /// + /// Enumerates the elements of a . + /// + public struct Enumerator : IEnumerator + { + private readonly OrderedDictionary _orderedDictionary; + private readonly int _version; + private int _index; + private TValue _current; + + /// + /// Gets the element at the current position of the enumerator. + /// + /// The element in the at the current position of the enumerator. + public TValue Current => _current; + + object? IEnumerator.Current => _current; + + internal Enumerator(OrderedDictionary orderedDictionary) + { + _orderedDictionary = orderedDictionary; + _version = orderedDictionary._version; + _index = 0; + _current = default!; + } + + /// + /// Releases all resources used by the . + /// + public void Dispose() + { + } + + /// + /// Advances the enumerator to the next element of the . + /// + /// true if the enumerator was successfully advanced to the next element; false if the enumerator has passed the end of the collection. + /// The collection was modified after the enumerator was created. + public bool MoveNext() + { + if (_version != _orderedDictionary._version) + { + throw new InvalidOperationException(); + } + + if (_index < _orderedDictionary.Count) + { + _current = _orderedDictionary._entries[_index].Value; + ++_index; + return true; + } + _current = default!; + return false; + } + + void IEnumerator.Reset() + { + if (_version != _orderedDictionary._version) + { + throw new InvalidOperationException(); + } + + _index = 0; + _current = default!; + } + } + } + } +} diff --git a/src/net/KEFCore/Shared8/OrderedDictionary.cs b/src/net/KEFCore/Shared8/OrderedDictionary.cs new file mode 100644 index 00000000..67d462ef --- /dev/null +++ b/src/net/KEFCore/Shared8/OrderedDictionary.cs @@ -0,0 +1,901 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Collections; + +namespace Microsoft.EntityFrameworkCore.Utilities +{ + internal enum InsertionBehavior + { + None = 0, + OverwriteExisting = 1, + ThrowOnExisting = 2 + } + + /// + /// Represents an ordered collection of keys and values with the same performance as with O(1) lookups and adds but with O(n) inserts and removes. + /// + /// The type of the keys in the dictionary. + /// The type of the values in the dictionary. + [DebuggerTypeProxy(typeof(IDictionaryDebugView<,>))] + [DebuggerDisplay("Count = {Count}")] + internal sealed partial class OrderedDictionary : IDictionary, IReadOnlyDictionary, IList>, IReadOnlyList> + { + private struct Entry + { + public uint HashCode; + public TKey Key; + public TValue Value; + public int Next; // the index of the next item in the same bucket, -1 if last + } + + // We want to initialize without allocating arrays. We also want to avoid null checks. + // Array.Empty would give divide by zero in modulo operation. So we use static one element arrays. + // The first add will cause a resize replacing these with real arrays of three elements. + // Arrays are wrapped in a class to avoid being duplicated for each + private static readonly Entry[] InitialEntries = new Entry[1]; + // 1-based index into _entries; 0 means empty + private int[] _buckets = HashHelpers.SizeOneIntArray; + // remains contiguous and maintains order + private Entry[] _entries = InitialEntries; + private int _count; + private int _version; + // is null when comparer is EqualityComparer.Default so that the GetHashCode method is used explicitly on the object + private readonly IEqualityComparer? _comparer; + private KeyCollection? _keys; + private ValueCollection? _values; + + /// + /// Gets the number of key/value pairs contained in the . + /// + /// The number of key/value pairs contained in the . + public int Count => _count; + + /// + /// Gets the that is used to determine equality of keys for the dictionary. + /// + /// The generic interface implementation that is used to determine equality of keys for the current and to provide hash values for the keys. + public IEqualityComparer Comparer => _comparer ?? EqualityComparer.Default; + + /// + /// Gets a collection containing the keys in the . + /// + /// An containing the keys in the . + public KeyCollection Keys => _keys ??= new KeyCollection(this); + + /// + /// Gets a collection containing the values in the . + /// + /// An containing the values in the . + public ValueCollection Values => _values ??= new ValueCollection(this); + + /// + /// Gets or sets the value associated with the specified key as an O(1) operation. + /// + /// The key of the value to get or set. + /// The value associated with the specified key. If the specified key is not found, a get operation throws a , and a set operation creates a new element with the specified key. + /// is null. + /// The property is retrieved and does not exist in the collection. + public TValue this[TKey key] + { + get + { + var index = IndexOf(key); + return index < 0 + ? throw new KeyNotFoundException($"Key {key} not found in the dictionary") + : _entries[index].Value; + } + set => TryInsert(null, key, value, InsertionBehavior.OverwriteExisting); + } + + /// + /// Gets or sets the value at the specified index as an O(1) operation. + /// + /// The zero-based index of the element to get or set. + /// The value at the specified index. + /// is less than 0.-or- is equal to or greater than . + public TValue this[int index] + { + get + { + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, Count); + + return _entries[index].Value; + } + set + { + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, Count); + + _entries[index].Value = value; + } + } + + /// + /// Initializes a new instance of the class that is empty, has the default initial capacity, and uses the default equality comparer for the key type. + /// + public OrderedDictionary() + : this(0, null) + { + } + + /// + /// Initializes a new instance of the class that is empty, has the specified initial capacity, and uses the default equality comparer for the key type. + /// + /// The initial number of elements that the can contain. + /// is less than 0. + public OrderedDictionary(int capacity) + : this(capacity, null) + { + } + + /// + /// Initializes a new instance of the class that is empty, has the default initial capacity, and uses the specified . + /// + /// The implementation to use when comparing keys, or null to use the default for the type of the key. + public OrderedDictionary(IEqualityComparer comparer) + : this(0, comparer) + { + } + + /// + /// Initializes a new instance of the class that is empty, has the specified initial capacity, and uses the specified . + /// + /// The initial number of elements that the can contain. + /// The implementation to use when comparing keys, or null to use the default for the type of the key. + /// is less than 0. + public OrderedDictionary(int capacity, IEqualityComparer? comparer) + { + ArgumentOutOfRangeException.ThrowIfNegative(capacity); + + if (capacity > 0) + { + var newSize = HashHelpers.GetPrime(capacity); + _buckets = new int[newSize]; + _entries = new Entry[newSize]; + } + + if (comparer != EqualityComparer.Default) + { + _comparer = comparer; + } + } + + /// + /// Initializes a new instance of the class that contains elements copied from the specified and uses the default equality comparer for the key type. + /// + /// The whose elements are copied to the new . + /// is null. + /// contains one or more duplicate keys. + public OrderedDictionary(IEnumerable> collection) + : this(collection, null) + { + } + + /// + /// Initializes a new instance of the class that contains elements copied from the specified and uses the specified . + /// + /// The whose elements are copied to the new . + /// The implementation to use when comparing keys, or null to use the default for the type of the key. + /// is null. + /// contains one or more duplicate keys. + public OrderedDictionary(IEnumerable> collection, IEqualityComparer? comparer) + : this((collection as ICollection>)?.Count ?? 0, comparer) + { + ArgumentNullException.ThrowIfNull(collection); + + foreach (var pair in collection) + { + Add(pair.Key, pair.Value); + } + } + + /// + /// Adds the specified key and value to the dictionary as an O(1) operation. + /// + /// The key of the element to add. + /// The value of the element to add. The value can be null for reference types. + /// is null. + /// An element with the same key already exists in the . + public void Add(TKey key, TValue value) => TryInsert(null, key, value, InsertionBehavior.ThrowOnExisting); + + /// + /// Removes all keys and values from the . + /// + public void Clear() + { + if (_count > 0) + { + Array.Clear(_buckets, 0, _buckets.Length); + Array.Clear(_entries, 0, _count); + _count = 0; + ++_version; + } + } + + /// + /// Determines whether the contains the specified key as an O(1) operation. + /// + /// The key to locate in the . + /// true if the contains an element with the specified key; otherwise, false. + /// is null. + public bool ContainsKey(TKey key) => IndexOf(key) >= 0; + + /// + /// Resizes the internal data structure if necessary to ensure no additional resizing to support the specified capacity. + /// + /// The number of elements that the must be able to contain. + /// The capacity of the . + /// is less than 0. + public int EnsureCapacity(int capacity) + { + ArgumentOutOfRangeException.ThrowIfNegative(capacity); + + if (_entries.Length >= capacity) + { + return _entries.Length; + } + var newSize = HashHelpers.GetPrime(capacity); + Resize(newSize); + ++_version; + return newSize; + } + + /// + /// Returns an enumerator that iterates through the . + /// + /// An structure for the . + public Enumerator GetEnumerator() => new Enumerator(this); + + /// + /// Adds a key/value pair to the if the key does not already exist as an O(1) operation. + /// + /// The key of the element to add. + /// The value to be added, if the key does not already exist. + /// The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary. + /// is null. + public TValue GetOrAdd(TKey key, TValue value) => GetOrAdd(key, () => value); + + /// + /// Adds a key/value pair to the by using the specified function, if the key does not already exist as an O(1) operation. + /// + /// The key of the element to add. + /// The function used to generate a value for the key. + /// The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value for the key as returned by valueFactory if the key was not in the dictionary. + /// is null.-or- is null. + public TValue GetOrAdd(TKey key, Func valueFactory) + { + ArgumentNullException.ThrowIfNull(valueFactory); + + var index = IndexOf(key, out var hashCode); + TValue value; + if (index < 0) + { + value = valueFactory(); + AddInternal(null, key, value, hashCode); + } + else + { + value = _entries[index].Value; + } + return value; + } + + /// + /// Returns the zero-based index of the element with the specified key within the as an O(1) operation. + /// + /// The key of the element to locate. + /// The zero-based index of the element with the specified key within the , if found; otherwise, -1. + /// is null. + public int IndexOf(TKey key) => IndexOf(key, out _); + + /// + /// Inserts the specified key/value pair into the at the specified index as an O(n) operation. + /// + /// The zero-based index of the key/value pair to insert. + /// The key of the element to insert. + /// The value of the element to insert. + /// is null. + /// An element with the same key already exists in the . + /// is less than 0.-or- is greater than . + public void Insert(int index, TKey key, TValue value) + { + ArgumentOutOfRangeException.ThrowIfGreaterThan(index, Count); + + TryInsert(index, key, value, InsertionBehavior.ThrowOnExisting); + } + + /// + /// Inserts the element in this sorted dictionary to the corresponding index using the default comparer. + /// + /// The key of the element to insert. + /// The value of the element to insert. + public void Insert(TKey key, TValue value) + => Insert(key, value, Comparer.Default); + + /// + /// Inserts the element in this sorted dictionary to the corresponding index using the default comparer. + /// + /// The key of the element to insert. + /// The value of the element to insert. + /// The comparer to use. + public void Insert(TKey key, TValue value, IComparer comparer) + { + var existingIndex = IndexOf(key, out var hashCode); + if (existingIndex >= 0) + { + throw new ArgumentException($"Key {key} is already present"); + } + + for (var i = _count - 1; i >= 0; i--) + { + if (comparer.Compare(key, _entries[i].Key) >= 0) + { + AddInternal(i + 1, key, value, hashCode); + return; + } + } + + AddInternal(0, key, value, hashCode); + } + + /// + /// Moves the element at the specified fromIndex to the specified toIndex while re-arranging the elements in between. + /// + /// The zero-based index of the element to move. + /// The zero-based index to move the element to. + /// + /// is less than 0. + /// -or- + /// is equal to or greater than + /// -or- + /// is less than 0. + /// -or- + /// is equal to or greater than + /// + public void Move(int fromIndex, int toIndex) + { + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(fromIndex, Count); + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(toIndex, Count); + + if (fromIndex == toIndex) + { + return; + } + + var entries = _entries; + var temp = entries[fromIndex]; + RemoveEntryFromBucket(fromIndex); + var direction = fromIndex < toIndex ? 1 : -1; + for (var i = fromIndex; i != toIndex; i += direction) + { + entries[i] = entries[i + direction]; + UpdateBucketIndex(i + direction, -direction); + } + AddEntryToBucket(ref temp, toIndex, _buckets); + entries[toIndex] = temp; + ++_version; + } + + /// + /// Moves the specified number of elements at the specified fromIndex to the specified toIndex while re-arranging the elements in between. + /// + /// The zero-based index of the elements to move. + /// The zero-based index to move the elements to. + /// The number of elements to move. + /// is less than 0. + /// -or- + /// is equal to or greater than . + /// -or- + /// is less than 0. + /// -or- + /// is equal to or greater than . + /// -or- + /// is less than 0. + /// + is greater than . + /// -or- + /// + is greater than . + public void MoveRange(int fromIndex, int toIndex, int count) + { + if (count == 1) + { + Move(fromIndex, toIndex); + return; + } + + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(fromIndex, Count); + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(toIndex, Count); + ArgumentOutOfRangeException.ThrowIfNegative(count); + ArgumentOutOfRangeException.ThrowIfGreaterThan(fromIndex + count, Count, nameof(fromIndex)); + ArgumentOutOfRangeException.ThrowIfGreaterThan(toIndex + count, Count, nameof(toIndex)); + + if (fromIndex == toIndex || count == 0) + { + return; + } + + var entries = _entries; + // Make a copy of the entries to move. Consider using ArrayPool instead to avoid allocations? + var entriesToMove = new Entry[count]; + for (var i = 0; i < count; ++i) + { + entriesToMove[i] = entries[fromIndex + i]; + RemoveEntryFromBucket(fromIndex + i); + } + + // Move entries in between + var direction = 1; + var amount = count; + var start = fromIndex; + var end = toIndex; + if (fromIndex > toIndex) + { + direction = -1; + amount = -count; + start = fromIndex + count - 1; + end = toIndex + count - 1; + } + for (var i = start; i != end; i += direction) + { + entries[i] = entries[i + amount]; + UpdateBucketIndex(i + amount, -amount); + } + + var buckets = _buckets; + // Copy entries to destination + for (var i = 0; i < count; ++i) + { + var temp = entriesToMove[i]; + AddEntryToBucket(ref temp, toIndex + i, buckets); + entries[toIndex + i] = temp; + } + ++_version; + } + + /// + /// Removes the value with the specified key from the as an O(n) operation. + /// + /// The key of the element to remove. + /// true if the element is successfully found and removed; otherwise, false. This method returns false if is not found in the . + /// is null. + public bool Remove(TKey key) => Remove(key, out _); + + /// + /// Removes the value with the specified key from the and returns the value as an O(n) operation. + /// + /// The key of the element to remove. + /// When this method returns, contains the value associated with the specified key, if the key is found; otherwise, the default value for the type of the parameter. This parameter is passed uninitialized. + /// true if the element is successfully found and removed; otherwise, false. This method returns false if is not found in the . + /// is null. + public bool Remove(TKey key, out TValue value) + { + var index = IndexOf(key); + if (index >= 0) + { + value = _entries[index].Value; + RemoveAt(index); + return true; + } + value = default!; + return false; + } + + /// + /// Removes the value at the specified index from the as an O(n) operation. + /// + /// The zero-based index of the element to remove. + /// is less than 0.-or- is equal to or greater than . + public void RemoveAt(int index) + { + var count = Count; + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, count); + + // Remove the entry from the bucket + RemoveEntryFromBucket(index); + + // Decrement the indices > index + var entries = _entries; + for (var i = index + 1; i < count; ++i) + { + entries[i - 1] = entries[i]; + UpdateBucketIndex(i, incrementAmount: -1); + } + --_count; + entries[_count] = default; + ++_version; + } + + /// + /// Sets the capacity of an object to the actual number of elements it contains, rounded up to a nearby, implementation-specific value. + /// + public void TrimExcess() => TrimExcess(Count); + + /// + /// Sets the capacity of an object to the specified capacity, rounded up to a nearby, implementation-specific value. + /// + /// The number of elements that the must be able to contain. + /// is less than . + public void TrimExcess(int capacity) + { + ArgumentOutOfRangeException.ThrowIfLessThan(capacity, Count); + + var newSize = HashHelpers.GetPrime(capacity); + if (newSize < _entries.Length) + { + Resize(newSize); + ++_version; + } + } + + /// + /// Tries to add the specified key and value to the dictionary as an O(1) operation. + /// + /// The key of the element to add. + /// The value of the element to add. The value can be null for reference types. + /// true if the element was added to the ; false if the already contained an element with the specified key. + /// is null. + public bool TryAdd(TKey key, TValue value) => TryInsert(null, key, value, InsertionBehavior.None); + + /// + /// Gets the value associated with the specified key as an O(1) operation. + /// + /// The key of the value to get. + /// When this method returns, contains the value associated with the specified key, if the key is found; otherwise, the default value for the type of the parameter. This parameter is passed uninitialized. + /// true if the contains an element with the specified key; otherwise, false. + /// is null. + public bool TryGetValue(TKey key, out TValue value) + { + var index = IndexOf(key); + if (index >= 0) + { + value = _entries[index].Value; + return true; + } + value = default!; + return false; + } + + #region Explicit Interface Implementation + KeyValuePair IList>.this[int index] + { + get + { + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, Count); + + var entry = _entries[index]; + return new KeyValuePair(entry.Key, entry.Value); + } + set + { + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, Count); + + var key = value.Key; + var foundIndex = IndexOf(key, out var hashCode); + // key does not exist in dictionary thus replace entry at index + if (foundIndex < 0) + { + RemoveEntryFromBucket(index); + var entry = new Entry { HashCode = hashCode, Key = key, Value = value.Value }; + AddEntryToBucket(ref entry, index, _buckets); + _entries[index] = entry; + ++_version; + } + // key already exists in dictionary at the specified index thus just replace the key and value as hashCode remains the same + else if (foundIndex == index) + { + ref var entry = ref _entries[index]; + entry.Key = key; + entry.Value = value.Value; + } + // key already exists in dictionary but not at the specified index thus throw exception as this method shouldn't affect the indices of other entries + else + { + throw new ArgumentException($"Key {key} already exists in dictionary but not at the specified index {index}"); + } + } + } + + KeyValuePair IReadOnlyList>.this[int index] => ((IList>)this)[index]; + + ICollection IDictionary.Keys => Keys; + + ICollection IDictionary.Values => Values; + + IEnumerable IReadOnlyDictionary.Keys => Keys; + + IEnumerable IReadOnlyDictionary.Values => Values; + + bool ICollection>.IsReadOnly => false; + + IEnumerator> IEnumerable>.GetEnumerator() => GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + void ICollection>.Add(KeyValuePair item) => Add(item.Key, item.Value); + + bool ICollection>.Contains(KeyValuePair item) => TryGetValue(item.Key, out var value) && EqualityComparer.Default.Equals(value, item.Value); + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + ArgumentNullException.ThrowIfNull(array); + ArgumentOutOfRangeException.ThrowIfGreaterThan(arrayIndex, array.Length); + var count = Count; + ArgumentOutOfRangeException.ThrowIfLessThan(array.Length - arrayIndex, count); + + var entries = _entries; + for (var i = 0; i < count; ++i) + { + var entry = entries[i]; + array[i + arrayIndex] = new KeyValuePair(entry.Key, entry.Value); + } + } + + bool ICollection>.Remove(KeyValuePair item) + { + var index = IndexOf(item.Key); + if (index >= 0 && EqualityComparer.Default.Equals(_entries[index].Value, item.Value)) + { + RemoveAt(index); + return true; + } + return false; + } + + int IList>.IndexOf(KeyValuePair item) + { + var index = IndexOf(item.Key); + if (index >= 0 && !EqualityComparer.Default.Equals(_entries[index].Value, item.Value)) + { + index = -1; + } + return index; + } + + void IList>.Insert(int index, KeyValuePair item) => Insert(index, item.Key, item.Value); + #endregion + + private Entry[] Resize(int newSize) + { + var newBuckets = new int[newSize]; + var newEntries = new Entry[newSize]; + + var count = Count; + Array.Copy(_entries, newEntries, count); + for (var i = 0; i < count; ++i) + { + AddEntryToBucket(ref newEntries[i], i, newBuckets); + } + + _buckets = newBuckets; + _entries = newEntries; + return newEntries; + } + + private int IndexOf(TKey key, out uint hashCode) + { + ArgumentNullException.ThrowIfNull(key); + + var comparer = _comparer; + hashCode = (uint)(comparer?.GetHashCode(key) ?? key.GetHashCode()); + var index = _buckets[(int)(hashCode % (uint)_buckets.Length)] - 1; + if (index >= 0) + { + comparer ??= EqualityComparer.Default; + var entries = _entries; + var collisionCount = 0; + do + { + var entry = entries[index]; + if (entry.HashCode == hashCode && comparer.Equals(entry.Key, key)) + { + break; + } + index = entry.Next; + if (collisionCount >= entries.Length) + { + // The chain of entries forms a loop; which means a concurrent update has happened. + // Break out of the loop and throw, rather than looping forever. + throw new InvalidOperationException("Concurrent update detected"); + } + ++collisionCount; + } while (index >= 0); + } + return index; + } + + private bool TryInsert(int? index, TKey key, TValue value, InsertionBehavior behavior) + { + var i = IndexOf(key, out var hashCode); + if (i >= 0) + { + switch (behavior) + { + case InsertionBehavior.OverwriteExisting: + _entries[i].Value = value; + return true; + case InsertionBehavior.ThrowOnExisting: + throw new ArgumentException($"Key {key} is already present"); + default: + return false; + } + } + + AddInternal(index, key, value, hashCode); + return true; + } + + private int AddInternal(int? index, TKey key, TValue value, uint hashCode) + { + var entries = _entries; + // Check if resize is needed + var count = Count; + if (entries.Length == count || entries.Length == 1) + { + entries = Resize(HashHelpers.ExpandPrime(entries.Length)); + } + + // Increment indices >= index; + var actualIndex = index ?? count; + for (var i = count - 1; i >= actualIndex; --i) + { + entries[i + 1] = entries[i]; + UpdateBucketIndex(i, incrementAmount: 1); + } + + ref var entry = ref entries[actualIndex]; + entry.HashCode = hashCode; + entry.Key = key; + entry.Value = value; + AddEntryToBucket(ref entry, actualIndex, _buckets); + ++_count; + ++_version; + return actualIndex; + } + + // Returns the index of the next entry in the bucket + private void AddEntryToBucket(ref Entry entry, int entryIndex, int[] buckets) + { + ref var b = ref buckets[(int)(entry.HashCode % (uint)buckets.Length)]; + entry.Next = b - 1; + b = entryIndex + 1; + } + + private void RemoveEntryFromBucket(int entryIndex) + { + var entries = _entries; + var entry = entries[entryIndex]; + ref var b = ref _buckets[(int)(entry.HashCode % (uint)_buckets.Length)]; + // Bucket was pointing to removed entry. Update it to point to the next in the chain + if (b == entryIndex + 1) + { + b = entry.Next + 1; + } + else + { + // Start at the entry the bucket points to, and walk the chain until we find the entry with the index we want to remove, then fix the chain + var i = b - 1; + var collisionCount = 0; + while (true) + { + ref var e = ref entries[i]; + if (e.Next == entryIndex) + { + e.Next = entry.Next; + return; + } + i = e.Next; + if (collisionCount >= entries.Length) + { + // The chain of entries forms a loop; which means a concurrent update has happened. + // Break out of the loop and throw, rather than looping forever. + throw new InvalidOperationException("Concurrent update detected"); + } + ++collisionCount; + } + } + } + + private void UpdateBucketIndex(int entryIndex, int incrementAmount) + { + var entries = _entries; + var entry = entries[entryIndex]; + ref var b = ref _buckets[(int)(entry.HashCode % (uint)_buckets.Length)]; + // Bucket was pointing to entry. Increment the index by incrementAmount. + if (b == entryIndex + 1) + { + b += incrementAmount; + } + else + { + // Start at the entry the bucket points to, and walk the chain until we find the entry with the index we want to increment. + var i = b - 1; + var collisionCount = 0; + while (true) + { + ref var e = ref entries[i]; + if (e.Next == entryIndex) + { + e.Next += incrementAmount; + return; + } + i = e.Next; + if (collisionCount >= entries.Length) + { + // The chain of entries forms a loop; which means a concurrent update has happened. + // Break out of the loop and throw, rather than looping forever. + throw new InvalidOperationException("Concurrent update detected"); + } + ++collisionCount; + } + } + } + + /// + /// Enumerates the elements of a . + /// + public struct Enumerator : IEnumerator> + { + private readonly OrderedDictionary _orderedDictionary; + private readonly int _version; + private int _index; + private KeyValuePair _current; + + /// + /// Gets the element at the current position of the enumerator. + /// + /// The element in the at the current position of the enumerator. + public KeyValuePair Current => _current; + + object IEnumerator.Current => _current; + + internal Enumerator(OrderedDictionary orderedDictionary) + { + _orderedDictionary = orderedDictionary; + _version = orderedDictionary._version; + _index = 0; + } + + /// + /// Releases all resources used by the . + /// + public void Dispose() + { + } + + /// + /// Advances the enumerator to the next element of the . + /// + /// true if the enumerator was successfully advanced to the next element; false if the enumerator has passed the end of the collection. + /// The collection was modified after the enumerator was created. + public bool MoveNext() + { + if (_version != _orderedDictionary._version) + { + throw new InvalidOperationException("The dictionary has been modified during enumeration"); + } + + if (_index < _orderedDictionary.Count) + { + var entry = _orderedDictionary._entries[_index]; + _current = new KeyValuePair(entry.Key, entry.Value); + ++_index; + return true; + } + _current = default; + return false; + } + + void IEnumerator.Reset() + { + if (_version != _orderedDictionary._version) + { + throw new InvalidOperationException("The dictionary has been modified during enumeration"); + } + + _index = 0; + _current = default; + } + } + } +} diff --git a/src/net/KEFCore/Shared8/PropertyInfoExtensions.cs b/src/net/KEFCore/Shared8/PropertyInfoExtensions.cs new file mode 100644 index 00000000..e93d8c47 --- /dev/null +++ b/src/net/KEFCore/Shared8/PropertyInfoExtensions.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +// ReSharper disable once CheckNamespace + +namespace System.Reflection; + +[DebuggerStepThrough] +internal static class PropertyInfoExtensions +{ + public static bool IsStatic(this PropertyInfo property) + => (property.GetMethod ?? property.SetMethod)!.IsStatic; + + public static bool IsCandidateProperty(this MemberInfo memberInfo, bool needsWrite = true, bool publicOnly = true) + => memberInfo is PropertyInfo propertyInfo + ? !propertyInfo.IsStatic() + && propertyInfo.CanRead + && (!needsWrite || propertyInfo.FindSetterProperty() != null) + && propertyInfo.GetMethod != null + && (!publicOnly || propertyInfo.GetMethod.IsPublic) + && propertyInfo.GetIndexParameters().Length == 0 + : memberInfo is FieldInfo { IsStatic: false } fieldInfo + && (!publicOnly || fieldInfo.IsPublic); + + public static bool IsIndexerProperty(this PropertyInfo propertyInfo) + { + var indexParams = propertyInfo.GetIndexParameters(); + return indexParams.Length == 1 + && indexParams[0].ParameterType == typeof(string); + } + + public static PropertyInfo? FindGetterProperty(this PropertyInfo propertyInfo) + => propertyInfo.DeclaringType! + .GetPropertiesInHierarchy(propertyInfo.GetSimpleMemberName()) + .FirstOrDefault(p => p.GetMethod != null); + + public static PropertyInfo? FindSetterProperty(this PropertyInfo propertyInfo) + => propertyInfo.DeclaringType! + .GetPropertiesInHierarchy(propertyInfo.GetSimpleMemberName()) + .FirstOrDefault(p => p.SetMethod != null); +} diff --git a/src/net/KEFCore/Shared8/SharedStopwatch.cs b/src/net/KEFCore/Shared8/SharedStopwatch.cs new file mode 100644 index 00000000..1c012199 --- /dev/null +++ b/src/net/KEFCore/Shared8/SharedStopwatch.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Utilities; + +// Copied from https://github.com/dotnet/roslyn/blob/main/src/Compilers/Core/Portable/InternalUtilities/SharedStopwatch.cs +internal readonly struct SharedStopwatch +{ + private static readonly Stopwatch Stopwatch = Stopwatch.StartNew(); + + private readonly TimeSpan _started; + + private SharedStopwatch(TimeSpan started) + { + _started = started; + } + + public TimeSpan Elapsed + => Stopwatch.Elapsed - _started; + + public static SharedStopwatch StartNew() + { + // This call to StartNewCore isn't required, but is included to avoid measurement errors + // which can occur during periods of high allocation activity. In some cases, calls to Stopwatch + // operations can block at their return point on the completion of a background GC operation. When + // this occurs, the GC wait time ends up included in the measured time span. In the event the first + // call to StartNewCore blocked on a GC operation, the second call will most likely occur when the + // GC is no longer active. In practice, a substantial improvement to the consistency of analyzer + // timing data was observed. + // + // Note that the call to SharedStopwatch.Elapsed is not affected, because the GC wait will occur + // after the timer has already recorded its stop time. + _ = StartNewCore(); + return StartNewCore(); + } + + private static SharedStopwatch StartNewCore() + => new(Stopwatch.Elapsed); +} diff --git a/src/net/KEFCore/Shared8/SharedTypeExtensions.cs b/src/net/KEFCore/Shared8/SharedTypeExtensions.cs new file mode 100644 index 00000000..4a116ef6 --- /dev/null +++ b/src/net/KEFCore/Shared8/SharedTypeExtensions.cs @@ -0,0 +1,619 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; + +// ReSharper disable once CheckNamespace +namespace System; + +[DebuggerStepThrough] +internal static class SharedTypeExtensions +{ + private static readonly Dictionary BuiltInTypeNames = new() + { + { typeof(bool), "bool" }, + { typeof(byte), "byte" }, + { typeof(char), "char" }, + { typeof(decimal), "decimal" }, + { typeof(double), "double" }, + { typeof(float), "float" }, + { typeof(int), "int" }, + { typeof(long), "long" }, + { typeof(object), "object" }, + { typeof(sbyte), "sbyte" }, + { typeof(short), "short" }, + { typeof(string), "string" }, + { typeof(uint), "uint" }, + { typeof(ulong), "ulong" }, + { typeof(ushort), "ushort" }, + { typeof(void), "void" } + }; + + public static Type UnwrapNullableType(this Type type) + => Nullable.GetUnderlyingType(type) ?? type; + + public static bool IsNullableValueType(this Type type) + => type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); + + public static bool IsNullableType(this Type type) + => !type.IsValueType || type.IsNullableValueType(); + + public static bool IsValidEntityType(this Type type) + => type is { IsClass: true, IsArray: false } + && type != typeof(string); + + public static bool IsValidComplexType(this Type type) + => !type.IsArray + && !type.IsInterface + && !IsScalarType(type); + + public static bool IsScalarType(this Type type) + => type == typeof(string) + || CommonTypeDictionary.ContainsKey(type); + + public static bool IsPropertyBagType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) + { + if (type.IsGenericTypeDefinition) + { + return false; + } + + var types = GetGenericTypeImplementations(type, typeof(IDictionary<,>)); + return types.Any( + t => t.GetGenericArguments()[0] == typeof(string) + && t.GetGenericArguments()[1] == typeof(object)); + } + + public static Type MakeNullable(this Type type, bool nullable = true) + => type.IsNullableType() == nullable + ? type + : nullable + ? typeof(Nullable<>).MakeGenericType(type) + : type.UnwrapNullableType(); + + public static bool IsNumeric(this Type type) + { + type = type.UnwrapNullableType(); + + return type.IsInteger() + || type == typeof(decimal) + || type == typeof(float) + || type == typeof(double); + } + + public static bool IsInteger(this Type type) + { + type = type.UnwrapNullableType(); + + return type == typeof(int) + || type == typeof(long) + || type == typeof(short) + || type == typeof(byte) + || type == typeof(uint) + || type == typeof(ulong) + || type == typeof(ushort) + || type == typeof(sbyte) + || type == typeof(char); + } + + public static bool IsSignedInteger(this Type type) + => type == typeof(int) + || type == typeof(long) + || type == typeof(short) + || type == typeof(sbyte); + + public static bool IsAnonymousType(this Type type) + => type.Name.StartsWith("<>", StringComparison.Ordinal) + && type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), inherit: false).Length > 0 + && type.Name.Contains("AnonymousType"); + + public static PropertyInfo? GetAnyProperty( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.NonPublicProperties)] + this Type type, + string name) + { + var props = type.GetRuntimeProperties().Where(p => p.Name == name).ToList(); + if (props.Count > 1) + { + throw new AmbiguousMatchException(); + } + + return props.SingleOrDefault(); + } + + public static bool IsInstantiable(this Type type) + => type is { IsAbstract: false, IsInterface: false } + && (!type.IsGenericType || !type.IsGenericTypeDefinition); + + public static Type UnwrapEnumType(this Type type) + { + var isNullable = type.IsNullableType(); + var underlyingNonNullableType = isNullable ? type.UnwrapNullableType() : type; + if (!underlyingNonNullableType.IsEnum) + { + return type; + } + + var underlyingEnumType = Enum.GetUnderlyingType(underlyingNonNullableType); + return isNullable ? MakeNullable(underlyingEnumType) : underlyingEnumType; + } + + public static Type GetSequenceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) + { + var sequenceType = TryGetSequenceType(type); + if (sequenceType == null) + { + throw new ArgumentException($"The type {type.Name} does not represent a sequence"); + } + + return sequenceType; + } + + public static Type? TryGetSequenceType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) + => type.TryGetElementType(typeof(IEnumerable<>)) + ?? type.TryGetElementType(typeof(IAsyncEnumerable<>)); + + public static Type? TryGetElementType( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type, + Type interfaceOrBaseType) + { + if (type.IsGenericTypeDefinition) + { + return null; + } + + var types = GetGenericTypeImplementations(type, interfaceOrBaseType); + + Type? singleImplementation = null; + foreach (var implementation in types) + { + if (singleImplementation == null) + { + singleImplementation = implementation; + } + else + { + singleImplementation = null; + break; + } + } + + return singleImplementation?.GenericTypeArguments.FirstOrDefault(); + } + + public static bool IsCompatibleWith(this Type propertyType, Type fieldType) + { + if (propertyType.IsAssignableFrom(fieldType) + || fieldType.IsAssignableFrom(propertyType)) + { + return true; + } + + var propertyElementType = propertyType.TryGetSequenceType(); + var fieldElementType = fieldType.TryGetSequenceType(); + + return propertyElementType != null + && fieldElementType != null + && IsCompatibleWith(propertyElementType, fieldElementType); + } + + public static IEnumerable GetGenericTypeImplementations(this Type type, Type interfaceOrBaseType) + { + var typeInfo = type.GetTypeInfo(); + if (!typeInfo.IsGenericTypeDefinition) + { + var baseTypes = interfaceOrBaseType.GetTypeInfo().IsInterface + ? typeInfo.ImplementedInterfaces + : type.GetBaseTypes(); + foreach (var baseType in baseTypes) + { + if (baseType.IsGenericType + && baseType.GetGenericTypeDefinition() == interfaceOrBaseType) + { + yield return baseType; + } + } + + if (type.IsGenericType + && type.GetGenericTypeDefinition() == interfaceOrBaseType) + { + yield return type; + } + } + } + + public static IEnumerable GetBaseTypes(this Type type) + { + var currentType = type.BaseType; + + while (currentType != null) + { + yield return currentType; + + currentType = currentType.BaseType; + } + } + + public static List GetBaseTypesAndInterfacesInclusive(this Type type) + { + var baseTypes = new List(); + var typesToProcess = new Queue(); + typesToProcess.Enqueue(type); + + while (typesToProcess.Count > 0) + { + type = typesToProcess.Dequeue(); + baseTypes.Add(type); + + if (type.IsNullableValueType()) + { + typesToProcess.Enqueue(Nullable.GetUnderlyingType(type)!); + } + + if (type.IsConstructedGenericType) + { + typesToProcess.Enqueue(type.GetGenericTypeDefinition()); + } + + if (type is { IsGenericTypeDefinition: false, IsInterface: false }) + { + if (type.BaseType != null) + { + typesToProcess.Enqueue(type.BaseType); + } + + foreach (var @interface in GetDeclaredInterfaces(type)) + { + typesToProcess.Enqueue(@interface); + } + } + } + + return baseTypes; + } + + public static IEnumerable GetTypesInHierarchy(this Type type) + { + var currentType = type; + + while (currentType != null) + { + yield return currentType; + + currentType = currentType.BaseType; + } + } + + public static IEnumerable GetDeclaredInterfaces( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) + { + var interfaces = type.GetInterfaces(); + if (type.BaseType == typeof(object) + || type.BaseType == null) + { + return interfaces; + } + + return interfaces.Except(GetInterfacesSuppressed(type.BaseType)); + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070", Justification = "https://github.com/dotnet/linker/issues/2473")] + static IEnumerable GetInterfacesSuppressed(Type type) + => type.GetInterfaces(); + } + + public static ConstructorInfo? GetDeclaredConstructor( + [DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] + this Type type, + Type[]? types) + { + types ??= Array.Empty(); + + return type.GetTypeInfo().DeclaredConstructors + .SingleOrDefault( + c => !c.IsStatic + && c.GetParameters().Select(p => p.ParameterType).SequenceEqual(types))!; + } + + public static IEnumerable GetPropertiesInHierarchy(this Type type, string name) + { + var currentType = type; + do + { + var typeInfo = currentType.GetTypeInfo(); + foreach (var propertyInfo in typeInfo.DeclaredProperties) + { + if (propertyInfo.Name.Equals(name, StringComparison.Ordinal) + && !(propertyInfo.GetMethod ?? propertyInfo.SetMethod)!.IsStatic) + { + yield return propertyInfo; + } + } + + currentType = typeInfo.BaseType; + } + while (currentType != null); + } + + // Looking up the members through the whole hierarchy allows to find inherited private members. + public static IEnumerable GetMembersInHierarchy(this Type type) + { + var currentType = type; + + do + { + // Do the whole hierarchy for properties first since looking for fields is slower. + foreach (var propertyInfo in currentType.GetRuntimeProperties().Where(pi => !(pi.GetMethod ?? pi.SetMethod)!.IsStatic)) + { + yield return propertyInfo; + } + + foreach (var fieldInfo in currentType.GetRuntimeFields().Where(f => !f.IsStatic)) + { + yield return fieldInfo; + } + + currentType = currentType.BaseType; + } + while (currentType != null); + } + + public static IEnumerable GetMembersInHierarchy( + [DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicProperties + | DynamicallyAccessedMemberTypes.NonPublicProperties + | DynamicallyAccessedMemberTypes.PublicFields + | DynamicallyAccessedMemberTypes.NonPublicFields)] + this Type type, + string name) + => type.GetMembersInHierarchy().Where(m => m.Name == name); + + private static readonly Dictionary CommonTypeDictionary = new() + { +#pragma warning disable IDE0034 // Simplify 'default' expression - default causes default(object) + { typeof(int), default(int) }, + { typeof(Guid), default(Guid) }, + { typeof(DateOnly), default(DateOnly) }, + { typeof(DateTime), default(DateTime) }, + { typeof(DateTimeOffset), default(DateTimeOffset) }, + { typeof(TimeOnly), default(TimeOnly) }, + { typeof(long), default(long) }, + { typeof(bool), default(bool) }, + { typeof(double), default(double) }, + { typeof(short), default(short) }, + { typeof(float), default(float) }, + { typeof(byte), default(byte) }, + { typeof(char), default(char) }, + { typeof(uint), default(uint) }, + { typeof(ushort), default(ushort) }, + { typeof(ulong), default(ulong) }, + { typeof(sbyte), default(sbyte) } +#pragma warning restore IDE0034 // Simplify 'default' expression + }; + + public static object? GetDefaultValue( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + this Type type) + { + if (!type.IsValueType) + { + return null; + } + + // A bit of perf code to avoid calling Activator.CreateInstance for common types and + // to avoid boxing on every call. This is about 50% faster than just calling CreateInstance + // for all value types. + return CommonTypeDictionary.TryGetValue(type, out var value) + ? value + : Activator.CreateInstance(type); + } + + [RequiresUnreferencedCode("Gets all types from the given assembly - unsafe for trimming")] + public static IEnumerable GetConstructibleTypes(this Assembly assembly) + => assembly.GetLoadableDefinedTypes().Where( + t => t is { IsAbstract: false, IsGenericTypeDefinition: false }); + + [RequiresUnreferencedCode("Gets all types from the given assembly - unsafe for trimming")] + public static IEnumerable GetLoadableDefinedTypes(this Assembly assembly) + { + try + { + return assembly.DefinedTypes; + } + catch (ReflectionTypeLoadException ex) + { + return ex.Types.Where(t => t != null).Select(IntrospectionExtensions.GetTypeInfo!); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public static string DisplayName(this Type type, bool fullName = true, bool compilable = false) + { + var stringBuilder = new StringBuilder(); + ProcessType(stringBuilder, type, fullName, compilable); + return stringBuilder.ToString(); + } + + private static void ProcessType(StringBuilder builder, Type type, bool fullName, bool compilable) + { + if (type.IsGenericType) + { + var genericArguments = type.GetGenericArguments(); + ProcessGenericType(builder, type, genericArguments, genericArguments.Length, fullName, compilable); + } + else if (type.IsArray) + { + ProcessArrayType(builder, type, fullName, compilable); + } + else if (BuiltInTypeNames.TryGetValue(type, out var builtInName)) + { + builder.Append(builtInName); + } + else if (!type.IsGenericParameter) + { + if (compilable) + { + if (type.IsNested) + { + ProcessType(builder, type.DeclaringType!, fullName, compilable); + builder.Append('.'); + } + else if (fullName) + { + builder.Append(type.Namespace).Append('.'); + } + + builder.Append(type.Name); + } + else + { + builder.Append(fullName ? type.FullName : type.Name); + } + } + } + + private static void ProcessArrayType(StringBuilder builder, Type type, bool fullName, bool compilable) + { + var innerType = type; + while (innerType.IsArray) + { + innerType = innerType.GetElementType()!; + } + + ProcessType(builder, innerType, fullName, compilable); + + while (type.IsArray) + { + builder.Append('['); + builder.Append(',', type.GetArrayRank() - 1); + builder.Append(']'); + type = type.GetElementType()!; + } + } + + private static void ProcessGenericType( + StringBuilder builder, + Type type, + Type[] genericArguments, + int length, + bool fullName, + bool compilable) + { + if (type.IsConstructedGenericType + && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + ProcessType(builder, type.UnwrapNullableType(), fullName, compilable); + builder.Append('?'); + return; + } + + var offset = type.IsNested ? type.DeclaringType!.GetGenericArguments().Length : 0; + + if (compilable) + { + if (type.IsNested) + { + ProcessType(builder, type.DeclaringType!, fullName, compilable); + builder.Append('.'); + } + else if (fullName) + { + builder.Append(type.Namespace); + builder.Append('.'); + } + } + else + { + if (fullName) + { + if (type.IsNested) + { + ProcessGenericType(builder, type.DeclaringType!, genericArguments, offset, fullName, compilable); + builder.Append('+'); + } + else + { + builder.Append(type.Namespace); + builder.Append('.'); + } + } + } + + var genericPartIndex = type.Name.IndexOf('`'); + if (genericPartIndex <= 0) + { + builder.Append(type.Name); + return; + } + + builder.Append(type.Name, 0, genericPartIndex); + builder.Append('<'); + + for (var i = offset; i < length; i++) + { + ProcessType(builder, genericArguments[i], fullName, compilable); + if (i + 1 == length) + { + continue; + } + + builder.Append(','); + if (!genericArguments[i + 1].IsGenericParameter) + { + builder.Append(' '); + } + } + + builder.Append('>'); + } + + public static IEnumerable GetNamespaces(this Type type) + { + if (BuiltInTypeNames.ContainsKey(type)) + { + yield break; + } + + if (type.IsArray) + { + foreach (var ns in type.GetElementType()!.GetNamespaces()) + { + yield return ns; + } + + yield break; + } + + yield return type.Namespace!; + + if (type.IsGenericType) + { + foreach (var typeArgument in type.GenericTypeArguments) + { + foreach (var ns in typeArgument.GetNamespaces()) + { + yield return ns; + } + } + } + } + + public static ConstantExpression GetDefaultValueConstant(this Type type) + => (ConstantExpression)GenerateDefaultValueConstantMethod + .MakeGenericMethod(type).Invoke(null, Array.Empty())!; + + private static readonly MethodInfo GenerateDefaultValueConstantMethod = + typeof(SharedTypeExtensions).GetTypeInfo().GetDeclaredMethod(nameof(GenerateDefaultValueConstant))!; + + private static ConstantExpression GenerateDefaultValueConstant() + => Expression.Constant(default(TDefault), typeof(TDefault)); +} diff --git a/src/net/KEFCore/Shared8/StringBuilderExtensions.cs b/src/net/KEFCore/Shared8/StringBuilderExtensions.cs new file mode 100644 index 00000000..0e0b253f --- /dev/null +++ b/src/net/KEFCore/Shared8/StringBuilderExtensions.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Globalization; + +namespace System.Text; + +internal static class StringBuilderExtensions +{ + public static StringBuilder AppendJoin( + this StringBuilder stringBuilder, + IEnumerable values, + string separator = ", ") + => stringBuilder.AppendJoin(values, (sb, value) => sb.Append(value), separator); + + public static StringBuilder AppendJoin( + this StringBuilder stringBuilder, + string separator, + params string[] values) + => stringBuilder.AppendJoin(values, (sb, value) => sb.Append(value), separator); + + public static StringBuilder AppendJoin( + this StringBuilder stringBuilder, + IEnumerable values, + Action joinAction, + string separator = ", ") + { + var appended = false; + + foreach (var value in values) + { + joinAction(stringBuilder, value); + stringBuilder.Append(separator); + appended = true; + } + + if (appended) + { + stringBuilder.Length -= separator.Length; + } + + return stringBuilder; + } + + public static StringBuilder AppendJoin( + this StringBuilder stringBuilder, + IEnumerable values, + Func joinFunc, + string separator = ", ") + { + var appended = false; + + foreach (var value in values) + { + if (joinFunc(stringBuilder, value)) + { + stringBuilder.Append(separator); + appended = true; + } + } + + if (appended) + { + stringBuilder.Length -= separator.Length; + } + + return stringBuilder; + } + + public static StringBuilder AppendJoin( + this StringBuilder stringBuilder, + IEnumerable values, + TParam param, + Action joinAction, + string separator = ", ") + { + var appended = false; + + foreach (var value in values) + { + joinAction(stringBuilder, value, param); + stringBuilder.Append(separator); + appended = true; + } + + if (appended) + { + stringBuilder.Length -= separator.Length; + } + + return stringBuilder; + } + + public static void AppendBytes(this StringBuilder builder, byte[] bytes) + { + builder.Append("'0x"); + + for (var i = 0; i < bytes.Length; i++) + { + if (i > 31) + { + builder.Append("..."); + break; + } + + builder.Append(bytes[i].ToString("X2", CultureInfo.InvariantCulture)); + } + + builder.Append('\''); + } +} diff --git a/src/net/KEFCore/Storage/Internal/EntityTypeProducer.cs b/src/net/KEFCore/Storage/Internal/EntityTypeProducer.cs index 7bef6792..099ad82f 100644 --- a/src/net/KEFCore/Storage/Internal/EntityTypeProducer.cs +++ b/src/net/KEFCore/Storage/Internal/EntityTypeProducer.cs @@ -220,6 +220,7 @@ public EntityTypeProducer(IEntityType entityType, IKafkaCluster cluster) }; if (_onChangeEvent != null) { + _kafkaCompactedReplicator.OnRemoteAdd += KafkaCompactedReplicator_OnRemoteAdd; _kafkaCompactedReplicator.OnRemoteUpdate += KafkaCompactedReplicator_OnRemoteUpdate; _kafkaCompactedReplicator.OnRemoteRemove += KafkaCompactedReplicator_OnRemoteRemove; } @@ -238,6 +239,7 @@ public EntityTypeProducer(IEntityType entityType, IKafkaCluster cluster) _streamData = new KafkaStreamsTableRetriever(cluster, entityType, _keySerdes!, _valueSerdes!); } } + /// public virtual IEntityType EntityType => _entityType; /// @@ -296,11 +298,19 @@ public IEnumerable ValueBuffers } } + private void KafkaCompactedReplicator_OnRemoteAdd(IKNetCompactedReplicator arg1, KeyValuePair arg2) + { + Task.Factory.StartNew(() => + { + _onChangeEvent?.Invoke(new EntityTypeChanged(_entityType, EntityTypeChanged.ChangeKindType.Added, arg2.Key)); + }); + } + private void KafkaCompactedReplicator_OnRemoteUpdate(IKNetCompactedReplicator arg1, KeyValuePair arg2) { Task.Factory.StartNew(() => { - _onChangeEvent?.Invoke(new EntityTypeChanged(_entityType, EntityTypeChanged.ChangeKindType.Upserted, arg2.Key)); + _onChangeEvent?.Invoke(new EntityTypeChanged(_entityType, EntityTypeChanged.ChangeKindType.Updated, arg2.Key)); }); } diff --git a/src/net/KEFCore/Storage/Internal/KafkaCluster.cs b/src/net/KEFCore/Storage/Internal/KafkaCluster.cs index d91b8cbd..886ca894 100644 --- a/src/net/KEFCore/Storage/Internal/KafkaCluster.cs +++ b/src/net/KEFCore/Storage/Internal/KafkaCluster.cs @@ -79,11 +79,19 @@ public virtual KafkaIntegerValueGenerator GetIntegerValueGenerator( property, entityType.GetDerivedTypesInclusive().Select(type => EnsureTable(type)).ToArray()); +#else + var entityType = property.DeclaringType; + + return EnsureTable(entityType.ContainingEntityType).GetIntegerValueGenerator( + property, + entityType.ContainingEntityType.GetDerivedTypesInclusive().Select(type => EnsureTable(type)).ToArray()); +#endif } } /// diff --git a/src/net/KEFCore/Storage/Internal/KafkaTypeMapping.cs b/src/net/KEFCore/Storage/Internal/KafkaTypeMapping.cs index 27f8f3be..beb4e5d1 100644 --- a/src/net/KEFCore/Storage/Internal/KafkaTypeMapping.cs +++ b/src/net/KEFCore/Storage/Internal/KafkaTypeMapping.cs @@ -15,7 +15,9 @@ * * Refer to LICENSE for more information. */ - +#if NET8_0 +using Microsoft.EntityFrameworkCore.Storage.Json; +#endif namespace MASES.EntityFrameworkCore.KNet.Storage.Internal; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -45,7 +47,16 @@ private KafkaTypeMapping(CoreTypeMappingParameters parameters) : base(parameters) { } +#if !NET8_0 /// public override CoreTypeMapping Clone(ValueConverter? converter) => new KafkaTypeMapping(Parameters.WithComposedConverter(converter)); +#else + /// + protected override CoreTypeMapping Clone(CoreTypeMappingParameters parameters) + => new KafkaTypeMapping(parameters); + /// + public override CoreTypeMapping WithComposedConverter(ValueConverter? converter, ValueComparer? comparer = null, ValueComparer? keyComparer = null, CoreTypeMapping? elementMapping = null, JsonValueReaderWriter? jsonValueReaderWriter = null) + => new KafkaTypeMapping(Parameters.WithComposedConverter(converter, comparer, keyComparer, elementMapping, jsonValueReaderWriter)); +#endif } diff --git a/src/net/KEFCore/ValueGeneration/Internal/KafkaValueGeneratorSelector.cs b/src/net/KEFCore/ValueGeneration/Internal/KafkaValueGeneratorSelector.cs index d4546c56..0465b08a 100644 --- a/src/net/KEFCore/ValueGeneration/Internal/KafkaValueGeneratorSelector.cs +++ b/src/net/KEFCore/ValueGeneration/Internal/KafkaValueGeneratorSelector.cs @@ -38,6 +38,7 @@ public KafkaValueGeneratorSelector( { _kafkaCluster = kafkaDatabase.Cluster; } +#if !NET8_0 /// public override ValueGenerator Select(IProperty property, IEntityType entityType) => property.GetValueGeneratorFactory() == null @@ -45,7 +46,15 @@ public override ValueGenerator Select(IProperty property, IEntityType entityType && property.ClrType.UnwrapNullableType() != typeof(char) ? GetOrCreate(property) : base.Select(property, entityType); - +#else + /// + public override ValueGenerator Select(IProperty property, ITypeBase typeBase) + => property.GetValueGeneratorFactory() == null + && property.ClrType.IsInteger() + && property.ClrType.UnwrapNullableType() != typeof(char) + ? GetOrCreate(property) + : base.Select(property, typeBase); +#endif private ValueGenerator GetOrCreate(IProperty property) { var type = property.ClrType.UnwrapNullableType().UnwrapEnumType(); @@ -89,9 +98,14 @@ private ValueGenerator GetOrCreate(IProperty property) { return _kafkaCluster.GetIntegerValueGenerator(property); } - +#if !NET8_0 throw new ArgumentException( CoreStrings.InvalidValueGeneratorFactoryProperty( "KafkaIntegerValueGeneratorFactory", property.Name, property.DeclaringEntityType.DisplayName())); +#else + throw new ArgumentException( + CoreStrings.InvalidValueGeneratorFactoryProperty( + "KafkaIntegerValueGeneratorFactory", property.Name, property.DeclaringType.DisplayName())); +#endif } } diff --git a/src/net/templates/templatepack.csproj b/src/net/templates/templatepack.csproj index a3e2c44d..e07b5572 100644 --- a/src/net/templates/templatepack.csproj +++ b/src/net/templates/templatepack.csproj @@ -23,6 +23,6 @@ - + diff --git a/src/net/templates/templates/kefcoreApp/kefcoreApp.csproj b/src/net/templates/templates/kefcoreApp/kefcoreApp.csproj index 92f6865e..e3aed586 100644 --- a/src/net/templates/templates/kefcoreApp/kefcoreApp.csproj +++ b/src/net/templates/templates/kefcoreApp/kefcoreApp.csproj @@ -1,7 +1,7 @@  latest - net6.0;net7.0 + net6.0;net7.0;net8.0 diff --git a/src/net/templates/templates/kefcoreAppWithEvents/kefcoreAppWithEvents.csproj b/src/net/templates/templates/kefcoreAppWithEvents/kefcoreAppWithEvents.csproj index 92f6865e..e3aed586 100644 --- a/src/net/templates/templates/kefcoreAppWithEvents/kefcoreAppWithEvents.csproj +++ b/src/net/templates/templates/kefcoreAppWithEvents/kefcoreAppWithEvents.csproj @@ -1,7 +1,7 @@  latest - net6.0;net7.0 + net6.0;net7.0;net8.0 diff --git a/test/Common/Common.props b/test/Common/Common.props index c6477ec0..fc01f005 100644 --- a/test/Common/Common.props +++ b/test/Common/Common.props @@ -12,14 +12,13 @@ - - - - - - + + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/test/KEFCore.Benchmark.Test/Benchmark.Avro.KNetReplicator.json b/test/KEFCore.Benchmark.Test/Benchmark.Avro.KNetReplicator.json index 5240969e..a1d17bb0 100644 --- a/test/KEFCore.Benchmark.Test/Benchmark.Avro.KNetReplicator.json +++ b/test/KEFCore.Benchmark.Test/Benchmark.Avro.KNetReplicator.json @@ -1,6 +1,6 @@ { "UseAvro": true, "DatabaseName": "TestDBBenchmarkAvro", - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfExecutions": 10 } diff --git a/test/KEFCore.Benchmark.Test/Benchmark.Avro.KafkaStreams.json b/test/KEFCore.Benchmark.Test/Benchmark.Avro.KafkaStreams.json index ad97188e..ed105d41 100644 --- a/test/KEFCore.Benchmark.Test/Benchmark.Avro.KafkaStreams.json +++ b/test/KEFCore.Benchmark.Test/Benchmark.Avro.KafkaStreams.json @@ -2,6 +2,6 @@ "UseAvro": true, "DatabaseName": "TestDBBenchmarkAvro", "UseCompactedReplicator": false, - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfExecutions": 10 } diff --git a/test/KEFCore.Benchmark.Test/Benchmark.KNetReplicator.json b/test/KEFCore.Benchmark.Test/Benchmark.KNetReplicator.json index a3d394ce..571c0beb 100644 --- a/test/KEFCore.Benchmark.Test/Benchmark.KNetReplicator.json +++ b/test/KEFCore.Benchmark.Test/Benchmark.KNetReplicator.json @@ -1,5 +1,5 @@ { "DatabaseName": "TestDBBenchmark", - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfExecutions": 10 } diff --git a/test/KEFCore.Benchmark.Test/Benchmark.KafkaStreams.json b/test/KEFCore.Benchmark.Test/Benchmark.KafkaStreams.json index 0ce92c41..9eda823e 100644 --- a/test/KEFCore.Benchmark.Test/Benchmark.KafkaStreams.json +++ b/test/KEFCore.Benchmark.Test/Benchmark.KafkaStreams.json @@ -1,6 +1,6 @@ { "DatabaseName": "TestDBBenchmark", "UseCompactedReplicator": false, - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfExecutions": 10 } diff --git a/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KNetReplicator.json b/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KNetReplicator.json index 902c9431..05902135 100644 --- a/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KNetReplicator.json +++ b/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KNetReplicator.json @@ -1,6 +1,6 @@ { "UseProtobuf": true, "DatabaseName": "TestDBBenchmarkProtobuf", - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfExecutions": 10 } diff --git a/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KafkaStreams.json b/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KafkaStreams.json index cc6af1b8..e0d9de2b 100644 --- a/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KafkaStreams.json +++ b/test/KEFCore.Benchmark.Test/Benchmark.Protobuf.KafkaStreams.json @@ -2,6 +2,6 @@ "UseProtobuf": true, "DatabaseName": "TestDBBenchmarkProtobuf", "UseCompactedReplicator": false, - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfExecutions": 10 } diff --git a/test/KEFCore.Benchmark.Test/KEFCore.Benchmark.Test.csproj b/test/KEFCore.Benchmark.Test/KEFCore.Benchmark.Test.csproj index 808f7550..64754747 100644 --- a/test/KEFCore.Benchmark.Test/KEFCore.Benchmark.Test.csproj +++ b/test/KEFCore.Benchmark.Test/KEFCore.Benchmark.Test.csproj @@ -43,4 +43,11 @@ PreserveNewest + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/test/KEFCore.Benchmark.Test/Program.cs b/test/KEFCore.Benchmark.Test/Program.cs index b5384c58..6bbfa4f1 100644 --- a/test/KEFCore.Benchmark.Test/Program.cs +++ b/test/KEFCore.Benchmark.Test/Program.cs @@ -130,13 +130,13 @@ static void Main(string[] args) { Url = "http://blogs.msdn.com/adonet" + i.ToString(), Posts = new List() - { - new Post() { - Title = "title", - Content = i.ToString() - } - }, + new Post() + { + Title = "title", + Content = i.ToString() + } + }, Rating = i, }); } diff --git a/test/KEFCore.Complex.Test/ComplexTest.KNetReplicator.json b/test/KEFCore.Complex.Test/ComplexTest.KNetReplicator.json index 2eea8bca..67dd32ef 100644 --- a/test/KEFCore.Complex.Test/ComplexTest.KNetReplicator.json +++ b/test/KEFCore.Complex.Test/ComplexTest.KNetReplicator.json @@ -1,6 +1,6 @@ { "DatabaseName": "TestDBComplex", - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "NumberOfElements": 10, "NumberOfExtraElements": 1 } diff --git a/test/KEFCore.Complex.Test/KEFCore.Complex.Test.csproj b/test/KEFCore.Complex.Test/KEFCore.Complex.Test.csproj index 2ecf4495..30a76963 100644 --- a/test/KEFCore.Complex.Test/KEFCore.Complex.Test.csproj +++ b/test/KEFCore.Complex.Test/KEFCore.Complex.Test.csproj @@ -22,4 +22,11 @@ PreserveNewest + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/test/KEFCore.Extractor.Test/Extractor.Test.Blog.Avro.json b/test/KEFCore.Extractor.Test/Extractor.Test.Blog.Avro.json index 94db10ef..5c2e6bad 100644 --- a/test/KEFCore.Extractor.Test/Extractor.Test.Blog.Avro.json +++ b/test/KEFCore.Extractor.Test/Extractor.Test.Blog.Avro.json @@ -1,4 +1,4 @@ { - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "TopicToSubscribe": "TestDBBenchmarkAvro.MASES.EntityFrameworkCore.KNet.Test.Blog" } diff --git a/test/KEFCore.Extractor.Test/Extractor.Test.Blog.json b/test/KEFCore.Extractor.Test/Extractor.Test.Blog.json index 61be11e3..c0df676a 100644 --- a/test/KEFCore.Extractor.Test/Extractor.Test.Blog.json +++ b/test/KEFCore.Extractor.Test/Extractor.Test.Blog.json @@ -1,4 +1,4 @@ { - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "TopicToSubscribe": "TestDBBenchmark.MASES.EntityFrameworkCore.KNet.Test.Blog" } diff --git a/test/KEFCore.Extractor.Test/KEFCore.Extractor.Test.csproj b/test/KEFCore.Extractor.Test/KEFCore.Extractor.Test.csproj index 7df147d5..4933a331 100644 --- a/test/KEFCore.Extractor.Test/KEFCore.Extractor.Test.csproj +++ b/test/KEFCore.Extractor.Test/KEFCore.Extractor.Test.csproj @@ -22,4 +22,11 @@ PreserveNewest + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/test/KEFCore.Test/KEFCore.Test.csproj b/test/KEFCore.Test/KEFCore.Test.csproj index 5a5f27d1..a67450a1 100644 --- a/test/KEFCore.Test/KEFCore.Test.csproj +++ b/test/KEFCore.Test/KEFCore.Test.csproj @@ -70,4 +70,11 @@ PreserveNewest + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/test/KEFCore.Test/Test.KNetReplicator.json b/test/KEFCore.Test/Test.KNetReplicator.json index 243bea6a..a8269495 100644 --- a/test/KEFCore.Test/Test.KNetReplicator.json +++ b/test/KEFCore.Test/Test.KNetReplicator.json @@ -1,3 +1,3 @@ { - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/Test.KNetReplicatorModelBuilder.json b/test/KEFCore.Test/Test.KNetReplicatorModelBuilder.json index 562f8dc7..382d90aa 100644 --- a/test/KEFCore.Test/Test.KNetReplicatorModelBuilder.json +++ b/test/KEFCore.Test/Test.KNetReplicatorModelBuilder.json @@ -1,4 +1,4 @@ { "UseModelBuilder": true, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/Test.KNetReplicatorNoLoad.json b/test/KEFCore.Test/Test.KNetReplicatorNoLoad.json index 99c3dced..9a280c08 100644 --- a/test/KEFCore.Test/Test.KNetReplicatorNoLoad.json +++ b/test/KEFCore.Test/Test.KNetReplicatorNoLoad.json @@ -1,5 +1,5 @@ { "DeleteApplicationData": false, "LoadApplicationData": false, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/Test.KNetReplicatorWithEvents.json b/test/KEFCore.Test/Test.KNetReplicatorWithEvents.json index ec89b6ea..7445d8e1 100644 --- a/test/KEFCore.Test/Test.KNetReplicatorWithEvents.json +++ b/test/KEFCore.Test/Test.KNetReplicatorWithEvents.json @@ -1,5 +1,5 @@ { - "BootstrapServers": "192.168.1.103:9092", + "BootstrapServers": "192.168.1.105:9092", "DatabaseName": "TestDBWithEvents", "NumberOfElements": 10, "WithEvents": true diff --git a/test/KEFCore.Test/Test.KafkaStreams.json b/test/KEFCore.Test/Test.KafkaStreams.json index 602575ba..1b15e6c2 100644 --- a/test/KEFCore.Test/Test.KafkaStreams.json +++ b/test/KEFCore.Test/Test.KafkaStreams.json @@ -1,4 +1,4 @@ { "UseCompactedReplicator": false, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/Test.KafkaStreamsModelBuilder.json b/test/KEFCore.Test/Test.KafkaStreamsModelBuilder.json index 3020d098..2a9d05ac 100644 --- a/test/KEFCore.Test/Test.KafkaStreamsModelBuilder.json +++ b/test/KEFCore.Test/Test.KafkaStreamsModelBuilder.json @@ -1,5 +1,5 @@ { "UseCompactedReplicator": false, "UseModelBuilder": true, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/Test.KafkaStreamsNoLoad.json b/test/KEFCore.Test/Test.KafkaStreamsNoLoad.json index 4b3b764d..e8e23839 100644 --- a/test/KEFCore.Test/Test.KafkaStreamsNoLoad.json +++ b/test/KEFCore.Test/Test.KafkaStreamsNoLoad.json @@ -2,5 +2,5 @@ "DeleteApplicationData": false, "LoadApplicationData": false, "UseCompactedReplicator": false, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/Test.KafkaStreamsPersisted.json b/test/KEFCore.Test/Test.KafkaStreamsPersisted.json index ca645ace..4f472961 100644 --- a/test/KEFCore.Test/Test.KafkaStreamsPersisted.json +++ b/test/KEFCore.Test/Test.KafkaStreamsPersisted.json @@ -3,5 +3,5 @@ "DeleteApplicationData": false, "LoadApplicationData": false, "UseCompactedReplicator": false, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" } diff --git a/test/KEFCore.Test/TestAvro.KNetReplicatorModelBuilder.json b/test/KEFCore.Test/TestAvro.KNetReplicatorModelBuilder.json index 0f04b9f1..d5dbcdb0 100644 --- a/test/KEFCore.Test/TestAvro.KNetReplicatorModelBuilder.json +++ b/test/KEFCore.Test/TestAvro.KNetReplicatorModelBuilder.json @@ -1,5 +1,5 @@ { "UseAvro": true, "UseModelBuilder": true, - "BootstrapServers": "192.168.1.103:9092" + "BootstrapServers": "192.168.1.105:9092" }