Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Expr::evalWithMemo #48

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions velox/exec/tests/ExprTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ class ExprTest : public testing::Test {
untyped, rowType ? rowType : testDataType_, execCtx_->pool());
}

std::unique_ptr<exec::ExprSet> compileExpression(
const std::string& expr,
const RowTypePtr& rowType) {
std::vector<std::shared_ptr<const core::ITypedExpr>> expressions = {
parseExpression(expr, rowType)};
return std::make_unique<exec::ExprSet>(
std::move(expressions), execCtx_.get());
}

std::vector<VectorPtr> evaluateMultiple(
const std::vector<std::string>& texts,
const RowVectorPtr& input) {
Expand All @@ -215,6 +224,15 @@ class ExprTest : public testing::Test {
return evaluateMultiple({text}, input)[0];
}

// Evaluates an already compiled 'exprSet' against 'input' using a selectivity
// vector sized to the input, and returns the contents of the single result
// slot.
VectorPtr evaluate(exec::ExprSet* exprSet, const RowVectorPtr& input) {
  SelectivityVector selectedRows(input->size());
  exec::EvalCtx evalCtx(execCtx_.get(), exprSet, input.get());

  // One output slot: only the first result is handed back to the caller.
  std::vector<VectorPtr> outputs(1);
  exprSet->eval(selectedRows, &evalCtx, &outputs);
  return outputs.front();
}

template <typename T>
void fillVectorAndReference(
const std::vector<EncodingOptions>& options,
Expand Down Expand Up @@ -618,6 +636,16 @@ class ExprTest : public testing::Test {
return vectorMaker_->arrayVector(size, sizeAt, valueAt, isNullAt);
}

// Convenience wrapper: forwards all four arguments unchanged to
// vectorMaker_->arrayVector() and returns whatever it produces.
//
// 'size'     - passed through as the first argument.
// 'sizeAt'   - callable taking a row; presumably yields the number of
//              elements in that row's array (confirm against VectorMaker).
// 'valueAt'  - callable taking two vector_size_t arguments (annotated "idx"
//              and "index" below); the memo test calls it as (row, index).
// 'isNullAt' - optional callable taking a row; defaults to nullptr.
//
// NOTE(review): the definition immediately preceding this one ends with the
// same forwarding call, so this may duplicate an existing helper - confirm
// before keeping both.
template <typename T>
ArrayVectorPtr makeArrayVector(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am pretty sure I have seen an equivalent function in ExprTest.cpp..

vector_size_t size,
std::function<vector_size_t(vector_size_t /* row */)> sizeAt,
std::function<T(vector_size_t /* idx */, vector_size_t /*index */)>
valueAt,
std::function<bool(vector_size_t /*row */)> isNullAt = nullptr) {
return vectorMaker_->arrayVector(size, sizeAt, valueAt, isNullAt);
}

// Create LazyVector that produces a flat vector and asserts that it is being
// loaded for a specific set of rows.
template <typename T>
Expand Down Expand Up @@ -2145,3 +2173,35 @@ TEST_F(ExprTest, rewriteInputs) {
ASSERT_EQ(*expectedExpr, *expr);
}
}

// Evaluates one compiled expression ("c0[1]") three times in a row, each time
// over a different dictionary wrapping of the same base array vector, and
// verifies every result against values computed directly from the base rows.
TEST_F(ExprTest, memo) {
  // Base row r is built with sizeAt(r) = r % 5 + 1 and
  // valueAt(r, i) = r % 3 + i.
  auto base = makeArrayVector<int64_t>(
      1'000,
      [](auto row) { return row % 5 + 1; },
      [](auto row, auto index) { return (row % 3) + index; });

  auto exprSet = compileExpression("c0[1]", ROW({"c0"}, {base->type()}));

  constexpr vector_size_t kNumRows = 100;

  // Wraps 'base' in a dictionary whose i-th index is baseRowAt(i), evaluates
  // the shared exprSet over it, and expects output row i to be
  // baseRowAt(i) % 3, i.e. valueAt(baseRowAt(i), 0).
  const auto verify = [&](auto baseRowAt) {
    auto indices = makeIndices(kNumRows, baseRowAt);
    auto input = makeRowVector({wrapInDictionary(indices, kNumRows, base)});
    auto actual = evaluate(exprSet.get(), input);
    auto expected = makeFlatVector<int64_t>(
        kNumRows, [baseRowAt](auto row) { return baseRowAt(row) % 3; });
    assertEqualVectors(expected, actual);
  };

  // Even base rows starting at 8.
  verify([](auto row) { return 8 + row * 2; });
  // Odd base rows starting at 9.
  verify([](auto row) { return 9 + row * 2; });
  // Every fifth base row starting at 0.
  verify([](auto row) { return row * 5; });
}
31 changes: 24 additions & 7 deletions velox/expression/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

#include "velox/expression/Expr.h"
#include "velox/core/Expressions.h"
#include "velox/expression/CastExpr.h"
#include "velox/expression/ControlExpr.h"
#include "velox/expression/ExprCompiler.h"
#include "velox/expression/VarSetter.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorFunctionRegistry.h"

namespace facebook::velox::exec {

Expand Down Expand Up @@ -211,12 +209,10 @@ bool Expr::checkGetSharedSubexprValues(
// losing values outside of missingRows.
bool updateFinalSelection = context->isFinalSelection() &&
(missingRows->countSelected() < rows.countSelected());
bool newIsFinalSelection =
updateFinalSelection ? false : context->isFinalSelection();
VarSetter finalSelectionOr(
context->mutableFinalSelection(), &rows, updateFinalSelection);
VarSetter isFinalSelectionOr(
context->mutableIsFinalSelection(), newIsFinalSelection);
context->mutableIsFinalSelection(), false, updateFinalSelection);

evalEncodings(*missingRows, context, &sharedSubexprValues_);
}
Expand Down Expand Up @@ -635,18 +631,39 @@ void Expr::evalWithMemo(
uncached->deselect(*cachedDictionaryIndices_);
}
if (uncached->hasSelections()) {
// Fix finalSelection at "rows" if uncached rows is a strict subset to
// avoid losing values not in uncached rows.
bool updateFinalSelection = context->isFinalSelection() &&
(uncached->countSelected() < rows.countSelected());
VarSetter finalSelectionMemo(
context->mutableFinalSelection(), &rows, updateFinalSelection);
VarSetter isFinalSelectionMemo(
context->mutableIsFinalSelection(), false, updateFinalSelection);

evalAll(*uncached, context, result);
deselectErrors(context, *uncached);
context->exprSet()->addToMemo(this);
auto newCacheSize = uncached->end();

// dictionaryCache_ is valid only for cachedDictionaryIndices_. Hence, a
// safe call to BaseVector::ensureWritable must include all the rows not
// covered by cachedDictionaryIndices_. If BaseVector::ensureWritable is
// called only for a subset of rows not covered by
// cachedDictionaryIndices_, it will attempt to copy rows that are not
// valid leading to a crash.
LocalSelectivityVector allUncached(context, dictionaryCache_->size());
allUncached.get()->setAll();
allUncached.get()->deselect(*cachedDictionaryIndices_);
BaseVector::ensureWritable(
*allUncached.get(), type(), context->pool(), &dictionaryCache_);

if (cachedDictionaryIndices_->size() < newCacheSize) {
int32_t oldSize = cachedDictionaryIndices_->size();
cachedDictionaryIndices_->resize(newCacheSize);
cachedDictionaryIndices_->setValidRange(oldSize, newCacheSize, false);
}

cachedDictionaryIndices_->select(*uncached);
BaseVector::ensureWritable(
*uncached, type(), context->pool(), &dictionaryCache_);
dictionaryCache_->copy(result->get(), *uncached, nullptr);
}
return;
Expand Down