Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split output batches of joins that do not respect batch size #12969

Merged
merged 11 commits into from
Oct 18, 2024
26 changes: 17 additions & 9 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ config_namespace! {
/// if the source of statistics is accurate.
/// We plan to make this the default in the future.
pub use_row_number_estimates_to_optimize_partitioning: bool, default = false

/// Should DataFusion enforce batch size in joins or not. By default,
/// DataFusion will not enforce batch size in joins. Enforcing batch size
/// in joins can help to avoid out-of-memory errors when joining large
/// tables with a highly-selective join filter.
alihan-synnada marked this conversation as resolved.
Show resolved Hide resolved
pub enforce_batch_size_in_joins: bool, default = false
}
}

Expand Down Expand Up @@ -1222,16 +1228,18 @@ impl ConfigField for TableOptions {
fn set(&mut self, key: &str, value: &str) -> Result<()> {
// Extensions are handled in the public `ConfigOptions::set`
let (key, rem) = key.split_once('.').unwrap_or((key, ""));
let Some(format) = &self.current_format else {
return _config_err!("Specify a format for TableOptions");
};
match key {
"format" => match format {
#[cfg(feature = "parquet")]
ConfigFileType::PARQUET => self.parquet.set(rem, value),
ConfigFileType::CSV => self.csv.set(rem, value),
ConfigFileType::JSON => self.json.set(rem, value),
},
"format" => {
let Some(format) = &self.current_format else {
return _config_err!("Specify a format for TableOptions");
};
match format {
#[cfg(feature = "parquet")]
ConfigFileType::PARQUET => self.parquet.set(rem, value),
ConfigFileType::CSV => self.csv.set(rem, value),
ConfigFileType::JSON => self.json.set(rem, value),
}
}
_ => _config_err!("Config value \"{key}\" not found on TableOptions"),
}
}
Expand Down
14 changes: 14 additions & 0 deletions datafusion/execution/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,20 @@ impl SessionConfig {
self
}

/// Enables or disables the enforcement of batch size in joins
pub fn with_enforce_batch_size_in_joins(
mut self,
enforce_batch_size_in_joins: bool,
) -> Self {
self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins;
self
}

/// Returns true if the joins will be enforced to output batches of the configured size
pub fn enforce_batch_size_in_joins(&self) -> bool {
self.options.execution.enforce_batch_size_in_joins
}

/// Convert configuration options to name-value pairs with values
/// converted to strings.
///
Expand Down
84 changes: 59 additions & 25 deletions datafusion/physical-plan/src/joins/cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
//! and producing batches in parallel for the right partitions

use super::utils::{
adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut,
adjust_right_output_partitioning, BatchSplitter, BatchTransformer,
BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
StatefulStreamResult,
};
use crate::coalesce_partitions::CoalescePartitionsExec;
Expand Down Expand Up @@ -86,6 +87,7 @@ impl CrossJoinExec {

let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
let cache = Self::compute_properties(&left, &right, Arc::clone(&schema));

CrossJoinExec {
left,
right,
Expand Down Expand Up @@ -246,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec {
let reservation =
MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());

let batch_size = context.session_config().batch_size();
let enforce_batch_size_in_joins =
context.session_config().enforce_batch_size_in_joins();

let left_fut = self.left_fut.once(|| {
load_left_input(
Arc::clone(&self.left),
Expand All @@ -255,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec {
)
});

Ok(Box::pin(CrossJoinStream {
schema: Arc::clone(&self.schema),
left_fut,
right: stream,
left_index: 0,
join_metrics,
state: CrossJoinStreamState::WaitBuildSide,
left_data: RecordBatch::new_empty(self.left().schema()),
}))
if enforce_batch_size_in_joins {
Ok(Box::pin(CrossJoinStream {
schema: Arc::clone(&self.schema),
left_fut,
right: stream,
left_index: 0,
join_metrics,
state: CrossJoinStreamState::WaitBuildSide,
left_data: RecordBatch::new_empty(self.left().schema()),
batch_transformer: BatchSplitter::new(batch_size),
}))
} else {
Ok(Box::pin(CrossJoinStream {
schema: Arc::clone(&self.schema),
left_fut,
right: stream,
left_index: 0,
join_metrics,
state: CrossJoinStreamState::WaitBuildSide,
left_data: RecordBatch::new_empty(self.left().schema()),
batch_transformer: NoopBatchTransformer::new(),
}))
}
}

fn statistics(&self) -> Result<Statistics> {
Expand Down Expand Up @@ -319,7 +339,7 @@ fn stats_cartesian_product(
}

/// A stream that issues [RecordBatch]es as they arrive from the right of the join.
struct CrossJoinStream {
struct CrossJoinStream<T> {
/// Input schema
schema: Arc<Schema>,
/// Future for data from left side
Expand All @@ -334,9 +354,11 @@ struct CrossJoinStream {
state: CrossJoinStreamState,
/// Left data
left_data: RecordBatch,
/// Batch transformer
batch_transformer: T,
}

impl RecordBatchStream for CrossJoinStream {
impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for CrossJoinStream<T> {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
Expand Down Expand Up @@ -390,7 +412,7 @@ fn build_batch(
}

#[async_trait]
impl Stream for CrossJoinStream {
impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
type Item = Result<RecordBatch>;

fn poll_next(
Expand All @@ -401,7 +423,7 @@ impl Stream for CrossJoinStream {
}
}

impl CrossJoinStream {
impl<T: BatchTransformer> CrossJoinStream<T> {
/// Separate implementation function that unpins the [`CrossJoinStream`] so
/// that partial borrows work correctly
fn poll_next_impl(
Expand Down Expand Up @@ -470,21 +492,33 @@ impl CrossJoinStream {
fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let right_batch = self.state.try_as_record_batch()?;
if self.left_index < self.left_data.num_rows() {
let join_timer = self.join_metrics.join_time.timer();
let result =
build_batch(self.left_index, right_batch, &self.left_data, &self.schema);
join_timer.done();

if let Ok(ref batch) = result {
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
match self.batch_transformer.next() {
None => {
let join_timer = self.join_metrics.join_time.timer();
let result = build_batch(
self.left_index,
right_batch,
&self.left_data,
&self.schema,
);
join_timer.done();

self.batch_transformer.set_batch(result?);
}
Some((batch, last)) => {
if last {
self.left_index += 1;
}

self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
return Ok(StatefulStreamResult::Ready(Some(batch)));
}
}
self.left_index += 1;
result.map(|r| StatefulStreamResult::Ready(Some(r)))
} else {
self.state = CrossJoinStreamState::FetchProbeBatch;
Ok(StatefulStreamResult::Continue)
}
Ok(StatefulStreamResult::Continue)
}
}

Expand Down
Loading
Loading