Improve parquet reading performance for columns with nulls by preserving bitmask when possible (#1037) #1054
Conversation
parquet/src/column/reader.rs (outdated)

>;

#[doc(hidden)]
pub struct GenericColumnReader<R, D, V> {
It would help me to document somewhere what R, D and V are intended for (to make reading this code easier)
Pushed a PR that fixes a bug in the handling of
} else if self.packed_count != self.packed_offset {
    let to_read = (self.packed_count - self.packed_offset).min(len - read);
    let offset = self.data_offset * 8 + self.packed_offset;
    buffer.append_packed_range(offset..offset + to_read, self.data.as_ref());
looks like this is the main change in this PR? how often does this case happen for def levels in practice?
> how often does this case happen for def levels in practice
This depends on what you mean by "this" 😅
The major change in this PR is not decoding definition levels for columns without nested nullability - i.e. max_def_level == 1, and just decoding directly to the null bitmask. This is very common, with almost all parquet data I've come across being flat.
My personal experience with projects trying to use nested data in parquet is eventually it becomes too much of a pain due to the patchy ecosystem support, and the schema ends up just getting flattened
Previously the code would allocate i16 buffers, populate them with the decoded data, and then deduce a null bitmask from these i16 buffers. This code will now decode directly to the null bitmask in the event of max_def_level == 1, avoiding allocations along with the costs associated with decode and bitmask reconstruction.
As an added bonus, it happens that by decoding directly we can exploit the inherent properties of the hybrid encoding to improve performance - with the packed representation already being a bitmask, and the RLE representation allowing operations on runs of bits.
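The hybrid-encoding property described above can be sketched as follows. This is a hypothetical, simplified illustration, not the arrow-rs implementation: `BitBuffer` and its `bool`-per-bit storage are assumptions for clarity, though `append_n` / `append_packed_range` mirror the method names quoted elsewhere in this review. The point it demonstrates is that when `max_def_level == 1`, an RLE run becomes a run of identical bits and a bit-packed block with bit width 1 already *is* a bitmask slice, so def levels can decode straight into the null mask.

```rust
// Toy bit buffer standing in for the real arrow-rs buffer (assumed API).
struct BitBuffer {
    bits: Vec<bool>, // real code packs into bytes; bools keep the sketch simple
}

impl BitBuffer {
    fn new() -> Self {
        Self { bits: Vec::new() }
    }

    // Append `n` copies of `value` -- an RLE run of def levels 0 or 1.
    fn append_n(&mut self, n: usize, value: bool) {
        self.bits.extend(std::iter::repeat(value).take(n));
    }

    // Copy a range of bits out of already bit-packed source data:
    // with bit width 1 the source bytes are themselves a bitmask.
    fn append_packed_range(&mut self, range: std::ops::Range<usize>, data: &[u8]) {
        for i in range {
            self.bits.push(data[i / 8] & (1 << (i % 8)) != 0);
        }
    }
}

fn main() {
    let mut mask = BitBuffer::new();
    // An RLE run of 5 valid values...
    mask.append_n(5, true);
    // ...followed by a bit-packed block 0b0000_0101 (LSB first: 1, 0, 1).
    mask.append_packed_range(0..3, &[0b0000_0101]);
    assert_eq!(
        mask.bits,
        vec![true, true, true, true, true, true, false, true]
    );
    println!("{:?}", mask.bits);
}
```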
Apologies, I should have been more explicit; what I meant is how common is it in practice to have max_def_level == 1 plus bit-packing of the def levels, because this is where the biggest optimization is, isn't it? RLE-encoded def level reading would still be better than before (as there is no intermediate translation into integers) and that's great, but probably not as fast as directly copying the bit-packed values. I do agree on flat parquet files being common though, most parquet files I have seen have been flat as well.
On second thought, reading of run-length-encoded def levels could be just as fast if append_packed could be used for it as well (except that the buffer to copy from would be a static buffer of all 1s of some fixed length)
> plus bit-packing of the def levels
The logic within RleEncoder uses run-length encoding if the repetition count is greater than 8, otherwise it uses the bit-packed version. Therefore how common bit-packed sequences are depends on the distribution of nulls within the data.
To be clear, what parquet calls RLE encoding is actually a hybrid encoding: a page isn't entirely bit-packed or run-length encoded, but contains blocks of either.
> but probably not as fast as directly copying the bit-packed values
I'm not sure I agree with this, copying the bit-packed values is actually potentially more expensive, as it requires shifting and masking the source data. By contrast, inserting a run of nulls is simply a case of incrementing the length of the buffer (as everything is 0-initialized), whereas setting sequences of valid bits can be done at the byte level (or possibly larger).
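The cost asymmetry described here can be sketched as follows. `NullMask` and its methods are hypothetical stand-ins, not the real arrow-rs buffer API: appending a run of nulls to a zero-initialized buffer only moves the length, appending a run of valid bits can write whole 0xFF bytes once aligned, while copying bit-packed source generally needs per-bit (or shift-and-mask) work.

```rust
// Hypothetical null-mask buffer, bits packed LSB-first into bytes.
struct NullMask {
    bytes: Vec<u8>,
    len: usize, // length in bits; invariant: bytes.len() == ceil(len / 8)
}

impl NullMask {
    fn new() -> Self {
        Self { bytes: Vec::new(), len: 0 }
    }

    // A run of nulls: the backing bytes are already zero, so this is
    // essentially just a length bump.
    fn append_nulls(&mut self, n: usize) {
        self.len += n;
        self.bytes.resize((self.len + 7) / 8, 0);
    }

    // A run of valid bits: handle the ragged head bit-by-bit, then
    // append whole bytes of set bits, then the ragged tail.
    fn append_valid(&mut self, mut n: usize) {
        while self.len % 8 != 0 && n > 0 {
            self.set_bit(self.len);
            self.len += 1;
            n -= 1;
        }
        for _ in 0..n / 8 {
            self.bytes.push(0xFF); // whole byte of valid bits at once
        }
        self.len += (n / 8) * 8;
        for _ in 0..n % 8 {
            self.set_bit(self.len);
            self.len += 1;
        }
    }

    fn set_bit(&mut self, i: usize) {
        if self.bytes.len() <= i / 8 {
            self.bytes.resize(i / 8 + 1, 0);
        }
        self.bytes[i / 8] |= 1 << (i % 8);
    }

    fn get(&self, i: usize) -> bool {
        self.bytes[i / 8] & (1 << (i % 8)) != 0
    }
}

fn main() {
    let mut mask = NullMask::new();
    mask.append_nulls(3); // RLE run of nulls: no bytes written
    mask.append_valid(10); // 5 head bits, then a 0xFF byte would follow for longer runs
    assert!(!mask.get(0) && !mask.get(2));
    assert!((3..13).all(|i| mask.get(i)));
    println!("len = {}", mask.len);
}
```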
while read != len {
    if self.rle_left != 0 {
        let to_read = self.rle_left.min(len - read);
        buffer.append_n(to_read, self.rle_value);
I wonder if append_n could be made faster using append_packed
Codecov Report

@@            Coverage Diff             @@
##           master    #1054      +/-   ##
==========================================
+ Coverage   82.53%   82.58%   +0.04%
==========================================
  Files         173      173
  Lines       50615    50876     +261
==========================================
+ Hits        41774    42014     +240
- Misses       8841     8862      +21

Continue to review full report at Codecov.
Looking into test failures
I think this is great work. Thank you @tustvold
@yordan-pavlov would you like to review again prior to merge?
/// [`Self::consume_def_levels`] and [`Self::consume_rep_levels`] will always return `None`
///
pub(crate) fn new_with_options(desc: ColumnDescPtr, null_mask_only: bool) -> Self {
    let def_levels = (desc.max_def_level() > 0)
I don't understand the use of null_mask_only here -- I thought null_mask_only would be set only if max_def_level() == 1
Added a comment clarifying; it's an edge case of nested nullability. Perhaps I should add an explicit test 🤔
Test added in 59846eb
type Slice = [i16];

impl DefinitionLevelBuffer {
    pub fn new(desc: &ColumnDescPtr, null_mask_only: bool) -> Self {
        let inner = match null_mask_only {
I wonder why null_mask_only is passed down all the way here only to be rechecked / assert!ed. Would it be possible / feasible to decide here in DefinitionLevelBuilder::new to use BufferInner::Mask if max_def_level() is 1 and max_rep_levels() is 0, and thus avoid plumbing the argument around?
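The reviewer's suggestion could be sketched roughly as below. All the types here (`ColumnDesc`, `BufferInner`, `DefinitionLevelBuffer`) are simplified stand-ins for the arrow-rs ones, not the real API, and as the author notes later there is an edge case of nested nullability this naive check would miss; the sketch only illustrates the shape of deciding the variant inside the constructor instead of plumbing a flag through.

```rust
// Stand-in for the parquet column descriptor (assumed fields).
struct ColumnDesc {
    max_def_level: i16,
    max_rep_level: i16,
}

enum BufferInner {
    // Flat nullable column: def levels are exactly the null mask.
    Mask { nulls: Vec<bool> },
    // Nested case: keep the full i16 definition levels.
    Full { levels: Vec<i16> },
}

struct DefinitionLevelBuffer {
    inner: BufferInner,
}

impl DefinitionLevelBuffer {
    // Decide the representation here, from the descriptor alone.
    fn new(desc: &ColumnDesc) -> Self {
        let inner = if desc.max_def_level == 1 && desc.max_rep_level == 0 {
            BufferInner::Mask { nulls: Vec::new() }
        } else {
            BufferInner::Full { levels: Vec::new() }
        };
        Self { inner }
    }
}

fn main() {
    let flat = DefinitionLevelBuffer::new(&ColumnDesc { max_def_level: 1, max_rep_level: 0 });
    assert!(matches!(flat.inner, BufferInner::Mask { .. }));

    let nested = DefinitionLevelBuffer::new(&ColumnDesc { max_def_level: 2, max_rep_level: 1 });
    assert!(matches!(nested.inner, BufferInner::Full { .. }));
    println!("ok");
}
```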
let decoder = match self.data.take() {
    Some(data) => self
        .packed_decoder
        .insert(PackedDecoder::new(self.encoding, data)),
TIL: Option::insert 👍
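For the next reader: `Option::insert` is a stable std method that stores a value in the `Option` and hands back a mutable reference to it, which is exactly the "create the decoder on first use, then keep using it" pattern in the snippet above. A minimal standalone illustration (the `Vec<u8>` payload is just an example):

```rust
fn main() {
    let mut decoder: Option<Vec<u8>> = None;

    // Equivalent to `decoder = Some(vec![1, 2, 3]); decoder.as_mut().unwrap()`,
    // but in one call: the Option now holds the value, and `d` borrows it.
    let d: &mut Vec<u8> = decoder.insert(vec![1, 2, 3]);
    d.push(4);

    assert_eq!(decoder, Some(vec![1, 2, 3, 4]));
    println!("{:?}", decoder);
}
```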
    }
}

struct PackedDecoder {
This code looks quite similar to BitReader (https://github.com/tustvold/arrow-rs/blob/bitmask-preservation/parquet/src/util/bit_util.rs#L501) -- I wonder if you looked at possibly reusing that implementation?
The short answer is that not using that implementation is the major reason this PR represents a non-trivial speed bump: this decoder can decode more optimally, as it decodes directly using append_packed_range / append_n. Will add some comments clarifying.
use rand::{thread_rng, Rng, RngCore};

#[test]
I know there is now significant coverage of this code using the fuzz tests -- #1156 and friends. Do you think that is sufficient coverage for PackedDecoder? Or would some more targeted unit tests be valuable too?
Should be possible to write a simple test that compares the output with that of BitReader 👍 Will do
Test added in b001f11
    packed_offset: usize,
}

impl PackedDecoder {
I don't understand the details of the parquet format sufficiently to truly evaluate the correctness of this code; perhaps some additional test coverage would help, but the fuzz testing may be good enough.
    }
}

struct PackedDecoder {
Suggested change:

struct PackedDecoder {

/// Specialized decoder for bitpacked hybrid format (TODO link) that contains
/// only 0 and 1 (for example, definition levels in a non-nested column)
/// that directly decodes into a bitmask in the fastest possible way
struct PackedDecoder {
I am trying to leave breadcrumbs for the next person to look at this code. Is this a correct description of what this structure implements?
Unless anyone wants additional time to review, I'll plan to merge this tomorrow
@@ -228,6 +232,20 @@ impl ColumnLevelDecoder for DefinitionLevelDecoder {
    }
}

/// An optimized decoder for decoding [RLE] and [BIT_PACKED] data with a bit width of 1
👍
Thanks @tustvold -- this is pretty epic
assert_eq!(range.start + writer.len, nulls.len());

let decoder = match self.data.take() {
    Some(data) => self.column_decoder.insert(
It looks like the intention is that self.data will only be used once (to create a ColumnLevelDecoderImpl), and if that's the case, why not move the entire match statement into the constructor?
Because the type of writer determines the type of decoder. If BufferInner::Full it constructs ColumnLevelDecoderImpl, otherwise it constructs PackedDecoder. I guess we could just construct both, but this way you'd get a panic if you change the writer type...
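That design choice could be sketched roughly as below; `BufferInner`, `Decoder` and `make_decoder` are simplified stand-ins, not the real arrow-rs types. Matching on the writer variant in one place when lazily constructing the decoder ties the two choices together, instead of building both decoders up front.

```rust
// Stand-in for the level-buffer representation chosen at construction time.
enum BufferInner {
    Full, // full i16 definition levels
    Mask, // plain null bitmask
}

// Stand-in for the two decoder implementations.
#[derive(Debug, PartialEq)]
enum Decoder {
    ColumnLevel, // decodes full i16 levels
    Packed,      // decodes directly to a bitmask
}

// The decoder is derived from the writer variant in a single match, so a
// mismatch between writer and decoder cannot pass silently.
fn make_decoder(writer: &BufferInner) -> Decoder {
    match writer {
        BufferInner::Full => Decoder::ColumnLevel,
        BufferInner::Mask => Decoder::Packed,
    }
}

fn main() {
    assert_eq!(make_decoder(&BufferInner::Mask), Decoder::Packed);
    assert_eq!(make_decoder(&BufferInner::Full), Decoder::ColumnLevel);
    println!("ok");
}
```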
Which issue does this PR close?
Highly experimental, builds on #1021 #1039 #1052 #1041. Closes #1037.
Rationale for this change
See ticket.
This leads to anything from a 2-6x performance improvement when decoding columns containing nulls. As is to be expected the biggest savings are where the other decode overheads are less - with the 6x return on "Int32Array, plain encoded, optional, half NULLs - old "
There is some funkiness with the benchmarks and the memory allocator on my local machine, with it "faster" to preallocate a single 64 byte array first before trying to read data.
What changes are included in this PR?
This changes RecordReader to use a new DefinitionLevelBuffer that has a corresponding DefinitionLevelDecoder that can read directly from parquet, skipping intermediate buffering and avoiding decoding parquet bitmasks where not necessary.
Are there any user-facing changes?
No