diff --git a/src/main/java/org/jabref/logic/search/query/SearchToSqlConversion.java b/src/main/java/org/jabref/logic/search/query/SearchToSqlConversion.java index 5522dfef5ad..265b3cf55c5 100644 --- a/src/main/java/org/jabref/logic/search/query/SearchToSqlConversion.java +++ b/src/main/java/org/jabref/logic/search/query/SearchToSqlConversion.java @@ -4,8 +4,8 @@ import org.jabref.search.SearchLexer; import org.jabref.search.SearchParser; -import org.antlr.v4.runtime.ANTLRInputStream; import org.antlr.v4.runtime.BailErrorStrategy; +import org.antlr.v4.runtime.CharStreams; import org.antlr.v4.runtime.CommonTokenStream; public class SearchToSqlConversion { @@ -17,7 +17,7 @@ public static String searchToSql(String table, String searchExpression) { } private static SearchParser.StartContext getStartContext(String searchExpression) { - SearchLexer lexer = new SearchLexer(new ANTLRInputStream(searchExpression)); + SearchLexer lexer = new SearchLexer(CharStreams.fromString(searchExpression)); lexer.removeErrorListeners(); // no infos on file system lexer.addErrorListener(ThrowingErrorListener.INSTANCE); SearchParser parser = new SearchParser(new CommonTokenStream(lexer)); diff --git a/src/main/java/org/jabref/logic/search/query/SearchToSqlVisitor.java b/src/main/java/org/jabref/logic/search/query/SearchToSqlVisitor.java index c96e369e70a..522da052d82 100644 --- a/src/main/java/org/jabref/logic/search/query/SearchToSqlVisitor.java +++ b/src/main/java/org/jabref/logic/search/query/SearchToSqlVisitor.java @@ -1,6 +1,8 @@ package org.jabref.logic.search.query; +import java.util.ArrayList; import java.util.EnumSet; +import java.util.List; import java.util.Optional; import org.jabref.model.entry.field.InternalField; @@ -9,7 +11,6 @@ import org.jabref.search.SearchBaseVisitor; import org.jabref.search.SearchParser; -import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,8 +25,8 @@ public class SearchToSqlVisitor extends SearchBaseVisitor { private static final String MAIN_TABLE = "main_table"; private final String mainTableName; - - private int cteCount = 0; + private final List ctes = new ArrayList<>(); + private int cteCounter = 0; public SearchToSqlVisitor(String mainTableName) { this.mainTableName = mainTableName; @@ -40,55 +41,82 @@ private enum SearchTermFlag { @Override public String visitStart(SearchParser.StartContext ctx) { - String whereClause = getWhereClause(ctx); - String result; - result = """ - WITH %s - SELECT * FROM cte%s - """.formatted(whereClause, cteCount - 1); - - LOGGER.trace("Converted search query to SQL: {}", result); - return result; - } + String query = visit(ctx.expression()); - @VisibleForTesting - public String getWhereClause(SearchParser.StartContext ctx) { - return visit(ctx.expression()); - } + StringBuilder sql = new StringBuilder("WITH\n"); + for (String cte : ctes) { + sql.append(cte).append(",\n"); + } - @Override - public String visitUnaryExpression(SearchParser.UnaryExpressionContext ctx) { - return "NOT " + visit(ctx.expression()); + // Remove the last comma and newline + if (!ctes.isEmpty()) { + sql.setLength(sql.length() - 2); + } + + sql.append("SELECT * FROM ").append(query); + LOGGER.trace("Converted search query to SQL: {}", sql); + return sql.toString(); } @Override - public String visitParenExpression(SearchParser.ParenExpressionContext ctx) { - return visit(ctx.expression()); + public String visitUnaryExpression(SearchParser.UnaryExpressionContext ctx) { + String subQuery = visit(ctx.expression()); + String cte = """ + cte%d AS ( + SELECT %s + FROM %s + EXCEPT + SELECT %s + FROM "%s" + ) + """.formatted( + cteCounter, + PostgreConstants.ENTRY_ID, + subQuery, + PostgreConstants.ENTRY_ID, + mainTableName); + ctes.add(cte); + return "cte" + cteCounter++; } @Override public String visitBinaryExpression(SearchParser.BinaryExpressionContext ctx) { String left = visit(ctx.left); String right = visit(ctx.right); - return """ - %s, - %s, - cte%s AS ( + String operator = "AND".equalsIgnoreCase(ctx.operator.getText()) ? "INTERSECT" : "UNION"; + + String cte = """ + cte%d AS ( SELECT %s - FROM cte%s + FROM %s %s SELECT %s - FROM cte%s + FROM %s ) """.formatted( - left, - right, - cteCount++, + cteCounter, PostgreConstants.ENTRY_ID, - cteCount - 3, - "AND".equalsIgnoreCase(ctx.operator.getText()) ? "INTERSECT" : "UNION", + left, + operator, PostgreConstants.ENTRY_ID, - cteCount - 2); + right); + ctes.add(cte); + return "cte" + cteCounter++; + } + + @Override + public String visitParenExpression(SearchParser.ParenExpressionContext ctx) { + return visit(ctx.expression()); + } + + @Override + public String visitAtomExpression(SearchParser.AtomExpressionContext ctx) { + return visit(ctx.comparison()); + } + + @Override + public String visitName(SearchParser.NameContext ctx) { + return ctx.getText(); } @Override @@ -102,6 +130,7 @@ public String visitComparison(SearchParser.ComparisonContext context) { } Optional fieldDescriptor = Optional.ofNullable(context.left); + String cte; if (fieldDescriptor.isPresent()) { String field = fieldDescriptor.get().getText(); @@ -134,11 +163,13 @@ public String visitComparison(SearchParser.ComparisonContext context) { setFlags(searchFlags, SearchTermFlag.REGULAR_EXPRESSION, true, true); } - return getFieldQueryNode(field, right, searchFlags); + cte = getFieldQueryNode(field, right, searchFlags); } else { // Query without any field name - return getFieldQueryNode("any", right, EnumSet.of(SearchTermFlag.INEXACT_MATCH, SearchTermFlag.CASE_INSENSITIVE)); + cte = getFieldQueryNode("any", right, EnumSet.of(SearchTermFlag.INEXACT_MATCH, SearchTermFlag.CASE_INSENSITIVE)); } + ctes.add(cte); + return "cte" + cteCounter++; } private void setFlags(EnumSet flags, SearchTermFlag matchType, boolean caseSensitive, boolean negation) { @@ -163,7 +194,7 @@ private String getFieldQueryNode(String field, String term, EnumSet searchFlags) { private String buildFieldQuery(String field, String operator, String prefixSuffix, String term) { return """ - cte%s AS ( + cte%d AS ( SELECT %s.%s FROM "%s" AS %s WHERE (%s.%s = '%s') AND ((%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s')) ) """.formatted( - cteCount++, + cteCounter, MAIN_TABLE, PostgreConstants.ENTRY_ID, mainTableName, MAIN_TABLE, MAIN_TABLE, PostgreConstants.FIELD_NAME, @@ -198,15 +229,15 @@ private String buildFieldQuery(String field, String operator, String prefixSuffi prefixSuffix, term, prefixSuffix); } - private String buildAnyFieldQuery(String operator, String prefixSuffix, String term) { + private String buildAllFieldsQuery(String operator, String prefixSuffix, String term) { return """ - cte%s AS ( + cte%d AS ( SELECT %s.%s FROM "%s" AS %s WHERE ((%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s')) ) """.formatted( - cteCount++, + cteCounter, MAIN_TABLE, PostgreConstants.ENTRY_ID, mainTableName, MAIN_TABLE, MAIN_TABLE, PostgreConstants.FIELD_VALUE_LITERAL,