diff --git a/velox/exec/tests/ExprTest.cpp b/velox/exec/tests/ExprTest.cpp index 8b0ed27ac718..83bd5acd71da 100644 --- a/velox/exec/tests/ExprTest.cpp +++ b/velox/exec/tests/ExprTest.cpp @@ -191,6 +191,15 @@ class ExprTest : public testing::Test { untyped, rowType ? rowType : testDataType_, execCtx_->pool()); } + std::unique_ptr compileExpression( + const std::string& expr, + const RowTypePtr& rowType) { + std::vector> expressions = { + parseExpression(expr, rowType)}; + return std::make_unique( + std::move(expressions), execCtx_.get()); + } + std::vector evaluateMultiple( const std::vector& texts, const RowVectorPtr& input) { @@ -215,6 +224,15 @@ class ExprTest : public testing::Test { return evaluateMultiple({text}, input)[0]; } + VectorPtr evaluate(exec::ExprSet* exprSet, const RowVectorPtr& input) { + exec::EvalCtx context(execCtx_.get(), exprSet, input.get()); + + SelectivityVector rows(input->size()); + std::vector result(1); + exprSet->eval(rows, &context, &result); + return result[0]; + } + template void fillVectorAndReference( const std::vector& options, @@ -618,6 +636,16 @@ class ExprTest : public testing::Test { return vectorMaker_->arrayVector(size, sizeAt, valueAt, isNullAt); } + template + ArrayVectorPtr makeArrayVector( + vector_size_t size, + std::function sizeAt, + std::function + valueAt, + std::function isNullAt = nullptr) { + return vectorMaker_->arrayVector(size, sizeAt, valueAt, isNullAt); + } + // Create LazyVector that produces a flat vector and asserts that is is being // loaded for a specific set of rows. template @@ -2145,3 +2173,35 @@ TEST_F(ExprTest, rewriteInputs) { ASSERT_EQ(*expectedExpr, *expr); } } + +TEST_F(ExprTest, memo) { + auto base = makeArrayVector( + 1'000, + [](auto row) { return row % 5 + 1; }, + [](auto row, auto index) { return (row % 3) + index; }); + + auto evenIndices = makeIndices(100, [](auto row) { return 8 + row * 2; }); + auto oddIndices = makeIndices(100, [](auto row) { return 9 + row * 2; }); + + auto rowType = ROW({"c0"}, {base->type()}); + auto exprSet = compileExpression("c0[1]", rowType); + + auto result = evaluate( + exprSet.get(), makeRowVector({wrapInDictionary(evenIndices, 100, base)})); + auto expectedResult = + makeFlatVector(100, [](auto row) { return (8 + row * 2) % 3; }); + assertEqualVectors(expectedResult, result); + + result = evaluate( + exprSet.get(), makeRowVector({wrapInDictionary(oddIndices, 100, base)})); + expectedResult = + makeFlatVector(100, [](auto row) { return (9 + row * 2) % 3; }); + assertEqualVectors(expectedResult, result); + + auto everyFifth = makeIndices(100, [](auto row) { return row * 5; }); + result = evaluate( + exprSet.get(), makeRowVector({wrapInDictionary(everyFifth, 100, base)})); + expectedResult = + makeFlatVector(100, [](auto row) { return (row * 5) % 3; }); + assertEqualVectors(expectedResult, result); +} diff --git a/velox/expression/Expr.cpp b/velox/expression/Expr.cpp index 0add3b80aed0..ebba3cc10859 100644 --- a/velox/expression/Expr.cpp +++ b/velox/expression/Expr.cpp @@ -14,12 +14,10 @@ #include "velox/expression/Expr.h" #include "velox/core/Expressions.h" -#include "velox/expression/CastExpr.h" #include "velox/expression/ControlExpr.h" #include "velox/expression/ExprCompiler.h" #include "velox/expression/VarSetter.h" #include "velox/expression/VectorFunction.h" -#include "velox/expression/VectorFunctionRegistry.h" namespace facebook::velox::exec { @@ -211,12 +209,10 @@ bool Expr::checkGetSharedSubexprValues( // losing values outside of missingRows. bool updateFinalSelection = context->isFinalSelection() && (missingRows->countSelected() < rows.countSelected()); - bool newIsFinalSelection = - updateFinalSelection ? false : context->isFinalSelection(); VarSetter finalSelectionOr( context->mutableFinalSelection(), &rows, updateFinalSelection); VarSetter isFinalSelectionOr( - context->mutableIsFinalSelection(), newIsFinalSelection); + context->mutableIsFinalSelection(), false, updateFinalSelection); evalEncodings(*missingRows, context, &sharedSubexprValues_); } @@ -635,18 +631,39 @@ void Expr::evalWithMemo( uncached->deselect(*cachedDictionaryIndices_); } if (uncached->hasSelections()) { + // Fix finalSelection at "rows" if uncached rows is a strict subset to + // avoid losing values not in uncached rows. + bool updateFinalSelection = context->isFinalSelection() && + (uncached->countSelected() < rows.countSelected()); + VarSetter finalSelectionMemo( + context->mutableFinalSelection(), &rows, updateFinalSelection); + VarSetter isFinalSelectionMemo( + context->mutableIsFinalSelection(), false, updateFinalSelection); + evalAll(*uncached, context, result); deselectErrors(context, *uncached); context->exprSet()->addToMemo(this); auto newCacheSize = uncached->end(); + + // dictionaryCache_ is valid only for cachedDictionaryIndices_. Hence, a + // safe call to BaseVector::ensureWritable must include all the rows not + // covered by cachedDictionaryIndices_. If BaseVector::ensureWritable is + // called only for a subset of rows not covered by + // cachedDictionaryIndices_, it will attempt to copy rows that are not + // valid leading to a crash. + LocalSelectivityVector allUncached(context, dictionaryCache_->size()); + allUncached.get()->setAll(); + allUncached.get()->deselect(*cachedDictionaryIndices_); + BaseVector::ensureWritable( + *allUncached.get(), type(), context->pool(), &dictionaryCache_); + if (cachedDictionaryIndices_->size() < newCacheSize) { int32_t oldSize = cachedDictionaryIndices_->size(); cachedDictionaryIndices_->resize(newCacheSize); cachedDictionaryIndices_->setValidRange(oldSize, newCacheSize, false); } + cachedDictionaryIndices_->select(*uncached); - BaseVector::ensureWritable( - *uncached, type(), context->pool(), &dictionaryCache_); dictionaryCache_->copy(result->get(), *uncached, nullptr); } return;