Skip to content

Commit 396830c

Browse files
committed
CSHARP-4468: LINQ V3 SelectMany + GroupBy results with redundant $push within $group.
1 parent 70ed174 commit 396830c

File tree

4 files changed

+200
-68
lines changed

4 files changed

+200
-68
lines changed

Diff for: src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs

+27-26
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ public static AstPipeline Optimize(AstPipeline pipeline)
4444
#endregion
4545

4646
private readonly AccumulatorSet _accumulators = new AccumulatorSet();
47+
private AstExpression _element; // normally either "$$ROOT" or "$_v"
4748

4849
private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStage groupStage)
4950
{
5051
try
5152
{
52-
if (IsOptimizableGroupStage(groupStage))
53+
if (IsOptimizableGroupStage(groupStage, out _element))
5354
{
5455
var followingStages = GetFollowingStagesToOptimize(pipeline, i + 1);
5556
if (followingStages == null)
@@ -71,22 +72,22 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag
7172

7273
return pipeline;
7374

74-
static bool IsOptimizableGroupStage(AstGroupStage groupStage)
75+
static bool IsOptimizableGroupStage(AstGroupStage groupStage, out AstExpression element)
7576
{
76-
// { $group : { _id : ?, _elements : { $push : "$$ROOT" } } }
77+
// { $group : { _id : ?, _elements : { $push : element } } }
7778
if (groupStage.Fields.Count == 1)
7879
{
7980
var field = groupStage.Fields[0];
8081
if (field.Path == "_elements" &&
8182
field.Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
82-
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push &&
83-
unaryAccumulatorExpression.Arg is AstVarExpression varExpression &&
84-
varExpression.Name == "ROOT")
83+
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push)
8584
{
85+
element = unaryAccumulatorExpression.Arg;
8686
return true;
8787
}
8888
}
8989

90+
element = null;
9091
return false;
9192
}
9293

@@ -173,7 +174,7 @@ private AstStage OptimizeLimitStage(AstLimitStage stage)
173174

174175
private AstStage OptimizeMatchStage(AstMatchStage stage)
175176
{
176-
var optimizedFilter = AccumulatorMover.MoveAccumulators(_accumulators, stage.Filter);
177+
var optimizedFilter = AccumulatorMover.MoveAccumulators(_accumulators, _element, stage.Filter);
177178
return stage.Update(optimizedFilter);
178179
}
179180

@@ -201,7 +202,7 @@ private AstProjectStageSpecification OptimizeProjectStageSpecification(AstProjec
201202

202203
private AstProjectStageSpecification OptimizeProjectStageSetFieldSpecification(AstProjectStageSetFieldSpecification specification)
203204
{
204-
var optimizedValue = AccumulatorMover.MoveAccumulators(_accumulators, specification.Value);
205+
var optimizedValue = AccumulatorMover.MoveAccumulators(_accumulators, _element, specification.Value);
205206
return specification.Update(optimizedValue);
206207
}
207208

@@ -249,27 +250,29 @@ public string AddAccumulatorExpression(AstAccumulatorExpression value)
249250
private class AccumulatorMover : AstNodeVisitor
250251
{
251252
#region static
252-
public static TNode MoveAccumulators<TNode>(AccumulatorSet accumulators, TNode node)
253+
public static TNode MoveAccumulators<TNode>(AccumulatorSet accumulators, AstExpression element, TNode node)
253254
where TNode : AstNode
254255
{
255-
var mover = new AccumulatorMover(accumulators);
256+
var mover = new AccumulatorMover(accumulators, element);
256257
return mover.VisitAndConvert(node);
257258
}
258259
#endregion
259260

260261
private readonly AccumulatorSet _accumulators;
262+
private readonly AstExpression _element;
261263

262-
private AccumulatorMover(AccumulatorSet accumulator)
264+
private AccumulatorMover(AccumulatorSet accumulator, AstExpression element)
263265
{
264266
_accumulators = accumulator;
267+
_element = element;
265268
}
266269

267270
public override AstNode VisitFilterField(AstFilterField node)
268271
{
269-
// "_elements.0.X" => { __agg0 : { $first : "$$ROOT" } } + "__agg0.X"
272+
// "_elements.0.X" => { __agg0 : { $first : element } } + "__agg0.X"
270273
if (node.Path.StartsWith("_elements.0."))
271274
{
272-
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.First, AstExpression.Var("ROOT"));
275+
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.First, _element);
273276
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
274277
var restOfPath = node.Path.Substring("_elements.0.".Length);
275278
var rewrittenPath = $"{accumulatorFieldName}.{restOfPath}";
@@ -288,9 +291,7 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
288291
{
289292
if (node.FieldName is AstConstantExpression constantFieldName &&
290293
constantFieldName.Value.IsString &&
291-
constantFieldName.Value.AsString == "_elements" &&
292-
node.Input is AstVarExpression varExpression &&
293-
varExpression.Name == "ROOT")
294+
constantFieldName.Value.AsString == "_elements")
294295
{
295296
throw new UnableToRemoveReferenceToElementsException();
296297
}
@@ -300,18 +301,18 @@ node.Input is AstVarExpression varExpression &&
300301

301302
public override AstNode VisitMapExpression(AstMapExpression node)
302303
{
303-
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => root) } } + "$__agg0"
304+
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0"
304305
if (node.Input is AstGetFieldExpression mapInputGetFieldExpression &&
305306
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
306307
mapInputconstantFieldExpression.Value.IsString &&
307308
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
308309
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
309310
mapInputGetFieldVarExpression.Name == "ROOT")
310311
{
311-
var root = AstExpression.Var("ROOT", isCurrent: true);
312-
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, root));
312+
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
313313
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg);
314314
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
315+
var root = AstExpression.Var("ROOT", isCurrent: true);
315316
return AstExpression.GetField(root, accumulatorFieldName);
316317
}
317318

