Skip to content

Commit

Permalink
Basic FromSql() support in new pipeline
Browse files Browse the repository at this point in the history
* Parameterization not implemented (see #15750)
* FromSqlRaw() and FromSqlInterpolated() are now defined over DbSet<>,
  not IQueryable<>. FromSql() is still defined over IQueryable<> but
  throws if not directly on a DbSet<>.

Closes #15704
  • Loading branch information
roji committed May 27, 2019
1 parent b25bb1d commit 244613b
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 102 deletions.
43 changes: 29 additions & 14 deletions src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace Microsoft.EntityFrameworkCore
/// </summary>
public static class RelationalQueryableExtensions
{
internal static readonly MethodInfo FromSqlMethodInfo
private static readonly MethodInfo _fromSqlOnQueryableMethodInfo
= typeof(RelationalQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(FromSqlRaw))
.GetTypeInfo().GetDeclaredMethods(nameof(FromSqlOnQueryable))
.Single();

/// <summary>
Expand Down Expand Up @@ -54,7 +54,10 @@ internal static readonly MethodInfo FromSqlMethodInfo
/// <returns> An <see cref="IQueryable{T}" /> representing the raw SQL query. </returns>
[StringFormatMethod("sql")]
[Obsolete(
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead.")]
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. " +
"For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead. " +
"Call either new method directly on the DbSet at the root of the query.",
error: true)]
public static IQueryable<TEntity> FromSql<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotParameterized] RawSqlString sql,
Expand All @@ -68,7 +71,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(sql.Format),
Expression.Constant(parameters)));
Expand Down Expand Up @@ -96,7 +99,10 @@ public static IQueryable<TEntity> FromSql<TEntity>(
/// <param name="sql"> The interpolated string representing a SQL query. </param>
/// <returns> An <see cref="IQueryable{T}" /> representing the interpolated string SQL query. </returns>
[Obsolete(
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead.")]
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. " +
"For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead. " +
"Call either new method directly on the DbSet at the root of the query.",
error: true)]
public static IQueryable<TEntity> FromSql<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] [NotParameterized] FormattableString sql)
Expand All @@ -109,7 +115,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
Expand Down Expand Up @@ -145,7 +151,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
/// <returns> An <see cref="IQueryable{T}" /> representing the raw SQL query. </returns>
[StringFormatMethod("sql")]
public static IQueryable<TEntity> FromSqlRaw<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] this DbSet<TEntity> source,
[NotParameterized] string sql,
[NotNull] params object[] parameters)
where TEntity : class
Expand All @@ -154,11 +160,12 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(
Check.NotEmpty(sql, nameof(sql));
Check.NotNull(parameters, nameof(parameters));

return source.Provider.CreateQuery<TEntity>(
var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql),
Expression.Constant(parameters)));
}
Expand All @@ -185,21 +192,29 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(
/// <param name="sql"> The interpolated string representing a SQL query with parameters. </param>
/// <returns> An <see cref="IQueryable{T}" /> representing the interpolated string SQL query. </returns>
public static IQueryable<TEntity> FromSqlInterpolated<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] this DbSet<TEntity> source,
[NotNull] [NotParameterized] FormattableString sql)
where TEntity : class
{
Check.NotNull(source, nameof(source));
Check.NotNull(sql, nameof(sql));
Check.NotEmpty(sql.Format, nameof(source));

return source.Provider.CreateQuery<TEntity>(
var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
}

internal static IQueryable<TEntity> FromSqlOnQueryable<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotParameterized] string sql,
[NotNull] params object[] parameters)
where TEntity : class
=> throw new NotSupportedException();
}
}
16 changes: 16 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ protected override Expression VisitTable(TableExpression tableExpression)
return tableExpression;
}

protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression)
{
_relationalCommandBuilder.AppendLine("(");

using (_relationalCommandBuilder.Indent())
{
_relationalCommandBuilder.AppendLines(fromSqlExpression.Sql);
// TODO: Generate parameters
}

_relationalCommandBuilder.Append(") AS ")
.Append(_sqlGenerationHelper.DelimitIdentifier(fromSqlExpression.Alias));

return fromSqlExpression;
}

protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpression)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.Coalesce)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,51 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
{
public class RelationalEntityQueryableExpressionVisitor2 : EntityQueryableExpressionVisitor2
{
private IModel _model;
private readonly IModel _model;

public RelationalEntityQueryableExpressionVisitor2(IModel model)
{
_model = model;
}

protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
return new RelationalShapedQueryExpression(_model.FindEntityType(elementType));
if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions)
&& methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlOnQueryable))
{
// The obsolete FromSql extension method is still defined as IQueryable (so that users get obsolete warnings),
// so we do a runtime check that it's invoked on an EntityQueryable<>. The new FromSqlRaw/FromSqlInterpolated methods
// are now defined on DbSet<> instead.
if (!(methodCallExpression.Arguments[0] is ConstantExpression constantExpression
&& constantExpression.IsEntityQueryable()))
{
throw new NotSupportedException(RelationalStrings.FromSqlNotOnDbSet);
}

// TODO: Implement parameters
var sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value;
var queryable = (IQueryable)((ConstantExpression)methodCallExpression.Arguments[0]).Value;
return CreateShapedQueryExpression(queryable.ElementType, sql);
}

