Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.RelationSubquery;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Replace;
import org.opensearch.sql.ast.tree.Reverse;
import org.opensearch.sql.ast.tree.Rex;
import org.opensearch.sql.ast.tree.Sort;
Expand Down Expand Up @@ -775,6 +776,11 @@ public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext con
return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context));
}

@Override
public LogicalPlan visitReplace(Replace node, AnalysisContext context) {
throw getOnlyForCalciteException("Replace");
}

@Override
public LogicalPlan visitJoin(Join node, AnalysisContext context) {
throw getOnlyForCalciteException("Join");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.RelationSubquery;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Replace;
import org.opensearch.sql.ast.tree.Reverse;
import org.opensearch.sql.ast.tree.Rex;
import org.opensearch.sql.ast.tree.SPath;
Expand Down Expand Up @@ -239,6 +240,10 @@ public T visitRename(Rename node, C context) {
return visitChildren(node, context);
}

public T visitReplace(Replace node, C context) {
return visitChildren(node, context);
}

public T visitEval(Eval node, C context) {
return visitChildren(node, context);
}
Expand Down
91 changes: 91 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Replace.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = false)
public class Replace extends UnresolvedPlan {
private final UnresolvedExpression pattern;
private final UnresolvedExpression replacement;
private final List<Field> fieldList;
private UnresolvedPlan child;

public Replace(
UnresolvedExpression pattern, UnresolvedExpression replacement, List<Field> fieldList) {
this.pattern = pattern;
this.replacement = replacement;
this.fieldList = fieldList;
validate();
}

private void validate() {
if (pattern == null) {
throw new IllegalArgumentException("Pattern expression cannot be null in Replace command");
}
if (replacement == null) {
throw new IllegalArgumentException(
"Replacement expression cannot be null in Replace command");
}

// Validate pattern is a string literal
if (!(pattern instanceof Literal && ((Literal) pattern).getType() == DataType.STRING)) {
throw new IllegalArgumentException("Pattern must be a string literal in Replace command");
}

// Validate replacement is a string literal
if (!(replacement instanceof Literal && ((Literal) replacement).getType() == DataType.STRING)) {
throw new IllegalArgumentException("Replacement must be a string literal in Replace command");
}

if (fieldList == null || fieldList.isEmpty()) {
throw new IllegalArgumentException(
"Field list cannot be empty in Replace command. Use IN clause to specify the field.");
}
Set<String> uniqueFields = new HashSet<>();
List<String> duplicates =
fieldList.stream()
.map(field -> field.getField().toString())
.filter(fieldName -> !uniqueFields.add(fieldName))
.collect(Collectors.toList());

if (!duplicates.isEmpty()) {
throw new IllegalArgumentException(
String.format("Duplicate fields [%s] in Replace command", String.join(", ", duplicates)));
}
}

@Override
public Replace attach(UnresolvedPlan child) {
if (null == this.child) {
this.child = child;
} else {
this.child.attach(child);
}
return this;
}

@Override
public List<UnresolvedPlan> getChild() {
return this.child == null ? ImmutableList.of() : ImmutableList.of(this.child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitReplace(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -109,6 +110,7 @@
import org.opensearch.sql.ast.tree.Regex;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Replace;
import org.opensearch.sql.ast.tree.Rex;
import org.opensearch.sql.ast.tree.SPath;
import org.opensearch.sql.ast.tree.Sort;
Expand All @@ -135,11 +137,13 @@
import org.opensearch.sql.expression.parse.RegexCommonUtils;
import org.opensearch.sql.utils.ParseUtils;
import org.opensearch.sql.utils.WildcardRenameUtils;
import org.opensearch.sql.utils.WildcardReplaceUtils;

public class CalciteRelNodeVisitor extends AbstractNodeVisitor<RelNode, CalcitePlanContext> {

private final CalciteRexNodeVisitor rexVisitor;
private final CalciteAggCallVisitor aggVisitor;
private static final String NEW_FIELD_PREFIX = "new_";

public CalciteRelNodeVisitor() {
this.rexVisitor = new CalciteRexNodeVisitor(this);
Expand Down Expand Up @@ -2136,6 +2140,62 @@ public RelNode visitValues(Values values, CalcitePlanContext context) {
}
}

@Override
public RelNode visitReplace(Replace node, CalcitePlanContext context) {
visitChildren(node, context);
String pattern = ((Literal) node.getPattern()).getValue().toString();
String replacement = ((Literal) node.getReplacement()).getValue().toString();

// Remove quotes if present
pattern = pattern.replaceAll("^[\"']|[\"']$", "");
replacement = replacement.replaceAll("^[\"']|[\"']$", "");

// Validate patterns only if wildcards are present
if (WildcardRenameUtils.isWildcardPattern(pattern)
|| WildcardRenameUtils.isWildcardPattern(replacement)) {
WildcardReplaceUtils.validatePatterns(pattern, replacement);
}

List<RexNode> projectList = new ArrayList<>();
List<String> newFieldNames = new ArrayList<>();
// Add original fields
for (String fieldName : context.relBuilder.peek().getRowType().getFieldNames()) {
projectList.add(context.relBuilder.field(fieldName));
newFieldNames.add(fieldName);
}
// Process fields for replacement
for (Field field : node.getFieldList()) {
String fieldName = field.getField().toString();
RexNode fieldRef = context.relBuilder.field(fieldName);
if (WildcardRenameUtils.isWildcardPattern(pattern)
|| WildcardRenameUtils.isWildcardPattern(replacement)) {
String regexPattern = WildcardReplaceUtils.convertToRegexPattern(pattern);
String regexReplacement = WildcardReplaceUtils.convertToRegexReplacement(replacement);
// Use REGEXP_REPLACE for wildcard patterns
RexNode replaceCall =
context.relBuilder.call(
SqlLibraryOperators.REGEXP_REPLACE_3,
fieldRef,
context.relBuilder.literal(regexPattern),
context.relBuilder.literal(regexReplacement));
projectList.add(replaceCall);
} else {
System.out.println("Using REPLACE");
// Use standard REPLACE for non-wildcard patterns
RexNode replaceCall =
context.relBuilder.call(
SqlStdOperatorTable.REPLACE,
fieldRef,
context.relBuilder.literal(pattern),
context.relBuilder.literal(replacement));
projectList.add(replaceCall);
}
newFieldNames.add(NEW_FIELD_PREFIX + fieldName);
}
context.relBuilder.project(projectList, newFieldNames);
return context.relBuilder.peek();
}

private void buildParseRelNode(Parse node, CalcitePlanContext context) {
RexNode sourceField = rexVisitor.analyze(node.getSourceField(), context);
ParseMethod parseMethod = node.getParseMethod();
Expand Down
111 changes: 111 additions & 0 deletions core/src/main/java/org/opensearch/sql/utils/WildcardReplaceUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.opensearch.sql.utils;

import java.util.regex.Pattern;

/** Utility class for handling wildcard patterns in replace operations. */
public class WildcardReplaceUtils {

/**
* Convert wildcard pattern to regex pattern for REGEXP_REPLACE.
*
* @param pattern Pattern that may contain wildcards
* @return Regex pattern
*/
public static String convertToRegexPattern(String pattern) {
if (pattern == null || pattern.isEmpty()) {
return pattern;
}

// If not a wildcard pattern, return as is
if (!WildcardRenameUtils.isWildcardPattern(pattern)) {
return pattern;
}

// Check for consecutive wildcards before any substring operations
if (pattern.matches(".*\\*{2,}.*")) {
throw new IllegalArgumentException("Consecutive wildcards are not supported");
}

// Handle single wildcard pattern
if (pattern.equals("*")) {
return "(.*)";
}

// Handle different wildcard positions
if (pattern.startsWith("*") && pattern.endsWith("*")) {
// *abc* -> Pattern matches 'abc' anywhere
String middle = pattern.substring(1, pattern.length() - 1);
return "(.*)" + Pattern.quote(middle) + "(.*)";
} else if (pattern.startsWith("*")) {
// *abc -> Pattern matches 'abc' at end
String end = pattern.substring(1);
return "(.*)" + Pattern.quote(end) + "$";
} else if (pattern.endsWith("*")) {
// abc* -> Pattern matches 'abc' at start with explicit capture group
String start = pattern.substring(0, pattern.length() - 1);
return "^" + Pattern.quote(start) + "(.*)"; // Explicitly create capture group
}
return pattern;
}

/**
* Convert wildcard replacement to regex replacement. Converts * to corresponding regex group
* references ($1, $2, etc.)
*
* @param replacement Replacement pattern with wildcards
* @return Regex replacement string
*/
public static String convertToRegexReplacement(String replacement) {
if (!WildcardRenameUtils.isWildcardPattern(replacement)) {
return replacement;
}
if (replacement.startsWith("*") && replacement.endsWith("*")) {
// *XYZ* -> Replacement with both prefix and suffix captured content
String middle = replacement.substring(1, replacement.length() - 1);
return "$1" + middle + "$2";
} else if (replacement.startsWith("*")) {
// *XYZ -> Replacement with prefix captured content
String end = replacement.substring(1);
return "$1" + end;
} else if (replacement.endsWith("*")) {
// XYZ* -> Replacement with suffix captured content
String start = replacement.substring(0, replacement.length() - 1);
return start + "$1";
}
return replacement;
}

/**
* Validate wildcard patterns compatibility.
*
* @param pattern Source pattern
* @param replacement Replacement pattern
* @throws IllegalArgumentException if patterns are invalid
*/
public static void validatePatterns(String pattern, String replacement) {
if (WildcardRenameUtils.isWildcardPattern(pattern)
|| WildcardRenameUtils.isWildcardPattern(replacement)) {
if (pattern.matches(".*\\*{2,}.*") || replacement.matches(".*\\*{2,}.*")) {
throw new IllegalArgumentException("Consecutive wildcards are not supported");
}
}

// If replacement has wildcard, pattern must have wildcard
if (WildcardRenameUtils.isWildcardPattern(replacement)
&& !WildcardRenameUtils.isWildcardPattern(pattern)) {
throw new IllegalArgumentException(
"If replacement contains wildcard, pattern must contain wildcard");
}

// Check if wildcard count matches
if (WildcardRenameUtils.isWildcardPattern(replacement)) {
long patternWildcards = pattern.chars().filter(ch -> ch == '*').count();
long replacementWildcards = replacement.chars().filter(ch -> ch == '*').count();

if (replacementWildcards > patternWildcards) {
throw new IllegalArgumentException(
"Number of wildcards in replacement cannot exceed number of wildcards in pattern");
}
}
}
}
Loading
Loading