-
Notifications
You must be signed in to change notification settings - Fork 875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separate ArrayReader::next_batch with read_records and consume_batch #2237
Changes from all commits
16ad70f
76b7d40
bc07bdb
631f55a
1e800bb
429e241
fa42ba0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,9 +39,13 @@ where | |
pages: Box<dyn PageIterator>, | ||
def_levels_buffer: Option<Vec<i16>>, | ||
rep_levels_buffer: Option<Vec<i16>>, | ||
data_buffer: Vec<T::T>, | ||
column_desc: ColumnDescPtr, | ||
column_reader: Option<ColumnReaderImpl<T>>, | ||
converter: C, | ||
in_progress_def_levels_buffer: Option<Vec<i16>>, | ||
in_progress_rep_levels_buffer: Option<Vec<i16>>, | ||
before_consume: bool, | ||
_parquet_type_marker: PhantomData<T>, | ||
_converter_marker: PhantomData<C>, | ||
} | ||
|
@@ -59,7 +63,10 @@ where | |
&self.data_type | ||
} | ||
|
||
fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> { | ||
fn read_records(&mut self, batch_size: usize) -> Result<usize> { | ||
if !self.before_consume { | ||
self.before_consume = true; | ||
} | ||
// Try to initialize column reader | ||
if self.column_reader.is_none() { | ||
self.next_column_reader()?; | ||
|
@@ -126,7 +133,6 @@ where | |
break; | ||
} | ||
} | ||
|
||
data_buffer.truncate(num_read); | ||
def_levels_buffer | ||
.iter_mut() | ||
|
@@ -135,13 +141,35 @@ where | |
.iter_mut() | ||
.for_each(|buf| buf.truncate(num_read)); | ||
|
||
self.def_levels_buffer = def_levels_buffer; | ||
self.rep_levels_buffer = rep_levels_buffer; | ||
if let Some(mut def_levels_buffer) = def_levels_buffer { | ||
match &mut self.in_progress_def_levels_buffer { | ||
None => { | ||
self.in_progress_def_levels_buffer = Some(def_levels_buffer); | ||
} | ||
Some(buf) => buf.append(&mut def_levels_buffer), | ||
} | ||
} | ||
|
||
if let Some(mut rep_levels_buffer) = rep_levels_buffer { | ||
match &mut self.in_progress_rep_levels_buffer { | ||
None => { | ||
self.in_progress_rep_levels_buffer = Some(rep_levels_buffer); | ||
} | ||
Some(buf) => buf.append(&mut rep_levels_buffer), | ||
} | ||
} | ||
|
||
self.data_buffer.append(&mut data_buffer); | ||
|
||
Ok(num_read) | ||
} | ||
|
||
let data: Vec<Option<T::T>> = if self.def_levels_buffer.is_some() { | ||
fn consume_batch(&mut self) -> Result<ArrayRef> { | ||
let data: Vec<Option<T::T>> = if self.in_progress_def_levels_buffer.is_some() { | ||
let data_buffer = std::mem::take(&mut self.data_buffer); | ||
data_buffer | ||
.into_iter() | ||
.zip(self.def_levels_buffer.as_ref().unwrap().iter()) | ||
.zip(self.in_progress_def_levels_buffer.as_ref().unwrap().iter()) | ||
.map(|(t, def_level)| { | ||
if *def_level == self.column_desc.max_def_level() { | ||
Some(t) | ||
|
@@ -151,7 +179,7 @@ where | |
}) | ||
.collect() | ||
} else { | ||
data_buffer.into_iter().map(Some).collect() | ||
self.data_buffer.iter().map(|x| Some(x.clone())).collect() | ||
}; | ||
|
||
let mut array = self.converter.convert(data)?; | ||
|
@@ -160,6 +188,11 @@ where | |
array = arrow::compute::cast(&array, &self.data_type)?; | ||
} | ||
|
||
self.data_buffer = vec![]; | ||
self.def_levels_buffer = std::mem::take(&mut self.in_progress_def_levels_buffer); | ||
self.rep_levels_buffer = std::mem::take(&mut self.in_progress_rep_levels_buffer); | ||
self.before_consume = false; | ||
|
||
Ok(array) | ||
} | ||
|
||
|
@@ -168,20 +201,31 @@ where | |
Some(reader) => reader.skip_records(num_records), | ||
None => { | ||
if self.next_column_reader()? { | ||
self.column_reader.as_mut().unwrap().skip_records(num_records) | ||
}else { | ||
self.column_reader | ||
.as_mut() | ||
.unwrap() | ||
.skip_records(num_records) | ||
} else { | ||
Ok(0) | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn get_def_levels(&self) -> Option<&[i16]> { | ||
self.def_levels_buffer.as_deref() | ||
if self.before_consume { | ||
self.in_progress_def_levels_buffer.as_deref() | ||
} else { | ||
self.def_levels_buffer.as_deref() | ||
} | ||
} | ||
|
||
fn get_rep_levels(&self) -> Option<&[i16]> { | ||
self.rep_levels_buffer.as_deref() | ||
if self.before_consume { | ||
self.in_progress_rep_levels_buffer.as_deref() | ||
} else { | ||
self.rep_levels_buffer.as_deref() | ||
} | ||
} | ||
} | ||
|
||
|
@@ -208,9 +252,13 @@ where | |
pages, | ||
def_levels_buffer: None, | ||
rep_levels_buffer: None, | ||
data_buffer: vec![], | ||
column_desc, | ||
column_reader: None, | ||
converter, | ||
in_progress_def_levels_buffer: None, | ||
in_progress_rep_levels_buffer: None, | ||
before_consume: true, | ||
_parquet_type_marker: PhantomData, | ||
_converter_marker: PhantomData, | ||
}) | ||
|
@@ -349,30 +397,32 @@ mod tests { | |
|
||
let mut accu_len: usize = 0; | ||
|
||
let array = array_reader.next_batch(values_per_page / 2).unwrap(); | ||
assert_eq!(array.len(), values_per_page / 2); | ||
let len = array_reader.read_records(values_per_page / 2).unwrap(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now, after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again I think removing batch_size from consume_batch allows preserving the existing behaviour |
||
assert_eq!(len, values_per_page / 2); | ||
assert_eq!( | ||
Some(&def_levels[accu_len..(accu_len + array.len())]), | ||
Some(&def_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_def_levels() | ||
); | ||
assert_eq!( | ||
Some(&rep_levels[accu_len..(accu_len + array.len())]), | ||
Some(&rep_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_rep_levels() | ||
); | ||
accu_len += array.len(); | ||
accu_len += len; | ||
array_reader.consume_batch().unwrap(); | ||
|
||
// Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, | ||
// and the last values_per_page/2 ones are from the second column chunk | ||
let array = array_reader.next_batch(values_per_page).unwrap(); | ||
assert_eq!(array.len(), values_per_page); | ||
let len = array_reader.read_records(values_per_page).unwrap(); | ||
assert_eq!(len, values_per_page); | ||
assert_eq!( | ||
Some(&def_levels[accu_len..(accu_len + array.len())]), | ||
Some(&def_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_def_levels() | ||
); | ||
assert_eq!( | ||
Some(&rep_levels[accu_len..(accu_len + array.len())]), | ||
Some(&rep_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_rep_levels() | ||
); | ||
let array = array_reader.consume_batch().unwrap(); | ||
let strings = array.as_any().downcast_ref::<StringArray>().unwrap(); | ||
for i in 0..array.len() { | ||
if array.is_valid(i) { | ||
|
@@ -384,19 +434,20 @@ mod tests { | |
assert_eq!(all_values[i + accu_len], None) | ||
} | ||
} | ||
accu_len += array.len(); | ||
accu_len += len; | ||
|
||
// Try to read values_per_page values, however there are only values_per_page/2 values | ||
let array = array_reader.next_batch(values_per_page).unwrap(); | ||
assert_eq!(array.len(), values_per_page / 2); | ||
let len = array_reader.read_records(values_per_page).unwrap(); | ||
assert_eq!(len, values_per_page / 2); | ||
assert_eq!( | ||
Some(&def_levels[accu_len..(accu_len + array.len())]), | ||
Some(&def_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_def_levels() | ||
); | ||
assert_eq!( | ||
Some(&rep_levels[accu_len..(accu_len + array.len())]), | ||
Some(&rep_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_rep_levels() | ||
); | ||
array_reader.consume_batch().unwrap(); | ||
} | ||
|
||
#[test] | ||
|
@@ -491,31 +542,34 @@ mod tests { | |
let mut accu_len: usize = 0; | ||
|
||
// println!("---------- reading a batch of {} values ----------", values_per_page / 2); | ||
let array = array_reader.next_batch(values_per_page / 2).unwrap(); | ||
assert_eq!(array.len(), values_per_page / 2); | ||
let len = array_reader.read_records(values_per_page / 2).unwrap(); | ||
assert_eq!(len, values_per_page / 2); | ||
assert_eq!( | ||
Some(&def_levels[accu_len..(accu_len + array.len())]), | ||
Some(&def_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_def_levels() | ||
); | ||
assert_eq!( | ||
Some(&rep_levels[accu_len..(accu_len + array.len())]), | ||
Some(&rep_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_rep_levels() | ||
); | ||
accu_len += array.len(); | ||
accu_len += len; | ||
array_reader.consume_batch().unwrap(); | ||
|
||
// Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, | ||
// and the last values_per_page/2 ones are from the second column chunk | ||
// println!("---------- reading a batch of {} values ----------", values_per_page); | ||
let array = array_reader.next_batch(values_per_page).unwrap(); | ||
assert_eq!(array.len(), values_per_page); | ||
//let array = array_reader.next_batch(values_per_page).unwrap(); | ||
let len = array_reader.read_records(values_per_page).unwrap(); | ||
assert_eq!(len, values_per_page); | ||
assert_eq!( | ||
Some(&def_levels[accu_len..(accu_len + array.len())]), | ||
Some(&def_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_def_levels() | ||
); | ||
assert_eq!( | ||
Some(&rep_levels[accu_len..(accu_len + array.len())]), | ||
Some(&rep_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_rep_levels() | ||
); | ||
let array = array_reader.consume_batch().unwrap(); | ||
let strings = array.as_any().downcast_ref::<StringArray>().unwrap(); | ||
for i in 0..array.len() { | ||
if array.is_valid(i) { | ||
|
@@ -527,19 +581,20 @@ mod tests { | |
assert_eq!(all_values[i + accu_len], None) | ||
} | ||
} | ||
accu_len += array.len(); | ||
accu_len += len; | ||
|
||
// Try to read values_per_page values, however there are only values_per_page/2 values | ||
// println!("---------- reading a batch of {} values ----------", values_per_page); | ||
let array = array_reader.next_batch(values_per_page).unwrap(); | ||
assert_eq!(array.len(), values_per_page / 2); | ||
let len = array_reader.read_records(values_per_page).unwrap(); | ||
assert_eq!(len, values_per_page / 2); | ||
assert_eq!( | ||
Some(&def_levels[accu_len..(accu_len + array.len())]), | ||
Some(&def_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_def_levels() | ||
); | ||
assert_eq!( | ||
Some(&rep_levels[accu_len..(accu_len + array.len())]), | ||
Some(&rep_levels[accu_len..(accu_len + len)]), | ||
array_reader.get_rep_levels() | ||
); | ||
array_reader.consume_batch().unwrap(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can just return
self.def_levels_buffer.as_deref()
, if you look at PrimitiveArrayReader it will only make the data available after the call toconsume_batch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
without this, will fail in
this cause by
get_level
before consumecomplex_object_array
(this is the common situation like other readers), butcomplex_object_array
sometimes(self is nullable) needget_level
after consume, so i think we should keep this check