diff --git a/daft/__init__.py b/daft/__init__.py index add31fc9e2..be04c80f8f 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -72,7 +72,7 @@ def refresh_logger() -> None: from daft.dataframe import DataFrame from daft.logical.schema import Schema from daft.datatype import DataType, TimeUnit -from daft.expressions import Expression, col, lit, interval +from daft.expressions import Expression, col, lit, interval, coalesce from daft.io import ( DataCatalogTable, DataCatalogType, @@ -135,4 +135,5 @@ def refresh_logger() -> None: "sql", "sql_expr", "to_struct", + "coalesce", ] diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index a5e4ad0844..9fc430b663 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1176,6 +1176,7 @@ def minhash( seed: int = 1, hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3", ) -> PyExpr: ... +def coalesce(exprs: list[PyExpr]) -> PyExpr: ... # ----- # SQL functions diff --git a/daft/expressions/__init__.py b/daft/expressions/__init__.py index 6e07ffe0f7..bc28ed1925 100644 --- a/daft/expressions/__init__.py +++ b/daft/expressions/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .expressions import Expression, ExpressionsProjection, col, lit, interval +from .expressions import Expression, ExpressionsProjection, col, lit, interval, coalesce -__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval"] +__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval", "coalesce"] diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index baa698aaf4..45fd329250 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -184,6 +184,37 @@ def interval( return Expression._from_pyexpr(lit_value) +def coalesce(*args: Expression) -> Expression: + """Returns the first non-null value in a list of expressions. If all inputs are null, returns null. + + Example: + >>> import daft + >>> df = daft.from_pydict({"x": [1, None, 3], "y": [None, 2, None]}) + >>> df = df.with_column("first_valid", daft.coalesce(df["x"], df["y"])) + >>> df.show() + ╭───────┬───────┬─────────────╮ + │ x ┆ y ┆ first_valid │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Int64 │ + ╞═══════╪═══════╪═════════════╡ + │ 1 ┆ None ┆ 1 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ None ┆ 2 ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ None ┆ 3 │ + ╰───────┴───────┴─────────────╯ + + (Showing first 3 of 3 rows) + + Args: + *args: Two or more expressions to coalesce + + Returns: + Expression: Expression containing first non-null value encountered when evaluating arguments in order + """ + return Expression._from_pyexpr(native.coalesce([arg._expr for arg in args])) + + class Expression: _expr: _PyExpr = None # type: ignore diff --git a/src/daft-core/src/array/from_iter.rs b/src/daft-core/src/array/from_iter.rs index 8e9f45fbd9..536cbe219f 100644 --- a/src/daft-core/src/array/from_iter.rs +++ b/src/daft-core/src/array/from_iter.rs @@ -89,7 +89,6 @@ impl BinaryArray { .unwrap() } } - impl FixedSizeBinaryArray { pub fn from_iter>( name: &str, diff --git a/src/daft-functions/src/coalesce.rs b/src/daft-functions/src/coalesce.rs new file mode 100644 index 0000000000..bc2376e980 --- /dev/null +++ b/src/daft-functions/src/coalesce.rs @@ -0,0 +1,297 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{BooleanArray, DaftLogical, Field, Schema}, + series::{IntoSeries, Series}, + utils::supertype::try_get_supertype, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Coalesce {} + +#[typetag::serde] +impl ScalarUDF for Coalesce { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "coalesce" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [] => Err(DaftError::SchemaMismatch( + "Expected at least 1 input args, got 0".to_string(), + )), + [input] => { + let input_field = input.to_field(schema)?; + Ok(input_field) + } + _ => { + let first_field = inputs[0].to_field(schema)?; + let mut output_dtype = first_field.dtype.clone(); + + for input in inputs { + let lhs = input.to_field(schema)?.dtype; + let rhs = &first_field.dtype; + output_dtype = try_get_supertype(&lhs, rhs)?; + + if try_get_supertype(&lhs, rhs).is_err() { + return Err(DaftError::SchemaMismatch(format!( + "All input fields must have the same data type. Got {lhs} and {rhs}" + ))); + } + } + Ok(Field::new(first_field.name, output_dtype)) + } + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs.len() { + 0 => Err(DaftError::ComputeError("No inputs provided".to_string())), + 1 => Ok(inputs[0].clone()), + _ => { + let name = inputs[0].name(); + let dtype = inputs[0].data_type(); + let len = inputs[0].len(); + // the first input is not null, so no work to do + if inputs[0].validity().is_none() { + return Ok(inputs[0].clone()); + } + + let mut current_value = Series::full_null(name, dtype, len); + let remainder = BooleanArray::from_values(name, vec![true; len].into_iter()); + let all_false = BooleanArray::from_values(name, vec![false; len].into_iter()); + let mut remainder = remainder.into_series(); + + for input in inputs { + let to_apply = remainder.and(&input.not_null()?)?; + current_value = input.if_else(¤t_value, &to_apply)?; + + remainder = remainder.and(&input.is_null()?)?; + + // exit early if all values are filled + if remainder.bool().unwrap() == &all_false { + break; + } + } + + Ok(current_value.rename(name)) + } + } + } +} + +#[must_use] +/// Coalesce returns the first non-null value in a list of expressions. +/// Returns the first non-null value from a sequence of expressions. +/// +/// # Arguments +/// * `inputs` - A vector of expressions to evaluate in order +pub fn coalesce(inputs: Vec) -> ExprRef { + ScalarFunction::new(Coalesce {}, inputs).into() +} + +#[cfg(test)] +mod tests { + use common_error::DaftError; + use daft_core::{ + prelude::{DataType, Field, FullNull, Int8Array, Schema, Utf8Array}, + series::{IntoSeries, Series}, + }; + use daft_dsl::{col, functions::ScalarUDF, lit, null_lit}; + + #[test] + fn test_coalesce_0() { + let s0 = Int8Array::from_iter( + Field::new("s0", DataType::Int8), + vec![None, None, Some(10), Some(11), None].into_iter(), + ) + .into_series(); + let s1 = Int8Array::from_iter( + Field::new("s1", DataType::Int8), + vec![None, Some(2), Some(3), None, None].into_iter(), + ) + .into_series(); + let s2 = Int8Array::from_iter( + Field::new("s2", DataType::Int8), + vec![None, Some(1), Some(4), Some(4), Some(10)].into_iter(), + ) + .into_series(); + + let coalesce = super::Coalesce {}; + let output = coalesce.evaluate(&[s0, s1, s2]).unwrap(); + let actual = output.i8().unwrap(); + let expected = Int8Array::from_iter( + Field::new("s0", DataType::Int8), + vec![None, Some(2), Some(10), Some(11), Some(10)].into_iter(), + ); + + assert_eq!(actual, &expected); + } + + #[test] + fn test_coalesce_1() { + let s0 = Int8Array::from_iter( + Field::new("s0", DataType::Int8), + vec![None, None, Some(10), Some(11), None].into_iter(), + ) + .into_series(); + + let s1 = Int8Array::from_iter( + Field::new("s1", DataType::Int8), + vec![None, Some(2), Some(3), None, None].into_iter(), + ) + .into_series(); + + let coalesce = super::Coalesce {}; + let output = coalesce.evaluate(&[s0, s1]).unwrap(); + let actual = output.i8().unwrap(); + let expected = Int8Array::from_iter( + Field::new("s0", DataType::Int8), + vec![None, Some(2), Some(10), Some(11), None].into_iter(), + ); + + assert_eq!(actual, &expected); + } + + #[test] + fn test_coalesce_no_args() { + let coalesce = super::Coalesce {}; + let output = coalesce.evaluate(&[]); + + assert!(output.is_err()); + } + + #[test] + fn test_coalesce_one_arg() { + let s0 = Int8Array::from_iter( + Field::new("s0", DataType::Int8), + vec![None, None, Some(10), Some(11), None].into_iter(), + ) + .into_series(); + + let coalesce = super::Coalesce {}; + let output = coalesce.evaluate(&[s0.clone()]).unwrap(); + // can't directly compare as null != null + let output = output.i8().unwrap(); + let s0 = s0.i8().unwrap(); + assert_eq!(output, s0); + } + + #[test] + fn test_coalesce_full_nulls() { + let s0 = Series::full_null("s0", &DataType::Utf8, 100); + let s1 = Series::full_null("s1", &DataType::Utf8, 100); + let s2 = Series::full_null("s2", &DataType::Utf8, 100); + + let coalesce = super::Coalesce {}; + let output = coalesce.evaluate(&[s0, s1, s2]).unwrap(); + let actual = output.utf8().unwrap(); + let expected = Utf8Array::full_null("s0", &DataType::Utf8, 100); + + assert_eq!(actual, &expected); + } + + #[test] + fn test_coalesce_with_mismatched_types() { + let s0 = Int8Array::from_iter( + Field::new("s0", DataType::Int8), + vec![None, None, Some(10), Some(11), None].into_iter(), + ) + .into_series(); + let s1 = Int8Array::from_iter( + Field::new("s1", DataType::Int8), + vec![None, Some(2), Some(3), None, None].into_iter(), + ) + .into_series(); + let s2 = Utf8Array::from_iter( + "s2", + vec![ + None, + Some("hello"), + Some("world"), + Some("hello"), + Some("world"), + ] + .into_iter(), + ) + .into_series(); + + let coalesce = super::Coalesce {}; + let output = coalesce.evaluate(&[s0, s1, s2]); + + let expected = Utf8Array::from_iter( + "s2", + vec![None, Some("2"), Some("10"), Some("11"), Some("world")].into_iter(), + ); + assert_eq!(output.unwrap().utf8().unwrap(), &expected); + } + + #[test] + fn test_to_field() { + let col_0 = null_lit().alias("s0"); + let fallback = lit(0); + + let schema = Schema::new(vec![ + Field::new("s0", DataType::Int32), + Field::new("s1", DataType::Int32), + ]) + .unwrap(); + let expected = Field::new("s0", DataType::Int32); + + let coalesce = super::Coalesce {}; + let output = coalesce.to_field(&[col_0, fallback], &schema).unwrap(); + assert_eq!(output, expected); + } + + #[test] + fn test_to_field_with_mismatched_types() { + let col_0 = col("s0"); + let col_1 = col("s1"); + let fallback = lit("not found"); + + let schema = Schema::new(vec![ + Field::new("s0", DataType::Int8), + Field::new("s1", DataType::Int8), + Field::new("s2", DataType::Utf8), + ]) + .unwrap(); + let expected = Field::new("s0", DataType::Utf8); + + let coalesce = super::Coalesce {}; + let output = coalesce + .to_field(&[col_0, col_1, fallback], &schema) + .unwrap(); + assert_eq!(output, expected); + } + + #[test] + fn test_to_field_with_incompatible_types() { + let col_0 = col("s0"); + let col_1 = col("s1"); + let col_2 = lit(1u32); + + let schema = Schema::new(vec![ + Field::new("s0", DataType::Date), + Field::new("s1", DataType::Boolean), + Field::new("s2", DataType::UInt32), + ]); + let expected = "could not determine supertype of Boolean and Date".to_string(); + let coalesce = super::Coalesce {}; + let DaftError::TypeError(e) = coalesce + .to_field(&[col_0, col_1, col_2], &schema.unwrap()) + .unwrap_err() + else { + panic!("Expected error") + }; + + assert_eq!(e, expected); + } +} diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index ba5ef18c97..20c17e358c 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -1,4 +1,5 @@ #![feature(async_closure)] +pub mod coalesce; pub mod count_matches; pub mod distance; pub mod float; diff --git a/src/daft-functions/src/python/coalesce.rs b/src/daft-functions/src/python/coalesce.rs new file mode 100644 index 0000000000..76e9f42ee9 --- /dev/null +++ b/src/daft-functions/src/python/coalesce.rs @@ -0,0 +1,7 @@ +use daft_dsl::python::PyExpr; +use pyo3::pyfunction; + +#[pyfunction] +pub fn coalesce(exprs: Vec) -> PyExpr { + crate::coalesce::coalesce(exprs.into_iter().map(|expr| expr.into()).collect()).into() +} diff --git a/src/daft-functions/src/python/mod.rs b/src/daft-functions/src/python/mod.rs index 7e067a2a65..14e3266373 100644 --- a/src/daft-functions/src/python/mod.rs +++ b/src/daft-functions/src/python/mod.rs @@ -12,6 +12,7 @@ macro_rules! simple_python_wrapper { }; } +mod coalesce; mod distance; mod float; mod image; @@ -35,6 +36,7 @@ pub fn register(parent: &Bound) -> PyResult<()> { }; } + add!(coalesce::coalesce); add!(distance::cosine_distance); add!(float::is_inf); diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index 19172a2e1e..00bddf92a3 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -10,9 +10,10 @@ use sqlparser::ast::{ use crate::{ error::{PlannerError, SQLPlannerResult}, modules::{ - hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat, SQLModuleImage, - SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric, SQLModulePartitioning, - SQLModulePython, SQLModuleSketch, SQLModuleStructs, SQLModuleTemporal, SQLModuleUtf8, + coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat, + SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric, + SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs, + SQLModuleTemporal, SQLModuleUtf8, }, planner::SQLPlanner, unsupported_sql_err, @@ -36,6 +37,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.add_fn("coalesce", SQLCoalesce {}); functions }); diff --git a/src/daft-sql/src/modules/coalesce.rs b/src/daft-sql/src/modules/coalesce.rs new file mode 100644 index 0000000000..724355d5ba --- /dev/null +++ b/src/daft-sql/src/modules/coalesce.rs @@ -0,0 +1,26 @@ +use crate::functions::SQLFunction; + +pub struct SQLCoalesce {} + +impl SQLFunction for SQLCoalesce { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + let args = inputs + .iter() + .map(|arg| planner.plan_function_arg(arg)) + .collect::>>()?; + + Ok(daft_functions::coalesce::coalesce(args)) + } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::DOCSTRING.to_string() + } +} + +mod static_docs { + pub(super) const DOCSTRING: &str = "Coalesce the first non-null value from a list of inputs."; +} diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index ded8007e2d..30195dc52f 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -1,6 +1,7 @@ use crate::functions::SQLFunctions; pub mod aggs; +pub mod coalesce; pub mod config; pub mod float; pub mod hashing; diff --git a/tests/sql/test_exprs.py b/tests/sql/test_exprs.py index 438d96b837..56eb6d6994 100644 --- a/tests/sql/test_exprs.py +++ b/tests/sql/test_exprs.py @@ -224,3 +224,27 @@ def test_interval_comparison(date_values, ts_values, expected_intervals): ) assert expected_df == actual_sql == expected_intervals + + +def test_coalesce(): + df = daft.from_pydict( + { + "a": [None, None, 3, None], + "b": [None, 2, 4, None], + "c": [None, None, 5, 6], + } + ) + + expected = df.select(daft.coalesce(col("a"), col("b"), col("c")).alias("result")).to_pydict() + + catalog = SQLCatalog({"df": df}) + actual = daft.sql( + """ + SELECT + COALESCE(a, b, c) as result + FROM df + """, + catalog=catalog, + ).to_pydict() + + assert actual == expected