From c4a6d382edc98fd82cd0de33b7fefea5a8f79d7e Mon Sep 17 00:00:00 2001 From: Diego Krupitza Date: Tue, 28 Jun 2022 15:21:33 +0200 Subject: [PATCH] Adds support for more SelectBody types in JSqlParserQueryEhancer. We now support `ValuesStatement` and `SetOperationList`. This allows native queries to use `union`, `except`, and `with` statements in native SQL queries. Closes #2578. --- .../query/JSqlParserQueryEnhancer.java | 97 +++++++++++- .../jpa/repository/UserRepositoryTests.java | 66 ++++++++ .../query/QueryEnhancerUnitTests.java | 144 +++++++++++++++++- .../jpa/repository/sample/UserRepository.java | 34 +++++ 4 files changed, 333 insertions(+), 8 deletions(-) diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java index 9ac2bc7466..1c6c42ae80 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java @@ -29,9 +29,13 @@ import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SetOperationList; +import net.sf.jsqlparser.statement.select.WithItem; import net.sf.jsqlparser.statement.update.Update; +import net.sf.jsqlparser.statement.values.ValuesStatement; import java.util.ArrayList; import java.util.Collections; @@ -107,6 +111,13 @@ public String applySorting(Sort sort, @Nullable String alias) { } Select selectStatement = parseSelectStatement(queryString); + + if (selectStatement.getSelectBody()instanceof SetOperationList setOperationList) { + return applySortingToSetOperationList(setOperationList, sort); + } else if (!(selectStatement.getSelectBody() instanceof PlainSelect)) { + return queryString; + } + PlainSelect selectBody = (PlainSelect) selectStatement.getSelectBody(); final Set joinAliases = getJoinAliases(selectBody); @@ -115,7 +126,7 @@ public String applySorting(Sort sort, @Nullable String alias) { List orderByElements = sort.stream() // .map(order -> getOrderClause(joinAliases, selectionAliases, alias, order)) // - .collect(Collectors.toList()); + .toList(); if (CollectionUtils.isEmpty(selectBody.getOrderByElements())) { selectBody.setOrderByElements(new ArrayList<>()); @@ -127,6 +138,33 @@ public String applySorting(Sort sort, @Nullable String alias) { } + /** + * Returns the {@link SetOperationList} as a string query with {@link Sort}s applied in the right order. + * + * @param setOperationListStatement + * @param sort + * @return + */ + private String applySortingToSetOperationList(SetOperationList setOperationListStatement, Sort sort) { + + // special case: ValuesStatements are detected as nested OperationListStatements + if (setOperationListStatement.getSelects().stream().anyMatch(ValuesStatement.class::isInstance)) { + return setOperationListStatement.toString(); + } + + // if (CollectionUtils.isEmpty(setOperationListStatement.getOrderByElements())) { + if (setOperationListStatement.getOrderByElements() == null) { + setOperationListStatement.setOrderByElements(new ArrayList<>()); + } + + List orderByElements = sort.stream() // + .map(order -> getOrderClause(Collections.emptySet(), Collections.emptySet(), null, order)) // + .toList(); + setOperationListStatement.getOrderByElements().addAll(orderByElements); + + return setOperationListStatement.toString(); + } + /** * Returns the aliases used inside the selection part in the query. * @@ -175,7 +213,12 @@ private Set getJoinAliases(String query) { return new HashSet<>(); } - return getJoinAliases((PlainSelect) parseSelectStatement(query).getSelectBody()); + Select selectStatement = parseSelectStatement(query); + if (selectStatement.getSelectBody()instanceof PlainSelect selectBody) { + return getJoinAliases(selectBody); + } + + return new HashSet<>(); } /** @@ -259,6 +302,17 @@ private String detectAlias(String query) { } Select selectStatement = parseSelectStatement(query); + + /* + For all the other types ({@link ValuesStatement} and {@link SetOperationList}) it does not make sense to provide + alias since: + * ValuesStatement has no alias + * SetOperation can have multiple alias for each operation item + */ + if (!(selectStatement.getSelectBody() instanceof PlainSelect)) { + return null; + } + PlainSelect selectBody = (PlainSelect) selectStatement.getSelectBody(); return detectAlias(selectBody); } @@ -273,6 +327,10 @@ private String detectAlias(String query) { @Nullable private static String detectAlias(PlainSelect selectBody) { + if (selectBody.getFromItem() == null) { + return null; + } + Alias alias = selectBody.getFromItem().getAlias(); return alias == null ? null : alias.getName(); } @@ -287,6 +345,14 @@ public String createCountQueryFor(@Nullable String countProjection) { Assert.hasText(this.query.getQueryString(), "OriginalQuery must not be null or empty"); Select selectStatement = parseSelectStatement(this.query.getQueryString()); + + /* + We only support count queries for {@link PlainSelect}. + */ + if (!(selectStatement.getSelectBody() instanceof PlainSelect)) { + return this.query.getQueryString(); + } + PlainSelect selectBody = (PlainSelect) selectStatement.getSelectBody(); // remove order by @@ -322,8 +388,15 @@ public String createCountQueryFor(@Nullable String countProjection) { Function jSqlCount = getJSqlCount(Collections.singletonList(countProp), distinct); selectBody.setSelectItems(Collections.singletonList(new SelectExpressionItem(jSqlCount))); - return selectBody.toString(); + if (CollectionUtils.isEmpty(selectStatement.getWithItemsList())) { + return selectBody.toString(); + } + String withStatements = selectStatement.getWithItemsList().stream() // + .map(WithItem::toString) // + .collect(Collectors.joining(",")); + + return "with " + withStatements + "\n" + selectBody; } @Override @@ -336,9 +409,23 @@ public String getProjection() { Assert.hasText(query.getQueryString(), "Query must not be null or empty"); Select selectStatement = parseSelectStatement(query.getQueryString()); - PlainSelect selectBody = (PlainSelect) selectStatement.getSelectBody(); - return selectBody.getSelectItems() // + if (selectStatement.getSelectBody() instanceof ValuesStatement) { + return ""; + } + + SelectBody selectBody = selectStatement.getSelectBody(); + + if (selectStatement.getSelectBody()instanceof SetOperationList setOperationList) { + // using the first one since for setoperations the projection has to be the same + selectBody = setOperationList.getSelects().get(0); + + if (!(selectBody instanceof PlainSelect)) { + return ""; + } + } + + return ((PlainSelect) selectBody).getSelectItems() // .stream() // .map(Object::toString) // .collect(Collectors.joining(", ")).trim(); diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryTests.java index 7bc48e30a6..25dcfcdf6e 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryTests.java @@ -2902,6 +2902,72 @@ public void correctlyBuildSortClauseWhenSortingByFunctionAliasAndFunctionContain repository.findAllAndSortByFunctionResultNamedParameter("prefix", "suffix", Sort.by("idWithPrefixAndSuffix")); } + @Test // GH-2578 + void simpleNativeExceptTest() { + + flushTestUsers(); + + List foundIds = repository.findWithSimpleExceptNative(); + + assertThat(foundIds) // + .isNotEmpty() // + .contains("Oliver", "kevin"); + } + + @Test // GH-2578 + void simpleNativeUnionTest() { + + flushTestUsers(); + + List foundIds = repository.findWithSimpleUnionNative(); + + assertThat(foundIds) // + .isNotEmpty() // + .containsExactlyInAnyOrder("Dave", "Joachim", "Oliver", "kevin"); + } + + @Test // GH-2578 + void complexNativeExceptTest() { + + flushTestUsers(); + + List foundIds = repository.findWithComplexExceptNative(); + + assertThat(foundIds).containsExactly("Oliver", "kevin"); + } + + @Test // GH-2578 + void simpleValuesStatementNative() { + + flushTestUsers(); + + List foundIds = repository.valuesStatementNative(); + + assertThat(foundIds).containsExactly(1); + } + + @Test // GH-2578 + void withStatementNative() { + + flushTestUsers(); + + List foundData = repository.withNativeStatement(); + + assertThat(foundData) // + .map(User::getFirstname) // + .containsExactly("Joachim", "Dave", "kevin"); + } + + @Test // GH-2578 + void complexWithNativeStatement() { + + flushTestUsers(); + + List foundData = repository.complexWithNativeStatement(); + + assertThat(foundData).containsExactly("joachim", "dave", "kevin"); + } + private Page executeSpecWithSort(Sort sort) { flushTestUsers(); diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java index 3572fa403e..c879468633 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java @@ -687,8 +687,8 @@ void detectsJoinAliasesCorrectly(String queryString, List aliases) { void correctFunctionAliasWithComplexNestedFunctions() { String queryString = "\nSELECT \nCAST(('{' || string_agg(distinct array_to_string(c.institutes_ids, ','), ',') || '}') AS bigint[]) as institutesIds\nFROM\ncity c"; - StringQuery nativeQuery = new StringQuery(queryString, true); + StringQuery nativeQuery = new StringQuery(queryString, true); JSqlParserQueryEnhancer queryEnhancer = (JSqlParserQueryEnhancer) getEnhancer(nativeQuery); assertThat(queryEnhancer.getSelectionAliases()).contains("institutesIds"); @@ -696,6 +696,7 @@ void correctFunctionAliasWithComplexNestedFunctions() { @Test // GH-2441 void correctApplySortOnComplexNestedFunctionQuery() { + String queryString = "SELECT dd.institutesIds FROM (\n" // + " SELECT\n" // + " CAST(('{' || string_agg(distinct array_to_string(c.institutes_ids, ','), ',') || '}') AS bigint[]) as institutesIds\n" @@ -704,9 +705,7 @@ void correctApplySortOnComplexNestedFunctionQuery() { + " ) dd"; StringQuery nativeQuery = new StringQuery(queryString, true); - QueryEnhancer queryEnhancer = getEnhancer(nativeQuery); - String result = queryEnhancer.applySorting(Sort.by(new Sort.Order(Sort.Direction.ASC, "institutesIds"))); assertThat(result).containsIgnoringCase("order by dd.institutesIds"); @@ -716,6 +715,7 @@ void correctApplySortOnComplexNestedFunctionQuery() { void countQueryUsesCorrectVariable() { StringQuery nativeQuery = new StringQuery("SELECT * FROM User WHERE created_at > $1", true); + QueryEnhancer queryEnhancer = getEnhancer(nativeQuery); String countQueryFor = queryEnhancer.createCountQueryFor(); assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM User WHERE created_at > $1"); @@ -751,6 +751,144 @@ void modifyingQueriesAreDetectedCorrectly() { assertThat(QueryEnhancerFactory.forQuery(modiQuery).createCountQueryFor()).isEqualToIgnoringCase(modifyingQuery); } + @Test // GH-2578 + void setOperationListWorksWithJSQLParser() { + + String setQuery = "select SOME_COLUMN from SOME_TABLE where REPORTING_DATE = :REPORTING_DATE \n" // + + "except \n" // + + "select SOME_COLUMN from SOME_OTHER_TABLE where REPORTING_DATE = :REPORTING_DATE"; + + StringQuery stringQuery = new StringQuery(setQuery, true); + QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery); + + assertThat(stringQuery.getAlias()).isNullOrEmpty(); + assertThat(stringQuery.getProjection()).isEqualToIgnoringCase("SOME_COLUMN"); + assertThat(stringQuery.hasConstructorExpression()).isFalse(); + + assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase(setQuery); + assertThat(queryEnhancer.applySorting(Sort.by("SOME_COLUMN"))).endsWith("ORDER BY SOME_COLUMN ASC"); + assertThat(queryEnhancer.getJoinAliases()).isEmpty(); + assertThat(queryEnhancer.detectAlias()).isNullOrEmpty(); + assertThat(queryEnhancer.getProjection()).isEqualToIgnoringCase("SOME_COLUMN"); + assertThat(queryEnhancer.hasConstructorExpression()).isFalse(); + } + + @Test // GH-2578 + void complexSetOperationListWorksWithJSQLParser() { + + String setQuery = "select SOME_COLUMN from SOME_TABLE where REPORTING_DATE = :REPORTING_DATE \n" // + + "except \n" // + + "select SOME_COLUMN from SOME_OTHER_TABLE where REPORTING_DATE = :REPORTING_DATE \n" // + + "union select SOME_COLUMN from SOME_OTHER_OTHER_TABLE"; + + StringQuery stringQuery = new StringQuery(setQuery, true); + QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery); + + assertThat(stringQuery.getAlias()).isNullOrEmpty(); + assertThat(stringQuery.getProjection()).isEqualToIgnoringCase("SOME_COLUMN"); + assertThat(stringQuery.hasConstructorExpression()).isFalse(); + + assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase(setQuery); + assertThat(queryEnhancer.applySorting(Sort.by("SOME_COLUMN").ascending())).endsWith("ORDER BY SOME_COLUMN ASC"); + assertThat(queryEnhancer.getJoinAliases()).isEmpty(); + assertThat(queryEnhancer.detectAlias()).isNullOrEmpty(); + assertThat(queryEnhancer.getProjection()).isEqualToIgnoringCase("SOME_COLUMN"); + assertThat(queryEnhancer.hasConstructorExpression()).isFalse(); + } + + @Test // GH-2578 + void deeplyNestedcomplexSetOperationListWorksWithJSQLParser() { + + String setQuery = "SELECT CustomerID FROM (\n" // + + "\t\t\tselect * from Customers\n" // + + "\t\t\texcept\n"// + + "\t\t\tselect * from Customers where country = 'Austria'\n"// + + "\t)\n" // + + "\texcept\n"// + + "\tselect CustomerID from customers where country = 'Germany'\n"// + + "\t;"; + + StringQuery stringQuery = new StringQuery(setQuery, true); + QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery); + + assertThat(stringQuery.getAlias()).isNullOrEmpty(); + assertThat(stringQuery.getProjection()).isEqualToIgnoringCase("CustomerID"); + assertThat(stringQuery.hasConstructorExpression()).isFalse(); + + assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase(setQuery); + assertThat(queryEnhancer.applySorting(Sort.by("CustomerID").descending())).endsWith("ORDER BY CustomerID DESC"); + assertThat(queryEnhancer.getJoinAliases()).isEmpty(); + assertThat(queryEnhancer.detectAlias()).isNullOrEmpty(); + assertThat(queryEnhancer.getProjection()).isEqualToIgnoringCase("CustomerID"); + assertThat(queryEnhancer.hasConstructorExpression()).isFalse(); + } + + @Test // GH-2578 + void valuesStatementsWorksWithJSQLParser() { + + String setQuery = "VALUES (1, 2, 'test')"; + + StringQuery stringQuery = new StringQuery(setQuery, true); + QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery); + + assertThat(stringQuery.getAlias()).isNullOrEmpty(); + assertThat(stringQuery.getProjection()).isNullOrEmpty(); + assertThat(stringQuery.hasConstructorExpression()).isFalse(); + + assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase(setQuery); + assertThat(queryEnhancer.applySorting(Sort.by("CustomerID").descending())).isEqualTo(setQuery); + assertThat(queryEnhancer.getJoinAliases()).isEmpty(); + assertThat(queryEnhancer.detectAlias()).isNullOrEmpty(); + assertThat(queryEnhancer.getProjection()).isNullOrEmpty(); + assertThat(queryEnhancer.hasConstructorExpression()).isFalse(); + } + + @Test // GH-2578 + void withStatementsWorksWithJSQLParser() { + + String setQuery = "with sample_data(day, value) as (values ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16))) \n" + + "select day, value from sample_data as a"; + + StringQuery stringQuery = new StringQuery(setQuery, true); + QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery); + + assertThat(stringQuery.getAlias()).isEqualToIgnoringCase("a"); + assertThat(stringQuery.getProjection()).isEqualToIgnoringCase("day, value"); + assertThat(stringQuery.hasConstructorExpression()).isFalse(); + + assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase( + "with sample_data (day, value) AS (VALUES ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16)))\n" + + "SELECT count(a) FROM sample_data AS a"); + assertThat(queryEnhancer.applySorting(Sort.by("day").descending())).endsWith("ORDER BY a.day DESC"); + assertThat(queryEnhancer.getJoinAliases()).isEmpty(); + assertThat(queryEnhancer.detectAlias()).isEqualToIgnoringCase("a"); + assertThat(queryEnhancer.getProjection()).isEqualToIgnoringCase("day, value"); + assertThat(queryEnhancer.hasConstructorExpression()).isFalse(); + } + + @Test // GH-2578 + void multipleWithStatementsWorksWithJSQLParser() { + + String setQuery = "with sample_data(day, value) as (values ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16))), test2 as (values (1,2,3)) \n" + + "select day, value from sample_data as a"; + + StringQuery stringQuery = new StringQuery(setQuery, true); + QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery); + + assertThat(stringQuery.getAlias()).isEqualToIgnoringCase("a"); + assertThat(stringQuery.getProjection()).isEqualToIgnoringCase("day, value"); + assertThat(stringQuery.hasConstructorExpression()).isFalse(); + + assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase( + "with sample_data (day, value) AS (VALUES ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16))),test2 AS (VALUES (1, 2, 3))\n" + + "SELECT count(a) FROM sample_data AS a"); + assertThat(queryEnhancer.applySorting(Sort.by("day").descending())).endsWith("ORDER BY a.day DESC"); + assertThat(queryEnhancer.getJoinAliases()).isEmpty(); + assertThat(queryEnhancer.detectAlias()).isEqualToIgnoringCase("a"); + assertThat(queryEnhancer.getProjection()).isEqualToIgnoringCase("day, value"); + assertThat(queryEnhancer.hasConstructorExpression()).isFalse(); + } + public static Stream detectsJoinAliasesCorrectlySource() { return Stream.of( // diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/sample/UserRepository.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/sample/UserRepository.java index c913ddd541..f9fd6b9865 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/sample/UserRepository.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/sample/UserRepository.java @@ -642,6 +642,40 @@ List findAllAndSortByFunctionResultNamedParameter(@Param("namedParameter @Query(value = "update SD_User u set u.active = false where u.id = :userId", nativeQuery = true) void setActiveToFalseWithModifyingNative(@Param("userId") int userId); + // GH-2578 + @Query(value = "SELECT u.firstname from SD_User u where u.age < 32 " // + + "except " // + + "SELECT u.firstname from SD_User u where u.age >= 32 ", nativeQuery = true) + List findWithSimpleExceptNative(); + + // GH-2578 + @Query(value = "SELECT u.firstname from SD_User u where u.age < 32 " // + + "union " // + + "SELECT u.firstname from SD_User u where u.age >= 32 ", nativeQuery = true) + List findWithSimpleUnionNative(); + + // GH-2578 + @Query(value = "SELECT u.firstname from (select * from SD_User u where u.age < 32) u " // + + "except " // + + "SELECT u.firstname from SD_User u where u.age >= 32 ", nativeQuery = true) + List findWithComplexExceptNative(); + + // GH-2578 + @Query(value = "VALUES (1)", nativeQuery = true) + List valuesStatementNative(); + + // GH-2578 + @Query(value = "with sample_data as ( Select * from SD_User u where u.age > 30 ) \n select * from sample_data", + nativeQuery = true) + List withNativeStatement(); + + // GH-2578 + @Query(value = "with sample_data as ( Select * from SD_User u where u.age > 30 ), \n " // + + "another as ( Select * from SD_User u) \n " // + + "select lower(s.firstname) as lowFirst from sample_data as s,another as a where s.firstname = a.firstname ", + nativeQuery = true) + List complexWithNativeStatement(); + interface RolesAndFirstname { String getFirstname();