From ec4e71465b8cc33dde0b3d33870b80249cab7c66 Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Fri, 7 Jun 2024 13:32:29 -0700 Subject: [PATCH] Fix LEAD/LAG window function implementation --- .../pinot/query/QueryEnvironmentTestBase.java | 4 + .../operator/utils/AggregationUtils.java | 4 +- .../window/value/LagValueWindowFunction.java | 56 +++++++++--- .../window/value/LeadValueWindowFunction.java | 50 +++++++++-- .../operator/WindowAggregateOperatorTest.java | 90 +++++++++++++++++++ 5 files changed, 182 insertions(+), 22 deletions(-) diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java index 6b3b8a363170..8e33ec1ef3b9 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java @@ -191,6 +191,10 @@ protected Object[][] provideQueries() { new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) FROM a"}, new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LEAD(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LAG(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LEAD(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LAG(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"}, new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), MIN(a.col3) OVER (ORDER BY a.col2) FROM a"}, new Object[]{ diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java index 0133843be0a5..ed24af5a3c1f 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java @@ -18,7 +18,6 @@ */ package org.apache.pinot.query.runtime.operator.utils; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import java.util.HashMap; import java.util.List; @@ -223,8 +222,7 @@ public Accumulator(RexExpression.FunctionCall aggCall, DataSchema inputSchema) { private RexExpression toAggregationFunctionOperand(RexExpression.FunctionCall aggCall) { List functionOperands = aggCall.getFunctionOperands(); int numOperands = functionOperands.size(); - Preconditions.checkState(numOperands < 2, "Aggregate functions cannot have more than one operand"); - return numOperands == 1 ? functionOperands.get(0) : new RexExpression.Literal(ColumnDataType.INT, 1); + return numOperands == 0 ? new RexExpression.Literal(ColumnDataType.INT, 1) : functionOperands.get(0); } } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java index 63c49bac4403..797fca9313bb 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java @@ -18,32 +18,66 @@ */ package org.apache.pinot.query.runtime.operator.window.value; -import java.util.ArrayList; +import com.google.common.base.Preconditions; +import java.util.Arrays; import java.util.List; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.PinotDataType; import org.apache.pinot.query.planner.logical.RexExpression; public class LagValueWindowFunction extends ValueWindowFunction { + private final int _offset; + private final Object _defaultValue; public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, List collations, boolean partitionByOnly) { super(aggCall, inputSchema, collations, partitionByOnly); + int offset = 1; + Object defaultValue = null; + List operands = aggCall.getFunctionOperands(); + int numOperands = operands.size(); + if (numOperands > 1) { + RexExpression secondOperand = operands.get(1); + Preconditions.checkArgument(secondOperand instanceof RexExpression.Literal, + "Second operand (offset) of LAG function must be a literal"); + Object offsetValue = ((RexExpression.Literal) secondOperand).getValue(); + if (offsetValue instanceof Number) { + offset = ((Number) offsetValue).intValue(); + } + } + if (numOperands == 3) { + RexExpression thirdOperand = operands.get(2); + Preconditions.checkArgument(thirdOperand instanceof RexExpression.Literal, + "Third operand (default value) of LAG function must be a literal"); + RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) thirdOperand; + defaultValue = defaultValueLiteral.getValue(); + if (defaultValue != null) { + DataSchema.ColumnDataType srcDataType = defaultValueLiteral.getDataType(); + DataSchema.ColumnDataType destDataType = inputSchema.getColumnDataType(0); + if (srcDataType != destDataType) { + // Convert the default value to the same data type as the input column + // (e.g. convert INT to LONG, FLOAT to DOUBLE, etc. + defaultValue = PinotDataType.getPinotDataTypeForExecution(destDataType) + .convert(defaultValue, PinotDataType.getPinotDataTypeForExecution(srcDataType)); + } + } + } + _offset = offset; + _defaultValue = defaultValue; } @Override public List processRows(List rows) { - List result = new ArrayList<>(rows.size()); - Object[] prevRow = null; - for (Object[] row : rows) { - if (prevRow == null) { - result.add(null); - } else { - result.add(extractValueFromRow(prevRow)); - } - prevRow = row; + int numRows = rows.size(); + Object[] result = new Object[numRows]; + if (_defaultValue != null) { + Arrays.fill(result, 0, _offset, _defaultValue); + } + for (int i = _offset; i < numRows; i++) { + result[i] = extractValueFromRow(rows.get(i - _offset)); } - return result; + return Arrays.asList(result); } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java index 530675844cdd..099c3fba5f9e 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java @@ -18,32 +18,66 @@ */ package org.apache.pinot.query.runtime.operator.window.value; +import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.List; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.PinotDataType; import org.apache.pinot.query.planner.logical.RexExpression; public class LeadValueWindowFunction extends ValueWindowFunction { + private final int _offset; + private final Object _defaultValue; + public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, List collations, boolean partitionByOnly) { super(aggCall, inputSchema, collations, partitionByOnly); + int offset = 1; + Object defaultValue = null; + List operands = aggCall.getFunctionOperands(); + int numOperands = operands.size(); + if (numOperands > 1) { + RexExpression secondOperand = operands.get(1); + Preconditions.checkArgument(secondOperand instanceof RexExpression.Literal, + "Second operand (offset) of LAG function must be a literal"); + Object offsetValue = ((RexExpression.Literal) secondOperand).getValue(); + if (offsetValue instanceof Number) { + offset = ((Number) offsetValue).intValue(); + } + } + if (numOperands == 3) { + RexExpression thirdOperand = operands.get(2); + Preconditions.checkArgument(thirdOperand instanceof RexExpression.Literal, + "Third operand (default value) of LAG function must be a literal"); + RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) thirdOperand; + defaultValue = defaultValueLiteral.getValue(); + if (defaultValue != null) { + DataSchema.ColumnDataType srcDataType = defaultValueLiteral.getDataType(); + DataSchema.ColumnDataType destDataType = inputSchema.getColumnDataType(0); + if (srcDataType != destDataType) { + // Convert the default value to the same data type as the input column + // (e.g. convert INT to LONG, FLOAT to DOUBLE, etc. + defaultValue = PinotDataType.getPinotDataTypeForExecution(destDataType) + .convert(defaultValue, PinotDataType.getPinotDataTypeForExecution(srcDataType)); + } + } + } + _offset = offset; + _defaultValue = defaultValue; } @Override public List processRows(List rows) { int numRows = rows.size(); Object[] result = new Object[numRows]; - Object[] nextRow = null; - for (int i = numRows - 1; i >= 0; i--) { - if (nextRow == null) { - result[i] = null; - } else { - result[i] = extractValueFromRow(nextRow); - } - nextRow = rows.get(i); + for (int i = 0; i < numRows - _offset; i++) { + result[i] = extractValueFromRow(rows.get(i + _offset)); + } + if (_defaultValue != null) { + Arrays.fill(result, numRows - _offset, numRows, _defaultValue); } return Arrays.asList(result); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index f351c94f2505..cd3ca26b6cfd 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -561,6 +561,96 @@ public void testShouldHandleWindowWithPartialResultsWhenHitDataRowsLimit() { "Max rows in window should be reached"); } + @Test + public void testLeadLagWindowFunction() { + // Given: + DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); + when(_input.nextBlock()).thenReturn( + OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new Object[]{2, "bar"}, new Object[]{2, "foo"}, + new Object[]{1, "foo"})).thenReturn( + OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"}, + new Object[]{2, "the"}, new Object[]{3, "true"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"}, + new ColumnDataType[]{INT, STRING, INT, INT}); + List keys = List.of(0); + List collations = + List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST)); + List aggCalls = + List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))), + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1)))); + WindowAggregateOperator operator = + getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE, + Integer.MIN_VALUE, 0); + + // When: + List resultRows = operator.nextBlock().getContainer(); + // Then: + verifyResultRows(resultRows, keys, Map.of( + 1, List.of( + new Object[]{1, "foo", 1, null}, + new Object[]{1, "foo", 1, 1}, + new Object[]{1, "numb", null, 1}), + 2, List.of( + new Object[]{2, "bar", 2, null}, + new Object[]{2, "foo", 2, 2}, + new Object[]{2, "foo", 2, 2}, + new Object[]{2, "the", null, 2}), + 3, List.of( + new Object[]{3, "and", 3, null}, + new Object[]{3, "true", null, 3}) + )); + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testLeadLagWindowFunction2() { + // Given: + DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); + when(_input.nextBlock()).thenReturn( + OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new Object[]{2, "bar"}, new Object[]{2, "foo"}, + new Object[]{1, "foo"})).thenReturn( + OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"}, + new Object[]{2, "the"}, new Object[]{3, "true"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"}, + new ColumnDataType[]{INT, STRING, INT, INT}); + List keys = List.of(0); + List collations = + List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST)); + List aggCalls = + List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 2), + new RexExpression.Literal(ColumnDataType.INT, 100))), + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1), + new RexExpression.Literal(ColumnDataType.INT, 200)))); + WindowAggregateOperator operator = + getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE, + Integer.MIN_VALUE, 0); + + // When: + List resultRows = operator.nextBlock().getContainer(); + // Then: + verifyResultRows(resultRows, keys, Map.of( + 1, List.of( + new Object[]{1, "foo", 1, 200}, + new Object[]{1, "foo", 100, 1}, + new Object[]{1, "numb", 100, 1}), + 2, List.of( + new Object[]{2, "bar", 2, 200}, + new Object[]{2, "foo", 2, 2}, + new Object[]{2, "foo", 100, 2}, + new Object[]{2, "the", 100, 2}), + 3, List.of( + new Object[]{3, "and", 100, 200}, + new Object[]{3, "true", 100, 3}) + )); + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + private WindowAggregateOperator getOperator(DataSchema inputSchema, DataSchema resultSchema, List keys, List collations, List aggCalls, WindowNode.WindowFrameType windowFrameType, int lowerBound, int upperBound, PlanNode.NodeHint nodeHint) {