Skip to content

Commit

Permalink
Attempt to use sub-queries with CTEs
Browse files Browse the repository at this point in the history
  • Loading branch information
LoayGhreeb committed Sep 28, 2024
1 parent d9f93c4 commit 38c0265
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import org.jabref.search.SearchLexer;
import org.jabref.search.SearchParser;

import com.google.common.annotations.VisibleForTesting;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.CommonTokenStream;
Expand All @@ -27,11 +26,4 @@ private static SearchParser.StartContext getStartContext(String searchExpression
parser.setErrorHandler(new BailErrorStrategy()); // ParseCancellationException on parse errors
return parser.start();
}

@VisibleForTesting
public static String getWhereClause(String table, String searchExpression) {
SearchParser.StartContext context = getStartContext(searchExpression);
SearchToSqlVisitor searchToSqlVisitor = new SearchToSqlVisitor(table);
return searchToSqlVisitor.getWhereClause(context);
}
}
117 changes: 61 additions & 56 deletions src/main/java/org/jabref/logic/search/query/SearchToSqlVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ public class SearchToSqlVisitor extends SearchBaseVisitor<String> {

private static final Logger LOGGER = LoggerFactory.getLogger(SearchToSqlVisitor.class);
private static final String MAIN_TABLE = "main_table";
private static final String SPLIT_TABLE = "split_table";

private final String mainTableName;
private final String splitTableName;
private boolean isExactMatch = false;

private int cteCount = 0;

public SearchToSqlVisitor(String mainTableName) {
this.mainTableName = mainTableName;
this.splitTableName = mainTableName + PostgreConstants.TABLE_NAME_SUFFIX;
}

private enum SearchTermFlag {
Expand All @@ -44,35 +42,11 @@ private enum SearchTermFlag {
public String visitStart(SearchParser.StartContext ctx) {
String whereClause = getWhereClause(ctx);
String result;
result = """
WITH %s
SELECT * FROM cte%s
""".formatted(whereClause, cteCount - 1);

if (isExactMatch) {
result = """
SELECT %s.%s FROM "%s" AS %s
LEFT JOIN "%s" AS %s
ON (%s.%s = %s.%s AND %s.%s = %s.%s)
WHERE (%s)
GROUP BY %s.%s
""".formatted(
MAIN_TABLE, PostgreConstants.ENTRY_ID,
mainTableName, MAIN_TABLE,
splitTableName, SPLIT_TABLE,
MAIN_TABLE, PostgreConstants.ENTRY_ID,
SPLIT_TABLE, PostgreConstants.ENTRY_ID,
MAIN_TABLE, PostgreConstants.FIELD_NAME,
SPLIT_TABLE, PostgreConstants.FIELD_NAME,
whereClause,
MAIN_TABLE, PostgreConstants.ENTRY_ID);
} else {
result = """
SELECT %s.%s FROM "%s" AS %s
WHERE (%s)
GROUP BY %s.%s
""".formatted(
MAIN_TABLE, PostgreConstants.ENTRY_ID,
mainTableName, MAIN_TABLE,
whereClause,
MAIN_TABLE, PostgreConstants.ENTRY_ID);
}
LOGGER.trace("Converted search query to SQL: {}", result);
return result;
}
Expand All @@ -89,16 +63,32 @@ public String visitUnaryExpression(SearchParser.UnaryExpressionContext ctx) {

@Override
public String visitParenExpression(SearchParser.ParenExpressionContext ctx) {
return "(" + visit(ctx.expression()) + ")";
return visit(ctx.expression());
}

@Override
public String visitBinaryExpression(SearchParser.BinaryExpressionContext ctx) {
if ("AND".equalsIgnoreCase(ctx.operator.getText())) {
return visit(ctx.left) + " AND " + visit(ctx.right);
} else {
return visit(ctx.left) + " OR " + visit(ctx.right);
}
String left = visit(ctx.left);
String right = visit(ctx.right);
return """
%s,
%s,
cte%s AS (
SELECT %s
FROM cte%s
%s
SELECT %s
FROM cte%s
)
""".formatted(
left,
right,
cteCount++,
PostgreConstants.ENTRY_ID,
cteCount - 3,
"AND".equalsIgnoreCase(ctx.operator.getText()) ? "INTERSECT" : "UNION",
PostgreConstants.ENTRY_ID,
cteCount - 2);
}

@Override
Expand Down Expand Up @@ -152,7 +142,6 @@ public String visitComparison(SearchParser.ComparisonContext context) {
}

private void setFlags(EnumSet<SearchTermFlag> flags, SearchTermFlag matchType, boolean caseSensitive, boolean negation) {
isExactMatch |= matchType.equals(SearchTermFlag.EXACT_MATCH);
flags.add(matchType);

flags.add(caseSensitive ? SearchTermFlag.CASE_SENSITIVE : SearchTermFlag.CASE_INSENSITIVE);
Expand All @@ -161,7 +150,7 @@ private void setFlags(EnumSet<SearchTermFlag> flags, SearchTermFlag matchType, b
}
}

private static String getFieldQueryNode(String field, String term, EnumSet<SearchTermFlag> searchFlags) {
private String getFieldQueryNode(String field, String term, EnumSet<SearchTermFlag> searchFlags) {
StringBuilder whereClause = new StringBuilder();
String operator = getOperator(searchFlags);
String prefixSuffix = searchFlags.contains(SearchTermFlag.INEXACT_MATCH) ? "%" : "";
Expand All @@ -174,15 +163,9 @@ private static String getFieldQueryNode(String field, String term, EnumSet<Searc
};

if ("anyfield".equals(field) || "any".equals(field)) {
whereClause.append(buildTableQuery(MAIN_TABLE, operator, prefixSuffix, term));
if (searchFlags.contains(SearchTermFlag.EXACT_MATCH)) {
whereClause.append(" OR ").append(buildTableQuery(SPLIT_TABLE, operator, prefixSuffix, term));
}
whereClause.append(buildAnyFieldQuery(operator, prefixSuffix, term));
} else {
whereClause.append(buildFieldQuery(MAIN_TABLE, field, operator, prefixSuffix, term));
if (searchFlags.contains(SearchTermFlag.EXACT_MATCH)) {
whereClause.append(" OR ").append(buildFieldQuery(SPLIT_TABLE, field, operator, prefixSuffix, term));
}
whereClause.append(buildFieldQuery(field, operator, prefixSuffix, term));
}

return whereClause.toString();
Expand All @@ -194,21 +177,43 @@ private static String getOperator(EnumSet<SearchTermFlag> searchFlags) {
: (searchFlags.contains(SearchTermFlag.NEGATION) ? "NOT " : "") + (searchFlags.contains(SearchTermFlag.CASE_SENSITIVE) ? "LIKE" : "ILIKE");
}

private static String buildTableQuery(String tableName, String operator, String prefixSuffix, String term) {
private String buildFieldQuery(String field, String operator, String prefixSuffix, String term) {
return """
(%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s')""".formatted(
tableName, PostgreConstants.FIELD_VALUE_LITERAL,
cte%s AS (
SELECT %s.%s
FROM "%s" AS %s
WHERE (%s.%s = '%s') AND ((%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s'))
)
""".formatted(
cteCount++,
MAIN_TABLE, PostgreConstants.ENTRY_ID,
mainTableName, MAIN_TABLE,
MAIN_TABLE, PostgreConstants.FIELD_NAME,
field,
MAIN_TABLE, PostgreConstants.FIELD_VALUE_LITERAL,
operator,
prefixSuffix, term, prefixSuffix,
tableName, PostgreConstants.FIELD_VALUE_TRANSFORMED,
MAIN_TABLE, PostgreConstants.FIELD_VALUE_TRANSFORMED,
operator,
prefixSuffix, term, prefixSuffix);
}

private static String buildFieldQuery(String tableName, String field, String operator, String prefixSuffix, String term) {
private String buildAnyFieldQuery(String operator, String prefixSuffix, String term) {
return """
((%s.%s = '%s') AND (%s))""".formatted(
tableName, PostgreConstants.FIELD_NAME, field,
buildTableQuery(tableName, operator, prefixSuffix, term));
cte%s AS (
SELECT %s.%s
FROM "%s" AS %s
WHERE ((%s.%s %s '%s%s%s') OR (%s.%s %s '%s%s%s'))
)
""".formatted(
cteCount++,
MAIN_TABLE, PostgreConstants.ENTRY_ID,
mainTableName, MAIN_TABLE,
MAIN_TABLE, PostgreConstants.FIELD_VALUE_LITERAL,
operator,
prefixSuffix, term, prefixSuffix,
MAIN_TABLE, PostgreConstants.FIELD_VALUE_TRANSFORMED,
operator,
prefixSuffix, term, prefixSuffix);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
class SearchToSqlConversionTest {
@ParameterizedTest
@CsvSource({
"(), alex",
"(), author=alex",
"(), title=dino AND author=alex",
"(), author=computer AND editor=science OR title=math",
"(), (author=computer AND editor=science) OR title=math",
"(), author=computer AND (editor=science OR title=math)",
"(), (author=computer AND editor=science) OR (title=math AND year=2021)",
// case insensitive contains
"((main_table.field_name = 'title') AND ((main_table.field_value_literal ILIKE '%compute%') OR (main_table.field_value_transformed ILIKE '%compute%'))), title=compute",

Expand Down Expand Up @@ -79,6 +86,6 @@ class SearchToSqlConversionTest {
})

void conversion(String expectedWhereClause, String input) {
assertEquals(expectedWhereClause, SearchToSqlConversion.getWhereClause("tableName", input));
assertEquals(expectedWhereClause, SearchToSqlConversion.searchToSql("tableName", input));
}
}

0 comments on commit 38c0265

Please sign in to comment.