9595import org .hibernate .persister .entity .EntityPersister ;
9696import org .hibernate .persister .entity .Joinable ;
9797import org .hibernate .persister .entity .SingleTableEntityPersister ;
98- import org .hibernate .query .sqm .BinaryArithmeticOperator ;
9998import org .hibernate .query .BindableType ;
100- import org .hibernate .query .sqm .CastType ;
101- import org .hibernate .query .sqm .ComparisonOperator ;
102- import org .hibernate .query .sqm .DynamicInstantiationNature ;
103- import org .hibernate .query .sqm .FetchClauseType ;
104- import org .hibernate .query .spi .NavigablePath ;
10599import org .hibernate .query .QueryLogging ;
106100import org .hibernate .query .ReturnableType ;
107101import org .hibernate .query .SemanticException ;
108- import org .hibernate .query .sqm .SortOrder ;
109- import org .hibernate .query .sqm .TemporalUnit ;
110- import org .hibernate .query .sqm .UnaryArithmeticOperator ;
111102import org .hibernate .query .criteria .JpaPath ;
103+ import org .hibernate .query .spi .NavigablePath ;
112104import org .hibernate .query .spi .QueryOptions ;
113105import org .hibernate .query .spi .QueryParameterBinding ;
114106import org .hibernate .query .spi .QueryParameterBindings ;
115107import org .hibernate .query .spi .QueryParameterImplementor ;
108+ import org .hibernate .query .sqm .BinaryArithmeticOperator ;
109+ import org .hibernate .query .sqm .CastType ;
110+ import org .hibernate .query .sqm .ComparisonOperator ;
111+ import org .hibernate .query .sqm .DynamicInstantiationNature ;
112+ import org .hibernate .query .sqm .FetchClauseType ;
116113import org .hibernate .query .sqm .InterpretationException ;
114+ import org .hibernate .query .sqm .SortOrder ;
117115import org .hibernate .query .sqm .SqmExpressible ;
118116import org .hibernate .query .sqm .SqmPathSource ;
119117import org .hibernate .query .sqm .SqmQuerySource ;
118+ import org .hibernate .query .sqm .TemporalUnit ;
119+ import org .hibernate .query .sqm .UnaryArithmeticOperator ;
120120import org .hibernate .query .sqm .function .AbstractSqmSelfRenderingFunctionDescriptor ;
121121import org .hibernate .query .sqm .function .SelfRenderingAggregateFunctionSqlAstExpression ;
122122import org .hibernate .query .sqm .function .SelfRenderingFunctionSqlAstExpression ;
262262import org .hibernate .sql .ast .tree .SqlAstNode ;
263263import org .hibernate .sql .ast .tree .Statement ;
264264import org .hibernate .sql .ast .tree .cte .CteColumn ;
265+ import org .hibernate .sql .ast .tree .cte .CteContainer ;
265266import org .hibernate .sql .ast .tree .cte .CteStatement ;
266267import org .hibernate .sql .ast .tree .cte .CteTable ;
267268import org .hibernate .sql .ast .tree .cte .SearchClauseSpecification ;
288289import org .hibernate .sql .ast .tree .expression .Over ;
289290import org .hibernate .sql .ast .tree .expression .Overflow ;
290291import org .hibernate .sql .ast .tree .expression .QueryLiteral ;
292+ import org .hibernate .sql .ast .tree .expression .QueryTransformer ;
291293import org .hibernate .sql .ast .tree .expression .SelfRenderingExpression ;
292294import org .hibernate .sql .ast .tree .expression .SelfRenderingSqlFragmentExpression ;
293295import org .hibernate .sql .ast .tree .expression .SqlSelectionExpression ;
@@ -385,6 +387,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
385387 private final SqlAstCreationContext creationContext ;
386388 private final boolean jpaQueryComplianceEnabled ;
387389 private final SqmStatement <?> statement ;
390+ private final CteContainer cteContainer = new GlobalCteContainer ();
388391
389392 private final QueryOptions queryOptions ;
390393 private final LoadQueryInfluencers loadQueryInfluencers ;
@@ -430,6 +433,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
430433 private final Stack <Supplier <MappingModelExpressible <?>>> inferrableTypeAccessStack = new StandardStack <>(
431434 () -> null
432435 );
436+ private final Stack <List <QueryTransformer >> queryTransformers = new StandardStack <>();
433437 private boolean inTypeInference ;
434438
435439 private SqmByUnit appliedByUnit ;
@@ -659,7 +663,7 @@ public Statement visitStatement(SqmStatement<?> sqmStatement) {
659663
660664 @ Override
661665 public UpdateStatement visitUpdateStatement (SqmUpdateStatement <?> sqmStatement ) {
662- Map < String , CteStatement > cteStatements = this .visitCteContainer ( sqmStatement );
666+ final CteContainer cteContainer = this .visitCteContainer ( sqmStatement );
663667
664668 final SqmRoot <?> sqmTarget = sqmStatement .getTarget ();
665669 final String entityName = sqmTarget .getEntityName ();
@@ -719,7 +723,7 @@ public UpdateStatement visitUpdateStatement(SqmUpdateStatement<?> sqmStatement)
719723 }
720724
721725 return new UpdateStatement (
722- sqmStatement . isWithRecursive (), cteStatements ,
726+ cteContainer ,
723727 (NamedTableReference ) rootTableGroup .getPrimaryTableReference (),
724728 assignments ,
725729 SqlAstTreeHelper .combinePredicates ( suppliedPredicate , additionalRestrictions ),
@@ -897,7 +901,7 @@ public Expression resolveSqlExpression(
897901
898902 @ Override
899903 public DeleteStatement visitDeleteStatement (SqmDeleteStatement <?> statement ) {
900- Map < String , CteStatement > cteStatements = this .visitCteContainer ( statement );
904+ final CteContainer cteContainer = this .visitCteContainer ( statement );
901905
902906 final String entityName = statement .getTarget ().getEntityName ();
903907 final EntityPersister entityDescriptor = creationContext .getSessionFactory ()
@@ -947,8 +951,7 @@ public DeleteStatement visitDeleteStatement(SqmDeleteStatement<?> statement) {
947951 }
948952
949953 return new DeleteStatement (
950- statement .isWithRecursive (),
951- cteStatements ,
954+ cteContainer ,
952955 (NamedTableReference ) rootTableGroup .getPrimaryTableReference (),
953956 SqlAstTreeHelper .combinePredicates ( suppliedPredicate , additionalRestrictions ),
954957 Collections .emptyList ()
@@ -964,7 +967,7 @@ public DeleteStatement visitDeleteStatement(SqmDeleteStatement<?> statement) {
964967
965968 @ Override
966969 public InsertStatement visitInsertSelectStatement (SqmInsertSelectStatement <?> sqmStatement ) {
967- Map < String , CteStatement > cteStatements = this .visitCteContainer ( sqmStatement );
970+ final CteContainer cteContainer = this .visitCteContainer ( sqmStatement );
968971
969972 final String entityName = sqmStatement .getTarget ().getEntityName ();
970973 final EntityPersister entityDescriptor = creationContext .getSessionFactory ()
@@ -1005,8 +1008,7 @@ public InsertStatement visitInsertSelectStatement(SqmInsertSelectStatement<?> sq
10051008 getFromClauseAccess ().registerTableGroup ( rootPath , rootTableGroup );
10061009
10071010 insertStatement = new InsertStatement (
1008- sqmStatement .isWithRecursive (),
1009- cteStatements ,
1011+ cteContainer ,
10101012 (NamedTableReference ) rootTableGroup .getPrimaryTableReference (),
10111013 Collections .emptyList ()
10121014 );
@@ -1051,7 +1053,7 @@ public InsertStatement visitInsertSelectStatement(SqmInsertSelectStatement<?> sq
10511053
10521054 @ Override
10531055 public InsertStatement visitInsertValuesStatement (SqmInsertValuesStatement <?> sqmStatement ) {
1054- Map < String , CteStatement > cteStatements = this .visitCteContainer ( sqmStatement );
1056+ final CteContainer cteContainer = this .visitCteContainer ( sqmStatement );
10551057 final String entityName = sqmStatement .getTarget ().getEntityName ();
10561058 final EntityPersister entityDescriptor = creationContext .getSessionFactory ()
10571059 .getRuntimeMetamodels ()
@@ -1087,8 +1089,7 @@ public InsertStatement visitInsertValuesStatement(SqmInsertValuesStatement<?> sq
10871089 getFromClauseAccess ().registerTableGroup ( rootPath , rootTableGroup );
10881090
10891091 final InsertStatement insertStatement = new InsertStatement (
1090- sqmStatement .isWithRecursive (),
1091- cteStatements ,
1092+ cteContainer ,
10921093 (NamedTableReference ) rootTableGroup .getPrimaryTableReference (),
10931094 Collections .emptyList ()
10941095 );
@@ -1371,10 +1372,10 @@ public Values visitValues(SqmValues sqmValues) {
13711372
13721373 @ Override
13731374 public SelectStatement visitSelectStatement (SqmSelectStatement <?> statement ) {
1374- Map < String , CteStatement > cteStatements = this .visitCteContainer ( statement );
1375+ final CteContainer cteContainer = this .visitCteContainer ( statement );
13751376 final QueryPart queryPart = visitQueryPart ( statement .getQueryPart () );
13761377 final List <DomainResult <?>> domainResults = queryPart .isRoot () ? this .domainResults : Collections .emptyList ();
1377- return new SelectStatement ( statement . isWithRecursive (), cteStatements , queryPart , domainResults );
1378+ return new SelectStatement ( cteContainer , queryPart , domainResults );
13781379 }
13791380
13801381 @ Override
@@ -1560,14 +1561,15 @@ public static CteTable createCteTable(
15601561 }
15611562
15621563 @ Override
1563- public Map < String , CteStatement > visitCteContainer (SqmCteContainer consumer ) {
1564+ public CteContainer visitCteContainer (SqmCteContainer consumer ) {
15641565 final Collection <SqmCteStatement <?>> sqmCteStatements = consumer .getCteStatements ();
1565- final Map <String , CteStatement > cteStatements = new LinkedHashMap <>( sqmCteStatements .size () );
1566+ if ( consumer .isWithRecursive () ) {
1567+ cteContainer .setWithRecursive ( true );
1568+ }
15661569 for ( SqmCteStatement <?> sqmCteStatement : sqmCteStatements ) {
1567- final CteStatement cteStatement = visitCteStatement ( sqmCteStatement );
1568- cteStatements .put ( cteStatement .getCteTable ().getTableExpression (), cteStatement );
1570+ cteContainer .addCteStatement ( visitCteStatement ( sqmCteStatement ) );
15691571 }
1570- return cteStatements ;
1572+ return cteContainer ;
15711573 }
15721574
15731575 private boolean trackSelectionsForGroup ;
@@ -1688,6 +1690,7 @@ else if ( sqmQuerySpec.hasPositionalGroupItem() ) {
16881690 // In sub-queries, we can never deduplicate the selection items as that might change semantics
16891691 deduplicateSelectionItems = false ;
16901692 pushProcessingState ( processingState );
1693+ queryTransformers .push ( new ArrayList <>() );
16911694
16921695 try {
16931696 // we want to visit the from-clause first
@@ -1721,14 +1724,23 @@ else if ( sqmQuerySpec.hasPositionalGroupItem() ) {
17211724 applyCollectionFilterPredicates ( sqlQuerySpec );
17221725 }
17231726
1724- return sqlQuerySpec ;
1727+ QuerySpec finalQuerySpec = sqlQuerySpec ;
1728+ for ( QueryTransformer transformer : queryTransformers .getCurrent () ) {
1729+ finalQuerySpec = transformer .transform (
1730+ cteContainer ,
1731+ finalQuerySpec ,
1732+ this
1733+ );
1734+ }
1735+ return finalQuerySpec ;
17251736 }
17261737 finally {
17271738 if ( additionalRestrictions != null ) {
17281739 sqlQuerySpec .applyPredicate ( additionalRestrictions );
17291740 }
17301741 additionalRestrictions = originalAdditionalRestrictions ;
17311742 popProcessingStateStack ();
1743+ queryTransformers .pop ();
17321744 currentSqmQueryPart = sqmQueryPart ;
17331745 deduplicateSelectionItems = originalDeduplicateSelectionItems ;
17341746 }
@@ -4738,6 +4750,11 @@ public Expression visitFunction(SqmFunction<?> sqmFunction) {
47384750 }
47394751 }
47404752
4753+ @ Override
4754+ public void registerQueryTransformer (QueryTransformer transformer ) {
4755+ queryTransformers .getCurrent ().add ( transformer );
4756+ }
4757+
47414758 @ Override
47424759 public Star visitStar (SqmStar sqmStar ) {
47434760 return new Star ();
@@ -6564,4 +6581,38 @@ private static JdbcMappingContainer highestPrecedence(JdbcMappingContainer type1
65646581
65656582 return type1 ;
65666583 }
6584+
6585+ private class GlobalCteContainer implements CteContainer {
6586+ private final Map <String , CteStatement > cteStatements ;
6587+ private boolean recursive ;
6588+
6589+ public GlobalCteContainer () {
6590+ this .cteStatements = new LinkedHashMap <>();
6591+ }
6592+
6593+ @ Override
6594+ public boolean isWithRecursive () {
6595+ return recursive ;
6596+ }
6597+
6598+ @ Override
6599+ public void setWithRecursive (boolean recursive ) {
6600+ this .recursive = recursive ;
6601+ }
6602+
6603+ @ Override
6604+ public Map <String , CteStatement > getCteStatements () {
6605+ return cteStatements ;
6606+ }
6607+
6608+ @ Override
6609+ public CteStatement getCteStatement (String cteLabel ) {
6610+ return cteStatements .get ( cteLabel );
6611+ }
6612+
6613+ @ Override
6614+ public void addCteStatement (CteStatement cteStatement ) {
6615+ cteStatements .put ( cteStatement .getCteTable ().getTableExpression (), cteStatement );
6616+ }
6617+ }
65676618}
0 commit comments