Skip to content

Commit

Permalink
Fix CTEs sub-queries and grouping
Browse files Browse the repository at this point in the history
TODO: EXCAT_MATCH to search in split table
  • Loading branch information
LoayGhreeb committed Sep 28, 2024
1 parent 38c0265 commit e8a2836
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import org.jabref.search.SearchLexer;
import org.jabref.search.SearchParser;

import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;

public class SearchToSqlConversion {
Expand All @@ -17,7 +17,7 @@ public static String searchToSql(String table, String searchExpression) {
}

private static SearchParser.StartContext getStartContext(String searchExpression) {
SearchLexer lexer = new SearchLexer(new ANTLRInputStream(searchExpression));
SearchLexer lexer = new SearchLexer(CharStreams.fromString(searchExpression));
lexer.removeErrorListeners(); // no infos on file system
lexer.addErrorListener(ThrowingErrorListener.INSTANCE);
SearchParser parser = new SearchParser(new CommonTokenStream(lexer));
Expand Down
115 changes: 73 additions & 42 deletions src/main/java/org/jabref/logic/search/query/SearchToSqlVisitor.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.jabref.logic.search.query;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;

import org.jabref.model.entry.field.InternalField;
Expand All @@ -9,7 +11,6 @@
import org.jabref.search.SearchBaseVisitor;
import org.jabref.search.SearchParser;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -24,8 +25,8 @@ public class SearchToSqlVisitor extends SearchBaseVisitor<String> {
private static final String MAIN_TABLE = "main_table";

private final String mainTableName;

private int cteCount = 0;
private final List<String> ctes = new ArrayList<>();
private int cteCounter = 0;

public SearchToSqlVisitor(String mainTableName) {
this.mainTableName = mainTableName;
Expand All @@ -40,55 +41,82 @@ private enum SearchTermFlag {

@Override
public String visitStart(SearchParser.StartContext ctx) {
String whereClause = getWhereClause(ctx);
String result;
result = """
WITH %s
SELECT * FROM cte%s
""".formatted(whereClause, cteCount - 1);

LOGGER.trace("Converted search query to SQL: {}", result);
return result;
}
String query = visit(ctx.expression());

@VisibleForTesting
public String getWhereClause(SearchParser.StartContext ctx) {
return visit(ctx.expression());
}
StringBuilder sql = new StringBuilder("WITH\n");
for (String cte : ctes) {
sql.append(cte).append(",\n");
}

@Override
public String visitUnaryExpression(SearchParser.UnaryExpressionContext ctx) {
return "NOT " + visit(ctx.expression());
// Remove the last comma and newline
if (!ctes.isEmpty()) {
sql.setLength(sql.length() - 2);
}

sql.append("SELECT * FROM ").append(query);
LOGGER.trace("Converted search query to SQL: {}", sql);
return sql.toString();
}

@Override
public String visitParenExpression(SearchParser.ParenExpressionContext ctx) {
return visit(ctx.expression());
public String visitUnaryExpression(SearchParser.UnaryExpressionContext ctx) {
String subQuery = visit(ctx.expression());
String cte = """
cte%d AS (
SELECT %s
FROM %s
EXCEPT
SELECT %s
FROM "%s"
)
""".formatted(
cteCounter,
PostgreConstants.ENTRY_ID,
subQuery,
PostgreConstants.ENTRY_ID,
mainTableName);
ctes.add(cte);
return "cte" + cteCounter++;
}

@Override
public String visitBinaryExpression(SearchParser.BinaryExpressionContext ctx) {
String left = visit(ctx.left);
String right = visit(ctx.right);
return """
%s,
%s,
cte%s AS (
String operator = "AND".equalsIgnoreCase(ctx.operator.getText()) ? "INTERSECT" : "UNION";

String cte = """
cte%d AS (
SELECT %s
FROM cte%s
FROM %s
%s
SELECT %s
FROM cte%s
FROM %s
)
""".formatted(
left,
right,
cteCount++,
cteCounter,
PostgreConstants.ENTRY_ID,
cteCount - 3,
"AND".equalsIgnoreCase(ctx.operator.getText()) ? "INTERSECT" : "UNION",
left,
operator,
PostgreConstants.ENTRY_ID,
cteCount - 2);
right);
ctes.add(cte);
return "cte" + cteCounter++;
}

@Override
public String visitParenExpression(SearchParser.ParenExpressionContext ctx) {
return visit(ctx.expression());
}

@Override
public String visitAtomExpression(SearchParser.AtomExpressionContext ctx) {
return visit(ctx.comparison());
}

@Override
public String visitName(SearchParser.NameContext ctx) {
return ctx.getText();
}

@Override
Expand All @@ -102,6 +130,7 @@ public String visitComparison(SearchParser.ComparisonContext context) {
}

Optional<SearchParser.NameContext> fieldDescriptor = Optional.ofNullable(context.left);
String cte;
if (fieldDescriptor.isPresent()) {
String field = fieldDescriptor.get().getText();

Expand Down Expand Up @@ -134,11 +163,13 @@ public String visitComparison(SearchParser.ComparisonContext context) {
setFlags(searchFlags, SearchTermFlag.REGULAR_EXPRESSION, true, true);
}

return getFieldQueryNode(field, right, searchFlags);
cte = getFieldQueryNode(field, right, searchFlags);
} else {
// Query without any field name
return getFieldQueryNode("any", right, EnumSet.of(SearchTermFlag.INEXACT_MATCH, SearchTermFlag.CASE_INSENSITIVE));
cte = getFieldQueryNode("any", right, EnumSet.of(SearchTermFlag.INEXACT_MATCH, SearchTermFlag.CASE_INSENSITIVE));
}
ctes.add(cte);
return "cte" + cteCounter++;
}

private void setFlags(EnumSet<SearchTermFlag> flags, SearchTermFlag matchType, boolean caseSensitive, boolean negation) {
Expand All @@ -163,7 +194,7 @@ private String getFieldQueryNode(String field, String term, EnumSet<SearchTermFl
};

if ("anyfield".equals(field) || "any".equals(field)) {
whereClause.append(buildAnyFieldQuery(operator, prefixSuffix, term));
whereClause.append(buildAllFieldsQuery(operator, prefixSuffix, term));
} else {
whereClause.append(buildFieldQuery(field, operator, prefixSuffix, term));
}
Expand All @@ -179,13 +210,13 @@ private static String getOperator(EnumSet<SearchTermFlag> searchFlags) {

private String buildFieldQuery(String field, String operator, String prefixSuffix, String term) {
return """
cte%s AS (
cte%d AS (
SELECT %s.%s
FROM "%s" AS %s
WHERE (%s.%s = '%s') AND ((%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s'))
)
""".formatted(
cteCount++,
cteCounter,
MAIN_TABLE, PostgreConstants.ENTRY_ID,
mainTableName, MAIN_TABLE,
MAIN_TABLE, PostgreConstants.FIELD_NAME,
Expand All @@ -198,15 +229,15 @@ private String buildFieldQuery(String field, String operator, String prefixSuffi
prefixSuffix, term, prefixSuffix);
}

private String buildAnyFieldQuery(String operator, String prefixSuffix, String term) {
private String buildAllFieldsQuery(String operator, String prefixSuffix, String term) {
return """
cte%s AS (
cte%d AS (
SELECT %s.%s
FROM "%s" AS %s
WHERE ((%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s'))
)
""".formatted(
cteCount++,
cteCounter,
MAIN_TABLE, PostgreConstants.ENTRY_ID,
mainTableName, MAIN_TABLE,
MAIN_TABLE, PostgreConstants.FIELD_VALUE_LITERAL,
Expand Down

0 comments on commit e8a2836

Please sign in to comment.