Use usize rather than Option<usize> to represent Limit::skip and Limit::offset #3374

Merged · 5 commits · Sep 7, 2022
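
This PR narrows the `skip` argument of `limit` from `Option<usize>` to a plain `usize`, with `0` meaning "skip nothing"; `fetch` stays `Option<usize>`, where `None` means "fetch everything". A minimal caller-side sketch of the new signature, adapted from the doc examples updated below (it assumes datafusion at this PR's revision and a local `tests/example.csv`):

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
    // skip is now a plain usize (0 = skip nothing); fetch stays Option<usize>.
    let df = df.limit(0, Some(100))?; // previously: df.limit(None, Some(100))?
    let results = df.collect().await?;
    println!("collected {} record batches", results.len());
    Ok(())
}
```

Representing "no skip" as `0` rather than `None` removes an `unwrap_or(0)` at every use site, since the two encodings were semantically identical.
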
32 changes: 13 additions & 19 deletions datafusion/core/src/dataframe.rs
@@ -62,7 +62,7 @@ use std::sync::Arc;
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.filter(col("a").lt_eq(col("b")))?
/// .aggregate(vec![col("a")], vec![min(col("b"))])?
/// .limit(None, Some(100))?;
/// .limit(0, Some(100))?;
/// let results = df.collect();
/// # Ok(())
/// # }
@@ -217,15 +217,11 @@ impl DataFrame {
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.limit(None, Some(100))?;
/// let df = df.limit(0, Some(100))?;
/// # Ok(())
/// # }
/// ```
pub fn limit(
&self,
skip: Option<usize>,
fetch: Option<usize>,
) -> Result<Arc<DataFrame>> {
pub fn limit(&self, skip: usize, fetch: Option<usize>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.limit(skip, fetch)?
.build()?;
@@ -438,7 +434,7 @@ impl DataFrame {
/// # }
/// ```
pub async fn show_limit(&self, num: usize) -> Result<()> {
let results = self.limit(None, Some(num))?.collect().await?;
let results = self.limit(0, Some(num))?.collect().await?;
Ok(pretty::print_batches(&results)?)
}

@@ -543,7 +539,7 @@ impl DataFrame {
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let batches = df.limit(None, Some(100))?.explain(false, false)?.collect().await?;
/// let batches = df.limit(0, Some(100))?.explain(false, false)?.collect().await?;
/// # Ok(())
/// # }
/// ```
@@ -789,7 +785,7 @@ impl TableProvider for DataFrame {
Self::new(
self.session_state.clone(),
&limit
.map_or_else(|| Ok(expr.clone()), |n| expr.limit(None, Some(n)))?
.map_or_else(|| Ok(expr.clone()), |n| expr.limit(0, Some(n)))?
.plan
.clone(),
)
@@ -923,9 +919,7 @@ mod tests {
async fn limit() -> Result<()> {
// build query using Table API
let t = test_table().await?;
let t2 = t
.select_columns(&["c1", "c2", "c11"])?
.limit(None, Some(10))?;
let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(0, Some(10))?;
let plan = t2.plan.clone();

// build query using SQL
@@ -944,7 +938,7 @@
let df = test_table().await?;
let df = df
.select_columns(&["c1", "c2", "c11"])?
.limit(None, Some(10))?
.limit(0, Some(10))?
.explain(false, false)?;
let plan = df.plan.clone();

@@ -1205,7 +1199,7 @@
.await?
.select_columns(&["c1", "c2", "c3"])?
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
.limit(None, Some(1))?
.limit(0, Some(1))?
.sort(vec![
// make the test deterministic
col("c1").sort(true, true),
@@ -1248,7 +1242,7 @@
col("t2.c2").sort(true, true),
col("t2.c3").sort(true, true),
])?
.limit(None, Some(1))?;
.limit(0, Some(1))?;

let df_results = df.collect().await?;
assert_batches_sorted_eq!(
@@ -1266,7 +1260,7 @@

assert_eq!("\
Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
\n Limit: skip=None, fetch=1\
\n Limit: skip=0, fetch=1\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1\
@@ -1276,7 +1270,7 @@

assert_eq!("\
Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
\n Limit: skip=None, fetch=1\
\n Limit: skip=0, fetch=1\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1 projection=[c1, c2, c3]\
@@ -1305,7 +1299,7 @@ mod tests {
let df = test_table()
.await?
.select_columns(&["c2", "c3"])?
.limit(None, Some(1))?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;

let df_results = df.collect().await?;
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context.rs
@@ -149,7 +149,7 @@ const DEFAULT_SCHEMA: &str = "public";
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.filter(col("a").lt_eq(col("b")))?
/// .aggregate(vec![col("a")], vec![min(col("b"))])?
/// .limit(None, Some(100))?;
/// .limit(0, Some(100))?;
/// let results = df.collect();
/// # Ok(())
/// # }
2 changes: 1 addition & 1 deletion datafusion/core/src/lib.rs
@@ -44,7 +44,7 @@
//! // create a plan
//! let df = df.filter(col("a").lt_eq(col("b")))?
//! .aggregate(vec![col("a")], vec![min(col("b"))])?
//! .limit(None, Some(100))?;
//! .limit(0, Some(100))?;
//!
//! // execute the plan
//! let results: Vec<RecordBatch> = df.collect().await?;
16 changes: 8 additions & 8 deletions datafusion/core/src/physical_optimizer/repartition.rs
@@ -329,15 +329,15 @@ mod tests {
fn limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
Arc::new(LocalLimitExec::new(input, 100)),
None,
0,
Some(100),
))
}

fn limit_exec_with_skip(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
Arc::new(LocalLimitExec::new(input, 100)),
Some(5),
5,
Some(100),
))
}
@@ -407,7 +407,7 @@ mod tests {
let plan = limit_exec(filter_exec(parquet_exec()));

let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
// nothing sorts the data, so the local limit doesn't require sorted data either
@@ -441,7 +441,7 @@ mod tests {
let plan = limit_exec(sort_exec(parquet_exec()));

let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
// data is sorted so can't repartition here
"SortExec: [c1@0 ASC]",
@@ -457,7 +457,7 @@
let plan = limit_exec(filter_exec(sort_exec(parquet_exec())));

let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
// data is sorted so can't repartition here even though
@@ -478,12 +478,12 @@
"AggregateExec: mode=Final, gby=[], aggr=[]",
"AggregateExec: mode=Partial, gby=[], aggr=[]",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
// repartition should happen prior to the filter to maximize parallelism
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
// Expect no repartition to happen for local limit
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
@@ -508,7 +508,7 @@
"FilterExec: c1@0",
// repartition should happen prior to the filter to maximize parallelism
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
// Expect no repartition to happen for local limit
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
50 changes: 22 additions & 28 deletions datafusion/core/src/physical_plan/limit.rs
@@ -49,20 +49,17 @@ pub struct GlobalLimitExec {
/// Input execution plan
input: Arc<dyn ExecutionPlan>,
/// Number of rows to skip before fetch
skip: Option<usize>,
/// Maximum number of rows to fetch
skip: usize,
/// Maximum number of rows to fetch,
/// `None` means fetching all rows
fetch: Option<usize>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
}

impl GlobalLimitExec {
/// Create a new GlobalLimitExec
pub fn new(
input: Arc<dyn ExecutionPlan>,
skip: Option<usize>,
fetch: Option<usize>,
) -> Self {
pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
GlobalLimitExec {
input,
skip,
Expand All @@ -77,8 +74,8 @@ impl GlobalLimitExec {
}

/// Number of rows to skip before fetch
pub fn skip(&self) -> Option<&usize> {
self.skip.as_ref()
pub fn skip(&self) -> usize {
self.skip
}

/// Maximum number of rows to fetch
@@ -181,7 +178,7 @@ impl ExecutionPlan for GlobalLimitExec {
write!(
f,
"GlobalLimitExec: skip={}, fetch={}",
self.skip.map_or("None".to_string(), |x| x.to_string()),
self.skip,
self.fetch.map_or("None".to_string(), |x| x.to_string())
)
}
@@ -194,7 +191,7 @@

fn statistics(&self) -> Statistics {
let input_stats = self.input.statistics();
let skip = self.skip.unwrap_or(0);
let skip = self.skip;
match input_stats {
Statistics {
num_rows: Some(nr), ..
@@ -319,7 +316,7 @@ impl ExecutionPlan for LocalLimitExec {
let stream = self.input.execute(partition, context)?;
Ok(Box::pin(LimitStream::new(
stream,
None,
0,
Some(self.fetch),
baseline_metrics,
)))
@@ -397,13 +394,13 @@ struct LimitStream {
impl LimitStream {
fn new(
input: SendableRecordBatchStream,
skip: Option<usize>,
skip: usize,
fetch: Option<usize>,
baseline_metrics: BaselineMetrics,
) -> Self {
let schema = input.schema();
Self {
skip: skip.unwrap_or(0),
skip,
fetch: fetch.unwrap_or(usize::MAX),
input: Some(input),
schema,
@@ -524,11 +521,8 @@ mod tests {
// input should have 4 partitions
assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

let limit = GlobalLimitExec::new(
Arc::new(CoalescePartitionsExec::new(csv)),
None,
Some(7),
);
let limit =
GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

// the result should contain 4 batches (one per input partition)
let iter = limit.execute(0, task_ctx)?;
@@ -559,7 +553,7 @@
// (5 rows) and 1 row from the second (1 row)
let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
let limit_stream =
LimitStream::new(Box::pin(input), None, Some(6), baseline_metrics);
LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
assert_eq!(index.value(), 0);

let results = collect(Box::pin(limit_stream)).await.unwrap();
@@ -574,7 +568,7 @@
}

// test cases for "skip"
async fn skip_and_fetch(skip: Option<usize>, fetch: Option<usize>) -> Result<usize> {
async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();

@@ -594,52 +588,52 @@

#[tokio::test]
async fn skip_none_fetch_none() -> Result<()> {
let row_count = skip_and_fetch(None, None).await?;
let row_count = skip_and_fetch(0, None).await?;
assert_eq!(row_count, 100);
Ok(())
}

#[tokio::test]
async fn skip_none_fetch_50() -> Result<()> {
let row_count = skip_and_fetch(None, Some(50)).await?;
let row_count = skip_and_fetch(0, Some(50)).await?;
assert_eq!(row_count, 50);
Ok(())
}

#[tokio::test]
async fn skip_3_fetch_none() -> Result<()> {
// there are a total of 100 rows, we skipped 3 rows (offset = 3)
let row_count = skip_and_fetch(Some(3), None).await?;
let row_count = skip_and_fetch(3, None).await?;
assert_eq!(row_count, 97);
Ok(())
}

#[tokio::test]
async fn skip_3_fetch_10() -> Result<()> {
// there are a total of 100 rows, we skipped 3 rows (offset = 3)
let row_count = skip_and_fetch(Some(3), Some(10)).await?;
let row_count = skip_and_fetch(3, Some(10)).await?;
assert_eq!(row_count, 10);
Ok(())
}

#[tokio::test]
async fn skip_100_fetch_none() -> Result<()> {
let row_count = skip_and_fetch(Some(100), None).await?;
let row_count = skip_and_fetch(100, None).await?;
assert_eq!(row_count, 0);
Ok(())
}

#[tokio::test]
async fn skip_100_fetch_1() -> Result<()> {
let row_count = skip_and_fetch(Some(100), Some(1)).await?;
let row_count = skip_and_fetch(100, Some(1)).await?;
assert_eq!(row_count, 0);
Ok(())
}

#[tokio::test]
async fn skip_101_fetch_none() -> Result<()> {
// there are a total of 100 rows, we skipped 101 rows (offset = 101)
let row_count = skip_and_fetch(Some(101), None).await?;
let row_count = skip_and_fetch(101, None).await?;
assert_eq!(row_count, 0);
Ok(())
}
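
The skip/fetch semantics these tests assert can be pictured with plain iterators — an illustrative sketch only, not DataFusion API: skip `skip` rows first, then yield at most `fetch` rows, where `fetch = None` means no cap.

```rust
// Illustrative only: mirrors the expectations in the tests above.
fn limit(rows: impl Iterator<Item = u64>, skip: usize, fetch: Option<usize>) -> usize {
    rows.skip(skip).take(fetch.unwrap_or(usize::MAX)).count()
}

fn main() {
    assert_eq!(limit(0u64..100, 0, None), 100); // skip_none_fetch_none
    assert_eq!(limit(0u64..100, 3, Some(10)), 10); // skip_3_fetch_10
    assert_eq!(limit(0u64..100, 101, None), 0); // skip_101_fetch_none: skip past the end
}
```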