Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more arguments for LEAD/LAG window functions #13340

Merged
merged 1 commit into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ protected Object[][] provideQueries() {
new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"},
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
new Object[]{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.pinot.query.runtime.operator.utils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -223,8 +222,7 @@ public Accumulator(RexExpression.FunctionCall aggCall, DataSchema inputSchema) {
private RexExpression toAggregationFunctionOperand(RexExpression.FunctionCall aggCall) {
List<RexExpression> functionOperands = aggCall.getFunctionOperands();
int numOperands = functionOperands.size();
Preconditions.checkState(numOperands < 2, "Aggregate functions cannot have more than one operand");
return numOperands == 1 ? functionOperands.get(0) : new RexExpression.Literal(ColumnDataType.INT, 1);
return numOperands == 0 ? new RexExpression.Literal(ColumnDataType.INT, 1) : functionOperands.get(0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,66 @@
*/
package org.apache.pinot.query.runtime.operator.window.value;

import java.util.ArrayList;
import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.query.planner.logical.RexExpression;


public class LagValueWindowFunction extends ValueWindowFunction {
private final int _offset;
private final Object _defaultValue;

public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
List<RelFieldCollation> collations, boolean partitionByOnly) {
super(aggCall, inputSchema, collations, partitionByOnly);
int offset = 1;
Object defaultValue = null;
List<RexExpression> operands = aggCall.getFunctionOperands();
int numOperands = operands.size();
if (numOperands > 1) {
RexExpression secondOperand = operands.get(1);
Preconditions.checkArgument(secondOperand instanceof RexExpression.Literal,
"Second operand (offset) of LAG function must be a literal");
Object offsetValue = ((RexExpression.Literal) secondOperand).getValue();
if (offsetValue instanceof Number) {
offset = ((Number) offsetValue).intValue();
}
}
if (numOperands == 3) {
RexExpression thirdOperand = operands.get(2);
Preconditions.checkArgument(thirdOperand instanceof RexExpression.Literal,
"Third operand (default value) of LAG function must be a literal");
RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) thirdOperand;
defaultValue = defaultValueLiteral.getValue();
if (defaultValue != null) {
DataSchema.ColumnDataType srcDataType = defaultValueLiteral.getDataType();
DataSchema.ColumnDataType destDataType = inputSchema.getColumnDataType(0);
if (srcDataType != destDataType) {
// Convert the default value to the same data type as the input column
// (e.g. convert INT to LONG, FLOAT to DOUBLE, etc.
defaultValue = PinotDataType.getPinotDataTypeForExecution(destDataType)
.convert(defaultValue, PinotDataType.getPinotDataTypeForExecution(srcDataType));
}
}
}
_offset = offset;
_defaultValue = defaultValue;
}

@Override
public List<Object> processRows(List<Object[]> rows) {
List<Object> result = new ArrayList<>(rows.size());
Object[] prevRow = null;
for (Object[] row : rows) {
if (prevRow == null) {
result.add(null);
} else {
result.add(extractValueFromRow(prevRow));
}
prevRow = row;
int numRows = rows.size();
Object[] result = new Object[numRows];
if (_defaultValue != null) {
Arrays.fill(result, 0, _offset, _defaultValue);
}
for (int i = _offset; i < numRows; i++) {
result[i] = extractValueFromRow(rows.get(i - _offset));
}
return result;
return Arrays.asList(result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,66 @@
*/
package org.apache.pinot.query.runtime.operator.window.value;

import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.PinotDataType;
import org.apache.pinot.query.planner.logical.RexExpression;


public class LeadValueWindowFunction extends ValueWindowFunction {

private final int _offset;
private final Object _defaultValue;

public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
List<RelFieldCollation> collations, boolean partitionByOnly) {
super(aggCall, inputSchema, collations, partitionByOnly);
int offset = 1;
Object defaultValue = null;
List<RexExpression> operands = aggCall.getFunctionOperands();
int numOperands = operands.size();
if (numOperands > 1) {
RexExpression secondOperand = operands.get(1);
Preconditions.checkArgument(secondOperand instanceof RexExpression.Literal,
"Second operand (offset) of LAG function must be a literal");
Object offsetValue = ((RexExpression.Literal) secondOperand).getValue();
if (offsetValue instanceof Number) {
offset = ((Number) offsetValue).intValue();
}
}
if (numOperands == 3) {
RexExpression thirdOperand = operands.get(2);
Preconditions.checkArgument(thirdOperand instanceof RexExpression.Literal,
"Third operand (default value) of LAG function must be a literal");
RexExpression.Literal defaultValueLiteral = (RexExpression.Literal) thirdOperand;
defaultValue = defaultValueLiteral.getValue();
if (defaultValue != null) {
DataSchema.ColumnDataType srcDataType = defaultValueLiteral.getDataType();
DataSchema.ColumnDataType destDataType = inputSchema.getColumnDataType(0);
if (srcDataType != destDataType) {
// Convert the default value to the same data type as the input column
// (e.g. convert INT to LONG, FLOAT to DOUBLE, etc.
defaultValue = PinotDataType.getPinotDataTypeForExecution(destDataType)
.convert(defaultValue, PinotDataType.getPinotDataTypeForExecution(srcDataType));
}
}
}
_offset = offset;
_defaultValue = defaultValue;
}

@Override
public List<Object> processRows(List<Object[]> rows) {
int numRows = rows.size();
Object[] result = new Object[numRows];
Object[] nextRow = null;
for (int i = numRows - 1; i >= 0; i--) {
if (nextRow == null) {
result[i] = null;
} else {
result[i] = extractValueFromRow(nextRow);
}
nextRow = rows.get(i);
for (int i = 0; i < numRows - _offset; i++) {
result[i] = extractValueFromRow(rows.get(i + _offset));
}
if (_defaultValue != null) {
Arrays.fill(result, numRows - _offset, numRows, _defaultValue);
}
return Arrays.asList(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,96 @@ public void testShouldHandleWindowWithPartialResultsWhenHitDataRowsLimit() {
"Max rows in window should be reached");
}

@Test
public void testLeadLagWindowFunction() {
// Given:
DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING});
when(_input.nextBlock()).thenReturn(
OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new Object[]{2, "bar"}, new Object[]{2, "foo"},
new Object[]{1, "foo"})).thenReturn(
OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"},
new Object[]{2, "the"}, new Object[]{3, "true"}))
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"},
new ColumnDataType[]{INT, STRING, INT, INT});
List<Integer> keys = List.of(0);
List<RelFieldCollation> collations =
List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
List<RexExpression.FunctionCall> aggCalls =
List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(),
List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))),
new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(),
List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))));
WindowAggregateOperator operator =
getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE,
Integer.MIN_VALUE, 0);

