Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tiny cleanup of CloningExpressionVisitor #30700

Merged
merged 1 commit into from
Apr 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 109 additions & 118 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -924,145 +924,136 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (expression is SelectExpression selectExpression)
switch (expression)
{
var newProjectionMappings = new Dictionary<ProjectionMember, Expression>(selectExpression._projectionMapping.Count);
foreach (var (projectionMember, value) in selectExpression._projectionMapping)
case SelectExpression selectExpression:
{
newProjectionMappings[projectionMember] = Visit(value);
}

var newProjections = selectExpression._projection.Select(Visit).ToList<ProjectionExpression>();
var newProjectionMappings = new Dictionary<ProjectionMember, Expression>(selectExpression._projectionMapping.Count);
foreach (var (projectionMember, value) in selectExpression._projectionMapping)
{
newProjectionMappings[projectionMember] = Visit(value);
}

var newTables = selectExpression._tables.Select(Visit).ToList<TableExpressionBase>();
var tpcTablesMap = selectExpression._tables.Select(UnwrapJoinExpression).Zip(newTables.Select(UnwrapJoinExpression))
.Where(e => e.First is TpcTablesExpression)
.ToDictionary(e => (TpcTablesExpression)e.First, e => (TpcTablesExpression)e.Second);
var newProjections = selectExpression._projection.Select(Visit).ToList<ProjectionExpression>();

var newTables = selectExpression._tables.Select(Visit).ToList<TableExpressionBase>();
var tpcTablesMap = selectExpression._tables.Select(UnwrapJoinExpression).Zip(newTables.Select(UnwrapJoinExpression))
.Where(e => e.First is TpcTablesExpression)
.ToDictionary(e => (TpcTablesExpression)e.First, e => (TpcTablesExpression)e.Second);

// Since we are cloning we need to generate new table references
// In other cases (like VisitChildren), we just reuse the same table references and update the SelectExpression inside it.
// We initially assign old SelectExpression in table references and later update it once we construct clone
var newTableReferences = selectExpression._tableReferences
.Select(e => new TableReferenceExpression(selectExpression, e.Alias)).ToList();
Check.DebugAssert(
newTables.Select(e => GetAliasFromTableExpressionBase(e)).SequenceEqual(newTableReferences.Select(e => e.Alias)),
"Alias of updated tables must match the old tables.");

var predicate = (SqlExpression?)Visit(selectExpression.Predicate);
var newGroupBy = selectExpression._groupBy.Select(Visit)
.Where(e => !(e is SqlConstantExpression || e is SqlParameterExpression))
.ToList<SqlExpression>();
var havingExpression = (SqlExpression?)Visit(selectExpression.Having);
var newOrderings = selectExpression._orderings.Select(Visit).ToList<OrderingExpression>();
var offset = (SqlExpression?)Visit(selectExpression.Offset);
var limit = (SqlExpression?)Visit(selectExpression.Limit);

var newSelectExpression = new SelectExpression(
selectExpression.Alias, newProjections, newTables, newTableReferences, newGroupBy, newOrderings,
selectExpression.GetAnnotations())
{
Predicate = predicate,
Having = havingExpression,
Offset = offset,
Limit = limit,
IsDistinct = selectExpression.IsDistinct,
Tags = selectExpression.Tags,
_usedAliases = selectExpression._usedAliases.ToHashSet(),
_projectionMapping = newProjectionMappings,
};
newSelectExpression._mutable = selectExpression._mutable;

newSelectExpression._removableJoinTables.AddRange(selectExpression._removableJoinTables);

foreach (var kvp in selectExpression._tpcDiscriminatorValues)
{
newSelectExpression._tpcDiscriminatorValues[tpcTablesMap[kvp.Key]] = kvp.Value;
}

// Since we are cloning we need to generate new table references
// In other cases (like VisitChildren), we just reuse the same table references and update the SelectExpression inside it.
// We initially assign old SelectExpression in table references and later update it once we construct clone
var newTableReferences = selectExpression._tableReferences
.Select(e => new TableReferenceExpression(selectExpression, e.Alias)).ToList();
Check.DebugAssert(
newTables.Select(e => GetAliasFromTableExpressionBase(e)).SequenceEqual(newTableReferences.Select(e => e.Alias)),
"Alias of updated tables must match the old tables.");

var predicate = (SqlExpression?)Visit(selectExpression.Predicate);
var newGroupBy = selectExpression._groupBy.Select(Visit)
.Where(e => !(e is SqlConstantExpression || e is SqlParameterExpression))
.ToList<SqlExpression>();
var havingExpression = (SqlExpression?)Visit(selectExpression.Having);
var newOrderings = selectExpression._orderings.Select(Visit).ToList<OrderingExpression>();
var offset = (SqlExpression?)Visit(selectExpression.Offset);
var limit = (SqlExpression?)Visit(selectExpression.Limit);

var newSelectExpression = new SelectExpression(
selectExpression.Alias, newProjections, newTables, newTableReferences, newGroupBy, newOrderings,
selectExpression.GetAnnotations())
{
Predicate = predicate,
Having = havingExpression,
Offset = offset,
Limit = limit,
IsDistinct = selectExpression.IsDistinct,
Tags = selectExpression.Tags,
_usedAliases = selectExpression._usedAliases.ToHashSet(),
_projectionMapping = newProjectionMappings,
};
newSelectExpression._mutable = selectExpression._mutable;

newSelectExpression._removableJoinTables.AddRange(selectExpression._removableJoinTables);
// Since identifiers are ColumnExpression, they are not visited since they don't contain SelectExpression inside it.
newSelectExpression._identifier.AddRange(selectExpression._identifier);
newSelectExpression._childIdentifiers.AddRange(selectExpression._childIdentifiers);

foreach (var kvp in selectExpression._tpcDiscriminatorValues)
{
newSelectExpression._tpcDiscriminatorValues[tpcTablesMap[kvp.Key]] = kvp.Value;
}
// Remap tableReferences in new select expression
foreach (var tableReference in newTableReferences)
{
tableReference.UpdateTableReference(selectExpression, newSelectExpression);
}

// Since identifiers are ColumnExpression, they are not visited since they don't contain SelectExpression inside it.
newSelectExpression._identifier.AddRange(selectExpression._identifier);
newSelectExpression._childIdentifiers.AddRange(selectExpression._childIdentifiers);
// Now that we have SelectExpression, we visit all components and update table references inside columns
newSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor(
selectExpression, newSelectExpression._tableReferences).Visit(newSelectExpression);

// Remap tableReferences in new select expression
foreach (var tableReference in newTableReferences)
{
tableReference.UpdateTableReference(selectExpression, newSelectExpression);
return newSelectExpression;
}

// Now that we have SelectExpression, we visit all components and update table references inside columns
newSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor(
selectExpression, newSelectExpression._tableReferences).Visit(newSelectExpression);

return newSelectExpression;
}

