diff --git a/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java b/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java index b8f7e147bb..150218b07c 100644 --- a/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java +++ b/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java @@ -76,6 +76,7 @@ * @author Greg Turnquist * @author Diego Krupitza * @author Jędrzej Biedrzycki + * @author Darin Manica */ public abstract class QueryUtils { @@ -100,6 +101,8 @@ public abstract class QueryUtils { private static final Pattern ALIAS_MATCH; private static final Pattern COUNT_MATCH; + private static final Pattern STARTS_WITH_PAREN = Pattern.compile("^\\s*\\("); + private static final Pattern PARENS_TO_REMOVE = Pattern.compile("(\\(.*\\bfrom\\b[^)]+\\))", CASE_INSENSITIVE); private static final Pattern PROJECTION_CLAUSE = Pattern.compile("select\\s+(?:distinct\\s+)?(.+)\\s+from", Pattern.CASE_INSENSITIVE); @@ -431,13 +434,70 @@ private static String toJpaDirection(Order order) { @Deprecated public static String detectAlias(String query) { String alias = null; - Matcher matcher = ALIAS_MATCH.matcher(query); + Matcher matcher = ALIAS_MATCH.matcher(removeSubqueries(query)); while (matcher.find()) { alias = matcher.group(2); } return alias; } + /** + * Remove subqueries from the query, in order to identify the correct alias + * in order by clauses. If the entire query is surrounded by parenthesis, the + * outermost parenthesis are not removed. + * + * @param query + * @return query with all subqueries removed. + */ + static String removeSubqueries(String query) { + if (!StringUtils.hasText(query)) { + return query; + } + + final List opens = new ArrayList<>(); + final List closes = new ArrayList<>(); + final List closeMatches = new ArrayList<>(); + for (int i=0; i=(startsWithParen?1:0); i--) { + final Integer open = opens.get(i); + final Integer close = findClose(open, closes, closeMatches) + 1; + + + if (close > open) { + final String subquery = sb.substring(open, close); + final Matcher matcher = PARENS_TO_REMOVE.matcher(subquery); + if (matcher.find()) { + sb.replace(open, close, new String(new char[close-open]).replace('\0', ' ')); + } + } + } + + return sb.toString(); + } + + private static Integer findClose(final Integer open, final List closes, final List closeMatches) { + for (int i=0; i open && !closeMatches.get(i)) { + closeMatches.set(i, Boolean.TRUE); + return close; + } + } + + return -1; + } + /** * Creates a where-clause referencing the given entities and appends it to the given query string. Binds the given * entities to the query. diff --git a/src/test/java/org/springframework/data/jpa/repository/query/QueryUtilsUnitTests.java b/src/test/java/org/springframework/data/jpa/repository/query/QueryUtilsUnitTests.java index 1c29b39352..87e807dda3 100644 --- a/src/test/java/org/springframework/data/jpa/repository/query/QueryUtilsUnitTests.java +++ b/src/test/java/org/springframework/data/jpa/repository/query/QueryUtilsUnitTests.java @@ -20,6 +20,8 @@ import java.util.Collections; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.assertj.core.api.SoftAssertions; import org.junit.jupiter.api.Test; @@ -27,6 +29,7 @@ import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Order; import org.springframework.data.jpa.domain.JpaSort; +import org.springframework.util.StringUtils; /** * Unit test for {@link QueryUtils}. @@ -41,6 +44,7 @@ * @author Mohammad Hewedy * @author Greg Turnquist * @author Jędrzej Biedrzycki + * @author Darin Manica */ class QueryUtilsUnitTests { @@ -50,6 +54,7 @@ class QueryUtilsUnitTests { private static final String COUNT_QUERY = "select count(u) from User u"; private static final String QUERY_WITH_AS = "select u from User as u where u.username = ?"; + private static final Pattern MULTI_WHITESPACE = Pattern.compile("\\s+"); @Test void createsCountQueryCorrectly() throws Exception { @@ -104,7 +109,7 @@ void allowsShortJpaSyntax() throws Exception { assertCountQuery(SIMPLE_QUERY, COUNT_QUERY); } - @Test + @Test // GH-2260 void detectsAliasCorrectly() throws Exception { assertThat(detectAlias(QUERY)).isEqualTo("u"); @@ -115,6 +120,47 @@ void detectsAliasCorrectly() throws Exception { assertThat(detectAlias("select u from User u")).isEqualTo("u"); assertThat(detectAlias("select u from com.acme.User u")).isEqualTo("u"); assertThat(detectAlias("select u from T05User u")).isEqualTo("u"); + assertThat(detectAlias("select u from User u where not exists (from User u2)")).isEqualTo("u"); + assertThat(detectAlias("(select u from User u where not exists (from User u2))")).isEqualTo("u"); + assertThat(detectAlias("(select u from User u where not exists ((from User u2 where not exists (from User u3))))")).isEqualTo("u"); + assertThat(detectAlias("from Foo f left join f.bar b with type(b) = BarChild where (f.id = (select max(f.id) from Foo f2 where type(f2) = FooChild) or 1 <> 1) and 1=1")).isEqualTo("f"); + assertThat(detectAlias("(from Foo f max(f) ((((select * from Foo f2 (from Foo f3) max(*)) (from Foo f4)) max(f5)) (f6)) (from Foo f7))")).isEqualTo("f"); + } + + @Test // GH-2260 + void testRemoveSubqueries() throws Exception { + // boundary conditions + assertThat(removeSubqueries(null)).isNull(); + assertThat(removeSubqueries("")).isEmpty(); + assertThat(removeSubqueries(" ")).isEqualTo(" "); + assertThat(removeSubqueries("(")).isEqualTo("("); + assertThat(removeSubqueries(")")).isEqualTo(")"); + assertThat(removeSubqueries("(()")).isEqualTo("(()"); + assertThat(removeSubqueries("())")).isEqualTo("())"); + + // realistic conditions + assertThat(removeSubqueries(QUERY)).isEqualTo(QUERY); + assertThat(removeSubqueries(SIMPLE_QUERY)).isEqualTo(SIMPLE_QUERY); + assertThat(removeSubqueries(COUNT_QUERY)).isEqualTo(COUNT_QUERY); + assertThat(removeSubqueries(QUERY_WITH_AS)).isEqualTo(QUERY_WITH_AS); + assertThat(removeSubqueries("SELECT FROM USER U")).isEqualTo("SELECT FROM USER U"); + assertThat(removeSubqueries("select u from User u")).isEqualTo("select u from User u"); + assertThat(removeSubqueries("select u from com.acme.User u")).isEqualTo("select u from com.acme.User u"); + assertThat(removeSubqueries("select u from T05User u")).isEqualTo("select u from T05User u"); + assertThat(normalizeWhitespace(removeSubqueries("select u from User u where not exists (from User u2)"))).isEqualTo("select u from User u where not exists"); + assertThat(normalizeWhitespace(removeSubqueries("(select u from User u where not exists (from User u2))"))).isEqualTo("(select u from User u where not exists )"); + assertThat(normalizeWhitespace(removeSubqueries("select u from User u where not exists (from User u2 where not exists (from User u3))"))).isEqualTo("select u from User u where not exists"); + assertThat(normalizeWhitespace(removeSubqueries("select u from User u where not exists ((from User u2 where not exists (from User u3)))"))).isEqualTo("select u from User u where not exists ( )"); + assertThat(normalizeWhitespace(removeSubqueries("(select u from User u where not exists ((from User u2 where not exists (from User u3))))"))).isEqualTo("(select u from User u where not exists ( ))"); + } + + private String normalizeWhitespace(String s) { + Matcher matcher = MULTI_WHITESPACE.matcher(s); + if (matcher.find()) { + return matcher.replaceAll(" ").trim(); + } + + return StringUtils.trimWhitespace(s); } @Test