Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add batch_size option to merge_columns #2896

Merged
merged 2 commits into the base branch from the contributor's branch (branch names omitted in this capture)
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/python/lance/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def merge_columns(
self,
value_func: Callable[[pa.RecordBatch], pa.RecordBatch],
columns: Optional[list[str]] = None,
batch_size: Optional[int] = None,
) -> Tuple[FragmentMetadata, LanceSchema]:
"""Add columns to this Fragment.

Expand All @@ -390,7 +391,7 @@ def merge_columns(
Tuple[FragmentMetadata, LanceSchema]
A new fragment with the added column(s) and the final schema.
"""
updater = self._fragment.updater(columns)
updater = self._fragment.updater(columns, batch_size)

while True:
batch = updater.next()
Expand Down
27 changes: 27 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,33 @@ def test_merge_with_commit(tmp_path: Path):
assert tbl == expected


def test_merge_batch_size(tmp_path: Path):
    """Verify Fragment.merge_columns honors the batch_size option.

    Creates a dataset of 10 fragments (100 rows each) and, for several batch
    sizes, merges in a new column while asserting that the UDF never receives
    a batch larger than the requested size.
    """
    # Create dataset with 10 fragments with 100 rows each
    table = pa.table({"a": range(1000)})
    for batch_size in [1, 10, 100, 1000]:
        # One dataset directory per batch size so runs don't collide.
        ds_path = str(tmp_path / str(batch_size))
        dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100)
        fragments = []

        # Bind the current batch_size as a default argument so the closure
        # captures this iteration's value rather than late-binding the loop
        # variable (flake8-bugbear B023).
        def mutate(batch, expected_max=batch_size):
            assert batch.num_rows <= expected_max
            return pa.RecordBatch.from_pydict({"b": batch.column("a")})

        for frag in dataset.get_fragments():
            merged, schema = frag.merge_columns(mutate, batch_size=batch_size)
            fragments.append(merged)

        # Commit all merged fragments as a single Merge operation.
        merge = lance.LanceOperation.Merge(fragments, schema)
        dataset = lance.LanceDataset.commit(
            ds_path, merge, read_version=dataset.version
        )

        dataset.validate()
        tbl = dataset.to_table()
        expected = pa.table({"a": range(1000), "b": range(1000)})
        assert tbl == expected


def test_merge_with_schema_holes(tmp_path: Path):
# Create table with 3 cols
table = pa.table({"a": range(10)})
Expand Down
6 changes: 4 additions & 2 deletions python/src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,12 @@ impl FileFragment {
Ok(Scanner::new(scn))
}

