/*
 * Review notes for this patch:
 *  - adds TColumnFilter::Append and NArrow::SliceToRecordBatches;
 *  - changes TProgramStep::BuildFilter to evaluate the step on each record batch returned by
 *    SliceToRecordBatches instead of on a single TDatumBatch::FromTable(t).
 * NOTE(review): template arguments (for example after std::optional, std::vector, std::shared_ptr)
 * appear to have been stripped from this copy of the patch - confirm against the original change.
 */
diff --git a/ydb/core/formats/arrow/arrow_filter.cpp b/ydb/core/formats/arrow/arrow_filter.cpp index b667d4006fd4..162bbb37b540 100644 --- a/ydb/core/formats/arrow/arrow_filter.cpp +++ b/ydb/core/formats/arrow/arrow_filter.cpp @@ -611,4 +611,12 @@ std::optional TColumnFilter::GetFilteredCount() const { return *FilteredCount; } +void TColumnFilter::Append(const TColumnFilter& filter) { + /* Appends another filter to this one: every element of filter.Filter is passed to Add() as a count, starting with filter.GetStartValue() and flipping the value after each element. NOTE(review): the loop reads filter.Filter while calling Add() on this object; whether passing this same object as the argument is safe depends on Add(), which is not shown here - confirm. */ bool currentVal = filter.GetStartValue(); + for (auto&& i : filter.Filter) { + Add(currentVal, i); + currentVal = !currentVal; + } +} + } diff --git a/ydb/core/formats/arrow/arrow_filter.h b/ydb/core/formats/arrow/arrow_filter.h index b33b9a13707e..80e449ef05c1 100644 --- a/ydb/core/formats/arrow/arrow_filter.h +++ b/ydb/core/formats/arrow/arrow_filter.h @@ -52,6 +52,7 @@ class TColumnFilter { FilteredCount.reset(); } public: + /* Appends the contents of filter to this filter. */ void Append(const TColumnFilter& filter); void Add(const bool value, const ui32 count = 1); std::optional GetFilteredCount() const; const std::vector& BuildSimpleFilter() const; diff --git a/ydb/core/formats/arrow/arrow_helpers.cpp b/ydb/core/formats/arrow/arrow_helpers.cpp index 1a2b90167313..a49cf23e686e 100644 --- a/ydb/core/formats/arrow/arrow_helpers.cpp +++ b/ydb/core/formats/arrow/arrow_helpers.cpp @@ -977,4 +977,35 @@ std::shared_ptr MergeColumns(const std::vector> SliceToRecordBatches(const std::shared_ptr& t) { + /* Cuts table t into record batches at every chunk start offset found in any column, so that each column contributes exactly one chunk slice to each batch (checked by AFL_VERIFY below). NOTE(review): if t has no columns, no split position is collected and the result is empty, so the final row-count check only passes when the table has zero rows - confirm callers never pass such a table. */ std::set splitPositions; + const ui32 numRows = t->num_rows(); + /* First pass: record the starting row offset of every chunk of every column. */ for (auto&& i : t->columns()) { + ui32 pos = 0; + for (auto&& arr : i->chunks()) { + splitPositions.emplace(pos); + pos += arr->length(); + } + /* The chunk lengths of each column must add up to the table row count. */ AFL_VERIFY(pos == t->num_rows()); + } + std::vector>> slicedData; + slicedData.resize(splitPositions.size()); + /* Copy the offsets out of the set; std::set iterates in ascending order without duplicates. */ std::vector positions(splitPositions.begin(), splitPositions.end()); + /* Second pass: slice every column at the collected offsets; the last range ends at numRows. */ for (auto&& i : t->columns()) { + for (ui32 idx = 0; idx < positions.size(); ++idx) { + auto slice = i->Slice(positions[idx], ((idx + 1 == positions.size()) ? 
numRows : positions[idx + 1]) - positions[idx]); + /* Each range between adjacent offsets is required to come back as exactly one chunk. */ AFL_VERIFY(slice->num_chunks() == 1); + slicedData[idx].emplace_back(slice->chunks().front()); + } + } + std::vector> result; + ui32 count = 0; + /* Build one record batch per offset range; its row count is taken from the first column slice. */ for (auto&& i : slicedData) { + result.emplace_back(arrow::RecordBatch::Make(t->schema(), i.front()->length(), i)); + count += result.back()->num_rows(); + } + /* The batches together must hold exactly as many rows as the table. */ AFL_VERIFY(count == t->num_rows())("count", count)("t", t->num_rows()); + return result; +} + } diff --git a/ydb/core/formats/arrow/arrow_helpers.h b/ydb/core/formats/arrow/arrow_helpers.h index 84aba485fa7e..efa064ee71a3 100644 --- a/ydb/core/formats/arrow/arrow_helpers.h +++ b/ydb/core/formats/arrow/arrow_helpers.h @@ -128,6 +128,8 @@ inline bool HasNulls(const std::shared_ptr& column) { return column->null_bitmap_data(); } +std::vector> SliceToRecordBatches(const std::shared_ptr& t); /* Splits table t into a sequence of record batches. */ + bool ArrayScalarsEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs); std::shared_ptr BoolVecToArray(const std::vector& vec); diff --git a/ydb/core/formats/arrow/program.cpp b/ydb/core/formats/arrow/program.cpp index a05ea71da082..621b48fff1a3 100644 --- a/ydb/core/formats/arrow/program.cpp +++ b/ydb/core/formats/arrow/program.cpp @@ -871,12 +871,18 @@ std::shared_ptr TProgramStep::BuildFilter(const std::shar if (Filters.empty()) { return nullptr; } - auto datumBatch = TDatumBatch::FromTable(t); - - NArrow::TStatusValidator::Validate(ApplyAssignes(*datumBatch, NArrow::GetCustomExecContext())); - NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter(); - NArrow::TStatusValidator::Validate(MakeCombinedFilter(*datumBatch, local)); - return std::make_shared(std::move(local)); + /* Evaluate the step on each record batch separately and concatenate the per-batch filters into fullLocal. */ std::vector> batches = NArrow::SliceToRecordBatches(t); + NArrow::TColumnFilter fullLocal = NArrow::TColumnFilter::BuildAllowFilter(); + for (auto&& rb : batches) { + auto datumBatch = TDatumBatch::FromRecordBatch(rb); + NArrow::TStatusValidator::Validate(ApplyAssignes(*datumBatch, NArrow::GetCustomExecContext())); + 
NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter(); + NArrow::TStatusValidator::Validate(MakeCombinedFilter(*datumBatch, local)); + /* The per-batch filter must cover exactly the rows of its batch. */ AFL_VERIFY(local.Size() == datumBatch->Rows)("local", local.Size())("datum", datumBatch->Rows); + fullLocal.Append(local); + } + /* The combined filter must cover every row of the table. */ AFL_VERIFY(fullLocal.Size() == t->num_rows())("filter", fullLocal.Size())("t", t->num_rows()); + return std::make_shared(std::move(fullLocal)); } const std::set& TProgramStep::GetFilterOriginalColumnIds() const {