Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Fix potential re-use of QuerySqlGen from multiple threads #15880

Merged
merged 1 commit into from
May 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ public bool MoveNext()
{
try
{

if (_enumerator == null)
{
_enumerator = _innerEnumerable.GetEnumerator();
Expand Down Expand Up @@ -261,7 +260,6 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
try
{

if (_enumerator == null)
{
_enumerator = _innerEnumerable.GetEnumerator();
Expand All @@ -274,7 +272,8 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
: default;

return hasNext;
} catch (Exception exception)
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
return Expression.New(
typeof(AsyncQueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType.GetGenericArguments().Single()).GetConstructors()[0],
Expression.Convert(QueryCompilationContext2.QueryContextParameter, typeof(RelationalQueryContext)),
Expression.Constant(_querySqlGeneratorFactory.Create()),
Expression.Constant(_querySqlGeneratorFactory),
Expression.Constant(selectExpression),
Expression.Constant(shaperLambda.Compile()),
Expression.Constant(_contextType),
Expand All @@ -74,7 +74,7 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
return Expression.New(
typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
Expression.Convert(QueryCompilationContext2.QueryContextParameter, typeof(RelationalQueryContext)),
Expression.Constant(_querySqlGeneratorFactory.Create()),
Expression.Constant(_querySqlGeneratorFactory),
Expression.Constant(selectExpression),
Expression.Constant(shaperLambda.Compile()),
Expression.Constant(_contextType),
Expand Down Expand Up @@ -216,20 +216,20 @@ private class AsyncQueryingEnumerable<T> : IAsyncEnumerable<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, Task<T>> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

public AsyncQueryingEnumerable(
RelationalQueryContext relationalQueryContext,
QuerySqlGenerator querySqlGenerator,
IQuerySqlGeneratorFactory2 querySqlGeneratorFactory,
SelectExpression selectExpression,
Func<QueryContext, DbDataReader, Task<T>> shaper,
Type contextType,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGenerator = querySqlGenerator;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
_selectExpression = selectExpression;
_shaper = shaper;
_contextType = contextType;
Expand All @@ -247,7 +247,7 @@ private sealed class AsyncEnumerator : IAsyncEnumerator<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, Task<T>> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

Expand All @@ -256,7 +256,7 @@ public AsyncEnumerator(AsyncQueryingEnumerable<T> queryingEnumerable)
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGenerator = queryingEnumerable._querySqlGenerator;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
}
Expand All @@ -272,38 +272,37 @@ public void Dispose()

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
if (_dataReader == null)
try
{
await _relationalQueryContext.Connection.OpenAsync(cancellationToken);

try
if (_dataReader == null)
{
var relationalCommand = _querySqlGenerator
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= await relationalCommand.ExecuteReaderAsync(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger,
cancellationToken);
await _relationalQueryContext.Connection.OpenAsync(cancellationToken);

try
{
var relationalCommand = _querySqlGeneratorFactory.Create()
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= await relationalCommand.ExecuteReaderAsync(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger,
cancellationToken);
}
catch (Exception)
{
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}

try
{
var hasNext = await _dataReader.ReadAsync(cancellationToken);

Current
Expand All @@ -315,7 +314,6 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
}
catch (Exception exception)
{

_logger.QueryIterationFailed(_contextType, exception);

throw;
Expand All @@ -329,19 +327,19 @@ private class QueryingEnumerable<T> : IEnumerable<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, T> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

public QueryingEnumerable(RelationalQueryContext relationalQueryContext,
QuerySqlGenerator querySqlGenerator,
IQuerySqlGeneratorFactory2 querySqlGeneratorFactory,
SelectExpression selectExpression,
Func<QueryContext, DbDataReader, T> shaper,
Type contextType,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGenerator = querySqlGenerator;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
_selectExpression = selectExpression;
_shaper = shaper;
_contextType = contextType;
Expand All @@ -357,7 +355,7 @@ private sealed class Enumerator : IEnumerator<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, T> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

Expand All @@ -366,7 +364,7 @@ public Enumerator(QueryingEnumerable<T> queryingEnumerable)
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGenerator = queryingEnumerable._querySqlGenerator;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
}
Expand All @@ -384,37 +382,36 @@ public void Dispose()

public bool MoveNext()
{
if (_dataReader == null)
try
{
_relationalQueryContext.Connection.Open();

try
if (_dataReader == null)
{
var relationalCommand = _querySqlGenerator
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= relationalCommand.ExecuteReader(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);
_relationalQueryContext.Connection.Open();

try
{
var relationalCommand = _querySqlGeneratorFactory.Create()
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= relationalCommand.ExecuteReader(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);
}
catch (Exception)
{
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}

try
{
var hasNext = _dataReader.Read();

Current
Expand All @@ -426,7 +423,6 @@ public bool MoveNext()
}
catch (Exception exception)
{

_logger.QueryIterationFailed(_contextType, exception);

throw;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ public override bool Equals(object obj)

private bool Equals(SqlConstantExpression sqlConstantExpression)
=> base.Equals(sqlConstantExpression)
&& Value?.Equals(sqlConstantExpression.Value) == true;
&& (Value == null
? sqlConstantExpression.Value == null
: Value.Equals(sqlConstantExpression.Value));

public override int GetHashCode()
{
Expand Down