Skip to content

Commit

Permalink
Add RelationalCommandCaching based on parameter value nullability
Browse files Browse the repository at this point in the history
Resolves #15892
  • Loading branch information
smitpatel committed Oct 21, 2019
1 parent bf62b29 commit 4562f7b
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query
Expand All @@ -19,29 +18,23 @@ public partial class RelationalShapedQueryCompilingExpressionVisitor
private class QueryingEnumerable<T> : IEnumerable<T>, IAsyncEnumerable<T>
{
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly RelationalCommandCache _relationalCommandCache;
private readonly IReadOnlyList<string> _columnNames;
private readonly Func<QueryContext, DbDataReader, ResultContext, int[], ResultCoordinator, T> _shaper;
private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IParameterNameGeneratorFactory _parameterNameGeneratorFactory;

public QueryingEnumerable(
RelationalQueryContext relationalQueryContext,
IQuerySqlGeneratorFactory querySqlGeneratorFactory,
ISqlExpressionFactory sqlExpressionFactory,
IParameterNameGeneratorFactory parameterNameGeneratorFactory,
SelectExpression selectExpression,
RelationalCommandCache relationalCommandCache,
IReadOnlyList<string> columnNames,
Func<QueryContext, DbDataReader, ResultContext, int[], ResultCoordinator, T> shaper,
Type contextType,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
_sqlExpressionFactory = sqlExpressionFactory;
_parameterNameGeneratorFactory = parameterNameGeneratorFactory;
_selectExpression = selectExpression;
_relationalCommandCache = relationalCommandCache;
_columnNames = columnNames;
_shaper = shaper;
_contextType = contextType;
_logger = logger;
Expand All @@ -53,28 +46,25 @@ public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToke

private sealed class Enumerator : IEnumerator<T>
{
private RelationalDataReader _dataReader;
private int[] _indexMap;
private ResultCoordinator _resultCoordinator;
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly RelationalCommandCache _relationalCommandCache;
private readonly IReadOnlyList<string> _columnNames;
private readonly Func<QueryContext, DbDataReader, ResultContext, int[], ResultCoordinator, T> _shaper;
private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IParameterNameGeneratorFactory _parameterNameGeneratorFactory;

private RelationalDataReader _dataReader;
private int[] _indexMap;
private ResultCoordinator _resultCoordinator;

public Enumerator(QueryingEnumerable<T> queryingEnumerable)
{
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_relationalCommandCache = queryingEnumerable._relationalCommandCache;
_columnNames = queryingEnumerable._columnNames;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
_sqlExpressionFactory = queryingEnumerable._sqlExpressionFactory;
_parameterNameGeneratorFactory = queryingEnumerable._parameterNameGeneratorFactory;
}

public T Current { get; private set; }
Expand All @@ -89,12 +79,8 @@ public bool MoveNext()
{
if (_dataReader == null)
{
var selectExpression = new ParameterValueBasedSelectExpressionOptimizer(
_sqlExpressionFactory,
_parameterNameGeneratorFactory)
.Optimize(_selectExpression, _relationalQueryContext.ParameterValues);

var relationalCommand = _querySqlGeneratorFactory.Create().GetCommand(selectExpression);
var relationalCommand = _relationalCommandCache.GetRelationalCommand(
_relationalQueryContext.ParameterValues);

_dataReader
= relationalCommand.ExecuteReader(
Expand All @@ -104,28 +90,22 @@ public bool MoveNext()
_relationalQueryContext.Context,
_relationalQueryContext.CommandLogger));

if (selectExpression.IsNonComposedFromSql())
// Non-Composed FromSql
if (_columnNames != null)
{
var projection = _selectExpression.Projection.ToList();
var readerColumns = Enumerable.Range(0, _dataReader.DbDataReader.FieldCount)
.ToDictionary(i => _dataReader.DbDataReader.GetName(i), i => i, StringComparer.OrdinalIgnoreCase);

_indexMap = new int[projection.Count];
for (var i = 0; i < projection.Count; i++)
_indexMap = new int[_columnNames.Count];
for (var i = 0; i < _columnNames.Count; i++)
{
if (projection[i].Expression is ColumnExpression columnExpression)
var columnName = _columnNames[i];
if (!readerColumns.TryGetValue(columnName, out var ordinal))
{
var columnName = columnExpression.Name;
if (columnName != null)
{
if (!readerColumns.TryGetValue(columnName, out var ordinal))
{
throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName));
}

_indexMap[i] = ordinal;
}
throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName));
}

_indexMap[i] = ordinal;
}
}
else
Expand Down Expand Up @@ -191,31 +171,28 @@ public void Dispose()

