Skip to content

Commit

Permalink
Fix LEAD/LAG window function implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Jun 7, 2024
1 parent e5decf3 commit 42f1e43
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ protected Object[][] provideQueries() {
new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LEAD(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT a.col1, LAG(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"},
new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"},
new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
new Object[]{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.apache.pinot.query.runtime.operator.window.value;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.pinot.common.utils.DataSchema;
Expand All @@ -27,23 +27,35 @@

public class LagValueWindowFunction extends ValueWindowFunction {

private final int _offset;
private final Object _defaultLeadValue;

public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
List<RelFieldCollation> collations, boolean partitionByOnly) {
super(aggCall, inputSchema, collations, partitionByOnly);
if (aggCall.getFunctionOperands().size() > 1) {
_offset = Integer.parseInt(aggCall.getFunctionOperands().get(1).toString());
} else {
_offset = 1;
}
if (aggCall.getFunctionOperands().size() == 3) {
_defaultLeadValue = inputSchema.getColumnDataType(0)
.convert(((RexExpression.Literal) aggCall.getFunctionOperands().get(2)).getValue());
} else {
_defaultLeadValue = null;
}
}

@Override
public List<Object> processRows(List<Object[]> rows) {
List<Object> result = new ArrayList<>(rows.size());
Object[] prevRow = null;
for (Object[] row : rows) {
if (prevRow == null) {
result.add(null);
} else {
result.add(extractValueFromRow(prevRow));
}
prevRow = row;
int numRows = rows.size();
Object[] result = new Object[numRows];
if (_defaultLeadValue != null) {
Arrays.fill(result, 0, _offset, _defaultLeadValue);
}
for (int i = _offset; i < numRows; i++) {
result[i] = extractValueFromRow(rows.get(i - _offset));
}
return result;
return Arrays.asList(result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,34 @@

public class LeadValueWindowFunction extends ValueWindowFunction {

private final int _offset;
private final Object _defaultLeadValue;

public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
List<RelFieldCollation> collations, boolean partitionByOnly) {
super(aggCall, inputSchema, collations, partitionByOnly);
if (aggCall.getFunctionOperands().size() > 1) {
_offset = Integer.parseInt(aggCall.getFunctionOperands().get(1).toString());
} else {
_offset = 1;
}
if (aggCall.getFunctionOperands().size() == 3) {
_defaultLeadValue = inputSchema.getColumnDataType(0)
.convert(((RexExpression.Literal) aggCall.getFunctionOperands().get(2)).getValue());
} else {
_defaultLeadValue = null;
}
}

@Override
public List<Object> processRows(List<Object[]> rows) {
int numRows = rows.size();
Object[] result = new Object[numRows];
Object[] nextRow = null;
for (int i = numRows - 1; i >= 0; i--) {
if (nextRow == null) {
result[i] = null;
} else {
result[i] = extractValueFromRow(nextRow);
}
nextRow = rows.get(i);
for (int i = 0; i < numRows - _offset; i++) {
result[i] = extractValueFromRow(rows.get(i + _offset));
}
if (_defaultLeadValue != null) {
Arrays.fill(result, numRows - _offset, numRows, _defaultLeadValue);
}
return Arrays.asList(result);
}
Expand Down

0 comments on commit 42f1e43

Please sign in to comment.