return base.VisitMethodCall(methodCallExpression);
}

protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
=> new RelationalShapedQueryExpression(_model.FindEntityType(elementType));

protected virtual ShapedQueryExpression CreateShapedQueryExpression(Type elementType, string sql)
=> new RelationalShapedQueryExpression(_model.FindEntityType(elementType), sql);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,17 @@ public RelationalShapedQueryExpression(IEntityType entityType)
typeof(ValueBuffer)),
false);
}

public RelationalShapedQueryExpression(IEntityType entityType, string sql)
{
QueryExpression = new SelectExpression(entityType, sql);
ShaperExpression = new EntityShaperExpression(
entityType,
new ProjectionBindingExpression(
QueryExpression,
new ProjectionMember(),
typeof(ValueBuffer)),
false);
}
}
}
4 changes: 4 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case ExistsExpression existsExpression:
return VisitExists(existsExpression);

case FromSqlExpression fromSqlExpression:
return VisitFromSql(fromSqlExpression);

case InExpression inExpression:
return VisitIn(inExpression);

Expand Down Expand Up @@ -76,6 +79,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
protected abstract Expression VisitExists(ExistsExpression existsExpression);
protected abstract Expression VisitIn(InExpression inExpression);
protected abstract Expression VisitCrossJoin(CrossJoinExpression crossJoinExpression);
protected abstract Expression VisitFromSql(FromSqlExpression fromSqlExpression);
protected abstract Expression VisitInnerJoin(InnerJoinExpression innerJoinExpression);
protected abstract Expression VisitLeftJoin(LeftJoinExpression leftJoinExpression);
protected abstract Expression VisitProjection(ProjectionExpression projectionExpression);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions
{
public class FromSqlExpression : TableExpressionBase
{
#region Fields & Constructors
public FromSqlExpression(
[NotNull] string sql,
[NotNull] string alias)
: base(alias)
{
Sql = sql;
}
#endregion

#region Public Properties

/// <summary>
/// Gets the SQL.
/// </summary>
/// <value>
/// The SQL.
/// </value>
public string Sql { get; }

#endregion

#region Expression-based methods

protected override Expression VisitChildren(ExpressionVisitor visitor)
=> this;

public override void Print(ExpressionPrinter expressionPrinter)
=> expressionPrinter.StringBuilder.Append(Sql);

#endregion

#region Equality & HashCode

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is FromSqlExpression fromSqlExpression
&& Equals(fromSqlExpression));

private bool Equals(FromSqlExpression fromSqlExpression)
=> base.Equals(fromSqlExpression)
&& string.Equals(Sql, fromSqlExpression.Sql);

public override int GetHashCode()
{
unchecked
{
var hashCode = base.GetHashCode();
hashCode = (hashCode * 397) ^ Sql.GetHashCode();

return hashCode;
}
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ public SelectExpression(IEntityType entityType)
_projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, tableExpression, false);
}

public SelectExpression(IEntityType entityType, string sql)
: base("")
{
var fromSqlExpression = new FromSqlExpression(
sql,
entityType.GetTableName().ToLower().Substring(0, 1));

_tables.Add(fromSqlExpression);

_projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, fromSqlExpression, false);
}

public SqlExpression BindProperty(Expression projectionExpression, IProperty property)
{
var member = (projectionExpression as ProjectionBindingExpression).ProjectionMember;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ public class FromSqlExpressionNode : ResultOperatorExpressionNodeBase
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static readonly IReadOnlyCollection<MethodInfo> SupportedMethods = new[] { RelationalQueryableExtensions.FromSqlMethodInfo };
public static readonly IReadOnlyCollection<MethodInfo> SupportedMethods = new List<MethodInfo>
{
// RelationalQueryableExtensions.FromSqlMethodInfo
};

private readonly string _sql;
private readonly Expression _arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ protected override Expression VisitExists(ExistsExpression existsExpression)
return ApplyConversion(existsExpression.Update(subquery), condition: true);
}

protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression)
=> fromSqlExpression;

protected override Expression VisitIn(InExpression inExpression)
{
var parentSearchCondition = _isSearchCondition;
Expand Down
Loading

0 comments on commit 244613b

Please sign in to comment.