private sealed class AsyncEnumerator : IAsyncEnumerator<T>
{
private RelationalDataReader _dataReader;
private int[] _indexMap;
private ResultCoordinator _resultCoordinator;
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly RelationalCommandCache _relationalCommandCache;
private readonly IReadOnlyList<string> _columnNames;
private readonly Func<QueryContext, DbDataReader, ResultContext, int[], ResultCoordinator, T> _shaper;
private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IParameterNameGeneratorFactory _parameterNameGeneratorFactory;
private readonly CancellationToken _cancellationToken;

private RelationalDataReader _dataReader;
private int[] _indexMap;
private ResultCoordinator _resultCoordinator;

public AsyncEnumerator(
QueryingEnumerable<T> queryingEnumerable,
CancellationToken cancellationToken)
{
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_relationalCommandCache = queryingEnumerable._relationalCommandCache;
_columnNames = queryingEnumerable._columnNames;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
_sqlExpressionFactory = queryingEnumerable._sqlExpressionFactory;
_parameterNameGeneratorFactory = queryingEnumerable._parameterNameGeneratorFactory;
_cancellationToken = cancellationToken;
}

Expand All @@ -229,12 +206,8 @@ public async ValueTask<bool> MoveNextAsync()
{
if (_dataReader == null)
{
var selectExpression = new ParameterValueBasedSelectExpressionOptimizer(
_sqlExpressionFactory,
_parameterNameGeneratorFactory)
.Optimize(_selectExpression, _relationalQueryContext.ParameterValues);

var relationalCommand = _querySqlGeneratorFactory.Create().GetCommand(selectExpression);
var relationalCommand = _relationalCommandCache.GetRelationalCommand(
_relationalQueryContext.ParameterValues);

_dataReader
= await relationalCommand.ExecuteReaderAsync(
Expand All @@ -245,28 +218,22 @@ public async ValueTask<bool> MoveNextAsync()
_relationalQueryContext.CommandLogger),
_cancellationToken);

if (selectExpression.IsNonComposedFromSql())
// Non-Composed FromSql
if (_columnNames != null)
{
var projection = _selectExpression.Projection.ToList();
var readerColumns = Enumerable.Range(0, _dataReader.DbDataReader.FieldCount)
.ToDictionary(i => _dataReader.DbDataReader.GetName(i), i => i, StringComparer.OrdinalIgnoreCase);

_indexMap = new int[projection.Count];
for (var i = 0; i < projection.Count; i++)
_indexMap = new int[_columnNames.Count];
for (var i = 0; i < _columnNames.Count; i++)
{
if (projection[i].Expression is ColumnExpression columnExpression)
var columnName = _columnNames[i];
if (!readerColumns.TryGetValue(columnName, out var ordinal))
{
var columnName = columnExpression.Name;
if (columnName != null)
{
if (!readerColumns.TryGetValue(columnName, out var ordinal))
{
throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName));
}

_indexMap[i] = ordinal;
}
throw new InvalidOperationException(RelationalStrings.FromSqlMissingColumn(columnName));
}

_indexMap[i] = ordinal;
}
}
else
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query
{
public partial class RelationalShapedQueryCompilingExpressionVisitor
{
private class RelationalCommandCache
{
private readonly ConcurrentDictionary<CommandCacheKey, IRelationalCommand> _commandCache
= new ConcurrentDictionary<CommandCacheKey, IRelationalCommand>(CommandCacheKeyComparer.Instance);
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IParameterNameGeneratorFactory _parameterNameGeneratorFactory;
private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory;
private readonly SelectExpression _selectExpression;
private readonly ParameterValueBasedSelectExpressionOptimizer _parameterValueBasedSelectExpressionOptimizer;

public RelationalCommandCache(
ISqlExpressionFactory sqlExpressionFactory,
IParameterNameGeneratorFactory parameterNameGeneratorFactory,
IQuerySqlGeneratorFactory querySqlGeneratorFactory,
SelectExpression selectExpression)
{
_sqlExpressionFactory = sqlExpressionFactory;
_parameterNameGeneratorFactory = parameterNameGeneratorFactory;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
_selectExpression = selectExpression;
_parameterValueBasedSelectExpressionOptimizer = new ParameterValueBasedSelectExpressionOptimizer(
_sqlExpressionFactory,
_parameterNameGeneratorFactory);
}

public virtual IRelationalCommand GetRelationalCommand(IReadOnlyDictionary<string, object> parameters)
{
var key = new CommandCacheKey(parameters);

if (_commandCache.TryGetValue(key, out var relationalCommand))
{
return relationalCommand;
}

var selectExpression = _parameterValueBasedSelectExpressionOptimizer.Optimize(_selectExpression, parameters);

relationalCommand = _querySqlGeneratorFactory.Create().GetCommand(selectExpression);

if (ReferenceEquals(selectExpression, _selectExpression))
{
_commandCache.TryAdd(key, relationalCommand);
}

return relationalCommand;
}

private sealed class CommandCacheKeyComparer : IEqualityComparer<CommandCacheKey>
{
public static readonly CommandCacheKeyComparer Instance = new CommandCacheKeyComparer();

private CommandCacheKeyComparer()
{
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool Equals(CommandCacheKey x, CommandCacheKey y)
{
if (x.ParameterValues.Count > 0)
{
foreach (var parameterValue in x.ParameterValues)
{
var value = parameterValue.Value;

if (!y.ParameterValues.TryGetValue(parameterValue.Key, out var otherValue))
{
return false;
}

if (value == null
!= (otherValue == null))
{
return false;
}

if (value is IEnumerable
&& value.GetType() == typeof(object[]))
{
// FromSql parameters must have the same number of elements
return ((object[])value).Length == (otherValue as object[])?.Length;
}
}
}

return true;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int GetHashCode(CommandCacheKey obj) => 0;
}

private readonly struct CommandCacheKey
{
public readonly IReadOnlyDictionary<string, object> ParameterValues;

public CommandCacheKey(IReadOnlyDictionary<string, object> parameterValues)
=> ParameterValues = parameterValues;
}
}
}
}
Loading

0 comments on commit 4562f7b

Please sign in to comment.