diff --git a/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java b/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java new file mode 100644 index 00000000000..8ede948649f --- /dev/null +++ b/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java @@ -0,0 +1,504 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.patterns; + +import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Static helper methods for pattern aggregation operations. These methods wrap the complex logic in + * BrainLogParser and PatternUtils to be callable from UDFs in scripted metric aggregations. + */ +public final class PatternAggregationHelpers { + + private PatternAggregationHelpers() { + // Utility class + } + + /** + * Initialize pattern accumulator state. + * + * @return Empty accumulator map with logMessages buffer and patternGroupMap + */ + public static Map initPatternAccumulator() { + Map acc = new HashMap<>(); + acc.put("logMessages", new ArrayList()); + acc.put("patternGroupMap", new HashMap>()); + return acc; + } + + /** + * Initialize pattern accumulator state in-place. This method is designed for OpenSearch scripted + * metric aggregation's init_script phase, where the state map is provided by OpenSearch and must + * be modified in-place rather than replaced. + * + * @param state The mutable state map provided by OpenSearch (will be modified in-place) + * @return The same state map (for chaining/return value) + */ + @SuppressWarnings("unchecked") + public static Map initPatternState(Object state) { + Map stateMap = (Map) state; + stateMap.put("logMessages", new ArrayList()); + stateMap.put("patternGroupMap", new HashMap>()); + return stateMap; + } + + /** + * Add a log message to the accumulator (overload for Object acc and int thresholdPercentage). + * This overload handles the case when the accumulator is passed as a generic Object and + * thresholdPercentage is passed as an integer at runtime (from the script engine). + * + * @param acc Current accumulator state (as Object, will be cast to Map) + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as int) + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Object acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + int thresholdPercentage) { + return addLogToPattern( + (Map) acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + (double) thresholdPercentage); + } + + /** + * Add a log message to the accumulator (overload for Object acc and BigDecimal + * thresholdPercentage). This overload handles the case when the accumulator is passed as a + * generic Object and thresholdPercentage is passed as BigDecimal at runtime. 
+ * + * @param acc Current accumulator state (as Object, will be cast to Map) + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Object acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage) { + return addLogToPattern( + (Map) acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE); + } + + /** + * Add a log message to the accumulator (overload for Object acc and double thresholdPercentage). + * This overload handles the case when the accumulator is passed as a generic Object and + * thresholdPercentage is passed as double at runtime. + * + * @param acc Current accumulator state (as Object, will be cast to Map) + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as double) + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Object acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + double thresholdPercentage) { + return addLogToPattern( + (Map) acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage); + } + + /** + * Add a log message to the accumulator (overload for int thresholdPercentage). This overload + * handles the case when thresholdPercentage is passed as an integer at runtime. + * + * @param acc Current accumulator state + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as int) + * @return Updated accumulator + */ + public static Map addLogToPattern( + Map acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + int thresholdPercentage) { + return addLogToPattern( + acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + (double) thresholdPercentage); + } + + /** + * Add a log message to the accumulator (overload for BigDecimal thresholdPercentage). This + * overload handles the case when thresholdPercentage is passed as BigDecimal at runtime. 
+ * + * @param acc Current accumulator state + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @return Updated accumulator + */ + public static Map addLogToPattern( + Map acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage) { + return addLogToPattern( + acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE); + } + + /** + * Add a log message to the accumulator and trigger partial merge if buffer is full. + * + * @param acc Current accumulator state + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Map acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + double thresholdPercentage) { + + if (logMessage == null) { + return acc; + } + + List logMessages = (List) acc.get("logMessages"); + logMessages.add(logMessage); + + // Trigger partial merge when buffer reaches limit + if (bufferLimit > 0 && logMessages.size() >= bufferLimit) { + Map> patternGroupMap = + (Map>) acc.get("patternGroupMap"); + + BrainLogParser parser = + new BrainLogParser(variableCountThreshold, (float) thresholdPercentage); + Map> partialPatterns = + parser.parseAllLogPatterns(logMessages, maxSampleCount); + + patternGroupMap = + PatternUtils.mergePatternGroups(patternGroupMap, partialPatterns, maxSampleCount); + + acc.put("patternGroupMap", patternGroupMap); + logMessages.clear(); + } + + return acc; + } + + /** + * Combine two accumulators (for combine_script phase). + * + * @param acc1 First accumulator + * @param acc2 Second accumulator + * @param maxSampleCount Maximum samples to keep per pattern + * @return Merged accumulator + */ + @SuppressWarnings("unchecked") + public static Map combinePatternAccumulators( + Map acc1, Map acc2, int maxSampleCount) { + + Map> patterns1 = + (Map>) acc1.get("patternGroupMap"); + Map> patterns2 = + (Map>) acc2.get("patternGroupMap"); + + Map> merged = + PatternUtils.mergePatternGroups(patterns1, patterns2, maxSampleCount); + + // Merge logMessages from both accumulators to preserve buffered messages + List logMessages1 = (List) acc1.get("logMessages"); + List logMessages2 = (List) acc2.get("logMessages"); + List mergedLogMessages = new ArrayList<>(); + if (logMessages1 != null) { + mergedLogMessages.addAll(logMessages1); + } + if (logMessages2 != null) { + mergedLogMessages.addAll(logMessages2); + } + + Map result = new HashMap<>(); + result.put("logMessages", mergedLogMessages); + result.put("patternGroupMap", merged); + return result; + } + + /** + * Produce final pattern result (for reduce_script phase). 
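+   * <p>A minimal usage sketch of the accumulate/produce lifecycle (illustrative only; the
+   * parameter values below are assumptions, not enforced defaults):
+   *
+   * <pre>{@code
+   * Map<String, Object> acc = PatternAggregationHelpers.initPatternAccumulator();
+   * acc = PatternAggregationHelpers.addLogToPattern(
+   *     acc, "Verification succeeded for blk_1", 10, 100000, 5, 0.3);
+   * List<Map<String, Object>> patterns =
+   *     PatternAggregationHelpers.producePatternResult(acc, 10, 5, 0.3, false);
+   * }</pre>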
+ * + * @param acc Accumulator state + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResult( + Map acc, + int maxSampleCount, + int variableCountThreshold, + double thresholdPercentage, + boolean showNumberedToken) { + + // Process any remaining logs in buffer + List logMessages = (List) acc.get("logMessages"); + Map> patternGroupMap = + (Map>) acc.get("patternGroupMap"); + + if (logMessages != null && !logMessages.isEmpty()) { + BrainLogParser parser = + new BrainLogParser(variableCountThreshold, (float) thresholdPercentage); + Map> partialPatterns = + parser.parseAllLogPatterns(logMessages, maxSampleCount); + patternGroupMap = + PatternUtils.mergePatternGroups(patternGroupMap, partialPatterns, maxSampleCount); + } + + // Format and sort final output by pattern count + return patternGroupMap.values().stream() + .sorted( + Comparator.comparing( + m -> (Long) m.get(PatternUtils.PATTERN_COUNT), + Comparator.nullsLast(Comparator.reverseOrder()))) + .map(m -> formatPatternOutput(m, showNumberedToken)) + .collect(Collectors.toList()); + } + + /** + * Produce final pattern result from states array (overload for int thresholdPercentage). This + * overload handles the case when thresholdPercentage is passed as an integer at runtime. + * + * @param states List of shard-level accumulator states + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as int) + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + public static List> producePatternResultFromStates( + List states, + int maxSampleCount, + int variableCountThreshold, + int thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + states, + maxSampleCount, + variableCountThreshold, + (double) thresholdPercentage, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (overload for Object states). This overload + * handles the case when states is passed as a generic Object at runtime due to type erasure. + * + * @param states List of shard-level accumulator states (as Object, will be cast to List) + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResultFromStates( + Object states, + int maxSampleCount, + int variableCountThreshold, + double thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + (List) states, + maxSampleCount, + variableCountThreshold, + thresholdPercentage, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (overload for Object states with BigDecimal). + * This overload handles the case when states is Object and thresholdPercentage is BigDecimal. 
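+   * <p>Illustrative reduce-phase sketch (the shard states and parameter values are assumptions
+   * for demonstration; real states come from the map/combine phases):
+   *
+   * <pre>{@code
+   * List<Object> states = List.of(PatternAggregationHelpers.initPatternAccumulator());
+   * List<Map<String, Object>> patterns =
+   *     PatternAggregationHelpers.producePatternResultFromStates(states, 10, 5, 0.3, true);
+   * }</pre>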
+ * + * @param states List of shard-level accumulator states (as Object, will be cast to List) + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResultFromStates( + Object states, + int maxSampleCount, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + (List) states, + maxSampleCount, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (overload for BigDecimal thresholdPercentage). + * This overload handles the case when thresholdPercentage is passed as BigDecimal at runtime. + * + * @param states List of shard-level accumulator states + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + public static List> producePatternResultFromStates( + List states, + int maxSampleCount, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + states, + maxSampleCount, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (for reduce_script phase). This method combines + * all shard-level states and produces the final aggregated result. + * + * @param states List of shard-level accumulator states + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResultFromStates( + List states, + int maxSampleCount, + int variableCountThreshold, + double thresholdPercentage, + boolean showNumberedToken) { + + if (states == null || states.isEmpty()) { + return new ArrayList<>(); + } + + // Combine all states into a single accumulator + Map combined = (Map) states.get(0); + for (int i = 1; i < states.size(); i++) { + Map state = (Map) states.get(i); + combined = combinePatternAccumulators(combined, state, maxSampleCount); + } + + // Produce final result from combined state + return producePatternResult( + combined, maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); + } + + /** + * Format a single pattern result for output. + * + *
<p>
Note: Token extraction is NOT done here. The pattern is returned with wildcards (e.g., + * {@code <*>}) and token extraction is performed later by {@code + * PatternParserFunctionImpl.evalAggSamples()} after the data returns from OpenSearch. This + * approach avoids the XContent serialization issue where nested {@code Map>} + * structures are not properly serialized. + * + * @param patternInfo Pattern information map + * @param showNumberedToken Whether numbered tokens should be shown (determines output format) + * @return Formatted pattern output with pattern, count, and sample_logs + */ + @SuppressWarnings("unchecked") + private static Map formatPatternOutput( + Map patternInfo, boolean showNumberedToken) { + + String pattern = (String) patternInfo.get(PatternUtils.PATTERN); + Long count = (Long) patternInfo.get(PatternUtils.PATTERN_COUNT); + List sampleLogs = (List) patternInfo.get(PatternUtils.SAMPLE_LOGS); + + // For UDAF pushdown, we don't compute tokens here. + // Tokens will be computed by PatternParserFunctionImpl.evalAggSamples() after data returns + // from OpenSearch. This avoids XContent serialization issues with nested Map structures. + // The showNumberedToken flag is passed through to indicate the expected output format. + return ImmutableMap.of( + PatternUtils.PATTERN, + pattern, + PatternUtils.PATTERN_COUNT, + count, + PatternUtils.SAMPLE_LOGS, + sampleLogs); + } +} diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 96fe2e04eea..9e766c0aa37 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -41,6 +41,7 @@ public enum Key { CALCITE_ENGINE_ENABLED("plugins.calcite.enabled"), CALCITE_FALLBACK_ALLOWED("plugins.calcite.fallback.allowed"), CALCITE_PUSHDOWN_ENABLED("plugins.calcite.pushdown.enabled"), + CALCITE_UDAF_PUSHDOWN_ENABLED("plugins.calcite.udaf_pushdown.enabled"), CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR( "plugins.calcite.pushdown.rowcount.estimation.factor"), CALCITE_SUPPORT_ALL_JOIN_TYPES("plugins.calcite.all_join_types.allowed"), diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index f1bc5fd6a0d..6b9ca4c52f3 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -3259,31 +3259,74 @@ private RexNode explicitMapType( return new RexInputRef(((RexInputRef) origin).getIndex(), newMapType); } + /** + * Flattens the parsed pattern result into individual fields for projection. + * + *
<p>This method handles two scenarios: + * + * <ul> + *   <li>Label mode: extracts pattern (and optionally tokens) from parsedNode + *   <li>Aggregation mode: extracts pattern, pattern_count, tokens (optional), and sample_logs + * </ul> + * + * <p>
When both flattenPatternAggResult and showNumberedToken are true, the pattern and tokens + * need transformation via evalAggSamples (converting wildcards to numbered tokens). + * + * @param originalPatternResultAlias alias for the pattern field + * @param parsedNode the source RexNode containing parsed pattern data + * @param context the Calcite plan context + * @param flattenPatternAggResult true if in aggregation mode (includes pattern_count, + * sample_logs) + * @param showNumberedToken true if tokens should be extracted and pattern transformed + */ private void flattenParsedPattern( String originalPatternResultAlias, RexNode parsedNode, CalcitePlanContext context, boolean flattenPatternAggResult, - Boolean showNumberedToken) { - List fattenedNodes = new ArrayList<>(); + boolean showNumberedToken) { + List flattenedNodes = new ArrayList<>(); List projectNames = new ArrayList<>(); - // Flatten map struct fields + + RelDataType varcharType = + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR); + + // For aggregation mode with numbered tokens, we need to compute tokens locally + // using evalAggSamples. The UDAF returns pattern with wildcards and sample_logs, + // but NOT tokens (to avoid XContent serialization issues with nested Maps). + // The transformed result contains: pattern (with numbered tokens) and tokens map. + RexNode transformedPatternResult = null; + if (flattenPatternAggResult && showNumberedToken) { + transformedPatternResult = buildEvalAggSamplesCall(parsedNode, context); + } + + // Determine source for pattern and tokens: + // - When transformedPatternResult exists, use it (pattern/tokens need transformation) + // - pattern_count and sample_logs always come from the original parsedNode + RexNode patternAndTokensSource = + transformedPatternResult != null ? transformedPatternResult : parsedNode; + + // 1. Always add pattern field RexNode patternExpr = context.rexBuilder.makeCast( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), + varcharType, PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, BuiltinFunctionName.INTERNAL_ITEM, - parsedNode, + patternAndTokensSource, context.rexBuilder.makeLiteral(PatternUtils.PATTERN)), true, true); - fattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias)); + flattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias)); projectNames.add(originalPatternResultAlias); + + // 2. Add pattern_count when in aggregation mode (from original parsedNode) if (flattenPatternAggResult) { + RelDataType bigintType = + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); RexNode patternCountExpr = context.rexBuilder.makeCast( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT), + bigintType, PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, BuiltinFunctionName.INTERNAL_ITEM, @@ -3291,31 +3334,40 @@ private void flattenParsedPattern( context.rexBuilder.makeLiteral(PatternUtils.PATTERN_COUNT)), true, true); - fattenedNodes.add(context.relBuilder.alias(patternCountExpr, PatternUtils.PATTERN_COUNT)); + flattenedNodes.add(context.relBuilder.alias(patternCountExpr, PatternUtils.PATTERN_COUNT)); projectNames.add(PatternUtils.PATTERN_COUNT); } + + // 3. 
Add tokens when showNumberedToken is enabled if (showNumberedToken) { + RelDataType tokensType = + context + .rexBuilder + .getTypeFactory() + .createMapType( + varcharType, + context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1)); RexNode tokensExpr = context.rexBuilder.makeCast( - UserDefinedFunctionUtils.tokensMap, + tokensType, PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, BuiltinFunctionName.INTERNAL_ITEM, - parsedNode, + patternAndTokensSource, context.rexBuilder.makeLiteral(PatternUtils.TOKENS)), true, true); - fattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS)); + flattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS)); projectNames.add(PatternUtils.TOKENS); } + + // 4. Add sample_logs when in aggregation mode (from original parsedNode) if (flattenPatternAggResult) { + RelDataType sampleLogsArrayType = + context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1); RexNode sampleLogsExpr = context.rexBuilder.makeCast( - context - .rexBuilder - .getTypeFactory() - .createArrayType( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1), + sampleLogsArrayType, PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, BuiltinFunctionName.INTERNAL_ITEM, @@ -3323,10 +3375,45 @@ private void flattenParsedPattern( context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)), true, true); - fattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS)); + flattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS)); projectNames.add(PatternUtils.SAMPLE_LOGS); } - projectPlusOverriding(fattenedNodes, projectNames, context); + + projectPlusOverriding(flattenedNodes, projectNames, context); + } + + /** + * Builds the evalAggSamples call to transform pattern with wildcards to numbered tokens and + * compute the tokens map from sample logs. 
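+   * <p>Conceptually, the returned expression corresponds to the call that appears in the
+   * pushed-down physical plan (see the expected explain output added in this change):
+   *
+   * <pre>{@code
+   * PATTERN_PARSER(ITEM(parsedNode, 'pattern'), ITEM(parsedNode, 'sample_logs'), true)
+   * }</pre>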
+ * + * @param parsedNode The UDAF result containing pattern and sample_logs + * @param context The Calcite plan context + * @return RexNode representing the evalAggSamples call result + */ + private RexNode buildEvalAggSamplesCall(RexNode parsedNode, CalcitePlanContext context) { + // Extract pattern string (with wildcards) from UDAF result + RexNode patternStr = + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + parsedNode, + context.rexBuilder.makeLiteral(PatternUtils.PATTERN)); + + // Extract sample_logs from UDAF result + RexNode sampleLogs = + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), + context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)); + + // Call evalAggSamples to transform pattern (wildcards -> numbered tokens) and compute tokens + return PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_PATTERN_PARSER, + patternStr, + sampleLogs, + context.rexBuilder.makeLiteral(true)); } private void buildExpandRelNode( diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java index f93a0e7c49d..ff207912e35 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java @@ -5,23 +5,24 @@ package org.opensearch.sql.calcite.udf.udaf; -import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; import org.opensearch.sql.calcite.udf.udaf.LogPatternAggFunction.LogParserAccumulator; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.patterns.BrainLogParser; -import org.opensearch.sql.common.patterns.PatternUtils; -import org.opensearch.sql.common.patterns.PatternUtils.ParseResult; - +import org.opensearch.sql.common.patterns.PatternAggregationHelpers; + +/** + * User-defined aggregate function for log pattern extraction using the Brain algorithm. This UDAF + * is used for in-memory pattern aggregation in Calcite. For OpenSearch scripted metric pushdown, + * see {@link PatternAggregationHelpers} which provides the same logic with Map-based state. + * + *
<p>
Both implementations share the same underlying logic through {@link PatternAggregationHelpers} + * to ensure consistency. + */ public class LogPatternAggFunction implements UserDefinedAggFunction { private int bufferLimit = 100000; private int maxSampleCount = 10; @@ -36,10 +37,9 @@ public LogParserAccumulator init() { @Override public Object result(LogParserAccumulator acc) { - if (acc.size() == 0 && acc.logSize() == 0) { + if (acc.isEmpty()) { return null; } - return acc.value( maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); } @@ -83,17 +83,16 @@ public LogParserAccumulator add( if (Objects.isNull(field)) { return acc; } + // Store parameters for result() phase this.bufferLimit = bufferLimit; this.maxSampleCount = maxSampleCount; this.showNumberedToken = showNumberedToken; this.variableCountThreshold = variableCountThreshold; this.thresholdPercentage = thresholdPercentage; - acc.evaluate(field); - if (bufferLimit > 0 && acc.logSize() == bufferLimit) { - acc.partialMerge( - maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); - acc.clearBuffer(); - } + + // Delegate to shared helper logic + PatternAggregationHelpers.addLogToPattern( + acc.state, field, maxSampleCount, bufferLimit, variableCountThreshold, thresholdPercentage); return acc; } @@ -147,82 +146,50 @@ public LogParserAccumulator add( this.variableCountThreshold); } + /** + * Accumulator for log pattern aggregation. This is a thin wrapper around the Map-based state used + * by {@link PatternAggregationHelpers}, providing type safety for Calcite UDAF while reusing the + * same underlying logic. + */ public static class LogParserAccumulator implements Accumulator { - private final List logMessages; - public Map> patternGroupMap = new HashMap<>(); - - public int size() { - return patternGroupMap.size(); - } - - public int logSize() { - return logMessages.size(); - } + /** The underlying state map, compatible with PatternAggregationHelpers */ + final Map state; public LogParserAccumulator() { - this.logMessages = new ArrayList<>(); - } - - public void evaluate(String value) { - logMessages.add(value); + this.state = PatternAggregationHelpers.initPatternAccumulator(); } - public void clearBuffer() { - logMessages.clear(); - } - - public void partialMerge(Object... argList) { - if (logMessages.isEmpty()) { - return; - } - assert argList.length == 4 : "partialMerge of LogParserAccumulator requires 4 parameters"; - int maxSampleCount = (int) argList[0]; - BrainLogParser logParser = - new BrainLogParser((int) argList[1], ((Double) argList[2]).floatValue()); - Map> partialPatternGroupMap = - logParser.parseAllLogPatterns(logMessages, maxSampleCount); - patternGroupMap = - PatternUtils.mergePatternGroups(patternGroupMap, partialPatternGroupMap, maxSampleCount); + @SuppressWarnings("unchecked") + public boolean isEmpty() { + List logMessages = (List) state.get("logMessages"); + Map patternGroupMap = (Map) state.get("patternGroupMap"); + return (logMessages == null || logMessages.isEmpty()) + && (patternGroupMap == null || patternGroupMap.isEmpty()); } @Override public Object value(Object... 
argList) { - partialMerge(argList); - clearBuffer(); - - Boolean showToken = (Boolean) argList[3]; - return patternGroupMap.values().stream() - .sorted( - Comparator.comparing( - m -> (Long) m.get(PatternUtils.PATTERN_COUNT), - Comparator.nullsLast(Comparator.reverseOrder()))) - .map( - m -> { - String pattern = (String) m.get(PatternUtils.PATTERN); - Long count = (Long) m.get(PatternUtils.PATTERN_COUNT); - List sampleLogs = (List) m.get(PatternUtils.SAMPLE_LOGS); - Map> tokensMap = new HashMap<>(); - ParseResult parseResult = null; - if (showToken) { - parseResult = PatternUtils.parsePattern(pattern, PatternUtils.WILDCARD_PATTERN); - for (String sampleLog : sampleLogs) { - PatternUtils.extractVariables( - parseResult, sampleLog, tokensMap, PatternUtils.WILDCARD_PREFIX); - } - } - return ImmutableMap.of( - PatternUtils.PATTERN, - showToken - ? parseResult.toTokenOrderString(PatternUtils.WILDCARD_PREFIX) - : pattern, - PatternUtils.PATTERN_COUNT, - count, - PatternUtils.TOKENS, - showToken ? tokensMap : Collections.EMPTY_MAP, - PatternUtils.SAMPLE_LOGS, - sampleLogs); - }) - .collect(Collectors.toList()); + // Return the current state for use by LogPatternAggFunction.result() + // The argList contains [maxSampleCount, variableCountThreshold, thresholdPercentage, + // showNumberedToken] + if (isEmpty()) { + return null; + } + int maxSampleCount = + argList.length > 0 && argList[0] != null ? ((Number) argList[0]).intValue() : 10; + int variableCountThreshold = + argList.length > 1 && argList[1] != null + ? ((Number) argList[1]).intValue() + : BrainLogParser.DEFAULT_VARIABLE_COUNT_THRESHOLD; + double thresholdPercentage = + argList.length > 2 && argList[2] != null + ? ((Number) argList[2]).doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE; + boolean showNumberedToken = + argList.length > 3 && argList[3] != null && Boolean.TRUE.equals(argList[3]); + + return PatternAggregationHelpers.producePatternResult( + state, maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); } } } diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index f619d966cc8..ff9bc23bc54 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -11,6 +11,7 @@ import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.*; import com.google.common.collect.ImmutableSet; +import java.lang.reflect.Type; import java.time.Instant; import java.time.ZoneId; import java.time.ZoneOffset; @@ -209,7 +210,7 @@ public static List convertToExprValues( * @return an adapted ImplementorUDF with the expr method, which is a UserDefinedFunctionBuilder */ public static ImplementorUDF adaptExprMethodToUDF( - java.lang.reflect.Type type, + Type type, String methodName, SqlReturnTypeInference returnTypeInference, NullPolicy nullPolicy, @@ -240,7 +241,7 @@ public UDFOperandMetadata getOperandMetadata() { * FunctionProperties} at the beginning to a Calcite-compatible UserDefinedFunctionBuilder. 
*/ public static ImplementorUDF adaptExprMethodWithPropertiesToUDF( - java.lang.reflect.Type type, + Type type, String methodName, SqlReturnTypeInference returnTypeInference, NullPolicy nullPolicy, @@ -317,4 +318,44 @@ public static List prependFunctionProperties( operandsWithProperties.addFirst(properties); return Collections.unmodifiableList(operandsWithProperties); } + + /** + * Adapt a static method from any class to a UserDefinedFunctionBuilder. This is a general-purpose + * adapter that can wrap static helper methods (e.g., PatternAggregationHelpers methods) as UDFs + * for use in scripted metrics. + * + * @param type the class containing the static method + * @param methodName the name of the static method to be invoked + * @param returnTypeInference the return type inference of the UDF + * @param nullPolicy the null policy of the UDF + * @param operandMetadata type checker for operands + * @return an adapted ImplementorUDF wrapping the static method + */ + public static ImplementorUDF adaptStaticMethodToUDF( + Type type, + String methodName, + SqlReturnTypeInference returnTypeInference, + NullPolicy nullPolicy, + @Nullable UDFOperandMetadata operandMetadata) { + + NotNullImplementor implementor = + (translator, call, translatedOperands) -> { + // For static methods that work with generic objects (Map, List, etc.), + // we don't need type conversion like adaptMathFunctionToUDF + // Just pass the operands directly to the static method + return Expressions.call(type, methodName, translatedOperands); + }; + + return new ImplementorUDF(implementor, nullPolicy) { + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return returnTypeInference; + } + + @Override + public UDFOperandMetadata getOperandMetadata() { + return operandMetadata; + } + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 50f88d47baf..6ee6a229d74 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -346,6 +346,11 @@ public enum BuiltinFunctionName { INTERNAL_PATTERN_PARSER(FunctionName.of("pattern_parser")), INTERNAL_PATTERN(FunctionName.of("pattern")), INTERNAL_UNCOLLECT_PATTERNS(FunctionName.of("uncollect_patterns")), + // Pattern aggregation UDFs for scripted metric pushdown + PATTERN_INIT_UDF(FunctionName.of("pattern_init_udf"), true), + PATTERN_ADD_UDF(FunctionName.of("pattern_add_udf"), true), + PATTERN_COMBINE_UDF(FunctionName.of("pattern_combine_udf"), true), + PATTERN_RESULT_UDF(FunctionName.of("pattern_result_udf"), true), INTERNAL_GROK(FunctionName.of("grok"), true), INTERNAL_PARSE(FunctionName.of("parse"), true), INTERNAL_REGEXP_REPLACE_PG_4(FunctionName.of("regexp_replace_pg_4"), true), diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 3810352cbfd..29c04625bea 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -21,12 +21,16 @@ import org.apache.calcite.adapter.enumerable.RexToLixTranslator; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.rel.type.RelDataType; 
+import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexCall; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeTransforms; +import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; import org.opensearch.sql.calcite.udf.udaf.FirstAggFunction; @@ -40,6 +44,7 @@ import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; +import org.opensearch.sql.common.patterns.PatternAggregationHelpers; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.datetime.DateTimeFunctions; import org.opensearch.sql.expression.function.CollectionUDF.AppendFunctionImpl; @@ -482,6 +487,49 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { PPLReturnTypes.STRING_ARRAY, PPLOperandTypes.ANY_SCALAR_OPTIONAL_INTEGER); + // Pattern aggregation helper UDFs for scripted metric pushdown + // This UDF takes state as parameter and modifies it in-place (for OpenSearch scripted metric) + public static final SqlOperator PATTERN_INIT_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "initPatternState", + ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map + NullPolicy.ANY, + null) // Takes state as parameter + .toUDF("PATTERN_INIT_UDF"); + + public static final SqlOperator PATTERN_ADD_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "addLogToPattern", + ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map + NullPolicy.ANY, + null) // TODO: Add proper operand type checking + .toUDF("PATTERN_ADD_UDF"); + + public static final SqlOperator PATTERN_COMBINE_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "combinePatternAccumulators", + ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map + NullPolicy.ANY, + null) // TODO: Add proper operand type checking + .toUDF("PATTERN_COMBINE_UDF"); + + public static final SqlOperator PATTERN_RESULT_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "producePatternResultFromStates", + opBinding -> { + // Returns List> - represented as ARRAY + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + RelDataType anyType = typeFactory.createSqlType(SqlTypeName.ANY); + return SqlTypeUtil.createArrayType(typeFactory, anyType, true); + }, + NullPolicy.ANY, + null) // TODO: Add proper operand type checking + .toUDF("PATTERN_RESULT_UDF"); + public static final SqlOperator ENHANCED_COALESCE = new EnhancedCoalesceFunction().toUDF("COALESCE"); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 2d594c48f55..4d4dbbda954 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -163,6 +163,10 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOW; import static org.opensearch.sql.expression.function.BuiltinFunctionName.NULLIF; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.OR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_ADD_UDF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_COMBINE_UDF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_INIT_UDF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_RESULT_UDF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERCENTILE_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_DIFF; @@ -981,6 +985,11 @@ void populate() { registerOperator(WEEKOFYEAR, PPLBuiltinOperators.WEEK); registerOperator(INTERNAL_PATTERN_PARSER, PPLBuiltinOperators.PATTERN_PARSER); + // Register pattern aggregation helper UDFs for scripted metric pushdown + registerOperator(PATTERN_INIT_UDF, PPLBuiltinOperators.PATTERN_INIT_UDF); + registerOperator(PATTERN_ADD_UDF, PPLBuiltinOperators.PATTERN_ADD_UDF); + registerOperator(PATTERN_COMBINE_UDF, PPLBuiltinOperators.PATTERN_COMBINE_UDF); + registerOperator(PATTERN_RESULT_UDF, PPLBuiltinOperators.PATTERN_RESULT_UDF); registerOperator(TONUMBER, PPLBuiltinOperators.TONUMBER); registerOperator(TOSTRING, PPLBuiltinOperators.TOSTRING); register( diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java index e4f7f1f9d1c..1e7aa80617b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java @@ -69,26 +69,35 @@ public Expression implement( : "PATTERN_PARSER should have 2 or 3 arguments"; RelDataType inputType = call.getOperands().get(1).getType(); - Method method = resolveEvaluationMethod(inputType); + Method method = resolveEvaluationMethod(inputType, operandCount); ScalarFunctionImpl function = (ScalarFunctionImpl) ScalarFunctionImpl.create(method); return function.getImplementor().implement(translator, call, RexImpTable.NullAs.NULL); } - private Method resolveEvaluationMethod(RelDataType inputType) { + private Method resolveEvaluationMethod(RelDataType inputType, int operandCount) { if (inputType.getSqlTypeName() == SqlTypeName.VARCHAR) { return getMethod(String.class, "evalField"); } RelDataType componentType = inputType.getComponentType(); - return (componentType.getSqlTypeName() == SqlTypeName.MAP) - ? 
Types.lookupMethod( - PatternParserFunctionImpl.class, - "evalAgg", - String.class, - Objects.class, - Boolean.class) - : getMethod(List.class, "evalSamples"); + if (componentType.getSqlTypeName() == SqlTypeName.MAP) { + // evalAgg: for label mode with aggregation results (array of maps) + return Types.lookupMethod( + PatternParserFunctionImpl.class, "evalAgg", String.class, Objects.class, Boolean.class); + } else if (operandCount == 3) { + // evalAggSamples: for UDAF pushdown aggregation mode + // Takes pattern (String), sample_logs (List), showNumberedToken (Boolean) + return Types.lookupMethod( + PatternParserFunctionImpl.class, + "evalAggSamples", + String.class, + List.class, + Boolean.class); + } else { + // evalSamples: for simple pattern with sample logs (2 arguments) + return getMethod(List.class, "evalSamples"); + } } private Method getMethod(Class paramType, String methodName) { @@ -126,13 +135,23 @@ public static Object evalAgg( if (bestCandidate != null) { String bestCandidatePattern = String.join(" ", bestCandidate); Map> tokensMap = new HashMap<>(); - if (showNumberedToken) { + String outputPattern = bestCandidatePattern; // Default: return as-is + + if (Boolean.TRUE.equals(showNumberedToken)) { + // Parse pattern with wildcard format (<*>, <*IP*>, etc.) + // LogPatternAggFunction.value() returns patterns in wildcard format ParseResult parseResult = - PatternUtils.parsePattern(bestCandidatePattern, PatternUtils.TOKEN_PATTERN); + PatternUtils.parsePattern(bestCandidatePattern, PatternUtils.WILDCARD_PATTERN); + + // Transform pattern from wildcards to numbered tokens (, , etc.) + outputPattern = parseResult.toTokenOrderString(PatternUtils.TOKEN_PREFIX); + + // Extract token values from the field PatternUtils.extractVariables(parseResult, field, tokensMap, PatternUtils.TOKEN_PREFIX); } + return ImmutableMap.of( - PatternUtils.PATTERN, bestCandidatePattern, + PatternUtils.PATTERN, outputPattern, PatternUtils.TOKENS, tokensMap); } else { return ImmutableMap.of(); @@ -174,6 +193,47 @@ public static Object evalSamples( tokensMap); } + /** + * Extract tokens from aggregated pattern and sample logs for UDAF pushdown. Transforms the + * pattern from wildcard format (e.g., <*>) to numbered token format (e.g., <token1>, + * <token2>) when showNumberedToken is true. + * + *
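+   * <p>A hedged example mirroring the integration test data in this change:
+   *
+   * <pre>{@code
+   * // pattern: "Verification succeeded <token1> blk_<token2>"
+   * // tokens:  {<token1>=[for], <token2>=[6996194389878584395]}
+   * Object result = evalAggSamples(
+   *     "Verification succeeded <*> blk_<*>",
+   *     List.of("Verification succeeded for blk_6996194389878584395"),
+   *     true);
+   * }</pre>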
<p>
This method is designed to be called after UDAF pushdown returns from OpenSearch. The UDAF + * returns patterns with wildcards, and this method transforms them to numbered tokens and + * extracts token values from sample logs. + * + * @param pattern The pattern string with wildcards (e.g., <*>, <*IP*>) + * @param sampleLogs List of sample log messages + * @param showNumberedToken Whether to transform to numbered tokens and extract token values + * @return Map containing pattern (possibly transformed) and tokens (if showNumberedToken is true) + */ + public static Object evalAggSamples( + @Parameter(name = "pattern") String pattern, + @Parameter(name = "sample_logs") List sampleLogs, + @Parameter(name = "showNumberedToken") Boolean showNumberedToken) { + if (Strings.isBlank(pattern)) { + return EMPTY_RESULT; + } + + Map> tokensMap = new HashMap<>(); + String outputPattern = pattern; // Default: return pattern as-is (with wildcards) + + if (Boolean.TRUE.equals(showNumberedToken)) { + // Parse pattern with wildcard format (<*>, <*IP*>, etc.) + ParseResult parseResult = PatternUtils.parsePattern(pattern, PatternUtils.WILDCARD_PATTERN); + + // Transform pattern from wildcards to numbered tokens (, , etc.) + outputPattern = parseResult.toTokenOrderString(PatternUtils.TOKEN_PREFIX); + + // Extract token values from sample logs + for (String sampleLog : sampleLogs) { + PatternUtils.extractVariables(parseResult, sampleLog, tokensMap, PatternUtils.TOKEN_PREFIX); + } + } + + return ImmutableMap.of(PatternUtils.PATTERN, outputPattern, PatternUtils.TOKENS, tokensMap); + } + private static List findBestCandidate( List> candidates, List tokens) { return candidates.stream() @@ -188,7 +248,8 @@ private static float calculateScore(List tokens, List candidate) String candidateToken = candidate.get(i); if (Objects.equals(preprocessedToken, candidateToken)) { score += 1; - } else if (preprocessedToken.startsWith("<*") && candidateToken.startsWith(" vs <*IP*>) score += 1; } } diff --git a/docs/user/ppl/cmd/patterns.md b/docs/user/ppl/cmd/patterns.md index 6941efbe4f1..c880b08baf1 100644 --- a/docs/user/ppl/cmd/patterns.md +++ b/docs/user/ppl/cmd/patterns.md @@ -13,7 +13,7 @@ The `patterns` command supports the following modes: The command identifies variable parts of log messages (such as timestamps, numbers, IP addresses, and unique identifiers) and replaces them with `<*>` placeholders to create reusable patterns. For example, email addresses like `amberduke@pyrami.com` and `hattiebond@netagy.com` are replaced with the pattern `<*>@<*>.<*>`. -> **Note**: The `patterns` command is not executed on OpenSearch data nodes. It only groups log patterns from log messages that have been returned to the coordinator node. +> **Note**: By default, the `patterns` command is not executed on OpenSearch data nodes. It only groups log patterns from log messages that have been returned to the coordinator node. However, when using `mode=aggregation` with `method=brain` and the `plugins.calcite.udaf_pushdown.enabled` cluster setting is set to `true`, the aggregation may be pushed down and executed on data nodes as a scripted metric aggregation for improved performance. See [Enabling UDAF pushdown for patterns aggregation](#enabling-udaf-pushdown-for-patterns-aggregation) for more details. ## Syntax @@ -67,7 +67,7 @@ The `brain` method accepts the following parameters. By default, the Apache Calcite engine labels variables using the `<*>` placeholder. 
If the `show_numbered_token` option is enabled, the Calcite engine's `label` mode not only labels the text pattern but also assigns numbered placeholders to variable tokens. In `aggregation` mode, it outputs both the labeled pattern and the variable tokens for each pattern. In this case, variable placeholders use the format `` instead of `<*>`. -## Changing the default pattern method +## Changing the default pattern method To override default pattern parameters, run the following command: @@ -83,7 +83,26 @@ PUT _cluster/settings } } ``` - + +## Enabling UDAF pushdown for patterns aggregation + +When using the `patterns` command with `mode=aggregation` and `method=brain`, the aggregation can optionally be pushed down to OpenSearch as a scripted metric aggregation for parallel execution across data nodes. This can improve performance for large datasets but uses scripted metric aggregations which lack circuit breaker protection. + +By default, UDAF pushdown is **disabled**. To enable it, run the following command: + +```bash ignore +PUT _cluster/settings +{ + "persistent": { + "plugins.calcite.udaf_pushdown.enabled": true + } +} +``` + +> **Warning**: Enabling UDAF pushdown executes user-defined aggregation functions as scripted metric aggregations on OpenSearch data nodes. This bypasses certain memory circuit breakers and may cause out-of-memory errors on nodes when processing very large datasets. Use with caution and monitor cluster resource usage. + +When UDAF pushdown is disabled (the default), the pattern aggregation runs locally on the coordinator node after fetching the data from OpenSearch. + ## Simple pattern examples The following are examples of using the `simple_pattern` method. diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java index 46df914e611..74e3f8a9958 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java @@ -19,6 +19,7 @@ import java.io.IOException; import org.json.JSONObject; import org.junit.Test; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.ppl.PPLIntegTestCase; public class CalcitePPLPatternsIT extends PPLIntegTestCase { @@ -530,4 +531,168 @@ public void testBrainParseWithUUID_ShowNumberedToken() throws IOException { "[PlaceOrder] user_id= user_currency=USD", ImmutableMap.of("", ImmutableList.of("d664d7be-77d8-11f0-8880-0242f00b101d")))); } + + @Test + public void testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken() throws IOException { + // Test UDAF pushdown for patterns BRAIN aggregation mode + // This verifies that the query is pushed down to OpenSearch as a scripted metric aggregation + // UDAF pushdown is disabled by default, enable it for this test + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + JSONObject result = + executeQuery( + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS)); + + // Verify schema matches expected output + verifySchema( + result, + schema("patterns_field", "string"), + schema("pattern_count", "bigint"), + schema("sample_logs", "array")); + + // Verify data rows - should match the non-pushdown results + verifyDataRows( + result, + rows( + "Verification 
succeeded <*> blk_<*>", + 2, + ImmutableList.of( + "Verification succeeded for blk_-1547954353065580372", + "Verification succeeded for blk_6996194389878584395")), + rows( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: <*IP*> is added to blk_<*>" + + " size <*>", + 2, + ImmutableList.of( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added" + + " to blk_-7017553867379051457 size 67108864", + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" + + " to blk_-3249711809227781266 size 67108864")), + rows( + "<*> NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_<*>_<*>_r_<*>_<*>/part<*>" + + " blk_<*>", + 2, + ImmutableList.of( + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." + + " blk_-6620182933895093708", + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." + + " blk_2096692261399680562")), + rows( + "PacketResponder failed <*> blk_<*>", + 2, + ImmutableList.of( + "PacketResponder failed for blk_6996194389878584395", + "PacketResponder failed for blk_-1547954353065580372"))); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } + } + + @Test + public void testBrainAggregationMode_UDAFPushdown_ShowNumberedToken() throws IOException { + // Test UDAF pushdown for patterns BRAIN aggregation mode with numbered tokens + // UDAF pushdown is disabled by default, enable it for this test + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + JSONObject result = + executeQuery( + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " show_numbered_token=true variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS)); + + // Verify schema includes tokens field + verifySchema( + result, + schema("patterns_field", "string"), + schema("pattern_count", "bigint"), + schema("tokens", "struct"), + schema("sample_logs", "array")); + + // Verify data rows with tokens + verifyDataRows( + result, + rows( + "Verification succeeded blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("for", "for"), + "", + ImmutableList.of("-1547954353065580372", "6996194389878584395")), + ImmutableList.of( + "Verification succeeded for blk_-1547954353065580372", + "Verification succeeded for blk_6996194389878584395")), + rows( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: is added to" + + " blk_ size ", + 2, + ImmutableMap.of( + "", + ImmutableList.of("10.251.31.85:50010", "10.251.107.19:50010"), + "", + ImmutableList.of("67108864", "67108864"), + "", + ImmutableList.of("-7017553867379051457", "-3249711809227781266")), + ImmutableList.of( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added" + + " to blk_-7017553867379051457 size 67108864", + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" + + " to blk_-3249711809227781266 size 67108864")), + rows( + " NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task___r__/part" + + " blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("0", "0"), + "", + ImmutableList.of("000296", "000318"), + "", + ImmutableList.of("-6620182933895093708", "2096692261399680562"), + "", + ImmutableList.of("-00296.", "-00318."), + "", + ImmutableList.of("BLOCK*", "BLOCK*"), + "", + 
ImmutableList.of("0002", "0002"), + "", + ImmutableList.of("200811092030", "200811092030")), + ImmutableList.of( + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." + + " blk_-6620182933895093708", + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." + + " blk_2096692261399680562")), + rows( + "PacketResponder failed blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("for", "for"), + "", + ImmutableList.of("6996194389878584395", "-1547954353065580372")), + ImmutableList.of( + "PacketResponder failed for blk_6996194389878584395", + "PacketResponder failed for blk_-1547954353065580372"))); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 62eadd7ef5e..7d192317564 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -433,13 +433,44 @@ public void testPatternsSimplePatternMethodWithAggPushDownExplain() throws IOExc @Test public void testPatternsBrainMethodWithAggPushDownExplain() throws IOException { - // TODO: Correct calcite expected result once pushdown is supported - String expected = loadExpectedPlan("explain_patterns_brain_agg_push.yaml"); - assertYamlEqualsIgnoreId( - expected, - explainQueryYaml( - "source=opensearch-sql_test_index_account" - + "| patterns email method=brain mode=aggregation show_numbered_token=true")); + // UDAF pushdown is disabled by default, enable it for this test + Assume.assumeTrue(isCalciteEnabled()); + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + String expected = loadExpectedPlan("explain_patterns_brain_agg_push.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account" + + "| patterns email method=brain mode=aggregation show_numbered_token=true")); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } + } + + @Test + public void testPatternsBrainMethodWithAggGroupByPushDownExplain() throws IOException { + // Patterns with group by is only supported in Calcite mode with UDAF pushdown enabled + Assume.assumeTrue(isCalciteEnabled()); + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + String expected = loadExpectedPlan("explain_patterns_brain_agg_group_by_push.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account| patterns email by gender method=brain" + + " mode=aggregation show_numbered_token=true")); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } } @Test diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml new file mode 100644 index 00000000000..007f93e365c --- /dev/null +++ 
b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml @@ -0,0 +1,19 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(gender=[$0], patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($2, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + LogicalAggregate(group=[{0}], patterns_field=[pattern($1, $2, $3, $4)]) + LogicalProject(gender=[$4], email=[$9], $f17=[10], $f18=[100000], $f19=[true]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + Uncollect + LogicalProject(patterns_field=[$cor0.patterns_field]) + LogicalValues(tuples=[[{ 0 }]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..2=[{inputs}], expr#3=['pattern'], expr#4=[ITEM($t2, $t3)], expr#5=[SAFE_CAST($t4)], expr#6=['sample_logs'], expr#7=[ITEM($t2, $t6)], expr#8=[true], expr#9=[PATTERN_PARSER($t5, $t7, $t8)], expr#10=[ITEM($t9, $t3)], expr#11=[SAFE_CAST($t10)], expr#12=['pattern_count'], expr#13=[ITEM($t2, $t12)], expr#14=[SAFE_CAST($t13)], expr#15=['tokens'], expr#16=[ITEM($t9, $t15)], expr#17=[SAFE_CAST($t16)], expr#18=[SAFE_CAST($t7)], gender=[$t0], patterns_field=[$t11], pattern_count=[$t14], tokens=[$t17], sample_logs=[$t18]) + EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},patterns_field=pattern($1, $2, $3, $4))], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":true,"missing_order":"first","order":"asc"}}}]},"aggregations":{"patterns_field":{"scripted_metric":{"init_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQCCnsKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0lOSVRfVURGIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAwLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJBTlkiLAogICAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAgICJwcmVjaXNpb24iOiAtMSwKICAgICAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogICAgICB9CiAgICB9CiAgXSwKICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgInR5cGUiOiB7CiAgICAidHlwZSI6ICJBTlkiLAogICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAicHJlY2lzaW9uIjogLTEsCiAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogIH0sCiAgImRldGVybWluaXN0aWMiOiB0cnVlLAogICJkeW5hbWljIjogZmFsc2UKfQ==\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 
0,"SOURCES":[3],"DIGESTS":["state"]}},"map_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEW3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0FERF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJWQVJDSEFSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlLAogICAgICAgICJwcmVjaXNpb24iOiAtMQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMiwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMywKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNCwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNSwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiRE9VQkxFIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0KICBdLAogICJjbGFzcyI6ICJvcmcub3BlbnNlYXJjaC5zcWwuZXhwcmVzc2lvbi5mdW5jdGlvbi5Vc2VyRGVmaW5lZEZ1bmN0aW9uQnVpbGRlciQxIiwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3,0,2,2,2,2],"DIGESTS":["state","email.keyword",10,100000,5,0.3]}},"combine_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQAgHsKICAiZHluYW1pY1BhcmFtIjogMCwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 
0,"SOURCES":[3],"DIGESTS":["state"]}},"reduce_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEH3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX1JFU1VMVF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAyLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAzLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJET1VCTEUiLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfSwKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDQsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkJPT0xFQU4iLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQVJSQVkiLAogICAgIm51bGxhYmxlIjogdHJ1ZSwKICAgICJjb21wb25lbnQiOiB7CiAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAicHJlY2lzaW9uIjogLTEsCiAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICB9CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3,2,2,2,2],"DIGESTS":["states",10,5,0.3,true]}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + EnumerableUncollect + EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.patterns_field], patterns_field=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml index 0b2d4584804..ed98865ce43 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml @@ -1,7 +1,7 @@ calcite: logical: | LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalProject(patterns_field=[SAFE_CAST(ITEM($1, 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($1, 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) + LogicalProject(patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) LogicalAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) LogicalProject(email=[$9], $f17=[10], $f18=[100000], $f19=[true]) @@ -11,11 +11,9 @@ calcite: LogicalValues(tuples=[[{ 0 }]]) physical: | EnumerableLimit(fetch=[10000]) - EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['pattern_count'], expr#6=[ITEM($t1, $t5)], expr#7=[SAFE_CAST($t6)], expr#8=['tokens'], expr#9=[ITEM($t1, $t8)], expr#10=[SAFE_CAST($t9)], 
expr#11=['sample_logs'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], patterns_field=[$t4], pattern_count=[$t7], tokens=[$t10], sample_logs=[$t13]) + EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['sample_logs'], expr#6=[ITEM($t1, $t5)], expr#7=[true], expr#8=[PATTERN_PARSER($t4, $t6, $t7)], expr#9=[ITEM($t8, $t2)], expr#10=[SAFE_CAST($t9)], expr#11=['pattern_count'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], expr#14=['tokens'], expr#15=[ITEM($t8, $t14)], expr#16=[SAFE_CAST($t15)], expr#17=[SAFE_CAST($t6)], patterns_field=[$t10], pattern_count=[$t13], tokens=[$t16], sample_logs=[$t17]) EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - EnumerableAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) - EnumerableCalc(expr#0=[{inputs}], expr#1=[10], expr#2=[100000], expr#3=[true], proj#0..3=[{exprs}]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[email]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["email"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={},patterns_field=pattern($0, $1, $2, $3))], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"patterns_field":{"scripted_metric":{"init_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQCCnsKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0lOSVRfVURGIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAwLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJBTlkiLAogICAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAgICJwcmVjaXNpb24iOiAtMSwKICAgICAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogICAgICB9CiAgICB9CiAgXSwKICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgInR5cGUiOiB7CiAgICAidHlwZSI6ICJBTlkiLAogICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAicHJlY2lzaW9uIjogLTEsCiAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogIH0sCiAgImRldGVybWluaXN0aWMiOiB0cnVlLAogICJkeW5hbWljIjogZmFsc2UKfQ==\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 
0,"SOURCES":[3],"DIGESTS":["state"]}},"map_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEW3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0FERF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJWQVJDSEFSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlLAogICAgICAgICJwcmVjaXNpb24iOiAtMQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMiwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMywKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNCwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNSwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiRE9VQkxFIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0KICBdLAogICJjbGFzcyI6ICJvcmcub3BlbnNlYXJjaC5zcWwuZXhwcmVzc2lvbi5mdW5jdGlvbi5Vc2VyRGVmaW5lZEZ1bmN0aW9uQnVpbGRlciQxIiwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3,0,2,2,2,2],"DIGESTS":["state","email.keyword",10,100000,5,0.3]}},"combine_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQAgHsKICAiZHluYW1pY1BhcmFtIjogMCwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 
0,"SOURCES":[3],"DIGESTS":["state"]}},"reduce_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEH3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX1JFU1VMVF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAyLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAzLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJET1VCTEUiLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfSwKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDQsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkJPT0xFQU4iLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQVJSQVkiLAogICAgIm51bGxhYmxlIjogdHJ1ZSwKICAgICJjb21wb25lbnQiOiB7CiAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAicHJlY2lzaW9uIjogLTEsCiAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICB9CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3,2,2,2,2],"DIGESTS":["states",10,5,0.3,true]}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) EnumerableUncollect EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.patterns_field], patterns_field=[$t2]) EnumerableValues(tuples=[[{ 0 }]]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml new file mode 100644 index 00000000000..acb0a0d0113 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml @@ -0,0 +1,21 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(gender=[$0], patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($2, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + LogicalAggregate(group=[{0}], patterns_field=[pattern($1, $2, $3, $4)]) + LogicalProject(gender=[$4], email=[$9], $f17=[10], $f18=[100000], $f19=[true]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + Uncollect + LogicalProject(patterns_field=[$cor0.patterns_field]) + LogicalValues(tuples=[[{ 0 }]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..2=[{inputs}], expr#3=['pattern'], expr#4=[ITEM($t2, $t3)], expr#5=[SAFE_CAST($t4)], expr#6=['sample_logs'], expr#7=[ITEM($t2, $t6)], expr#8=[true], expr#9=[PATTERN_PARSER($t5, $t7, $t8)], expr#10=[ITEM($t9, $t3)], expr#11=[SAFE_CAST($t10)], expr#12=['pattern_count'], expr#13=[ITEM($t2, $t12)], 
expr#14=[SAFE_CAST($t13)], expr#15=['tokens'], expr#16=[ITEM($t9, $t15)], expr#17=[SAFE_CAST($t16)], expr#18=[SAFE_CAST($t7)], gender=[$t0], patterns_field=[$t11], pattern_count=[$t14], tokens=[$t17], sample_logs=[$t18]) + EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + EnumerableAggregate(group=[{0}], patterns_field=[pattern($1, $2, $3, $4)]) + EnumerableCalc(expr#0..16=[{inputs}], expr#17=[10], expr#18=[100000], expr#19=[true], gender=[$t4], email=[$t9], $f17=[$t17], $f18=[$t18], $f19=[$t19]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + EnumerableUncollect + EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.patterns_field], patterns_field=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml index bc9bc027e34..d58ef2abc1d 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml @@ -1,7 +1,7 @@ calcite: logical: | LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalProject(patterns_field=[SAFE_CAST(ITEM($1, 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($1, 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) + LogicalProject(patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) LogicalAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) LogicalProject(email=[$9], $f17=[10], $f18=[100000], $f19=[true]) @@ -11,7 +11,7 @@ calcite: LogicalValues(tuples=[[{ 0 }]]) physical: | EnumerableLimit(fetch=[10000]) - EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['pattern_count'], expr#6=[ITEM($t1, $t5)], expr#7=[SAFE_CAST($t6)], expr#8=['tokens'], expr#9=[ITEM($t1, $t8)], expr#10=[SAFE_CAST($t9)], expr#11=['sample_logs'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], patterns_field=[$t4], pattern_count=[$t7], tokens=[$t10], sample_logs=[$t13]) + EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['sample_logs'], expr#6=[ITEM($t1, $t5)], expr#7=[true], expr#8=[PATTERN_PARSER($t4, $t6, $t7)], expr#9=[ITEM($t8, $t2)], expr#10=[SAFE_CAST($t9)], expr#11=['pattern_count'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], expr#14=['tokens'], expr#15=[ITEM($t8, $t14)], expr#16=[SAFE_CAST($t15)], expr#17=[SAFE_CAST($t6)], patterns_field=[$t10], pattern_count=[$t13], tokens=[$t16], sample_logs=[$t17]) EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) EnumerableAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) EnumerableCalc(expr#0..16=[{inputs}], expr#17=[10], expr#18=[100000], expr#19=[true], email=[$t9], $f17=[$t17], $f18=[$t18], $f19=[$t19]) diff --git 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index d772b3e603b..35ea858cf9c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -195,6 +195,19 @@ private ExprValue parse( return ExprNullValue.of(); } + // Check for arrays first, even if field type is not defined in mapping. + // This handles nested arrays in aggregation results where inner fields + // (like sample_logs in pattern aggregation) may not have type mappings. + // Exclude GeoPoint types as they have special array handling (e.g., [lon, lat] format). + if (content.isArray() + && (fieldType.isEmpty() || supportArrays) + && !fieldType + .map(t -> t.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint))) + .orElse(false)) { + ExprType type = fieldType.orElse(ARRAY); + return parseArray(content, field, type, supportArrays); + } + // Field type may be not defined in mapping if users have disabled dynamic mapping. // Then try to parse content directly based on the value itself // Besides, sub-fields of generated objects are also of type UNDEFINED. We parse the content @@ -481,24 +494,31 @@ public JsonPath getChildPath() { */ private ExprValue parseArray( Content content, String prefix, ExprType type, boolean supportArrays) { - List result = new ArrayList<>(); - // ARRAY is mapped to nested but can take the json structure of an Object. if (content.objectValue() instanceof ObjectNode) { + List result = new ArrayList<>(); result.add(parseStruct(content, prefix, supportArrays)); - // non-object type arrays are only supported when parsing inner_hits of OS response. - } else if (!(type instanceof OpenSearchDataType + return new ExprCollectionValue(result); + } + + // Get the array iterator once and reuse it + var arrayIterator = content.array(); + + // Handle empty arrays early + if (!arrayIterator.hasNext()) { + return supportArrays ? new ExprCollectionValue(List.of()) : ExprNullValue.of(); + } + + // non-object type arrays are only supported when parsing inner_hits of OS response. 
+ if (!(type instanceof OpenSearchDataType && ((OpenSearchDataType) type).getExprType().equals(ARRAY)) && !supportArrays) { - return parseInnerArrayValue(content.array().next(), prefix, type, supportArrays); - } else { - content - .array() - .forEachRemaining( - v -> { - result.add(parseInnerArrayValue(v, prefix, type, supportArrays)); - }); + return parseInnerArrayValue(arrayIterator.next(), prefix, type, supportArrays); } + + List result = new ArrayList<>(); + arrayIterator.forEachRemaining( + v -> result.add(parseInnerArrayValue(v, prefix, type, supportArrays))); return new ExprCollectionValue(result); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 247f40b3733..d5c31418bd5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -98,6 +98,7 @@ import org.opensearch.sql.opensearch.response.agg.StatsParser; import org.opensearch.sql.opensearch.response.agg.TopHitsParser; import org.opensearch.sql.opensearch.storage.script.aggregation.dsl.CompositeAggregationBuilder; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.ScriptedMetricUDAFRegistry; import org.opensearch.sql.utils.Utils; /** @@ -141,6 +142,7 @@ public static class AggregateBuilderHelper { final RelOptCluster cluster; final boolean bucketNullable; final int queryBucketSize; + final boolean udafPushdownEnabled; > T build(RexNode node, T aggBuilder) { return build(node, aggBuilder::field, aggBuilder::script); @@ -613,6 +615,24 @@ yield switch (functionName) { !args.isEmpty() ? args.getFirst().getKey() : null, AggregationBuilders.cardinality(aggName)), new SingleValueParser(aggName)); + case INTERNAL_PATTERN -> { + if (!helper.udafPushdownEnabled) { + throw new AggregateAnalyzerException( + "UDAF pushdown is disabled. Enable it via cluster setting" + + " 'plugins.calcite.udaf_pushdown.enabled'"); + } + yield ScriptedMetricUDAFRegistry.INSTANCE + .lookup(functionName) + .map( + udaf -> + udaf.buildAggregation( + args, aggName, helper.cluster, helper.rowType, helper.fieldTypes)) + .orElseThrow( + () -> + new AggregateAnalyzerException( + String.format( + "No scripted metric UDAF registered for %s", functionName))); + } default -> throw new AggregateAnalyzer.AggregateAnalyzerException( String.format("Unsupported push-down aggregator %s", aggCall.getAggregation())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java new file mode 100644 index 00000000000..a06399711e6 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.ScriptedMetric; + +/** + * Parser for scripted metric aggregation responses. Extracts the final result from the reduce phase + * of a scripted metric aggregation. 
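+ *
+ * <p>For example (illustrative): if the reduce script of an aggregation named "patterns_field"
+ * returns a two-element list, parse() wraps it as a single row
+ * List.of(Map.of("patterns_field", resultList)), so the downstream Uncollect in the query plan
+ * can expand the array back into rows.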
+ */ +@EqualsAndHashCode +@RequiredArgsConstructor +public class ScriptedMetricParser implements MetricParser { + + private final String name; + + @Override + public String getName() { + return name; + } + + @Override + @SuppressWarnings("unchecked") + public List> parse(Aggregation agg) { + if (agg instanceof ScriptedMetric scriptedMetric) { + // Extract the final result from the reduce script + Object result = scriptedMetric.aggregation(); + // The reduce script for UDAF aggregation returns List> + // which represents the array of results. We wrap this in a single Map with + // the aggregation field name as key, so the response is 1 row containing + // the array that can be expanded by Uncollect in the query plan. + if (result instanceof List) { + return List.of(Map.of(name, result)); + } + throw new IllegalArgumentException( + String.format( + "Expected List> from scripted metric but got %s", + result == null ? "null" : result.getClass().getSimpleName())); + } + throw new IllegalArgumentException( + String.format( + "Expected ScriptedMetric aggregation but got %s", + agg == null ? "null" : agg.getClass().getSimpleName())); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index bd8001f589d..31ffb28924b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -172,6 +172,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting CALCITE_UDAF_PUSHDOWN_ENABLED_SETTING = + Setting.boolSetting( + Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + public static final Setting CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR_SETTING = Setting.doubleSetting( Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR.getKeyValue(), @@ -455,6 +462,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.CALCITE_PUSHDOWN_ENABLED, CALCITE_PUSHDOWN_ENABLED_SETTING, new Updater(Key.CALCITE_PUSHDOWN_ENABLED)); + register( + settingBuilder, + clusterSettings, + Key.CALCITE_UDAF_PUSHDOWN_ENABLED, + CALCITE_UDAF_PUSHDOWN_ENABLED_SETTING, + new Updater(Key.CALCITE_UDAF_PUSHDOWN_ENABLED)); register( settingBuilder, clusterSettings, @@ -656,6 +669,7 @@ public static List> pluginSettings() { .add(CALCITE_ENGINE_ENABLED_SETTING) .add(CALCITE_FALLBACK_ALLOWED_SETTING) .add(CALCITE_PUSHDOWN_ENABLED_SETTING) + .add(CALCITE_UDAF_PUSHDOWN_ENABLED_SETTING) .add(CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR_SETTING) .add(CALCITE_SUPPORT_ALL_JOIN_TYPES_SETTING) .add(DEFAULT_PATTERN_METHOD_SETTING) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index dbe8306d4b2..a2b714d6360 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -386,9 +386,16 @@ public AbstractRelNode pushDownAggregate(Aggregate aggregate, @Nullable Project } int queryBucketSize = osIndex.getQueryBucketSize(); boolean bucketNullable = !PPLHintUtils.ignoreNullBucket(aggregate); + boolean udafPushdownEnabled = + 
osIndex.getSettings().getSettingValue(Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED); AggregateAnalyzer.AggregateBuilderHelper helper = new AggregateAnalyzer.AggregateBuilderHelper( - getRowType(), fieldTypes, getCluster(), bucketNullable, queryBucketSize); + getRowType(), + fieldTypes, + getCluster(), + bucketNullable, + queryBucketSize, + udafPushdownEnabled); final Pair, OpenSearchAggregationResponseParser> builderAndParser = AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); Map extendedTypeMapping = diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java index 224d7019ec2..7e2c3231e74 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java @@ -75,12 +75,17 @@ import org.opensearch.script.NumberSortScript; import org.opensearch.script.ScriptContext; import org.opensearch.script.ScriptEngine; +import org.opensearch.script.ScriptedMetricAggContexts; import org.opensearch.script.StringSortScript; import org.opensearch.search.lookup.SourceLookup; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.opensearch.storage.script.aggregation.CalciteAggregationScriptFactory; import org.opensearch.sql.opensearch.storage.script.field.CalciteFieldScriptFactory; import org.opensearch.sql.opensearch.storage.script.filter.CalciteFilterScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricCombineScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricInitScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricMapScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricReduceScriptFactory; import org.opensearch.sql.opensearch.storage.script.sort.CalciteNumberSortScriptFactory; import org.opensearch.sql.opensearch.storage.script.sort.CalciteStringSortScriptFactory; import org.opensearch.sql.opensearch.storage.serde.RelJsonSerializer; @@ -113,6 +118,18 @@ public CalciteScriptEngine(RelOptCluster relOptCluster) { .put(NumberSortScript.CONTEXT, CalciteNumberSortScriptFactory::new) .put(StringSortScript.CONTEXT, CalciteStringSortScriptFactory::new) .put(FieldScript.CONTEXT, CalciteFieldScriptFactory::new) + .put( + ScriptedMetricAggContexts.InitScript.CONTEXT, + (func, type) -> new CalciteScriptedMetricInitScriptFactory(func)) + .put( + ScriptedMetricAggContexts.MapScript.CONTEXT, + (func, type) -> new CalciteScriptedMetricMapScriptFactory(func)) + .put( + ScriptedMetricAggContexts.CombineScript.CONTEXT, + (func, type) -> new CalciteScriptedMetricCombineScriptFactory(func)) + .put( + ScriptedMetricAggContexts.ReduceScript.CONTEXT, + (func, type) -> new CalciteScriptedMetricReduceScriptFactory(func)) .build(); @Override @@ -214,6 +231,11 @@ public Object get(String name) { case DOC_VALUE -> getFromDocValue((String) digests.get(index)); case SOURCE -> getFromSource((String) digests.get(index)); case LITERAL -> digests.get(index); + case SPECIAL_VARIABLE -> + // Special variables (state, states) are not in this context + // They should be handled by ScriptedMetricDataContext + throw new IllegalStateException( + "SPECIAL_VARIABLE " + digests.get(index) + " not supported in this 
context"); }; } catch (Exception e) { throw new IllegalStateException("Failed to get value for parameter " + name); @@ -245,7 +267,8 @@ public Object getFromSource(String name) { public enum Source { DOC_VALUE(0), SOURCE(1), - LITERAL(2); + LITERAL(2), + SPECIAL_VARIABLE(3); // For scripted metric state/states variables private final int value; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java new file mode 100644 index 00000000000..48a9300bcb9 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.opensearch.script.ScriptedMetricAggContexts; + +/** + * Factory for Calcite-based CombineScript in scripted metric aggregations. Combines shard-level + * accumulators using RexNode expressions. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricCombineScriptFactory + implements ScriptedMetricAggContexts.CombineScript.Factory { + + private final Function1 function; + + @Override + public ScriptedMetricAggContexts.CombineScript newInstance( + Map params, Map state) { + return new CalciteScriptedMetricCombineScript(function, params, state); + } + + /** CombineScript that executes compiled RexNode expression. */ + private static class CalciteScriptedMetricCombineScript + extends ScriptedMetricAggContexts.CombineScript { + + private final Function1 function; + + public CalciteScriptedMetricCombineScript( + Function1 function, + Map params, + Map state) { + super(params, state); + this.function = function; + } + + @Override + public Object execute() { + // Create data context for combine script + @SuppressWarnings("unchecked") + Map state = (Map) getState(); + DataContext dataContext = new ScriptedMetricDataContext.CombineContext(getParams(), state); + + // Execute the compiled RexNode expression + Object[] result = function.apply(dataContext); + + // Return the combined result + return (result != null && result.length > 0) ? result[0] : getState(); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java new file mode 100644 index 00000000000..13d9dcbbbbf --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.opensearch.script.ScriptedMetricAggContexts; + +/** + * Factory for Calcite-based InitScript in scripted metric aggregations. 
Executes RexNode + * expressions compiled to Java code via CalciteScriptEngine. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricInitScriptFactory + implements ScriptedMetricAggContexts.InitScript.Factory { + + private final Function1 function; + + @Override + public ScriptedMetricAggContexts.InitScript newInstance( + Map params, Map state) { + return new CalciteScriptedMetricInitScript(function, params, state); + } + + /** InitScript that executes compiled RexNode expression. */ + private static class CalciteScriptedMetricInitScript + extends ScriptedMetricAggContexts.InitScript { + + private final Function1 function; + + public CalciteScriptedMetricInitScript( + Function1 function, + Map params, + Map state) { + super(params, state); + this.function = function; + } + + @Override + public void execute() { + // Create data context for init script (no document access, only params) + @SuppressWarnings("unchecked") + Map state = (Map) getState(); + DataContext dataContext = new ScriptedMetricDataContext.InitContext(getParams(), state); + + // Execute the compiled RexNode expression and merge result into state + Object[] result = function.apply(dataContext); + ScriptedMetricDataContext.mergeResultIntoState(result, state); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java new file mode 100644 index 00000000000..70128748a8f --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.lucene.index.LeafReaderContext; +import org.opensearch.script.ScriptedMetricAggContexts; +import org.opensearch.search.lookup.SearchLookup; + +/** + * Factory for Calcite-based MapScript in scripted metric aggregations. Executes RexNode expressions + * compiled to Java code with document field access. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricMapScriptFactory + implements ScriptedMetricAggContexts.MapScript.Factory { + + private final Function1 function; + + @Override + public ScriptedMetricAggContexts.MapScript.LeafFactory newFactory( + Map params, Map state, SearchLookup lookup) { + return new CalciteMapScriptLeafFactory(function, params, state, lookup); + } + + /** Leaf factory that creates MapScript instances for each segment. */ + @RequiredArgsConstructor + private static class CalciteMapScriptLeafFactory + implements ScriptedMetricAggContexts.MapScript.LeafFactory { + + private final Function1 function; + private final Map params; + private final Map state; + private final SearchLookup lookup; + + @Override + public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext ctx) { + return new CalciteScriptedMetricMapScript(function, params, state, lookup, ctx); + } + } + + /** + * MapScript that executes compiled RexNode expression for each document. + * + *

+ * <p>The DataContext is created once in the constructor and reused for all documents to avoid
+ * object allocation overhead per document. This is safe because:
+ *
+ * <ul>
+ *   <li>params and state references don't change between documents
+ *   <li>doc and sourceLookup are updated internally by OpenSearch before each execute() call
+ *   <li>sources and digests (derived from params) are the same for all documents
+ * </ul>
+ */ + private static class CalciteScriptedMetricMapScript extends ScriptedMetricAggContexts.MapScript { + + private final Function1 function; + private final DataContext dataContext; + + public CalciteScriptedMetricMapScript( + Function1 function, + Map params, + Map state, + SearchLookup lookup, + LeafReaderContext leafContext) { + super(params, state, lookup, leafContext); + this.function = function; + // Create DataContext once and reuse for all documents in this segment. + // OpenSearch updates doc values and source lookup internally before each execute(). + this.dataContext = + new ScriptedMetricDataContext.MapContext( + params, state, getDoc(), lookup.getLeafSearchLookup(leafContext).source()); + } + + @Override + @SuppressWarnings("unchecked") + public void execute() { + // Execute the compiled RexNode expression (reusing the same DataContext) + Object[] result = function.apply(dataContext); + ScriptedMetricDataContext.mergeResultIntoState(result, (Map) getState()); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java new file mode 100644 index 00000000000..cb8a93cdecf --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.List; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.opensearch.script.ScriptedMetricAggContexts; + +/** + * Factory for Calcite-based ReduceScript in scripted metric aggregations. Produces final result + * from all shard-level combined results using RexNode expressions. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricReduceScriptFactory + implements ScriptedMetricAggContexts.ReduceScript.Factory { + + private final Function1 function; + + @Override + public ScriptedMetricAggContexts.ReduceScript newInstance( + Map params, List states) { + return new CalciteScriptedMetricReduceScript(function, params, states); + } + + /** ReduceScript that executes compiled RexNode expression. */ + private static class CalciteScriptedMetricReduceScript + extends ScriptedMetricAggContexts.ReduceScript { + + private final Function1 function; + + public CalciteScriptedMetricReduceScript( + Function1 function, + Map params, + List states) { + super(params, states); + this.function = function; + } + + @Override + public Object execute() { + // Create data context for reduce script + DataContext dataContext = + new ScriptedMetricDataContext.ReduceContext(getParams(), getStates()); + + // Execute the compiled RexNode expression + Object[] result = function.apply(dataContext); + + // Return the final result + return (result != null && result.length > 0) ? 
result[0] : getStates(); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java new file mode 100644 index 00000000000..f95feff6f65 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java @@ -0,0 +1,209 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import static org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper.DIGESTS; +import static org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper.SOURCES; + +import java.util.List; +import java.util.Map; +import org.apache.calcite.DataContext; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.linq4j.QueryProvider; +import org.apache.calcite.schema.SchemaPlus; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.opensearch.index.fielddata.ScriptDocValues; +import org.opensearch.search.lookup.SourceLookup; +import org.opensearch.sql.opensearch.storage.script.CalciteScriptEngine.Source; + +/** + * DataContext implementations for scripted metric aggregation script phases. Provides access to + * params, state/states variables, and document fields depending on the phase. + * + *

+ * <p>Each script phase has its own context:
+ *
+ * <ul>
+ *   <li>{@link InitContext} - init_script: params and state
+ *   <li>{@link MapContext} - map_script: params, state, doc values, and source lookup
+ *   <li>{@link CombineContext} - combine_script: params and state
+ *   <li>{@link ReduceContext} - reduce_script: params and states (array from all shards)
+ * </ul>
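+ *
+ * <p>Dynamic parameters are named "?N" and resolved through the SOURCES/DIGESTS script params.
+ * Taking the map_script from the expected plans above as a worked example, SOURCES=[3,0,2,2,2,2]
+ * with DIGESTS=["state","email.keyword",10,100000,5,0.3] resolves ?0 to the mutable state map
+ * (SPECIAL_VARIABLE), ?1 to the email.keyword doc value, and ?2..?5 to the literal arguments.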
+ */ +public abstract class ScriptedMetricDataContext implements DataContext { + + protected final Map params; + protected final List sources; + protected final List digests; + + protected ScriptedMetricDataContext(Map params) { + this.params = params; + this.sources = ((List) params.get(SOURCES)).stream().map(Source::fromValue).toList(); + this.digests = (List) params.get(DIGESTS); + } + + @Override + public @Nullable SchemaPlus getRootSchema() { + return null; + } + + @Override + public JavaTypeFactory getTypeFactory() { + return null; + } + + @Override + public QueryProvider getQueryProvider() { + return null; + } + + /** + * Merges the execution result into the state map. This is a common operation used in init_script + * and map_script phases to update the accumulator state. + * + *

If the result is a Map, its entries are merged into the state. Otherwise, the result is + * stored under the "accumulator" key. + * + * @param result The result array from function execution (may be null or empty) + * @param state The state map to update + */ + @SuppressWarnings("unchecked") + public static void mergeResultIntoState(Object[] result, Map state) { + if (result != null && result.length > 0) { + if (result[0] instanceof Map) { + state.putAll((Map) result[0]); + } else { + state.put("accumulator", result[0]); + } + } + } + + /** + * Parse dynamic parameter index from name pattern "?N". + * + * @param name The parameter name (expected format: "?0", "?1", etc.) + * @return The parameter index + * @throws IllegalArgumentException if name doesn't match expected pattern or is malformed + */ + protected int parseDynamicParamIndex(String name) { + if (!name.startsWith("?")) { + throw new IllegalArgumentException( + "Unexpected parameter name format: " + name + ". Expected '?N' pattern."); + } + int index; + try { + index = Integer.parseInt(name.substring(1)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Malformed parameter name '" + name + "'. Expected '?N' pattern where N is an integer.", + e); + } + if (index >= sources.size()) { + throw new IllegalArgumentException( + "Parameter index " + index + " out of bounds. Sources size: " + sources.size()); + } + return index; + } + + /** + * Base class for init and combine phases that share identical get() logic. Both phases only have + * access to params and state (no doc values). + */ + protected abstract static class StateOnlyContext extends ScriptedMetricDataContext { + protected final Map state; + + protected StateOnlyContext(Map params, Map state) { + super(params); + this.state = state; + } + + @Override + public Object get(String name) { + int index = parseDynamicParamIndex(name); + return switch (sources.get(index)) { + case SPECIAL_VARIABLE -> state; + case LITERAL -> digests.get(index); + default -> + throw new IllegalStateException( + "Unexpected source type " + sources.get(index) + " in StateOnlyContext"); + }; + } + } + + /** DataContext for InitScript phase - provides params and state. */ + public static class InitContext extends StateOnlyContext { + public InitContext(Map params, Map state) { + super(params, state); + } + } + + /** DataContext for CombineScript phase - provides params and state. */ + public static class CombineContext extends StateOnlyContext { + public CombineContext(Map params, Map state) { + super(params, state); + } + } + + /** DataContext for MapScript phase - provides params, state, doc values, and source lookup. */ + public static class MapContext extends ScriptedMetricDataContext { + private final Map state; + private final Map> doc; + private final SourceLookup sourceLookup; + + public MapContext( + Map params, + Map state, + Map> doc, + SourceLookup sourceLookup) { + super(params); + this.state = state; + this.doc = doc; + this.sourceLookup = sourceLookup; + } + + @Override + public Object get(String name) { + int index = parseDynamicParamIndex(name); + return switch (sources.get(index)) { + case SPECIAL_VARIABLE -> state; + case LITERAL -> digests.get(index); + case DOC_VALUE -> getDocValue((String) digests.get(index)); + case SOURCE -> sourceLookup != null ? 
sourceLookup.get((String) digests.get(index)) : null; + }; + } + + private Object getDocValue(String fieldName) { + if (doc != null && doc.containsKey(fieldName)) { + ScriptDocValues docValue = doc.get(fieldName); + if (docValue != null && !docValue.isEmpty()) { + return docValue.get(0); + } + } + return null; + } + } + + /** DataContext for ReduceScript phase - provides params and states array from all shards. */ + public static class ReduceContext extends ScriptedMetricDataContext { + private final List states; + + public ReduceContext(Map params, List states) { + super(params); + this.states = states; + } + + @Override + public Object get(String name) { + int index = parseDynamicParamIndex(name); + return switch (sources.get(index)) { + case SPECIAL_VARIABLE -> states; + case LITERAL -> digests.get(index); + default -> + throw new IllegalStateException( + "Unexpected source type " + sources.get(index) + " in ReduceContext"); + }; + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAF.java new file mode 100644 index 00000000000..8bec1db3520 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAF.java @@ -0,0 +1,235 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.List; +import java.util.Map; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.ScriptedMetricParser; +import org.opensearch.sql.opensearch.storage.script.CompoundedScriptEngine; +import org.opensearch.sql.opensearch.storage.serde.RelJsonSerializer; +import org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper; +import org.opensearch.sql.opensearch.storage.serde.SerializationWrapper; + +/** + * Interface for User-Defined Aggregate Functions (UDAFs) that can be pushed down to OpenSearch as + * scripted metric aggregations. + * + *

+ * <p>A scripted metric aggregation has four phases:
+ *
+ * <ul>
+ *   <li>init_script: Initializes the accumulator state on each shard
+ *   <li>map_script: Processes each document, updating the accumulator
+ *   <li>combine_script: Combines shard-level states (runs on each shard)
+ *   <li>reduce_script: Produces final result from all shard states (runs on coordinator)
+ * </ul>
+ *
+ * <p>
Implementations should encapsulate all domain-specific logic for a particular UDAF, keeping + * the AggregateAnalyzer generic and reusable. + */ +public interface ScriptedMetricUDAF { + + /** + * Returns the function name this UDAF handles. + * + * @return The BuiltinFunctionName that this UDAF implements + */ + BuiltinFunctionName getFunctionName(); + + /** + * Build the init_script RexNode for initializing accumulator state. + * + * @param context The script context containing builders and utilities + * @return RexNode representing the init script expression + */ + RexNode buildInitScript(ScriptContext context); + + /** + * Build the map_script RexNode for processing each document. + * + * @param context The script context containing builders and utilities + * @param args The arguments from the aggregate call + * @return RexNode representing the map script expression + */ + RexNode buildMapScript(ScriptContext context, List args); + + /** + * Build the combine_script RexNode for combining shard-level states. + * + * @param context The script context containing builders and utilities + * @return RexNode representing the combine script expression + */ + RexNode buildCombineScript(ScriptContext context); + + /** + * Build the reduce_script RexNode for producing final result. + * + * @param context The script context containing builders and utilities + * @param args The arguments from the aggregate call + * @return RexNode representing the reduce script expression + */ + RexNode buildReduceScript(ScriptContext context, List args); + + /** + * Context object providing utilities for script generation. Each script phase gets its own + * context with isolated parameter helpers. + */ + class ScriptContext { + private final RexBuilder rexBuilder; + private final ScriptParameterHelper paramHelper; + private final RelOptCluster cluster; + private final RelDataType rowType; + private final Map fieldTypes; + + public ScriptContext( + RexBuilder rexBuilder, + ScriptParameterHelper paramHelper, + RelOptCluster cluster, + RelDataType rowType, + Map fieldTypes) { + this.rexBuilder = rexBuilder; + this.paramHelper = paramHelper; + this.cluster = cluster; + this.rowType = rowType; + this.fieldTypes = fieldTypes; + } + + public RexBuilder getRexBuilder() { + return rexBuilder; + } + + public ScriptParameterHelper getParamHelper() { + return paramHelper; + } + + public RelOptCluster getCluster() { + return cluster; + } + + public RelDataType getRowType() { + return rowType; + } + + public Map getFieldTypes() { + return fieldTypes; + } + + /** + * Add a special variable (like 'state' or 'states') and return its dynamic param reference. + * + * @param varName The variable name + * @param type The SQL type for the parameter + * @return RexNode representing the dynamic parameter reference + */ + public RexNode addSpecialVariableRef(String varName, SqlTypeName type) { + int index = paramHelper.addSpecialVariable(varName); + return rexBuilder.makeDynamicParam(rexBuilder.getTypeFactory().createSqlType(type), index); + } + } + + /** + * Build the complete scripted metric aggregation. + * + *

This is the main entry point that creates all four scripts and assembles them into an + * OpenSearch aggregation builder. The default implementation handles the common boilerplate. + * + * @param args The arguments from the aggregate call + * @param aggName The name of the aggregation + * @param cluster The RelOptCluster for creating builders + * @param rowType The row type containing field information + * @param fieldTypes Map of field names to expression types + * @return Pair of aggregation builder and metric parser + */ + default Pair buildAggregation( + List> args, + String aggName, + RelOptCluster cluster, + RelDataType rowType, + Map fieldTypes) { + + RelJsonSerializer serializer = new RelJsonSerializer(cluster); + RexBuilder rexBuilder = cluster.getRexBuilder(); + List fieldList = rowType.getFieldList(); + + // Create parameter helpers for each script phase + ScriptParameterHelper initParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + ScriptParameterHelper mapParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + ScriptParameterHelper combineParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + ScriptParameterHelper reduceParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + + // Create contexts for each phase + ScriptContext initContext = + new ScriptContext(rexBuilder, initParamHelper, cluster, rowType, fieldTypes); + ScriptContext mapContext = + new ScriptContext(rexBuilder, mapParamHelper, cluster, rowType, fieldTypes); + ScriptContext combineContext = + new ScriptContext(rexBuilder, combineParamHelper, cluster, rowType, fieldTypes); + ScriptContext reduceContext = + new ScriptContext(rexBuilder, reduceParamHelper, cluster, rowType, fieldTypes); + + // Extract RexNodes from args + List argRefs = args.stream().map(Pair::getKey).toList(); + + // Build scripts + RexNode initRex = buildInitScript(initContext); + RexNode mapRex = buildMapScript(mapContext, argRefs); + RexNode combineRex = buildCombineScript(combineContext); + RexNode reduceRex = buildReduceScript(reduceContext, argRefs); + + // Create Script objects + Script initScript = createScript(serializer, initRex, initParamHelper); + Script mapScript = createScript(serializer, mapRex, mapParamHelper); + Script combineScript = createScript(serializer, combineRex, combineParamHelper); + Script reduceScript = createScript(serializer, reduceRex, reduceParamHelper); + + // Build scripted metric aggregation + AggregationBuilder aggBuilder = + AggregationBuilders.scriptedMetric(aggName) + .initScript(initScript) + .mapScript(mapScript) + .combineScript(combineScript) + .reduceScript(reduceScript); + + return Pair.of(aggBuilder, new ScriptedMetricParser(aggName)); + } + + /** + * Create a Script object from a RexNode expression. 
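+ * The RexNode is serialized to JSON, wrapped with the Calcite language type marker, and packaged
+ * with the helper's SOURCES/DIGESTS parameters, so the compounded script engine can recompile it
+ * on the data nodes.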
+ * + * @param serializer The JSON serializer for RexNode + * @param rexNode The expression to serialize + * @param paramHelper The parameter helper containing script parameters + * @return Script object ready for OpenSearch + */ + private static Script createScript( + RelJsonSerializer serializer, RexNode rexNode, ScriptParameterHelper paramHelper) { + String serializedCode = serializer.serialize(rexNode, paramHelper); + String wrappedCode = + SerializationWrapper.wrapWithLangType( + CompoundedScriptEngine.ScriptEngineType.CALCITE, serializedCode); + return new Script( + Script.DEFAULT_SCRIPT_TYPE, + CompoundedScriptEngine.COMPOUNDED_LANG_NAME, + wrappedCode, + paramHelper.getParameters()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java new file mode 100644 index 00000000000..db181afd68b --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.udaf.PatternScriptedMetricUDAF; + +/** + * Registry for ScriptedMetricUDAF implementations. + * + *
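To make the four-phase contract concrete before moving on to the registry, here is a minimal sketch of what an implementation might look like. The MY_COLLECT_*_UDF operator constants and the BuiltinFunctionName.TAKE mapping are illustrative assumptions only, not part of this change; a real implementation wires in genuine UDF operators, as PatternScriptedMetricUDAF does below with PPLBuiltinOperators.

```java
// Hypothetical sketch only: MY_COLLECT_*_UDF are assumed SqlOperator constants
// registered elsewhere; they do not exist in this change.
public class CollectScriptedMetricUDAF implements ScriptedMetricUDAF {

  @Override
  public BuiltinFunctionName getFunctionName() {
    return BuiltinFunctionName.TAKE; // placeholder choice for illustration
  }

  @Override
  public RexNode buildInitScript(ScriptContext context) {
    // init_script phase: hand the scripted-metric 'state' to a UDF that seeds it
    RexNode stateRef = context.addSpecialVariableRef("state", SqlTypeName.ANY);
    return context.getRexBuilder().makeCall(MY_COLLECT_INIT_UDF, List.of(stateRef));
  }

  @Override
  public RexNode buildMapScript(ScriptContext context, List<RexNode> args) {
    // map_script phase: fold each document's field value (args.get(0)) into 'state'
    RexNode stateRef = context.addSpecialVariableRef("state", SqlTypeName.ANY);
    return context.getRexBuilder().makeCall(MY_COLLECT_ADD_UDF, List.of(stateRef, args.get(0)));
  }

  @Override
  public RexNode buildCombineScript(ScriptContext context) {
    // combine_script phase: ship the shard-local state to the reduce phase unchanged
    return context.addSpecialVariableRef("state", SqlTypeName.ANY);
  }

  @Override
  public RexNode buildReduceScript(ScriptContext context, List<RexNode> args) {
    // reduce_script phase: merge the per-shard 'states' list into the final value
    RexNode statesRef = context.addSpecialVariableRef("states", SqlTypeName.ANY);
    return context.getRexBuilder().makeCall(MY_COLLECT_RESULT_UDF, List.of(statesRef));
  }
}
```

Note how combine simply forwards the shard-local state; the pattern UDAF further down follows the same shape.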
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java
new file mode 100644
index 00000000000..db181afd68b
--- /dev/null
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage.script.scriptedmetric;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import org.opensearch.sql.expression.function.BuiltinFunctionName;
+import org.opensearch.sql.opensearch.storage.script.scriptedmetric.udaf.PatternScriptedMetricUDAF;
+
+/**
+ * Registry for ScriptedMetricUDAF implementations.
+ *
+ * <p>This registry provides a lookup mechanism for finding UDAF implementations that can be pushed
+ * down to OpenSearch as scripted metric aggregations. Each UDAF implementation is registered by its
+ * function name.
+ *
+ * <p>To add a new UDAF pushdown:
+ *
+ * <ol>
+ *   <li>Create a class implementing {@link ScriptedMetricUDAF}
+ *   <li>Register it in this registry by calling {@link #register(ScriptedMetricUDAF)}
+ * </ol>
+ */
+public final class ScriptedMetricUDAFRegistry {
+
+  /** Singleton instance */
+  public static final ScriptedMetricUDAFRegistry INSTANCE = new ScriptedMetricUDAFRegistry();
+
+  private final Map<BuiltinFunctionName, ScriptedMetricUDAF> udafMap;
+
+  private ScriptedMetricUDAFRegistry() {
+    this.udafMap = new HashMap<>();
+    registerBuiltinUDAFs();
+  }
+
+  /** Register all built-in scripted metric UDAFs. */
+  private void registerBuiltinUDAFs() {
+    // Register Pattern (BRAIN) UDAF
+    register(PatternScriptedMetricUDAF.INSTANCE);
+  }
+
+  /**
+   * Register a ScriptedMetricUDAF implementation.
+   *
+   * @param udaf The UDAF implementation to register
+   */
+  public void register(ScriptedMetricUDAF udaf) {
+    udafMap.put(udaf.getFunctionName(), udaf);
+  }
+
+  /**
+   * Look up a ScriptedMetricUDAF by function name.
+   *
+   * @param functionName The function name to look up
+   * @return Optional containing the UDAF if found, empty otherwise
+   */
+  public Optional<ScriptedMetricUDAF> lookup(BuiltinFunctionName functionName) {
+    return Optional.ofNullable(udafMap.get(functionName));
+  }
+
+  /**
+   * Look up a ScriptedMetricUDAF by function name string.
+   *
+   * @param functionName The function name string to look up
+   * @return Optional containing the UDAF if found, empty otherwise
+   */
+  public Optional<ScriptedMetricUDAF> lookup(String functionName) {
+    return BuiltinFunctionName.ofAggregation(functionName)
+        .flatMap(name -> Optional.ofNullable(udafMap.get(name)));
+  }
+
+  /**
+   * Check if a function name has a registered ScriptedMetricUDAF.
+   *
+   * @param functionName The function name to check
+   * @return true if a UDAF is registered for this function
+   */
+  public boolean hasUDAF(BuiltinFunctionName functionName) {
+    return udafMap.containsKey(functionName);
+  }
+}
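For orientation, a hedged sketch of how a pushdown analyzer might consult this registry; the surrounding variables (args, aggName, cluster, rowType, fieldTypes) are assumed to come from the enclosing aggregation analysis and are not defined in this change.

```java
// Sketch: resolve a pushdown-capable UDAF for the aggregate call, if any.
Optional<ScriptedMetricUDAF> udaf =
    ScriptedMetricUDAFRegistry.INSTANCE.lookup(BuiltinFunctionName.INTERNAL_PATTERN);

if (udaf.isPresent()) {
  // buildAggregation assembles the init/map/combine/reduce scripts into a
  // scripted_metric aggregation plus the parser for its response.
  Pair<AggregationBuilder, MetricParser> pushdown =
      udaf.get().buildAggregation(args, aggName, cluster, rowType, fieldTypes);
  AggregationBuilder aggBuilder = pushdown.getKey(); // sent to OpenSearch
  MetricParser parser = pushdown.getValue(); // reads the scripted_metric result back
}
```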
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/udaf/PatternScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/udaf/PatternScriptedMetricUDAF.java
new file mode 100644
index 00000000000..06caca43c7b
--- /dev/null
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/udaf/PatternScriptedMetricUDAF.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage.script.scriptedmetric.udaf;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.opensearch.sql.expression.function.BuiltinFunctionName;
+import org.opensearch.sql.expression.function.PPLBuiltinOperators;
+import org.opensearch.sql.opensearch.storage.script.scriptedmetric.ScriptedMetricUDAF;
+
+/**
+ * Scripted metric UDAF implementation for the Pattern (BRAIN) aggregation function.
+ *
+ * <p>This implementation handles the pushdown of the pattern detection algorithm to OpenSearch,
+ * using the BrainLogParser for log pattern analysis. The four script phases are:
+ *
+ * <ul>
+ *   <li>init_script: Initializes state with logMessages buffer and patternGroupMap
+ *   <li>map_script: Adds log messages to accumulator, triggers partial merge when buffer is full
+ *   <li>combine_script: Returns shard-level state for the reduce phase
+ *   <li>reduce_script: Combines all shard states and produces final pattern results
+ * </ul>
+ */
+public class PatternScriptedMetricUDAF implements ScriptedMetricUDAF {
+
+  // Default parameter values for pattern UDAF
+  private static final int DEFAULT_MAX_SAMPLE_COUNT = 10;
+  private static final int DEFAULT_BUFFER_LIMIT = 100000;
+  private static final int DEFAULT_VARIABLE_COUNT_THRESHOLD = 5;
+  private static final BigDecimal DEFAULT_THRESHOLD_PERCENTAGE = BigDecimal.valueOf(0.3);
+
+  /** Singleton instance */
+  public static final PatternScriptedMetricUDAF INSTANCE = new PatternScriptedMetricUDAF();
+
+  private PatternScriptedMetricUDAF() {}
+
+  @Override
+  public BuiltinFunctionName getFunctionName() {
+    return BuiltinFunctionName.INTERNAL_PATTERN;
+  }
+
+  @Override
+  public RexNode buildInitScript(ScriptContext context) {
+    RexBuilder rexBuilder = context.getRexBuilder();
+    RexNode stateRef = context.addSpecialVariableRef("state", SqlTypeName.ANY);
+    return rexBuilder.makeCall(PPLBuiltinOperators.PATTERN_INIT_UDF, List.of(stateRef));
+  }
+
+  @Override
+  public RexNode buildMapScript(ScriptContext context, List<RexNode> args) {
+    RexBuilder rexBuilder = context.getRexBuilder();
+    List<RexNode> mapArgs = new ArrayList<>();
+
+    // Add state variable reference
+    RexNode stateRef = context.addSpecialVariableRef("state", SqlTypeName.ANY);
+    mapArgs.add(stateRef);
+
+    // Add field reference (first argument)
+    if (!args.isEmpty()) {
+      mapArgs.add(args.get(0));
+    }
+
+    // Add parameters with defaults:
+    // args[1] = maxSampleCount
+    // args[2] = bufferLimit
+    // args[3] = showNumberedToken (not used in map script)
+    // args[4] = thresholdPercentage (optional)
+    // args[5] = variableCountThreshold (optional)
+    mapArgs.add(getArgOrDefault(args, 1, makeIntLiteral(rexBuilder, DEFAULT_MAX_SAMPLE_COUNT)));
+    mapArgs.add(getArgOrDefault(args, 2, makeIntLiteral(rexBuilder, DEFAULT_BUFFER_LIMIT)));
+    mapArgs.add(
+        getArgOrDefault(args, 5, makeIntLiteral(rexBuilder, DEFAULT_VARIABLE_COUNT_THRESHOLD)));
+    mapArgs.add(
+        getArgOrDefault(args, 4, makeDoubleLiteral(rexBuilder, DEFAULT_THRESHOLD_PERCENTAGE)));
+
+    return rexBuilder.makeCall(PPLBuiltinOperators.PATTERN_ADD_UDF, mapArgs);
+  }
+
+  @Override
+  public RexNode buildCombineScript(ScriptContext context) {
+    // Combine script simply returns the shard-level state
+    return context.addSpecialVariableRef("state", SqlTypeName.ANY);
+  }
+
+  @Override
+  public RexNode buildReduceScript(ScriptContext context, List<RexNode> args) {
+    RexBuilder rexBuilder = context.getRexBuilder();
+    RexNode statesRef = context.addSpecialVariableRef("states", SqlTypeName.ANY);
+
+    List<RexNode> reduceArgs = new ArrayList<>();
+    reduceArgs.add(statesRef);
+
+    // maxSampleCount
+    reduceArgs.add(getArgOrDefault(args, 1, makeIntLiteral(rexBuilder, DEFAULT_MAX_SAMPLE_COUNT)));
+
+    // Determine variableCountThreshold and thresholdPercentage
+    RexNode variableCountThreshold = makeIntLiteral(rexBuilder, DEFAULT_VARIABLE_COUNT_THRESHOLD);
+    RexNode thresholdPercentage = makeDoubleLiteral(rexBuilder, DEFAULT_THRESHOLD_PERCENTAGE);
+
+    if (args.size() > 5) {
+      thresholdPercentage = args.get(4);
+      variableCountThreshold = args.get(5);
+    } else if (args.size() > 4) {
+      RexNode arg4 = args.get(4);
+      SqlTypeName arg4Type = arg4.getType().getSqlTypeName();
+      if (arg4Type == SqlTypeName.DOUBLE
+          || arg4Type == SqlTypeName.DECIMAL
+          || arg4Type == SqlTypeName.FLOAT) {
+        thresholdPercentage = arg4;
+      } else {
+        variableCountThreshold = arg4;
+      }
+    }
+
+    reduceArgs.add(variableCountThreshold);
+    reduceArgs.add(thresholdPercentage);
+
+    // showNumberedToken (default false)
+    reduceArgs.add(getArgOrDefault(args, 3, rexBuilder.makeLiteral(false)));
+
+    return rexBuilder.makeCall(PPLBuiltinOperators.PATTERN_RESULT_UDF, reduceArgs);
+  }
+
+  /** Get argument from list or return default value. */
+  private static RexNode getArgOrDefault(List<RexNode> args, int index, RexNode defaultValue) {
+    return args.size() > index ? args.get(index) : defaultValue;
+  }
+
+  /** Create integer literal for pattern UDAF parameters. */
+  private static RexNode makeIntLiteral(RexBuilder rexBuilder, int value) {
+    return rexBuilder.makeLiteral(
+        value, rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER), true);
+  }
+
+  /** Create double literal for pattern UDAF parameters. */
+  private static RexNode makeDoubleLiteral(RexBuilder rexBuilder, BigDecimal value) {
+    return rexBuilder.makeLiteral(
+        value, rexBuilder.getTypeFactory().createSqlType(SqlTypeName.DOUBLE), true);
+  }
+}
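The subtlest part of PatternScriptedMetricUDAF is how buildReduceScript disambiguates the trailing optional arguments by arity and type. A short annotated walk-through of that dispatch rule, with illustrative argument lists (the values are examples, not fixtures from this change):

```java
// How buildReduceScript resolves the optional trailing arguments:
//
// args = [field, 10, 100000, false]          -> defaults apply:
//                                               thresholdPercentage = 0.3,
//                                               variableCountThreshold = 5
// args = [field, 10, 100000, false, 0.4]     -> args[4] has a floating-point type
//                                               (DOUBLE/DECIMAL/FLOAT), so it is the
//                                               thresholdPercentage
// args = [field, 10, 100000, false, 7]       -> args[4] has an integer type, so it is
//                                               the variableCountThreshold
// args = [field, 10, 100000, false, 0.4, 7]  -> both supplied positionally:
//                                               args[4] = thresholdPercentage,
//                                               args[5] = variableCountThreshold
```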
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java
index 1916ab6c2c3..aa6d2de0f4d 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java
@@ -42,9 +42,12 @@ public class ScriptParameterHelper {
    *
    * <p>0 stands for DOC_VALUE
    *
-   * <p>1 stand for SOURCE
+   * <p>1 stands for SOURCE
    *
    * <p>2 stands for LITERAL
+   *
+   * <p>3 stands for SPECIAL_VARIABLE - retrieves value from special context variables (e.g., state,
+   * states in scripted metric aggregations)
    */
   List<Integer> sources;
@@ -94,4 +97,18 @@ public Map<String, Object> getParameters() {
       }
     };
   }
+
+  /**
+   * Adds a special variable reference (like state or states in scripted metric aggregations) and
+   * returns the index.
+   *
+   * @param variableName The name of the special variable (e.g., "state", "states")
+   * @return The index in the sources/digests lists
+   */
+  public int addSpecialVariable(String variableName) {
+    int index = sources.size();
+    sources.add(3); // SPECIAL_VARIABLE = 3
+    digests.add(variableName);
+    return index;
+  }
 }
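Bridging the serde change back to the interface above: a hedged sketch of the producer side, where paramHelper and rexBuilder are assumed to come from the surrounding script-generation code. This is exactly what ScriptContext.addSpecialVariableRef does in one step.

```java
// Record the special variable: sources gets 3 (SPECIAL_VARIABLE) and digests
// gets the variable name, at the same index.
int index = paramHelper.addSpecialVariable("states");

// Reference that index from the Rex tree as a dynamic parameter. At execution
// time the script engine resolves the index to "states" and reads it from the
// scripted-metric reduce context instead of doc values, _source, or a literal.
RexNode statesRef =
    rexBuilder.makeDynamicParam(
        rexBuilder.getTypeFactory().createSqlType(SqlTypeName.ANY), index);
```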
diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java
index 2e79da953b1..de66ddd8338 100644
--- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java
+++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java
@@ -153,7 +153,8 @@ void analyze_aggCall_simple() throws ExpressionNotAnalyzableException {
             List.of(countCall, avgCall, sumCall, minCall, maxCall), ImmutableBitSet.of());
     Project project = createMockProject(List.of(0));
     AggregateAnalyzer.AggregateBuilderHelper helper =
-        new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE);
+        new AggregateAnalyzer.AggregateBuilderHelper(
+            rowType, fieldTypes, null, true, BUCKET_SIZE, false);
     Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
         AggregateAnalyzer.analyze(aggregate, project, outputFields, helper);
     assertEquals(
@@ -236,7 +237,8 @@ void analyze_aggCall_extended() throws ExpressionNotAnalyzableException {
             List.of(varSampCall, varPopCall, stddevSampCall, stddevPopCall), ImmutableBitSet.of());
     Project project = createMockProject(List.of(0));
     AggregateAnalyzer.AggregateBuilderHelper helper =
-        new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE);
+        new AggregateAnalyzer.AggregateBuilderHelper(
+            rowType, fieldTypes, null, true, BUCKET_SIZE, false);
     Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
         AggregateAnalyzer.analyze(aggregate, project, outputFields, helper);
     assertEquals(
@@ -277,7 +279,8 @@ void analyze_groupBy() throws ExpressionNotAnalyzableException {
     Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0, 1));
     Project project = createMockProject(List.of(0, 1));
     AggregateAnalyzer.AggregateBuilderHelper helper =
-        new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE);
+        new AggregateAnalyzer.AggregateBuilderHelper(
+            rowType, fieldTypes, null, true, BUCKET_SIZE, false);
     Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
         AggregateAnalyzer.analyze(aggregate, project, outputFields, helper);
 
@@ -318,7 +321,8 @@ void analyze_aggCall_TextWithoutKeyword() {
     Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of());
     Project project = createMockProject(List.of(2));
     AggregateAnalyzer.AggregateBuilderHelper helper =
-        new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE);
+        new AggregateAnalyzer.AggregateBuilderHelper(
+            rowType, fieldTypes, null, true, BUCKET_SIZE, false);
     ExpressionNotAnalyzableException exception =
         assertThrows(
             ExpressionNotAnalyzableException.class,
@@ -345,7 +349,8 @@ void analyze_groupBy_TextWithoutKeyword() {
     Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0));
     Project project = createMockProject(List.of(2));
     AggregateAnalyzer.AggregateBuilderHelper helper =
-        new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE);
+        new AggregateAnalyzer.AggregateBuilderHelper(
+            rowType, fieldTypes, null, true, BUCKET_SIZE, false);
     ExpressionNotAnalyzableException exception =
         assertThrows(
             ExpressionNotAnalyzableException.class,
@@ -699,7 +704,7 @@ void verify() throws ExpressionNotAnalyzableException {
     }
     AggregateAnalyzer.AggregateBuilderHelper helper =
         new AggregateAnalyzer.AggregateBuilderHelper(
-            rowType, fieldTypes, agg.getCluster(), true, BUCKET_SIZE);
+            rowType, fieldTypes, agg.getCluster(), true, BUCKET_SIZE, false);
     Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser> result =
         AggregateAnalyzer.analyze(agg, project, outputFields, helper);
 
diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java
index c272453b829..11d16dd2914 100644
--- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java
+++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java
@@ -395,9 +395,12 @@ public void testPatternsAggregationMode_ShowNumberedToken_ForBrainMethod() {
     RelNode root = getRelNode(ppl);
 
     String expectedLogical =
-        "LogicalProject(patterns_field=[SAFE_CAST(ITEM($1, 'pattern'))],"
-            + " pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($1,"
-            + " 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))])\n"
+        "LogicalProject(patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1,"
+            + " 'pattern')), ITEM($1, 'sample_logs'), true), 'pattern'))],"
+            + " pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))],"
+            + " tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1,"
+            + " 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1,"
+            + " 'sample_logs'))])\n"
            + "  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n"
            + "    LogicalAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)])\n"
            + "      LogicalProject(ENAME=[$1], $f8=[10], $f9=[100000], $f10=[true])\n"
@@ -408,11 +411,13 @@
     verifyLogical(root, expectedLogical);
 
     String expectedSparkSql =
-        "SELECT TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING) `patterns_field`,"
-            + " TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT) `pattern_count`,"
-            + " TRY_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >)"
-            + " `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING >)"
-            + " `sample_logs`\n"
+        "SELECT TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING),"
+            + " `t20`.`patterns_field`['sample_logs'], TRUE)['pattern'] AS STRING)"
+            + " `patterns_field`, TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)"
+            + " `pattern_count`, TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern']"
+            + " AS STRING), `t20`.`patterns_field`['sample_logs'], TRUE)['tokens'] AS MAP< VARCHAR,"
+            + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY<"
+            + " STRING >) `sample_logs`\n"
             + "FROM (SELECT `pattern`(`ENAME`, 10, 100000, TRUE) `patterns_field`\n"
             + "FROM `scott`.`EMP`) `$cor0`,\n"
             + "LATERAL UNNEST((SELECT `$cor0`.`patterns_field`\n"
@@ -460,9 +465,12 @@ public void testPatternsAggregationModeWithGroupBy_ShowNumberedToken_ForBrainMet
     RelNode root = getRelNode(ppl);
 
     String expectedLogical =
-        "LogicalProject(DEPTNO=[$0], patterns_field=[SAFE_CAST(ITEM($2, 'pattern'))],"
-            + " pattern_count=[SAFE_CAST(ITEM($2, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($2,"
-            + " 'tokens'))], sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))])\n"
+        "LogicalProject(DEPTNO=[$0],"
+            + " patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')),"
+            + " ITEM($2, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($2,"
+            + " 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2,"
+            + " 'pattern')), ITEM($2, 'sample_logs'), true), 'tokens'))],"
+            + " sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))])\n"
            + "  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}])\n"
            + "    LogicalAggregate(group=[{1}], patterns_field=[pattern($0, $2, $3, $4)])\n"
            + "      LogicalProject(ENAME=[$1], DEPTNO=[$7], $f8=[10], $f9=[100000], $f10=[true])\n"
@@ -473,11 +481,14 @@
     verifyLogical(root, expectedLogical);
 
     String expectedSparkSql =
-        "SELECT `$cor0`.`DEPTNO`, TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING)"
+        "SELECT `$cor0`.`DEPTNO`,"
+            + " TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING),"
+            + " `t20`.`patterns_field`['sample_logs'], TRUE)['pattern'] AS STRING)"
             + " `patterns_field`, TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)"
-            + " `pattern_count`, TRY_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR,"
-            + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS"
-            + " ARRAY< STRING >) `sample_logs`\n"
+            + " `pattern_count`, TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern']"
+            + " AS STRING), `t20`.`patterns_field`['sample_logs'], TRUE)['tokens'] AS MAP< VARCHAR,"
+            + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY<"
+            + " STRING >) `sample_logs`\n"
             + "FROM (SELECT `DEPTNO`, `pattern`(`ENAME`, 10, 100000, TRUE) `patterns_field`\n"
            + "FROM `scott`.`EMP`\n"
            + "GROUP BY `DEPTNO`) `$cor0`,\n"