// When:
List<Object[]> resultRows = operator.nextBlock().getContainer();
// Then:
verifyResultRows(resultRows, keys, Map.of(
1, List.of(
new Object[]{1, "foo", 1, null},
new Object[]{1, "foo", 1, 1},
new Object[]{1, "numb", null, 1}),
2, List.of(
new Object[]{2, "bar", 2, null},
new Object[]{2, "foo", 2, 2},
new Object[]{2, "foo", 2, 2},
new Object[]{2, "the", null, 2}),
3, List.of(
new Object[]{3, "and", 3, null},
new Object[]{3, "true", null, 3})
));
assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
}

@Test
public void testLeadLagWindowFunction2() {
// Given:
DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING});
when(_input.nextBlock()).thenReturn(
OperatorTestUtil.block(inputSchema, new Object[]{3, "and"}, new Object[]{2, "bar"}, new Object[]{2, "foo"},
new Object[]{1, "foo"})).thenReturn(
OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"},
new Object[]{2, "the"}, new Object[]{3, "true"}))
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"},
new ColumnDataType[]{INT, STRING, INT, INT});
List<Integer> keys = List.of(0);
List<RelFieldCollation> collations =
List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
List<RexExpression.FunctionCall> aggCalls =
List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(),
List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 2),
new RexExpression.Literal(ColumnDataType.INT, 100))),
new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(),
List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1),
new RexExpression.Literal(ColumnDataType.INT, 200))));
WindowAggregateOperator operator =
getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE,
Integer.MIN_VALUE, 0);

// When:
List<Object[]> resultRows = operator.nextBlock().getContainer();
// Then:
verifyResultRows(resultRows, keys, Map.of(
1, List.of(
new Object[]{1, "foo", 1, 200},
new Object[]{1, "foo", 100, 1},
new Object[]{1, "numb", 100, 1}),
2, List.of(
new Object[]{2, "bar", 2, 200},
new Object[]{2, "foo", 2, 2},
new Object[]{2, "foo", 100, 2},
new Object[]{2, "the", 100, 2}),
3, List.of(
new Object[]{3, "and", 100, 200},
new Object[]{3, "true", 100, 3})
));
assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)");
}

private WindowAggregateOperator getOperator(DataSchema inputSchema, DataSchema resultSchema, List<Integer> keys,
List<RelFieldCollation> collations, List<RexExpression.FunctionCall> aggCalls,
WindowNode.WindowFrameType windowFrameType, int lowerBound, int upperBound, PlanNode.NodeHint nodeHint) {
Expand Down
Loading