@@ -321,7 +322,7 @@ mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpressi
321322
public override AstNode VisitPickExpression(AstPickExpression node)
322323
{
323324
// { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
324-
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => root) } } } + "$__agg0"
325+
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
325326
if (node.Source is AstGetFieldExpression getFieldExpression &&
326327
getFieldExpression.Input is AstVarExpression varExpression &&
327328
varExpression.Name == "ROOT" &&
@@ -330,10 +331,10 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio
330331
constantFieldNameExpression.Value.AsString == "_elements")
331332
{
332333
var @operator = node.Operator.ToAccumulatorOperator();
333-
var root = AstExpression.Var("ROOT", isCurrent: true);
334-
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, root));
334+
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
335335
var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N);
336336
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
337+
var root = AstExpression.Var("ROOT", isCurrent: true);
337338
return AstExpression.GetField(root, accumulatorFieldName);
338339
}
339340

@@ -384,7 +385,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres
384385

385386
bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression)
386387
{
387-
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : "$$ROOT" } } + "$__agg0"
388+
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
388389
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
389390
node.Arg is AstGetFieldExpression getFieldExpression &&
390391
getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
@@ -393,7 +394,7 @@ getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameE
393394
getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
394395
getFieldInputVarExpression.Name == "ROOT")
395396
{
396-
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, root);
397+
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
397398
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
398399
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
399400
return true;
@@ -406,7 +407,7 @@ getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
406407

407408
bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpression)
408409
{
409-
// { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => root) } } + "$__agg0"
410+
// { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0"
410411
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
411412
node.Arg is AstMapExpression mapExpression &&
412413
mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
@@ -416,7 +417,7 @@ mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFi
416417
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
417418
mapInputGetFieldVarExpression.Name == "ROOT")
418419
{
419-
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, root));
420+
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));
420421
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg);
421422
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
422423
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);

Diff for: tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1537,8 +1537,8 @@ group f by f.D into g
15371537
4,
15381538
"{ $project : { _v : '$G', _id : 0 } }",
15391539
"{ $unwind : '$_v' }",
1540-
"{ $group : { _id : '$_v.D', _elements : { $push : '$_v' } } }",
1541-
"{ $project : { Key : '$_id', SumF : { $sum : '$_elements.E.F' }, _id : 0 } }");
1540+
"{ $group : { _id : '$_v.D', __agg0 : { $sum : '$_v.E.F' } } }",
1541+
"{ $project : { Key : '$_id', SumF : '$__agg0', _id : 0 } }");
15421542
}
15431543

15441544
[Fact]
@@ -1567,8 +1567,8 @@ group s by s.D into g
15671567
"{ $unwind : '$_v' }",
15681568
"{ $project : { '_v' : '$_v.S', '_id' : 0 } }",
15691569
"{ $unwind : '$_v' }",
1570-
"{ $group : { _id : '$_v.D', _elements : { $push : '$_v' } } }",
1571-
"{ $project : { Key : '$_id', SumF : { $sum : '$_elements.E.F' }, _id : 0 } }");
1570+
"{ $group : { _id : '$_v.D', __agg0 : { $sum : '$_v.E.F' } } }",
1571+
"{ $project : { Key : '$_id', SumF : '$__agg0', _id : 0 } }");
15721572
}
15731573

15741574
[Fact]

0 commit comments

Comments
 (0)