Skip to content

Commit

Permalink
optimize code and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Feb 12, 2024
1 parent 7f0f734 commit 769dbfd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 46 deletions.
7 changes: 6 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
optimizer::optimizer::Optimizer,
physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule},
};
use arrow_schema::Schema;
use datafusion_common::{
alias::AliasGenerator,
exec_err, not_impl_err, plan_datafusion_err, plan_err,
Expand Down Expand Up @@ -941,7 +942,11 @@ impl SessionContext {
) -> Result<DataFrame> {
// check schema uniqueness
let mut batches = batches.into_iter().peekable();
let schema: SchemaRef = batches.peek().unwrap().schema().clone();
let schema = if let Some(batch) = batches.peek() {
batch.schema().clone()
} else {
Arc::new(Schema::empty())
};
let provider =
MemTable::try_new(schema, batches.map(|batch| vec![batch]).collect())?;
Ok(DataFrame::new(
Expand Down
77 changes: 32 additions & 45 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1461,60 +1461,47 @@ async fn test_read_batches() -> Result<()> {
],
)
.unwrap(),
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![0, 1, 2])),
Arc::new(Float32Array::from(vec![4.44, 5.02, 6.03])),
],
)
.unwrap(),
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![0, 1, 3])),
Arc::new(Float32Array::from(vec![6.01, 2.02, 3.03])),
],
)
.unwrap(),
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![3, 1, 2])),
Arc::new(Float32Array::from(vec![1000.01, 2.02, 3.03])),
],
)
.unwrap(),
];
let df = ctx.read_batches(batches).unwrap();
df.clone().show().await.unwrap();
let result = df.collect().await?;
let expected = vec![
"+----+---------+",
"| id | number |",
"+----+---------+",
"| 1 | 1.12 |",
"| 2 | 3.4 |",
"| 3 | 2.33 |",
"| 4 | 9.1 |",
"| 5 | 6.66 |",
"| 3 | 1.11 |",
"| 4 | 2.22 |",
"| 5 | 3.33 |",
"| 0 | 4.44 |",
"| 1 | 5.02 |",
"| 2 | 6.03 |",
"| 0 | 6.01 |",
"| 1 | 2.02 |",
"| 3 | 3.03 |",
"| 3 | 1000.01 |",
"| 1 | 2.02 |",
"| 2 | 3.03 |",
"+----+---------+",
"+----+--------+",
"| id | number |",
"+----+--------+",
"| 1 | 1.12 |",
"| 2 | 3.4 |",
"| 3 | 2.33 |",
"| 4 | 9.1 |",
"| 5 | 6.66 |",
"| 3 | 1.11 |",
"| 4 | 2.22 |",
"| 5 | 3.33 |",
"+----+--------+",
];
assert_batches_sorted_eq!(expected, &result);
Ok(())
}
#[tokio::test]
async fn test_read_batches_empty() -> Result<()> {
let config = SessionConfig::new();
let runtime = Arc::new(RuntimeEnv::default());
let state = SessionState::new_with_config_rt(config, runtime);
let ctx = SessionContext::new_with_state(state);

let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("number", DataType::Float32, false),
]));

let batches = vec![];
let df = ctx.read_batches(batches).unwrap();
df.clone().show().await.unwrap();
let result = df.collect().await?;
let expected = vec!["++", "++"];
assert_batches_sorted_eq!(expected, &result);
Ok(())
}

#[tokio::test]
async fn consecutive_projection_same_schema() -> Result<()> {
Expand Down

0 comments on commit 769dbfd

Please sign in to comment.