Skip to content

Commit

Permalink
Properly process Final Selection in Reduce(). (facebookincubator#2408)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#2408

In Reduce() we ignored the fact that Final Selection might have been false before we set it to false ourselves.
In that case we need to use a different SelectivityVector.

Reviewed By: mbasmanova

Differential Revision: D39121620

fbshipit-source-id: 64fa7710d9c35b31cc6e325a54bddebd66d0881f
  • Loading branch information
Sergey Pershin authored and facebook-github-bot committed Sep 7, 2022
1 parent 125e086 commit d2217d1
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 21 deletions.
44 changes: 31 additions & 13 deletions velox/functions/prestosql/Reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,6 @@ class ReduceFunction : public exec::VectorFunction {

auto flatArray = flattenArray(rows, args[0], decodedArray);

// Loop over lambda functions and apply these to elements of the base array.
// In most cases there will be only one function and the loop will run once.
auto inputFuncIt = args[2]->asUnchecked<FunctionVector>()->iterator(&rows);

SelectivityVector arrayRows(flatArray->size(), false);
BufferPtr elementIndices =
allocateIndices(flatArray->size(), context->pool());

const auto& initialState = args[1];
auto partialResult =
BaseVector::create(initialState->type(), rows.end(), context->pool());
Expand All @@ -107,6 +99,15 @@ class ReduceFunction : public exec::VectorFunction {
VarSetter finalSelection(
context->mutableFinalSelection(), &rows, context->isFinalSelection());
VarSetter isFinalSelection(context->mutableIsFinalSelection(), false);
const SelectivityVector& finalSelectionRows = *context->finalSelection();

// Loop over lambda functions and apply these to elements of the base array.
// In most cases there will be only one function and the loop will run once.
auto inputFuncIt = args[2]->asUnchecked<FunctionVector>()->iterator(&rows);

BufferPtr elementIndices =
allocateIndices(flatArray->size(), context->pool());
SelectivityVector arrayRows(flatArray->size(), false);

// Iteratively apply input function to array elements.
// First, apply input function to first elements of all arrays.
Expand All @@ -117,22 +118,34 @@ class ReduceFunction : public exec::VectorFunction {
while (auto entry = inputFuncIt.next()) {
VectorPtr state = initialState;

int n = 0;
vector_size_t n = 0;
while (true) {
// Sets arrayRows[row] to true if array at that row has n-th element, to
// false otherwise.
// Set elementIndices[row] to the index of the n-th element in the
// array's elements vector.
if (!toNthElementRows(
flatArray, *entry.rows, n, arrayRows, elementIndices)) {
break; // Ran out of elements in all arrays.
}

auto nthElement = BaseVector::wrapInDictionary(
// Create dictionary row -> element in array's elements vector.
auto dictNthElements = BaseVector::wrapInDictionary(
BufferPtr(nullptr),
elementIndices,
flatArray->size(),
flatArray->elements());

std::vector<VectorPtr> lambdaArgs = {state, nthElement};
// Run input lambda on our dictionary - adding n-th element to the
// initial state for every row.
std::vector<VectorPtr> lambdaArgs = {state, dictNthElements};
entry.callable->apply(
arrayRows, rows, nullptr, context, lambdaArgs, &partialResult);
arrayRows,
finalSelectionRows,
nullptr,
context,
lambdaArgs,
&partialResult);
state = partialResult;
n++;
}
Expand All @@ -143,7 +156,12 @@ class ReduceFunction : public exec::VectorFunction {
while (auto entry = outputFuncIt.next()) {
std::vector<VectorPtr> lambdaArgs = {partialResult};
entry.callable->apply(
*entry.rows, rows, nullptr, context, lambdaArgs, result);
*entry.rows,
finalSelectionRows,
nullptr,
context,
lambdaArgs,
result);
}
}

Expand Down
8 changes: 8 additions & 0 deletions velox/functions/prestosql/tests/FunctionBaseTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ class FunctionBaseTest : public testing::Test,
return core::Expressions::inferTypes(untyped, rowType, pool());
}

std::unique_ptr<exec::ExprSet> compileExpression(
const std::string& expr,
const RowTypePtr& rowType) {
std::vector<core::TypedExprPtr> expressions = {
parseExpression(expr, rowType)};
return std::make_unique<exec::ExprSet>(std::move(expressions), &execCtx_);
}

std::unique_ptr<exec::ExprSet> compileExpressions(
const std::vector<std::string>& exprs,
const RowTypePtr& rowType) {
Expand Down
8 changes: 0 additions & 8 deletions velox/functions/prestosql/tests/MapFilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ using namespace facebook::velox::test;

class MapFilterTest : public functions::test::FunctionBaseTest {
protected:
std::unique_ptr<exec::ExprSet> compileExpression(
const std::string& expr,
const RowTypePtr& rowType) {
std::vector<std::shared_ptr<const core::ITypedExpr>> expressions = {
parseExpression(expr, rowType)};
return std::make_unique<exec::ExprSet>(std::move(expressions), &execCtx_);
}

template <typename K, typename V>
void checkMapFilter(
BaseVector* inputMap,
Expand Down
45 changes: 45 additions & 0 deletions velox/functions/prestosql/tests/ReduceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,48 @@ TEST_F(ReduceTest, conditional) {
nullEvery(11));
assertEqualVectors(expectedResult, result);
}

// Exercises reduce() as the else-branch of an if() expression, i.e. nested
// inside a conditional rather than evaluated as the top-level expression.
// Rows with c1 < 100 are expected to carry c1 itself; all other rows are
// expected to carry 10 plus the sum of (row + i) for i in [0, row % 5).
// nullEvery(11) is applied to both input columns and to the expected result.
TEST_F(ReduceTest, finalSelection) {
  const vector_size_t numRows = 1'000;

  // c0: array column whose element at (row, index) is row + index.
  // modN(5) presumably yields arrays of row % 5 elements, which is what the
  // expected-value computation below assumes.
  auto arrays = makeArrayVector<int64_t>(
      numRows,
      modN(5),
      [](auto row, auto index) { return row + index; },
      nullEvery(11));
  // c1: scalar column holding the row number.
  auto scalars = makeFlatVector<int64_t>(
      numRows, [](auto row) { return row; }, nullEvery(11));
  auto input = makeRowVector({arrays, scalars});

  // Input lambda of reduce(): adds element 'x' to running state 's'.
  registerLambda(
      "sum_input",
      rowType("s", BIGINT(), "x", BIGINT()),
      input->type(),
      "s + x");
  // Output lambda of reduce(): wraps the final state into a single-field row.
  registerLambda(
      "row_output",
      rowType("s", BIGINT()),
      input->type(),
      "row_constructor(s)");

  auto actual = evaluate<RowVector>(
      "if (c1 < 100, row_constructor(c1), "
      "reduce(c0, 10, function('sum_input'), function('row_output')))",
      input);

  // Value expected at a given row, mirroring the two branches of the if().
  auto expectedValue = [](auto row) -> int64_t {
    if (row < 100) {
      return row;
    }
    int64_t sum = 10;
    for (auto i = 0; i < row % 5; ++i) {
      sum += row + i;
    }
    return sum;
  };
  auto expected = makeRowVector(
      {makeFlatVector<int64_t>(numRows, expectedValue, nullEvery(11))});

  assertEqualVectors(expected, actual);
}

0 comments on commit d2217d1

Please sign in to comment.