if (expression is TpcTablesExpression tpcTablesExpression)
{
// Deep clone
var subSelectExpressions = tpcTablesExpression.SelectExpressions.Select(Visit).ToList<SelectExpression>();
var newTpcTable = new TpcTablesExpression(tpcTablesExpression.Alias, tpcTablesExpression.EntityType, subSelectExpressions);
foreach (var annotation in tpcTablesExpression.GetAnnotations())
case TpcTablesExpression tpcTablesExpression:
{
newTpcTable.AddAnnotation(annotation.Name, annotation.Value);
}

return newTpcTable;
}
// Deep clone
var subSelectExpressions = tpcTablesExpression.SelectExpressions.Select(Visit).ToList<SelectExpression>();
var newTpcTable = new TpcTablesExpression(tpcTablesExpression.Alias, tpcTablesExpression.EntityType, subSelectExpressions);
foreach (var annotation in tpcTablesExpression.GetAnnotations())
{
newTpcTable.AddAnnotation(annotation.Name, annotation.Value);
}

if (expression is TableValuedFunctionExpression tableValuedFunctionExpression)
{
var newArguments = new SqlExpression[tableValuedFunctionExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
newArguments[i] = (SqlExpression)Visit(tableValuedFunctionExpression.Arguments[i]);
return newTpcTable;
}

var newTableValuedFunctionExpression = new TableValuedFunctionExpression(
tableValuedFunctionExpression.StoreFunction,
newArguments)
case TableValuedFunctionExpression tableValuedFunctionExpression:
{
Alias = tableValuedFunctionExpression.Alias
};
var newArguments = new SqlExpression[tableValuedFunctionExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
newArguments[i] = (SqlExpression)Visit(tableValuedFunctionExpression.Arguments[i]);
}

foreach (var annotation in tableValuedFunctionExpression.GetAnnotations())
{
newTableValuedFunctionExpression.AddAnnotation(annotation.Name, annotation.Value);
var newTableValuedFunctionExpression = new TableValuedFunctionExpression(
tableValuedFunctionExpression.StoreFunction,
newArguments)
{
Alias = tableValuedFunctionExpression.Alias
};

foreach (var annotation in tableValuedFunctionExpression.GetAnnotations())
{
newTableValuedFunctionExpression.AddAnnotation(annotation.Name, annotation.Value);
}

return newTableValuedFunctionExpression;
}

return newTableValuedFunctionExpression;
}
case IClonableTableExpressionBase cloneable:
return cloneable.Clone();

if (expression is IClonableTableExpressionBase cloneable)
{
return cloneable.Clone();
}
// join and set operations are fine, because they contain other TableExpressionBases inside, that will get cloned
// and therefore set expression's Update function will generate a new instance.
case JoinExpressionBase or SetOperationBase:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maumar note using the base classes for joins/set operations rather than listing them out

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my thinking was that if we introduce new join/set expression that happens to need cloning logic, we would need to explicitly make the decision, rather than being covered by base. But now that I think about it, it was overkill.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I think your logic (as in the comment) should hold for any subclass of JoinExpressionBase/SetOperationBase... We can see where it goes...

return base.Visit(expression);

// join and set operations are fine, because they contain other TableExpressionBases inside, that will get cloned
// and therefore set expression's Update function will generate a new instance.
if (expression is CrossJoinExpression
or InnerJoinExpression
or LeftJoinExpression
or CrossApplyExpression
or OuterApplyExpression
or ExceptExpression
or IntersectExpression
or UnionExpression)
{
return base.Visit(expression);
}
case TableExpressionBase:
throw new InvalidOperationException(
RelationalStrings.TableExpressionBaseWithoutCloningLogic(
expression.GetType().Name,
nameof(TableExpressionBase),
nameof(IClonableTableExpressionBase),
nameof(CloningExpressionVisitor),
nameof(SelectExpression)));

if (expression is TableExpressionBase)
{
throw new InvalidOperationException(
RelationalStrings.TableExpressionBaseWithoutCloningLogic(
expression.GetType().Name,
nameof(TableExpressionBase),
nameof(IClonableTableExpressionBase),
nameof(CloningExpressionVisitor),
nameof(SelectExpression))); ;
default:
return base.Visit(expression);
}

return base.Visit(expression);
}
}

Expand Down