diff --git a/src/main/java/org/springframework/data/jpa/repository/query/DefaultQueryEnhancer.java b/src/main/java/org/springframework/data/jpa/repository/query/DefaultQueryEnhancer.java index 53a07bf6f9..92387c973e 100644 --- a/src/main/java/org/springframework/data/jpa/repository/query/DefaultQueryEnhancer.java +++ b/src/main/java/org/springframework/data/jpa/repository/query/DefaultQueryEnhancer.java @@ -46,7 +46,7 @@ public String detectAlias() { @Override public String createCountQueryFor(@Nullable String countProjection) { - return QueryUtils.createCountQueryFor(this.query.getQueryString(), countProjection); + return QueryUtils.createCountQueryFor(this.query.getQueryString(), countProjection, this.query.isNativeQuery()); } @Override diff --git a/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java b/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java index 64b95d569d..cd9fc61d79 100644 --- a/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java +++ b/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java @@ -15,9 +15,8 @@ */ package org.springframework.data.jpa.repository.query; -import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlCount; -import static org.springframework.data.jpa.repository.query.JSqlParserUtils.getJSqlLower; -import static org.springframework.data.jpa.repository.query.QueryUtils.checkSortExpression; +import static org.springframework.data.jpa.repository.query.JSqlParserUtils.*; +import static org.springframework.data.jpa.repository.query.QueryUtils.*; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Alias; @@ -414,7 +413,7 @@ public String createCountQueryFor(@Nullable String countProjection) { return selectBody.toString(); } - String countProp = tableAlias == null ? "*" : tableAlias; + String countProp = query.isNativeQuery() ? (distinct ? "*" : "1") : tableAlias == null ? "*" : tableAlias; Function jSqlCount = getJSqlCount(Collections.singletonList(countProp), distinct); selectBody.setSelectItems(Collections.singletonList(new SelectExpressionItem(jSqlCount))); diff --git a/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java b/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java index 895c7c5ff8..69d4ec7114 100644 --- a/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java +++ b/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java @@ -570,6 +570,19 @@ public static String createCountQueryFor(String originalQuery) { */ @Deprecated public static String createCountQueryFor(String originalQuery, @Nullable String countProjection) { + return createCountQueryFor(originalQuery, countProjection, false); + } + + /** + * Creates a count projected query from the given original query. + * + * @param originalQuery must not be {@literal null}. + * @param countProjection may be {@literal null}. + * @param nativeQuery whether the underlying query is a native query. + * @return a query String to be used a count query for pagination. Guaranteed to be not {@literal null}. + * @since 2.7.8 + */ + static String createCountQueryFor(String originalQuery, @Nullable String countProjection, boolean nativeQuery) { Assert.hasText(originalQuery, "OriginalQuery must not be null or empty!"); @@ -591,9 +604,14 @@ public static String createCountQueryFor(String originalQuery, @Nullable String String replacement = useVariable ? SIMPLE_COUNT_VALUE : complexCountValue; - String alias = QueryUtils.detectAlias(originalQuery); - if ("*".equals(variable) && alias != null) { - replacement = alias; + if (nativeQuery && (variable.contains(",") || "*".equals(variable))) { + replacement = "1"; + } else { + + String alias = QueryUtils.detectAlias(originalQuery); + if (("*".equals(variable) && alias != null)) { + replacement = alias; + } } countQuery = matcher.replaceFirst(String.format(COUNT_REPLACEMENT_TEMPLATE, replacement)); diff --git a/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java b/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java index 9128c7d3dc..2a5e73e6fd 100644 --- a/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java +++ b/src/test/java/org/springframework/data/jpa/repository/query/QueryEnhancerUnitTests.java @@ -228,7 +228,12 @@ void createsCountQueryForNestedReferenceCorrectly() { @Test // DATAJPA-420 void createsCountQueryForScalarSelects() { - assertCountQuery("select p.lastname,p.firstname from Person p", "select count(p) from Person p", true); + assertCountQuery("select p.lastname,p.firstname from Person p", "select count(p) from Person p", false); + } + + @Test // DATAJPA-420 + void createsCountQueryForNativeScalarSelects() { + assertCountQuery("select p.lastname,p.firstname from Person p", "select count(1) from Person p", true); } @Test // DATAJPA-456 @@ -487,7 +492,7 @@ void createCountQuerySupportsWhitespaceCharacters() { " order by user.name\n ", true); assertThat(getEnhancer(query).createCountQueryFor()) - .isEqualToIgnoringCase("select count(user) from User user where user.age = 18"); + .isEqualToIgnoringCase("select count(1) from User user where user.age = 18"); } @Test @@ -500,7 +505,7 @@ void createCountQuerySupportsLineBreaksInSelectClause() { " order\nby\nuser.name\n ", true); assertThat(getEnhancer(query).createCountQueryFor()) - .isEqualToIgnoringCase("select count(user) from User user where user.age = 18"); + .isEqualToIgnoringCase("select count(1) from User user where user.age = 18"); } @Test // DATAJPA-1061 @@ -721,17 +726,17 @@ void countQueryUsesCorrectVariable() { QueryEnhancer queryEnhancer = getEnhancer(nativeQuery); String countQueryFor = queryEnhancer.createCountQueryFor(); - assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM User WHERE created_at > $1"); + assertThat(countQueryFor).isEqualTo("SELECT count(1) FROM User WHERE created_at > $1"); nativeQuery = new StringQuery("SELECT * FROM (select * from test) ", true); queryEnhancer = getEnhancer(nativeQuery); countQueryFor = queryEnhancer.createCountQueryFor(); - assertThat(countQueryFor).isEqualTo("SELECT count(*) FROM (SELECT * FROM test)"); + assertThat(countQueryFor).isEqualTo("SELECT count(1) FROM (SELECT * FROM test)"); nativeQuery = new StringQuery("SELECT * FROM (select * from test) as test", true); queryEnhancer = getEnhancer(nativeQuery); countQueryFor = queryEnhancer.createCountQueryFor(); - assertThat(countQueryFor).isEqualTo("SELECT count(test) FROM (SELECT * FROM test) AS test"); + assertThat(countQueryFor).isEqualTo("SELECT count(1) FROM (SELECT * FROM test) AS test"); } @Test // GH-2555 @@ -861,7 +866,7 @@ void withStatementsWorksWithJSQLParser() { assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase( "with sample_data (day, value) AS (VALUES ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16)))\n" - + "SELECT count(a) FROM sample_data AS a"); + + "SELECT count(1) FROM sample_data AS a"); assertThat(queryEnhancer.applySorting(Sort.by("day").descending())).endsWith("ORDER BY a.day DESC"); assertThat(queryEnhancer.getJoinAliases()).isEmpty(); assertThat(queryEnhancer.detectAlias()).isEqualToIgnoringCase("a"); @@ -884,7 +889,7 @@ void multipleWithStatementsWorksWithJSQLParser() { assertThat(queryEnhancer.createCountQueryFor()).isEqualToIgnoringCase( "with sample_data (day, value) AS (VALUES ((0, 13), (1, 12), (2, 15), (3, 4), (4, 8), (5, 16))),test2 AS (VALUES (1, 2, 3))\n" - + "SELECT count(a) FROM sample_data AS a"); + + "SELECT count(1) FROM sample_data AS a"); assertThat(queryEnhancer.applySorting(Sort.by("day").descending())).endsWith("ORDER BY a.day DESC"); assertThat(queryEnhancer.getJoinAliases()).isEmpty(); assertThat(queryEnhancer.detectAlias()).isEqualToIgnoringCase("a"); @@ -989,4 +994,5 @@ private static void endsIgnoringCase(String original, String endWithIgnoreCase) private static QueryEnhancer getEnhancer(DeclaredQuery query) { return QueryEnhancerFactory.forQuery(query); } + }