Use `usize` rather than `Option<usize>` to represent `Limit::skip` and `Limit::offset` (#3374)
HaoYang670 authored Sep 7, 2022
1 parent c359018 commit 43e2d91
Showing 18 changed files with 173 additions and 197 deletions.
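In short: `skip` is now a plain `usize` (0 means "skip nothing"), while `fetch` stays an `Option<usize>` (`None` means "fetch all rows"). A minimal before/after sketch of the call-site change, mirroring the doc examples updated below — an editor's illustration, assuming the `datafusion` DataFrame API of this era plus `tokio`, and a local tests/example.csv:

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;

    // Before this commit: df.limit(None, Some(100))?
    // After it: the Option around `skip` is gone, so 0 replaces None.
    let df = df.limit(0, Some(100))?; // no offset, fetch at most 100 rows

    // `fetch` keeps its Option; None still means "no row limit".
    let df = df.limit(5, None)?; // skip 5 rows, fetch everything else

    df.show().await?; // `show` assumed from the DataFrame API of the time
    Ok(())
}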
32 changes: 13 additions & 19 deletions datafusion/core/src/dataframe.rs
@@ -62,7 +62,7 @@ use std::sync::Arc;
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.filter(col("a").lt_eq(col("b")))?
/// .aggregate(vec![col("a")], vec![min(col("b"))])?
- /// .limit(None, Some(100))?;
+ /// .limit(0, Some(100))?;
/// let results = df.collect();
/// # Ok(())
/// # }
@@ -217,15 +217,11 @@ impl DataFrame {
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
- /// let df = df.limit(None, Some(100))?;
+ /// let df = df.limit(0, Some(100))?;
/// # Ok(())
/// # }
/// ```
- pub fn limit(
-     &self,
-     skip: Option<usize>,
-     fetch: Option<usize>,
- ) -> Result<Arc<DataFrame>> {
+ pub fn limit(&self, skip: usize, fetch: Option<usize>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.limit(skip, fetch)?
.build()?;
@@ -438,7 +434,7 @@ impl DataFrame {
/// # }
/// ```
pub async fn show_limit(&self, num: usize) -> Result<()> {
- let results = self.limit(None, Some(num))?.collect().await?;
+ let results = self.limit(0, Some(num))?.collect().await?;
Ok(pretty::print_batches(&results)?)
}

@@ -543,7 +539,7 @@ impl DataFrame {
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
- /// let batches = df.limit(None, Some(100))?.explain(false, false)?.collect().await?;
+ /// let batches = df.limit(0, Some(100))?.explain(false, false)?.collect().await?;
/// # Ok(())
/// # }
/// ```
@@ -789,7 +785,7 @@ impl TableProvider for DataFrame {
Self::new(
self.session_state.clone(),
&limit
- .map_or_else(|| Ok(expr.clone()), |n| expr.limit(None, Some(n)))?
+ .map_or_else(|| Ok(expr.clone()), |n| expr.limit(0, Some(n)))?
.plan
.clone(),
)
@@ -923,9 +919,7 @@ mod tests {
async fn limit() -> Result<()> {
// build query using Table API
let t = test_table().await?;
- let t2 = t
-     .select_columns(&["c1", "c2", "c11"])?
-     .limit(None, Some(10))?;
+ let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(0, Some(10))?;
let plan = t2.plan.clone();

// build query using SQL
@@ -944,7 +938,7 @@
let df = test_table().await?;
let df = df
.select_columns(&["c1", "c2", "c11"])?
- .limit(None, Some(10))?
+ .limit(0, Some(10))?
.explain(false, false)?;
let plan = df.plan.clone();

@@ -1205,7 +1199,7 @@
.await?
.select_columns(&["c1", "c2", "c3"])?
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
- .limit(None, Some(1))?
+ .limit(0, Some(1))?
.sort(vec![
// make the test deterministic
col("c1").sort(true, true),
@@ -1248,7 +1242,7 @@
col("t2.c2").sort(true, true),
col("t2.c3").sort(true, true),
])?
- .limit(None, Some(1))?;
+ .limit(0, Some(1))?;

let df_results = df.collect().await?;
assert_batches_sorted_eq!(
@@ -1266,7 +1260,7 @@

assert_eq!("\
Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
- \n Limit: skip=None, fetch=1\
+ \n Limit: skip=0, fetch=1\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1\
@@ -1276,7 +1270,7 @@

assert_eq!("\
Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
- \n Limit: skip=None, fetch=1\
+ \n Limit: skip=0, fetch=1\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1 projection=[c1, c2, c3]\
@@ -1305,7 +1299,7 @@ mod tests {
let df = test_table()
.await?
.select_columns(&["c2", "c3"])?
- .limit(None, Some(1))?
+ .limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;

let df_results = df.collect().await?;
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context.rs
@@ -149,7 +149,7 @@ const DEFAULT_SCHEMA: &str = "public";
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.filter(col("a").lt_eq(col("b")))?
/// .aggregate(vec![col("a")], vec![min(col("b"))])?
- /// .limit(None, Some(100))?;
+ /// .limit(0, Some(100))?;
/// let results = df.collect();
/// # Ok(())
/// # }
2 changes: 1 addition & 1 deletion datafusion/core/src/lib.rs
@@ -44,7 +44,7 @@
//! // create a plan
//! let df = df.filter(col("a").lt_eq(col("b")))?
//! .aggregate(vec![col("a")], vec![min(col("b"))])?
- //! .limit(None, Some(100))?;
+ //! .limit(0, Some(100))?;
//!
//! // execute the plan
//! let results: Vec<RecordBatch> = df.collect().await?;
16 changes: 8 additions & 8 deletions datafusion/core/src/physical_optimizer/repartition.rs
@@ -329,15 +329,15 @@ mod tests {
fn limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
Arc::new(LocalLimitExec::new(input, 100)),
- None,
+ 0,
Some(100),
))
}

fn limit_exec_with_skip(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
Arc::new(LocalLimitExec::new(input, 100)),
- Some(5),
+ 5,
Some(100),
))
}
@@ -407,7 +407,7 @@ mod tests {
let plan = limit_exec(filter_exec(parquet_exec()));

let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
// nothing sorts the data, so the local limit doesn't require sorted data either
@@ -441,7 +441,7 @@ mod tests {
let plan = limit_exec(sort_exec(parquet_exec()));

let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
// data is sorted so can't repartition here
"SortExec: [c1@0 ASC]",
@@ -457,7 +457,7 @@ mod tests {
let plan = limit_exec(filter_exec(sort_exec(parquet_exec())));

let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
// data is sorted so can't repartition here even though
@@ -478,12 +478,12 @@ mod tests {
"AggregateExec: mode=Final, gby=[], aggr=[]",
"AggregateExec: mode=Partial, gby=[], aggr=[]",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
// repartition should happen prior to the filter to maximize parallelism
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
// Expect no repartition to happen for local limit
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
@@ -508,7 +508,7 @@ mod tests {
"FilterExec: c1@0",
// repartition should happen prior to the filter to maximize parallelism
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"GlobalLimitExec: skip=0, fetch=100",
"LocalLimitExec: fetch=100",
// Expect no repartition to happen for local limit
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
50 changes: 22 additions & 28 deletions datafusion/core/src/physical_plan/limit.rs
@@ -49,20 +49,17 @@ pub struct GlobalLimitExec {
/// Input execution plan
input: Arc<dyn ExecutionPlan>,
/// Number of rows to skip before fetch
- skip: Option<usize>,
- /// Maximum number of rows to fetch
+ skip: usize,
+ /// Maximum number of rows to fetch,
+ /// `None` means fetching all rows
fetch: Option<usize>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
}

impl GlobalLimitExec {
/// Create a new GlobalLimitExec
- pub fn new(
-     input: Arc<dyn ExecutionPlan>,
-     skip: Option<usize>,
-     fetch: Option<usize>,
- ) -> Self {
+ pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
GlobalLimitExec {
input,
skip,
Expand All @@ -77,8 +74,8 @@ impl GlobalLimitExec {
}

/// Number of rows to skip before fetch
- pub fn skip(&self) -> Option<&usize> {
-     self.skip.as_ref()
+ pub fn skip(&self) -> usize {
+     self.skip
}

/// Maximum number of rows to fetch
@@ -181,7 +178,7 @@ impl ExecutionPlan for GlobalLimitExec {
write!(
f,
"GlobalLimitExec: skip={}, fetch={}",
- self.skip.map_or("None".to_string(), |x| x.to_string()),
+ self.skip,
self.fetch.map_or("None".to_string(), |x| x.to_string())
)
}
@@ -194,7 +191,7 @@

fn statistics(&self) -> Statistics {
let input_stats = self.input.statistics();
- let skip = self.skip.unwrap_or(0);
+ let skip = self.skip;
match input_stats {
Statistics {
num_rows: Some(nr), ..
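For physical-plan users, the constructor and accessor change shape accordingly. A small sketch of building the operator under the new signature — an editor's illustration, not code from this commit; `EmptyExec` is only assumed here as a convenient stand-in input:

use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::limit::GlobalLimitExec;

fn main() {
    // Any ExecutionPlan works as input; EmptyExec keeps the sketch small.
    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)]));
    let input = Arc::new(EmptyExec::new(false, schema));

    // `skip` is now a bare usize (here: drop 5 rows); `fetch` stays optional.
    let limit = GlobalLimitExec::new(input, 5, Some(100));
    assert_eq!(limit.skip(), 5); // the accessor returns usize, not Option<&usize>

    // Per the fmt_as change above, this plan now displays as:
    //   GlobalLimitExec: skip=5, fetch=100
}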
@@ -319,7 +316,7 @@ impl ExecutionPlan for LocalLimitExec {
let stream = self.input.execute(partition, context)?;
Ok(Box::pin(LimitStream::new(
stream,
- None,
+ 0,
Some(self.fetch),
baseline_metrics,
)))
@@ -397,13 +394,13 @@ struct LimitStream {
impl LimitStream {
fn new(
input: SendableRecordBatchStream,
- skip: Option<usize>,
+ skip: usize,
fetch: Option<usize>,
baseline_metrics: BaselineMetrics,
) -> Self {
let schema = input.schema();
Self {
- skip: skip.unwrap_or(0),
+ skip,
fetch: fetch.unwrap_or(usize::MAX),
input: Some(input),
schema,
@@ -524,11 +521,8 @@ mod tests {
// input should have 4 partitions
assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

- let limit = GlobalLimitExec::new(
-     Arc::new(CoalescePartitionsExec::new(csv)),
-     None,
-     Some(7),
- );
+ let limit =
+     GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

// the result should contain 4 batches (one per input partition)
let iter = limit.execute(0, task_ctx)?;
@@ -559,7 +553,7 @@ mod tests {
// (5 rows) and 1 row from the second (1 row)
let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
let limit_stream =
- LimitStream::new(Box::pin(input), None, Some(6), baseline_metrics);
+ LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
assert_eq!(index.value(), 0);

let results = collect(Box::pin(limit_stream)).await.unwrap();
@@ -574,7 +568,7 @@
}

// test cases for "skip"
- async fn skip_and_fetch(skip: Option<usize>, fetch: Option<usize>) -> Result<usize> {
+ async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();

@@ -594,52 +588,52 @@

#[tokio::test]
async fn skip_none_fetch_none() -> Result<()> {
- let row_count = skip_and_fetch(None, None).await?;
+ let row_count = skip_and_fetch(0, None).await?;
assert_eq!(row_count, 100);
Ok(())
}

#[tokio::test]
async fn skip_none_fetch_50() -> Result<()> {
- let row_count = skip_and_fetch(None, Some(50)).await?;
+ let row_count = skip_and_fetch(0, Some(50)).await?;
assert_eq!(row_count, 50);
Ok(())
}

#[tokio::test]
async fn skip_3_fetch_none() -> Result<()> {
// there are a total of 100 rows, we skip 3 rows (offset = 3)
- let row_count = skip_and_fetch(Some(3), None).await?;
+ let row_count = skip_and_fetch(3, None).await?;
assert_eq!(row_count, 97);
Ok(())
}

#[tokio::test]
async fn skip_3_fetch_10() -> Result<()> {
// there are a total of 100 rows, we skip 3 rows (offset = 3)
- let row_count = skip_and_fetch(Some(3), Some(10)).await?;
+ let row_count = skip_and_fetch(3, Some(10)).await?;
assert_eq!(row_count, 10);
Ok(())
}

#[tokio::test]
async fn skip_100_fetch_none() -> Result<()> {
- let row_count = skip_and_fetch(Some(100), None).await?;
+ let row_count = skip_and_fetch(100, None).await?;
assert_eq!(row_count, 0);
Ok(())
}

#[tokio::test]
async fn skip_100_fetch_1() -> Result<()> {
- let row_count = skip_and_fetch(Some(100), Some(1)).await?;
+ let row_count = skip_and_fetch(100, Some(1)).await?;
assert_eq!(row_count, 0);
Ok(())
}

#[tokio::test]
async fn skip_101_fetch_none() -> Result<()> {
// there are a total of 100 rows, we skip 101 rows (offset = 101)
- let row_count = skip_and_fetch(Some(101), None).await?;
+ let row_count = skip_and_fetch(101, None).await?;
assert_eq!(row_count, 0);
Ok(())
}
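The test matrix above follows one rule: with `total` input rows, a limit returns min(fetch, total - skip) rows, saturating at zero. A compact restatement — an editor's summary with a hypothetical helper, not code from the commit:

// `expected_rows` is a hypothetical helper restating the tests' arithmetic.
fn expected_rows(total: usize, skip: usize, fetch: Option<usize>) -> usize {
    // None for fetch means "no cap"; skipping past the end yields zero rows.
    fetch.unwrap_or(usize::MAX).min(total.saturating_sub(skip))
}

fn main() {
    assert_eq!(expected_rows(100, 0, None), 100);    // skip_none_fetch_none
    assert_eq!(expected_rows(100, 0, Some(50)), 50); // skip_none_fetch_50
    assert_eq!(expected_rows(100, 3, None), 97);     // skip_3_fetch_none
    assert_eq!(expected_rows(100, 3, Some(10)), 10); // skip_3_fetch_10
    assert_eq!(expected_rows(100, 100, None), 0);    // skip_100_fetch_none
    assert_eq!(expected_rows(100, 100, Some(1)), 0); // skip_100_fetch_1
    assert_eq!(expected_rows(100, 101, None), 0);    // skip_101_fetch_none
}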
[Diff for the remaining 13 changed files not shown.]
