From b7d76c498d9312fe15bbd259963e99d1e4ccb6e7 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 12 May 2022 17:30:42 -0700 Subject: [PATCH] Query: Add interface to get ITableBase for root level table sources --- ...yableMethodTranslatingExpressionVisitor.cs | 14 +++++--- .../Query/SqlExpressionFactory.cs | 8 +---- .../Query/SqlExpressions/FromSqlExpression.cs | 32 ++++++++++++++----- .../SqlExpressions/ITableBasedExpression.cs | 21 ++++++++++++ .../Query/SqlExpressions/SelectExpression.cs | 8 +---- .../Query/SqlExpressions/TableExpression.cs | 5 ++- .../TableValuedFunctionExpression.cs | 5 ++- 7 files changed, 64 insertions(+), 29 deletions(-) create mode 100644 src/EFCore.Relational/Query/SqlExpressions/ITableBasedExpression.cs diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 36eb7b2ce46..31a745a013b 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -77,8 +77,7 @@ protected override Expression VisitExtension(Expression extensionExpression) _sqlExpressionFactory.Select( fromSqlQueryRootExpression.EntityType, new FromSqlExpression( - fromSqlQueryRootExpression.EntityType.GetDefaultMappings().Single().Table.Name[..1] - .ToLowerInvariant(), + fromSqlQueryRootExpression.EntityType.GetDefaultMappings().Single().Table, fromSqlQueryRootExpression.Sql, fromSqlQueryRootExpression.Argument))); @@ -127,9 +126,14 @@ protected override Expression VisitExtension(Expression extensionExpression) when queryRootExpression.GetType() == typeof(QueryRootExpression) && queryRootExpression.EntityType.GetSqlQueryMappings().FirstOrDefault(m => m.IsDefaultSqlQueryMapping)?.SqlQuery is ISqlQuery sqlQuery: - return Visit( - new FromSqlQueryRootExpression( - queryRootExpression.EntityType, sqlQuery.Sql, Expression.Constant(Array.Empty(), typeof(object[])))); + return CreateShapedQueryExpression( + queryRootExpression.EntityType, + _sqlExpressionFactory.Select( + queryRootExpression.EntityType, + new FromSqlExpression( + queryRootExpression.EntityType.GetDefaultMappings().Single().Table, + sqlQuery.Sql, + Expression.Constant(Array.Empty(), typeof(object[]))))); case GroupByShaperExpression groupByShaperExpression: var groupShapedQueryExpression = groupByShaperExpression.GroupingEnumerable; diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index d0f8491b8f5..1f0ff84c295 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -637,13 +637,7 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity return; } - var table = selectExpression.Tables[0] switch - { - TableExpression te => te.Table, - TableValuedFunctionExpression tvfe => tvfe.StoreFunction, - _ => entityType.GetDefaultMappings().Single().Table, - }; - + var table = ((ITableBasedExpression)selectExpression.Tables[0]).Table; if (table.IsOptional(entityType)) { var entityProjectionExpression = GetMappedEntityProjectionExpression(selectExpression); diff --git a/src/EFCore.Relational/Query/SqlExpressions/FromSqlExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/FromSqlExpression.cs index 41cd4879957..344ff9bb7f8 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/FromSqlExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/FromSqlExpression.cs @@ -14,28 +14,40 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; /// not used in application code. /// /// -public class FromSqlExpression : TableExpressionBase, IClonableTableExpressionBase +public class FromSqlExpression : TableExpressionBase, IClonableTableExpressionBase, ITableBasedExpression { + private readonly ITableBase _table; + /// /// Creates a new instance of the class. /// - /// A string alias for the table source. + /// A default table base associated with this table source. /// A user-provided custom SQL for the table source. /// A user-provided parameters to pass to the custom SQL. - public FromSqlExpression(string alias, string sql, Expression arguments) - : this(alias, sql, arguments, annotations: null) + public FromSqlExpression(ITableBase defaultTableBase, string sql, Expression arguments) + : this(defaultTableBase.Name[..1].ToLowerInvariant(), defaultTableBase, sql, arguments, annotations: null) { - Sql = sql; - Arguments = arguments; } + // See issue#21660/21627 + ///// + ///// Creates a new instance of the class. + ///// + ///// A sql query associated with this table source. + //public FromSqlExpression(ISqlQuery sqlQuery) + // : this(sqlQuery, sqlQuery.Sql, Constant(Array.Empty(), typeof(object[]))) + //{ + //} + private FromSqlExpression( string alias, + ITableBase tableBase, string sql, Expression arguments, IEnumerable? annotations) : base(alias, annotations) { + _table = tableBase; Sql = sql; Arguments = arguments; } @@ -68,16 +80,19 @@ public override string? Alias /// This expression if no children changed, or an expression with the updated children. public virtual FromSqlExpression Update(Expression arguments) => arguments != Arguments - ? new FromSqlExpression(Alias, Sql, arguments) + ? new FromSqlExpression(Alias, _table, Sql, arguments, GetAnnotations()) : this; + /// + ITableBase ITableBasedExpression.Table => _table; + /// protected override Expression VisitChildren(ExpressionVisitor visitor) => this; /// public virtual TableExpressionBase Clone() - => new FromSqlExpression(Alias, Sql, Arguments, GetAnnotations()); + => new FromSqlExpression(Alias, _table, Sql, Arguments, GetAnnotations()); /// protected override void Print(ExpressionPrinter expressionPrinter) @@ -95,6 +110,7 @@ public override bool Equals(object? obj) private bool Equals(FromSqlExpression fromSqlExpression) => base.Equals(fromSqlExpression) + && _table == fromSqlExpression._table && Sql == fromSqlExpression.Sql && ExpressionEqualityComparer.Instance.Equals(Arguments, fromSqlExpression.Arguments); diff --git a/src/EFCore.Relational/Query/SqlExpressions/ITableBasedExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ITableBasedExpression.cs new file mode 100644 index 00000000000..5940d74c7cd --- /dev/null +++ b/src/EFCore.Relational/Query/SqlExpressions/ITableBasedExpression.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +/// +/// +/// An interface that gives access to associated with given table source. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public interface ITableBasedExpression +{ + /// + /// The associated with given table source. + /// + ITableBase Table { get; } +} diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index dfd08191308..6daa1bf6241 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -374,13 +374,7 @@ internal SelectExpression(IEntityType entityType, TableExpressionBase tableExpre throw new InvalidOperationException(RelationalStrings.SelectExpressionNonTphWithCustomTable(entityType.DisplayName())); } - var table = tableExpressionBase switch - { - TableExpression tableExpression => tableExpression.Table, - TableValuedFunctionExpression tableValuedFunctionExpression => tableValuedFunctionExpression.StoreFunction, - _ => entityType.GetDefaultMappings().Single().Table - }; - + var table = ((ITableBasedExpression)tableExpressionBase).Table; var tableReferenceExpression = new TableReferenceExpression(this, tableExpressionBase.Alias!); AddTable(tableExpressionBase, tableReferenceExpression); diff --git a/src/EFCore.Relational/Query/SqlExpressions/TableExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/TableExpression.cs index dabf91e9df9..959cfdd2f34 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/TableExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/TableExpression.cs @@ -13,7 +13,7 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; /// application or database provider code. If this is a problem for your application or provider, then please file /// an issue at github.com/dotnet/efcore. /// -public sealed class TableExpression : TableExpressionBase, IClonableTableExpressionBase +public sealed class TableExpression : TableExpressionBase, IClonableTableExpressionBase, ITableBasedExpression { internal TableExpression(ITableBase table) : this(table, annotations: null) @@ -28,6 +28,9 @@ private TableExpression(ITableBase table, IEnumerable? annotations) Table = table; } + /// + ITableBase ITableBasedExpression.Table => Table; + /// protected override void Print(ExpressionPrinter expressionPrinter) { diff --git a/src/EFCore.Relational/Query/SqlExpressions/TableValuedFunctionExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/TableValuedFunctionExpression.cs index 97df39b551f..db40d23cd9d 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/TableValuedFunctionExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/TableValuedFunctionExpression.cs @@ -14,7 +14,7 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; /// not used in application code. /// /// -public class TableValuedFunctionExpression : TableExpressionBase +public class TableValuedFunctionExpression : TableExpressionBase, ITableBasedExpression { /// /// Creates a new instance of the class. @@ -61,6 +61,9 @@ public override string? Alias /// public virtual IReadOnlyList Arguments { get; } + /// + ITableBase ITableBasedExpression.Table => StoreFunction; + /// protected override Expression VisitChildren(ExpressionVisitor visitor) {