diff --git a/core/Cargo.lock b/core/Cargo.lock index e209e4a8d..3fb7b5f62 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -825,7 +825,7 @@ dependencies = [ [[package]] name = "datafusion" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "ahash", "arrow", @@ -867,7 +867,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "ahash", "arrow", @@ -886,7 +886,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "tokio", ] @@ -894,7 +894,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "arrow", "chrono", @@ -914,7 +914,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "ahash", "arrow", @@ -930,7 +930,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "arrow", "base64", @@ -954,7 +954,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "arrow", "async-trait", @@ -971,7 +971,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "ahash", "arrow", @@ -1005,7 +1005,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "ahash", "arrow", @@ -1035,7 +1035,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "36.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=111a940#111a940b297aa83839e4e2273f0e1a38e108b370" +source = "git+https://github.com/viirya/arrow-datafusion.git?rev=57b3be4#57b3be4297a47aa45094c16e37ddf0141d723bf0" dependencies = [ "arrow", "arrow-array", diff --git a/core/Cargo.toml b/core/Cargo.toml index 880d18d19..5d1604952 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -66,10 +66,10 @@ itertools = "0.11.0" chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.8" } paste = "1.0.14" -datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940" } -datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940", features = ["unicode_expressions"] } -datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940" } -datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940", default-features = false, features = ["unicode_expressions"] } +datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" } +datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", features = ["unicode_expressions"] } +datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4" } +datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "57b3be4", default-features = false, features = ["unicode_expressions"] } unicode-segmentation = "^1.10.1" once_cell = "1.18.0" regex = "1.9.6" diff --git a/core/src/execution/datafusion/operators/expand.rs b/core/src/execution/datafusion/operators/expand.rs index 5cf444b3b..ca3fdc1aa 100644 --- a/core/src/execution/datafusion/operators/expand.rs +++ b/core/src/execution/datafusion/operators/expand.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_array::RecordBatch; +use arrow_array::{RecordBatch, RecordBatchOptions}; use arrow_schema::SchemaRef; use datafusion::{ execution::TaskContext, @@ -169,7 +169,9 @@ impl ExpandStream { Ok::<(), DataFusionError>(()) })?; - RecordBatch::try_new(self.schema.clone(), columns).map_err(|e| e.into()) + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(self.schema.clone(), columns, &options) + .map_err(|e| e.into()) } } diff --git a/core/src/execution/operators/copy.rs b/core/src/execution/operators/copy.rs index 292271f9e..96c244935 100644 --- a/core/src/execution/operators/copy.rs +++ b/core/src/execution/operators/copy.rs @@ -24,7 +24,7 @@ use std::{ use futures::{Stream, StreamExt}; -use arrow_array::{ArrayRef, RecordBatch}; +use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::{execution::TaskContext, physical_expr::*, physical_plan::*}; @@ -149,7 +149,10 @@ impl CopyStream { .iter() .map(|v| copy_or_cast_array(v)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), vectors).map_err(|e| arrow_datafusion_err!(e)) + + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(self.schema.clone(), vectors, &options) + .map_err(|e| arrow_datafusion_err!(e)) } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 230ac36b0..b5ed5f457 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -40,6 +40,27 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + test("lead/lag should return the default value if the offset row does not exist") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + checkSparkAnswer(sql(""" + |SELECT + | lag(123, 100, 321) OVER (ORDER BY id) as lag, + | lead(123, 100, 321) OVER (ORDER BY id) as lead + |FROM (SELECT 1 as id) tmp + """.stripMargin)) + + checkSparkAnswer(sql(""" + |SELECT + | lag(123, 100, a) OVER (ORDER BY id) as lag, + | lead(123, 100, a) OVER (ORDER BY id) as lead + |FROM (SELECT 1 as id, 2 as a) tmp + """.stripMargin)) + } + } + test("multiple column distinct count") { withSQLConf( CometConf.COMET_ENABLED.key -> "true",