diff --git a/velox/core/QueryCtx.h b/velox/core/QueryCtx.h index f1bc10d5578f..5b8199b4766b 100644 --- a/velox/core/QueryCtx.h +++ b/velox/core/QueryCtx.h @@ -98,7 +98,8 @@ class QueryCtx : public Context { } uint64_t maxPartialAggregationMemoryUsage() const { - return 1L << 24; // 16MB + return get( + kMaxPartialAggregationMemory, kMaxPartialAggregationMemoryDefault); } uint64_t maxPartitionedOutputBufferSize() const { @@ -195,6 +196,9 @@ class QueryCtx : public Context { static constexpr const char* kMaxLocalExchangeBufferSize = "max_local_exchange_buffer_size"; + static constexpr const char* kMaxPartialAggregationMemory = + "max_partial_aggregation_memory"; + // Overrides the previous configuration. Note that this function is NOT // thread-safe and should probably only be used in tests. void setConfigOverridesUnsafe( @@ -222,6 +226,9 @@ class QueryCtx : public Context { static constexpr uint64_t kMaxLocalExchangeBufferSizeDefault = 32UL << 20; + // 16MB + static constexpr uint64_t kMaxPartialAggregationMemoryDefault = 1L << 24; + CancelPoolPtr cancelPool_; std::unique_ptr pool_; memory::MappedMemory* mappedMemory_; diff --git a/velox/exec/HashAggregation.cpp b/velox/exec/HashAggregation.cpp index 3d08e544da14..7a76a3a23bcf 100644 --- a/velox/exec/HashAggregation.cpp +++ b/velox/exec/HashAggregation.cpp @@ -408,6 +408,11 @@ RowVectorPtr HashAggregation::getOutput() { // Drop reference to input_ to make it singly-referenced at the producer and // allow for memory reuse. input_ = nullptr; + + if (partialFull_) { + groupingSet_->resetPartial(); + partialFull_ = false; + } return output; } diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp index 01c9cfe756bc..de8c4dfd1a9f 100644 --- a/velox/exec/tests/AggregationTest.cpp +++ b/velox/exec/tests/AggregationTest.cpp @@ -391,5 +391,45 @@ TEST_F(AggregationTest, allKeyTypes) { " GROUP BY c0, C1, C2, c3, C4, C5"); } +TEST_F(AggregationTest, partialAggregationMemoryLimit) { + auto vectors = { + makeRowVector({makeFlatVector( + 100, [](auto row) { return row; }, nullEvery(5))}), + makeRowVector({makeFlatVector( + 110, [](auto row) { return row + 29; }, nullEvery(7))}), + makeRowVector({makeFlatVector( + 90, [](auto row) { return row - 71; }, nullEvery(7))}), + }; + + createDuckDbTable(vectors); + + // Set an artificially low limit on the amount of data to accumulate in + // the partial aggregation. + CursorParameters params; + params.queryCtx = core::QueryCtx::create(); + + params.queryCtx->setConfigOverridesUnsafe({ + {core::QueryCtx::kMaxPartialAggregationMemory, "100"}, + }); + + // Distinct aggregation. + params.planNode = PlanBuilder() + .values(vectors) + .partialAggregation({0}, {}) + .finalAggregation({0}, {}) + .planNode(); + + assertQuery(params, "SELECT distinct c0 FROM tmp"); + + // Count aggregation. + params.planNode = PlanBuilder() + .values(vectors) + .partialAggregation({0}, {"count(1)"}) + .finalAggregation({0}, {"sum(a0)"}) + .planNode(); + + assertQuery(params, "SELECT c0, count(1) FROM tmp GROUP BY 1"); +} + } // namespace } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/OperatorTestBase.h b/velox/exec/tests/OperatorTestBase.h index 059b8c2fb9ac..d1fdd45db722 100644 --- a/velox/exec/tests/OperatorTestBase.h +++ b/velox/exec/tests/OperatorTestBase.h @@ -48,6 +48,13 @@ class OperatorTestBase : public testing::Test { return assertQuery(plan, splits, duckDbSql, sortingKeys); } + std::shared_ptr assertQuery( + const CursorParameters& params, + const std::string& duckDbSql) { + return test::assertQuery( + params, [&](exec::Task* /*task*/) {}, duckDbSql, duckDbQueryRunner_); + } + std::shared_ptr assertQuery( const std::shared_ptr& plan, const std::string& duckDbSql) { diff --git a/velox/exec/tests/QueryAssertions.cpp b/velox/exec/tests/QueryAssertions.cpp index bb92156e6139..6cc2705c76f5 100644 --- a/velox/exec/tests/QueryAssertions.cpp +++ b/velox/exec/tests/QueryAssertions.cpp @@ -524,7 +524,7 @@ std::shared_ptr assertQuery( } std::shared_ptr assertQuery( - CursorParameters& params, + const CursorParameters& params, std::function addSplits, const std::string& duckDbSql, DuckDbQueryRunner& duckDbQueryRunner, diff --git a/velox/exec/tests/QueryAssertions.h b/velox/exec/tests/QueryAssertions.h index 331eb91f5032..42e8d17ef8f9 100644 --- a/velox/exec/tests/QueryAssertions.h +++ b/velox/exec/tests/QueryAssertions.h @@ -84,7 +84,7 @@ std::shared_ptr assertQuery( std::optional> sortingKeys = std::nullopt); std::shared_ptr assertQuery( - CursorParameters& params, + const CursorParameters& params, std::function addSplits, const std::string& duckDbSql, DuckDbQueryRunner& duckDbQueryRunner,