From 42f1e43a122d46ab40129cdac9f09233448bce11 Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Fri, 7 Jun 2024 13:32:29 -0700 Subject: [PATCH] Fix LEAD/LAG window function implementation --- .../pinot/query/QueryEnvironmentTestBase.java | 4 +++ .../window/value/LagValueWindowFunction.java | 34 +++++++++++++------ .../window/value/LeadValueWindowFunction.java | 27 ++++++++++----- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java index 6b3b8a363170..8e33ec1ef3b9 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java @@ -191,6 +191,10 @@ protected Object[][] provideQueries() { new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) FROM a"}, new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LEAD(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LAG(a.col3, 5) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LEAD(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, + new Object[]{"SELECT a.col1, LAG(a.col3, 5, -1) OVER (PARTITION BY a.col2 ORDER BY a.col3) FROM a"}, new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"}, new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), MIN(a.col3) OVER (ORDER BY a.col2) FROM a"}, new Object[]{ diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java index 63c49bac4403..b776aec07b05 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java @@ -18,7 +18,7 @@ */ package org.apache.pinot.query.runtime.operator.window.value; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; @@ -27,23 +27,35 @@ public class LagValueWindowFunction extends ValueWindowFunction { + private final int _offset; + private final Object _defaultLeadValue; + public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, List collations, boolean partitionByOnly) { super(aggCall, inputSchema, collations, partitionByOnly); + if (aggCall.getFunctionOperands().size() > 1) { + _offset = Integer.parseInt(aggCall.getFunctionOperands().get(1).toString()); + } else { + _offset = 1; + } + if (aggCall.getFunctionOperands().size() == 3) { + _defaultLeadValue = inputSchema.getColumnDataType(0) + .convert(((RexExpression.Literal) aggCall.getFunctionOperands().get(2)).getValue()); + } else { + _defaultLeadValue = null; + } } @Override public List processRows(List rows) { - List result = new ArrayList<>(rows.size()); - Object[] prevRow = null; - for (Object[] row : rows) { - if (prevRow == null) { - result.add(null); - } else { - result.add(extractValueFromRow(prevRow)); - } - prevRow = row; + int numRows = rows.size(); + Object[] result = new Object[numRows]; + if (_defaultLeadValue != null) { + Arrays.fill(result, 0, _offset, _defaultLeadValue); + } + for (int i = _offset; i < numRows; i++) { + result[i] = extractValueFromRow(rows.get(i - _offset)); } - return result; + return Arrays.asList(result); } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java index 530675844cdd..dfa2bce076e4 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java @@ -27,23 +27,34 @@ public class LeadValueWindowFunction extends ValueWindowFunction { + private final int _offset; + private final Object _defaultLeadValue; + public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, List collations, boolean partitionByOnly) { super(aggCall, inputSchema, collations, partitionByOnly); + if (aggCall.getFunctionOperands().size() > 1) { + _offset = Integer.parseInt(aggCall.getFunctionOperands().get(1).toString()); + } else { + _offset = 1; + } + if (aggCall.getFunctionOperands().size() == 3) { + _defaultLeadValue = inputSchema.getColumnDataType(0) + .convert(((RexExpression.Literal) aggCall.getFunctionOperands().get(2)).getValue()); + } else { + _defaultLeadValue = null; + } } @Override public List processRows(List rows) { int numRows = rows.size(); Object[] result = new Object[numRows]; - Object[] nextRow = null; - for (int i = numRows - 1; i >= 0; i--) { - if (nextRow == null) { - result[i] = null; - } else { - result[i] = extractValueFromRow(nextRow); - } - nextRow = rows.get(i); + for (int i = 0; i < numRows - _offset; i++) { + result[i] = extractValueFromRow(rows.get(i + _offset)); + } + if (_defaultLeadValue != null) { + Arrays.fill(result, numRows - _offset, numRows, _defaultLeadValue); } return Arrays.asList(result); }