fn updater(&self, columns: Option<Vec<String>>) -> PyResult<Updater> {
fn updater(&self, columns: Option<Vec<String>>, batch_size: Option<u32>) -> PyResult<Updater> {
let cols = columns.as_deref();
let inner = RT
.block_on(None, async { self.fragment.updater(cols, None).await })?
.block_on(None, async {
self.fragment.updater(cols, None, batch_size).await
})?
.map_err(|err| PyIOError::new_err(err.to_string()))?;
Ok(Updater::new(inner))
}
Expand Down
9 changes: 5 additions & 4 deletions rust/lance/src/dataset/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ impl FileFragment {
&self,
columns: Option<&[T]>,
schemas: Option<(Schema, Schema)>,
batch_size: Option<u32>,
) -> Result<Updater> {
let mut schema = self.dataset.schema().clone();

Expand Down Expand Up @@ -1160,11 +1161,11 @@ impl FileFragment {
let reader = reader?;
let deletion_vector = deletion_vector?.unwrap_or_default();

Updater::try_new(self.clone(), reader, deletion_vector, schemas)
Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size)
}

pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result<Self> {
let mut updater = self.updater(Some(&[join_column]), None).await?;
let mut updater = self.updater(Some(&[join_column]), None, None).await?;

while let Some(batch) = updater.next().await? {
let batch = joiner.collect(batch[join_column].clone()).await?;
Expand Down Expand Up @@ -2433,7 +2434,7 @@ mod tests {
}

let fragment = &mut dataset.get_fragment(0).unwrap();
let mut updater = fragment.updater(Some(&["i"]), None).await.unwrap();
let mut updater = fragment.updater(Some(&["i"]), None, None).await.unwrap();
let new_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"double_i",
DataType::Int32,
Expand Down Expand Up @@ -2677,7 +2678,7 @@ mod tests {
let fragment = dataset.get_fragments().pop().unwrap();

// Write batch_s using add_columns
let mut updater = fragment.updater(Some(&["i"]), None).await?;
let mut updater = fragment.updater(Some(&["i"]), None, None).await?;
updater.next().await?;
updater.update(batch_s.clone()).await?;
let frag = updater.finish().await?;
Expand Down
32 changes: 15 additions & 17 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ use snafu::{location, Location};
#[cfg(feature = "substrait")]
use lance_datafusion::substrait::parse_substrait;

pub const DEFAULT_BATCH_SIZE: usize = 8192;
const BATCH_SIZE_FALLBACK: usize = 8192;
// For backwards compatibility / historical reasons we re-calculate the default batch size
// on each call
//
// Returns `Some(batch_size)` when the LANCE_DEFAULT_BATCH_SIZE environment
// variable is set, `None` when it is unset (or not valid unicode).
//
// Panics with a descriptive message (matching the pre-refactor behavior) if
// the variable is set but does not parse as a usize, instead of a bare
// `unwrap()` whose panic message would not name the offending value.
pub fn get_default_batch_size() -> Option<usize> {
    std::env::var("LANCE_DEFAULT_BATCH_SIZE").ok().map(|val| {
        val.parse().unwrap_or_else(|_| {
            panic!(
                "The value of LANCE_DEFAULT_BATCH_SIZE ({}) is not a valid batch size",
                val
            )
        })
    })
}

pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
lazy_static::lazy_static! {
Expand Down Expand Up @@ -262,23 +269,14 @@ impl Scanner {
// 64KB, this is 16K rows. For local file systems, the default block size
// is just 4K, which would mean only 1K rows, which might be a little small.
// So we use a default minimum of 8K rows.
std::env::var("LANCE_DEFAULT_BATCH_SIZE")
.map(|bs| {
bs.parse().unwrap_or_else(|_| {
panic!(
"The value of LANCE_DEFAULT_BATCH_SIZE ({}) is not a valid batch size",
bs
)
})
})
.unwrap_or_else(|_| {
self.batch_size.unwrap_or_else(|| {
std::cmp::max(
self.dataset.object_store().block_size() / 4,
DEFAULT_BATCH_SIZE,
)
})
get_default_batch_size().unwrap_or_else(|| {
self.batch_size.unwrap_or_else(|| {
std::cmp::max(
self.dataset.object_store().block_size() / 4,
BATCH_SIZE_FALLBACK,
)
})
})
}

fn ensure_not_fragment_scan(&self) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/dataset/schema_evolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async fn add_columns_impl(
}

let mut updater = fragment
.updater(read_columns_ref, schemas_ref.clone())
.updater(read_columns_ref, schemas_ref.clone(), None)
.await?;

let mut batch_index = 0;
Expand Down
34 changes: 23 additions & 11 deletions rust/lance/src/dataset/updater.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use lance_table::utils::stream::ReadBatchFutStream;
use snafu::{location, Location};

use super::fragment::FragmentReader;
use super::scanner::get_default_batch_size;
use super::write::{open_writer, GenericWriter};
use super::Dataset;
use crate::dataset::FileFragment;
Expand Down Expand Up @@ -59,16 +60,26 @@ impl Updater {
reader: FragmentReader,
deletion_vector: DeletionVector,
schemas: Option<(Schema, Schema)>,
batch_size: Option<u32>,
) -> Result<Self> {
let (write_schema, final_schema) = if let Some((write_schema, final_schema)) = schemas {
(Some(write_schema), Some(final_schema))
} else {
(None, None)
};

let batch_size = reader.legacy_num_rows_in_batch(0);
let legacy_batch_size = reader.legacy_num_rows_in_batch(0);

let input_stream = reader.read_all(1024)?;
let batch_size = match (&legacy_batch_size, batch_size) {
// If this is a v1 dataset we must use the row group size of the file
(Some(num_rows), _) => *num_rows,
// If this is a v2 dataset, let the user pick the batch size
(None, Some(legacy_batch_size)) => legacy_batch_size,
// Otherwise, default to 1024 if the user didn't specify anything
(None, None) => get_default_batch_size().unwrap_or(1024) as u32,
};

let input_stream = reader.read_all(batch_size)?;

Ok(Self {
fragment,
Expand All @@ -78,7 +89,7 @@ impl Updater {
write_schema,
final_schema,
finished: false,
deletion_restorer: DeletionRestorer::new(deletion_vector, batch_size),
deletion_restorer: DeletionRestorer::new(deletion_vector, legacy_batch_size),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we passing down the legacy batch size here, instead of batch_size?

Copy link
Contributor Author

@westonpace westonpace Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeletionRestorer uses the presence of this field to trigger a "v1 compatibility mode" if this is set (to match row group sizes). Otherwise the deletion restorer doesn't need to know the batch size.

I renamed the internal variable to legacy_batch_size as well to emphasize this. However, maybe I should rename to row_group_size or something better?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. That's fine as is. Thanks for the explanation.

})
}

Expand Down Expand Up @@ -226,18 +237,18 @@ struct DeletionRestorer {
current_row_id: u32,

/// Number of rows in each batch, only used in legacy files for validation
batch_size: Option<u32>,
legacy_batch_size: Option<u32>,

deletion_vector_iter: Option<Box<dyn Iterator<Item = u32> + Send>>,

last_deleted_row_id: Option<u32>,
}

impl DeletionRestorer {
fn new(deletion_vector: DeletionVector, batch_size: Option<u32>) -> Self {
fn new(deletion_vector: DeletionVector, legacy_batch_size: Option<u32>) -> Self {
Self {
current_row_id: 0,
batch_size,
legacy_batch_size,
deletion_vector_iter: Some(deletion_vector.into_sorted_iter()),
last_deleted_row_id: None,
}
Expand All @@ -248,12 +259,12 @@ impl DeletionRestorer {
}

fn is_full(batch_size: Option<u32>, num_rows: u32) -> bool {
if let Some(batch_size) = batch_size {
if let Some(legacy_batch_size) = batch_size {
// We should never encounter the case that `batch_size < num_rows` because
// that would mean we have a v1 writer and it generated a batch with more rows
// than expected
debug_assert!(batch_size >= num_rows);
batch_size == num_rows
debug_assert!(legacy_batch_size >= num_rows);
legacy_batch_size == num_rows
} else {
false
}
Expand Down Expand Up @@ -295,7 +306,8 @@ impl DeletionRestorer {
loop {
if let Some(next_deleted_id) = next_deleted_id {
if next_deleted_id > last_row_id
|| (next_deleted_id == last_row_id && Self::is_full(self.batch_size, num_rows))
|| (next_deleted_id == last_row_id
&& Self::is_full(self.legacy_batch_size, num_rows))
{
// Either the next deleted id is out of range or it is the next row but
// we are full. Either way, stash it and return
Expand All @@ -322,7 +334,7 @@ impl DeletionRestorer {
let deleted_batch_offsets = self.deleted_batch_offsets_in_range(batch.num_rows() as u32);
let batch = add_blanks(batch, &deleted_batch_offsets)?;

if let Some(batch_size) = self.batch_size {
if let Some(batch_size) = self.legacy_batch_size {
// validation just in case, when the input has a fixed batch size then the
// output should have the same fixed batch size (except the last batch)
let is_last = self.is_exhausted();
Expand Down
1 change: 1 addition & 0 deletions rust/lance/src/dataset/write/merge_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ impl MergeInsertJob {
.updater(
Some(&read_columns),
Some((write_schema, dataset.schema().clone())),
None,
)
.await?;

Expand Down
Loading