@@ -44,12 +44,13 @@ public static AstPipeline Optimize(AstPipeline pipeline)
44
44
#endregion
45
45
46
46
private readonly AccumulatorSet _accumulators = new AccumulatorSet ( ) ;
47
+ private AstExpression _element ; // normally either "$$ROOT" or "$_v"
47
48
48
49
private AstPipeline OptimizeGroupStage ( AstPipeline pipeline , int i , AstGroupStage groupStage )
49
50
{
50
51
try
51
52
{
52
- if ( IsOptimizableGroupStage ( groupStage ) )
53
+ if ( IsOptimizableGroupStage ( groupStage , out _element ) )
53
54
{
54
55
var followingStages = GetFollowingStagesToOptimize ( pipeline , i + 1 ) ;
55
56
if ( followingStages == null )
@@ -71,22 +72,22 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag
71
72
72
73
return pipeline ;
73
74
74
- static bool IsOptimizableGroupStage ( AstGroupStage groupStage )
75
+ static bool IsOptimizableGroupStage ( AstGroupStage groupStage , out AstExpression element )
75
76
{
76
- // { $group : { _id : ?, _elements : { $push : "$$ROOT" } } }
77
+ // { $group : { _id : ?, _elements : { $push : element } } }
77
78
if ( groupStage . Fields . Count == 1 )
78
79
{
79
80
var field = groupStage . Fields [ 0 ] ;
80
81
if ( field . Path == "_elements" &&
81
82
field . Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
82
- unaryAccumulatorExpression . Operator == AstUnaryAccumulatorOperator . Push &&
83
- unaryAccumulatorExpression . Arg is AstVarExpression varExpression &&
84
- varExpression . Name == "ROOT" )
83
+ unaryAccumulatorExpression . Operator == AstUnaryAccumulatorOperator . Push )
85
84
{
85
+ element = unaryAccumulatorExpression . Arg ;
86
86
return true ;
87
87
}
88
88
}
89
89
90
+ element = null ;
90
91
return false ;
91
92
}
92
93
@@ -173,7 +174,7 @@ private AstStage OptimizeLimitStage(AstLimitStage stage)
173
174
174
175
private AstStage OptimizeMatchStage ( AstMatchStage stage )
175
176
{
176
- var optimizedFilter = AccumulatorMover . MoveAccumulators ( _accumulators , stage . Filter ) ;
177
+ var optimizedFilter = AccumulatorMover . MoveAccumulators ( _accumulators , _element , stage . Filter ) ;
177
178
return stage . Update ( optimizedFilter ) ;
178
179
}
179
180
@@ -201,7 +202,7 @@ private AstProjectStageSpecification OptimizeProjectStageSpecification(AstProjec
201
202
202
203
private AstProjectStageSpecification OptimizeProjectStageSetFieldSpecification ( AstProjectStageSetFieldSpecification specification )
203
204
{
204
- var optimizedValue = AccumulatorMover . MoveAccumulators ( _accumulators , specification . Value ) ;
205
+ var optimizedValue = AccumulatorMover . MoveAccumulators ( _accumulators , _element , specification . Value ) ;
205
206
return specification . Update ( optimizedValue ) ;
206
207
}
207
208
@@ -249,27 +250,29 @@ public string AddAccumulatorExpression(AstAccumulatorExpression value)
249
250
private class AccumulatorMover : AstNodeVisitor
250
251
{
251
252
#region static
252
- public static TNode MoveAccumulators < TNode > ( AccumulatorSet accumulators , TNode node )
253
+ public static TNode MoveAccumulators < TNode > ( AccumulatorSet accumulators , AstExpression element , TNode node )
253
254
where TNode : AstNode
254
255
{
255
- var mover = new AccumulatorMover ( accumulators ) ;
256
+ var mover = new AccumulatorMover ( accumulators , element ) ;
256
257
return mover . VisitAndConvert ( node ) ;
257
258
}
258
259
#endregion
259
260
260
261
private readonly AccumulatorSet _accumulators ;
262
+ private readonly AstExpression _element ;
261
263
262
- private AccumulatorMover ( AccumulatorSet accumulator )
264
+ private AccumulatorMover ( AccumulatorSet accumulator , AstExpression element )
263
265
{
264
266
_accumulators = accumulator ;
267
+ _element = element ;
265
268
}
266
269
267
270
public override AstNode VisitFilterField ( AstFilterField node )
268
271
{
269
- // "_elements.0.X" => { __agg0 : { $first : "$$ROOT" } } + "__agg0.X"
272
+ // "_elements.0.X" => { __agg0 : { $first : element } } + "__agg0.X"
270
273
if ( node . Path . StartsWith ( "_elements.0." ) )
271
274
{
272
- var accumulatorExpression = AstExpression . UnaryAccumulator ( AstUnaryAccumulatorOperator . First , AstExpression . Var ( "ROOT" ) ) ;
275
+ var accumulatorExpression = AstExpression . UnaryAccumulator ( AstUnaryAccumulatorOperator . First , _element ) ;
273
276
var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
274
277
var restOfPath = node . Path . Substring ( "_elements.0." . Length ) ;
275
278
var rewrittenPath = $ "{ accumulatorFieldName } .{ restOfPath } ";
@@ -288,9 +291,7 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
288
291
{
289
292
if ( node . FieldName is AstConstantExpression constantFieldName &&
290
293
constantFieldName . Value . IsString &&
291
- constantFieldName . Value . AsString == "_elements" &&
292
- node . Input is AstVarExpression varExpression &&
293
- varExpression . Name == "ROOT" )
294
+ constantFieldName . Value . AsString == "_elements" )
294
295
{
295
296
throw new UnableToRemoveReferenceToElementsException ( ) ;
296
297
}
@@ -300,18 +301,18 @@ node.Input is AstVarExpression varExpression &&
300
301
301
302
public override AstNode VisitMapExpression ( AstMapExpression node )
302
303
{
303
- // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => root ) } } + "$__agg0"
304
+ // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element ) } } + "$__agg0"
304
305
if ( node . Input is AstGetFieldExpression mapInputGetFieldExpression &&
305
306
mapInputGetFieldExpression . FieldName is AstConstantExpression mapInputconstantFieldExpression &&
306
307
mapInputconstantFieldExpression . Value . IsString &&
307
308
mapInputconstantFieldExpression . Value . AsString == "_elements" &&
308
309
mapInputGetFieldExpression . Input is AstVarExpression mapInputGetFieldVarExpression &&
309
310
mapInputGetFieldVarExpression . Name == "ROOT" )
310
311
{
311
- var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
312
- var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( node . In , ( node . As , root ) ) ;
312
+ var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( node . In , ( node . As , _element ) ) ;
313
313
var accumulatorExpression = AstExpression . UnaryAccumulator ( AstUnaryAccumulatorOperator . Push , rewrittenArg ) ;
314
314
var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
315
+ var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
315
316
return AstExpression . GetField ( root , accumulatorFieldName ) ;
316
317
}
317
318
@@ -321,7 +322,7 @@ mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpressi
321
322
public override AstNode VisitPickExpression ( AstPickExpression node )
322
323
{
323
324
// { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
324
- // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => root ) } } } + "$__agg0"
325
+ // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element ) } } } + "$__agg0"
325
326
if ( node . Source is AstGetFieldExpression getFieldExpression &&
326
327
getFieldExpression . Input is AstVarExpression varExpression &&
327
328
varExpression . Name == "ROOT" &&
@@ -330,10 +331,10 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio
330
331
constantFieldNameExpression . Value . AsString == "_elements" )
331
332
{
332
333
var @operator = node . Operator . ToAccumulatorOperator ( ) ;
333
- var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
334
- var rewrittenSelector = ( AstExpression ) AstNodeReplacer . Replace ( node . Selector , ( node . As , root ) ) ;
334
+ var rewrittenSelector = ( AstExpression ) AstNodeReplacer . Replace ( node . Selector , ( node . As , _element ) ) ;
335
335
var accumulatorExpression = new AstPickAccumulatorExpression ( @operator , node . SortBy , rewrittenSelector , node . N ) ;
336
336
var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
337
+ var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
337
338
return AstExpression . GetField ( root , accumulatorFieldName ) ;
338
339
}
339
340
@@ -384,7 +385,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres
384
385
385
386
bool TryOptimizeAccumulatorOfElements ( out AstExpression optimizedExpression )
386
387
{
387
- // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : "$$ROOT" } } + "$__agg0"
388
+ // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
388
389
if ( node . Operator . IsAccumulator ( out var accumulatorOperator ) &&
389
390
node . Arg is AstGetFieldExpression getFieldExpression &&
390
391
getFieldExpression . FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
@@ -393,7 +394,7 @@ getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameE
393
394
getFieldExpression . Input is AstVarExpression getFieldInputVarExpression &&
394
395
getFieldInputVarExpression . Name == "ROOT" )
395
396
{
396
- var accumulatorExpression = AstExpression . UnaryAccumulator ( accumulatorOperator , root ) ;
397
+ var accumulatorExpression = AstExpression . UnaryAccumulator ( accumulatorOperator , _element ) ;
397
398
var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
398
399
optimizedExpression = AstExpression . GetField ( root , accumulatorFieldName ) ;
399
400
return true ;
@@ -406,7 +407,7 @@ getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
406
407
407
408
bool TryOptimizeAccumulatorOfMappedElements ( out AstExpression optimizedExpression )
408
409
{
409
- // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => root ) } } + "$__agg0"
410
+ // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element ) } } + "$__agg0"
410
411
if ( node . Operator . IsAccumulator ( out var accumulatorOperator ) &&
411
412
node . Arg is AstMapExpression mapExpression &&
412
413
mapExpression . Input is AstGetFieldExpression mapInputGetFieldExpression &&
@@ -416,7 +417,7 @@ mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFi
416
417
mapInputGetFieldExpression . Input is AstVarExpression mapInputGetFieldVarExpression &&
417
418
mapInputGetFieldVarExpression . Name == "ROOT" )
418
419
{
419
- var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( mapExpression . In , ( mapExpression . As , root ) ) ;
420
+ var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( mapExpression . In , ( mapExpression . As , _element ) ) ;
420
421
var accumulatorExpression = AstExpression . UnaryAccumulator ( accumulatorOperator , rewrittenArg ) ;
421
422
var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
422
423
optimizedExpression = AstExpression . GetField ( root , accumulatorFieldName ) ;
0 commit comments