From d00c00a340a830bcc26403b494e92b0def127665 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Tue, 9 Jul 2024 07:02:32 -0400
Subject: [PATCH 01/55] Expose missing functions to python

---
 src/common.rs           |  1 +
 src/common/data_type.rs |  2 +-
 src/dataframe.rs        | 12 ++++++++++++
 src/expr.rs             |  1 +
 src/functions.rs        | 25 ++++++++++++++++++++++---
 src/lib.rs              |  2 ++
 6 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/src/common.rs b/src/common.rs
index 094e70c0..453bf67a 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -27,6 +27,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 313318fc..3299a46f 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -764,7 +764,7 @@ pub enum SqlType {
 #[allow(non_camel_case_types)]
 #[allow(clippy::upper_case_acronyms)]
 #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
-#[pyclass(name = "PythonType", module = "datafusion.common")]
+#[pyclass(name = "NullTreatment", module = "datafusion.common")]
 pub enum NullTreatment {
     IGNORE_NULLS,
     RESPECT_NULLS,
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 9e36be2c..53e11234 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -320,6 +320,18 @@ impl PyDataFrame {
         Ok(Self::new(df))
     }
 
+    #[pyo3(signature = (columns, preserve_nulls=true))]
+    fn unnest_columns(&self, columns: Vec<String>, preserve_nulls: bool) -> PyResult<Self> {
+        let unnest_options = UnnestOptions { preserve_nulls };
+        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
+        let df = self
+            .df
+            .as_ref()
+            .clone()
+            .unnest_columns_with_options(&cols, unnest_options)?;
+        Ok(Self::new(df))
+    }
+
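A rough usage sketch of the newly exposed `unnest_columns` as it is expected to surface in Python (the wrapper lands in the next commit; the table and column names here are illustrative only, not part of the patch):

```python
from datafusion import SessionContext

ctx = SessionContext()
# `a` is a list-typed column; unnesting expands each list element into its own row.
df = ctx.from_pydict({"a": [[1, 2], [3]], "b": [10, 20]}, name="t")
df.unnest_columns(["a"], preserve_nulls=True).show()
```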
    /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
    fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
        let new_df = self
diff --git a/src/expr.rs b/src/expr.rs
index dc1de669..aab0daa6 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -583,6 +583,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/src/functions.rs b/src/functions.rs
index b39d98b3..42d1d058 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -232,6 +232,12 @@ fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
     Ok(functions::string::expr_fn::concat_ws(lit(sep), args).into())
 }
 
+#[pyfunction]
+#[pyo3(signature = (values, regex, flags = None))]
+fn regexp_like(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
+    Ok(functions::expr_fn::regexp_like(values.expr, regex.expr, flags.map(|x| x.expr)).into())
+}
+
 #[pyfunction]
 #[pyo3(signature = (values, regex, flags = None))]
 fn regexp_match(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
@@ -256,12 +262,12 @@ fn regexp_replace(
 }
 
 /// Creates a new Sort Expr
 #[pyfunction]
-fn order_by(expr: PyExpr, asc: Option<bool>, nulls_first: Option<bool>) -> PyResult<PyExpr> {
+fn order_by(expr: PyExpr, asc: bool, nulls_first: bool) -> PyResult<PyExpr> {
     Ok(PyExpr {
         expr: datafusion_expr::Expr::Sort(Sort {
             expr: Box::new(expr.expr),
-            asc: asc.unwrap_or(true),
-            nulls_first: nulls_first.unwrap_or(true),
+            asc,
+            nulls_first,
         }),
     })
 }
@@ -488,6 +494,7 @@ expr_fn!(chr, arg, "Returns the character with the given code.");
 expr_fn_vec!(coalesce);
 expr_fn!(cos, num);
 expr_fn!(cosh, num);
+expr_fn!(cot, num);
 expr_fn!(degrees, num);
 expr_fn!(decode, input encoding);
 expr_fn!(encode, input encoding);
@@ -499,6 +506,7 @@ expr_fn!(gcd, x y);
 expr_fn!(initcap, string, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
 expr_fn!(isnan, num);
 expr_fn!(iszero, num);
+expr_fn!(levenshtein, string1 string2);
 expr_fn!(lcm, x y);
 expr_fn!(left, string n, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
 expr_fn!(ln, num);
@@ -555,7 +563,9 @@ expr_fn!(sqrt, num);
 expr_fn!(starts_with, string prefix, "Returns true if string starts with prefix.");
 expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
 expr_fn!(substr, string position);
+expr_fn!(substr_index, string delimiter count);
 expr_fn!(substring, string position length);
+expr_fn!(find_in_set, string string_list);
 expr_fn!(tan, num);
 expr_fn!(tanh, num);
 expr_fn!(
@@ -568,6 +578,7 @@ expr_fn_vec!(to_timestamp);
 expr_fn_vec!(to_timestamp_millis);
 expr_fn_vec!(to_timestamp_micros);
 expr_fn_vec!(to_timestamp_seconds);
+expr_fn_vec!(to_unixtime);
 expr_fn!(current_date);
 expr_fn!(current_time);
 expr_fn!(date_part, part date);
@@ -575,6 +586,7 @@ expr_fn!(datepart, date_part, part date);
 expr_fn!(date_trunc, part date);
 expr_fn!(datetrunc, date_trunc, part date);
 expr_fn!(date_bin, stride source origin);
+expr_fn!(make_date, year month day);
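A sketch of how these additions surface in Python once the wrapper layer from the following commit is in place (data and pattern are illustrative; it assumes the `functions` module re-exports the wrapped versions):

```python
from datafusion import SessionContext, column, literal
from datafusion import functions as f

ctx = SessionContext()
df = ctx.from_pydict({"s": ["abc", "a1c", "a2c"]})
# regexp_like returns a boolean: does `s` match the regular expression?
df = df.select(f.regexp_like(column("s"), literal("^a[0-9]c$")))
# order_by now takes plain booleans for `asc` and `nulls_first` instead of Options.
sort_expr = f.order_by(column("s"), True, False)
```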
 expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
 expr_fn_vec!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
@@ -712,6 +724,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(corr))?;
     m.add_wrapped(wrap_pyfunction!(cos))?;
     m.add_wrapped(wrap_pyfunction!(cosh))?;
+    m.add_wrapped(wrap_pyfunction!(cot))?;
     m.add_wrapped(wrap_pyfunction!(count))?;
     m.add_wrapped(wrap_pyfunction!(count_star))?;
     m.add_wrapped(wrap_pyfunction!(covar))?;
@@ -725,6 +738,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(date_part))?;
     m.add_wrapped(wrap_pyfunction!(datetrunc))?;
     m.add_wrapped(wrap_pyfunction!(date_trunc))?;
+    m.add_wrapped(wrap_pyfunction!(make_date))?;
     m.add_wrapped(wrap_pyfunction!(digest))?;
     m.add_wrapped(wrap_pyfunction!(ends_with))?;
     m.add_wrapped(wrap_pyfunction!(exp))?;
@@ -737,6 +751,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(initcap))?;
     m.add_wrapped(wrap_pyfunction!(isnan))?;
     m.add_wrapped(wrap_pyfunction!(iszero))?;
+    m.add_wrapped(wrap_pyfunction!(levenshtein))?;
     m.add_wrapped(wrap_pyfunction!(lcm))?;
     m.add_wrapped(wrap_pyfunction!(left))?;
     m.add_wrapped(wrap_pyfunction!(length))?;
@@ -764,6 +779,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(pow))?;
     m.add_wrapped(wrap_pyfunction!(radians))?;
     m.add_wrapped(wrap_pyfunction!(random))?;
+    m.add_wrapped(wrap_pyfunction!(regexp_like))?;
     m.add_wrapped(wrap_pyfunction!(regexp_match))?;
     m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
     m.add_wrapped(wrap_pyfunction!(repeat))?;
@@ -789,7 +805,9 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(strpos))?;
     m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword
     m.add_wrapped(wrap_pyfunction!(substr))?;
+    m.add_wrapped(wrap_pyfunction!(substr_index))?;
     m.add_wrapped(wrap_pyfunction!(substring))?;
+    m.add_wrapped(wrap_pyfunction!(find_in_set))?;
     m.add_wrapped(wrap_pyfunction!(sum))?;
     m.add_wrapped(wrap_pyfunction!(tan))?;
     m.add_wrapped(wrap_pyfunction!(tanh))?;
@@ -798,6 +816,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(to_timestamp_millis))?;
     m.add_wrapped(wrap_pyfunction!(to_timestamp_micros))?;
     m.add_wrapped(wrap_pyfunction!(to_timestamp_seconds))?;
+    m.add_wrapped(wrap_pyfunction!(to_unixtime))?;
     m.add_wrapped(wrap_pyfunction!(translate))?;
     m.add_wrapped(wrap_pyfunction!(trim))?;
     m.add_wrapped(wrap_pyfunction!(trunc))?;
diff --git a/src/lib.rs b/src/lib.rs
index 71c27e1a..357eaacd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -92,6 +92,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
 
     // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/
     let common = PyModule::new_bound(py, "common")?;

From 27e4f30e94c0f129c976ab54e788801ad4420055 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Tue, 9 Jul 2024 07:10:20 -0400
Subject: [PATCH 02/55] Initial commit for creating wrapper classes and
 functions for all user facing python features

---
 python/datafusion/__init__.py             |  106 +-
 python/datafusion/catalog.py              |   42 +
 python/datafusion/context.py              | 1167 ++++++++++++++
 python/datafusion/dataframe.py            |  561 +++++++
 python/datafusion/expr.py                 |  253 ++-
 python/datafusion/functions.py            | 1728 ++++++++++++++++++++-
 python/datafusion/record_batch.py         |   35 +
 python/datafusion/substrait.py            |  153 +-
 python/datafusion/tests/conftest.py       |    3 +-
 python/datafusion/tests/test_dataframe.py |   17 +-
 python/datafusion/tests/test_functions.py |   12 +-
 python/datafusion/tests/test_imports.py   |   15 +-
 python/datafusion/tests/test_sql.py       |    3 +-
 python/datafusion/tests/test_substrait.py |   10 +-
 python/datafusion/udf.py                  |   41 +
 15 files changed, 4068 insertions(+), 78 deletions(-)
 create mode 100644 python/datafusion/catalog.py
 create mode 100644 python/datafusion/context.py
 create mode 100644 python/datafusion/dataframe.py
 create mode 100644 python/datafusion/record_batch.py
 create mode 100644 python/datafusion/udf.py

diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 846b1a45..8c3e9de7 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -25,64 +25,67 @@
 
 import pyarrow as pa
 
-from ._internal import (
-    AggregateUDF,
-    Config,
-    DataFrame,
+from .context import (
     SessionContext,
     SessionConfig,
     RuntimeConfig,
-    ScalarUDF,
     SQLOptions,
 )
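With this reorganization, user-facing classes come from the Python wrapper modules rather than `._internal`; a minimal sketch of the new import surface introduced by this commit:

```python
from datafusion.context import SessionContext, SessionConfig, RuntimeConfig, SQLOptions
from datafusion.dataframe import DataFrame
from datafusion.expr import Expr
from datafusion.udf import ScalarUDF, AggregateUDF
```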
+
+# The following imports are okay to remain as opaque to the user.
+from ._internal import Config
+
+from .udf import ScalarUDF, AggregateUDF
+
 from .common import (
     DFSchema,
 )
 
+from .dataframe import DataFrame
+
 from .expr import (
-    Alias,
-    Analyze,
+    # Alias,
+    # Analyze,
     Expr,
-    Filter,
-    Limit,
-    Like,
-    ILike,
-    Projection,
-    SimilarTo,
-    ScalarVariable,
-    Sort,
-    TableScan,
-    Not,
-    IsNotNull,
-    IsTrue,
-    IsFalse,
-    IsUnknown,
-    IsNotTrue,
-    IsNotFalse,
-    IsNotUnknown,
-    Negative,
-    InList,
-    Exists,
-    Subquery,
-    InSubquery,
-    ScalarSubquery,
-    GroupingSet,
-    Placeholder,
-    Case,
-    Cast,
-    TryCast,
-    Between,
-    Explain,
-    CreateMemoryTable,
-    SubqueryAlias,
-    Extension,
-    CreateView,
-    Distinct,
-    DropTable,
-    Repartition,
-    Partitioning,
-    Window,
+    # Filter,
+    # Limit,
+    # Like,
+    # ILike,
+    # Projection,
+    # SimilarTo,
+    # ScalarVariable,
+    # Sort,
+    # TableScan,
+    # Not,
+    # IsNotNull,
+    # IsTrue,
+    # IsFalse,
+    # IsUnknown,
+    # IsNotTrue,
+    # IsNotFalse,
+    # IsNotUnknown,
+    # Negative,
+    # InList,
+    # Exists,
+    # Subquery,
+    # InSubquery,
+    # ScalarSubquery,
+    # GroupingSet,
+    # Placeholder,
+    # Case,
+    # Cast,
+    # TryCast,
+    # Between,
+    # Explain,
+    # CreateMemoryTable,
+    # SubqueryAlias,
+    # Extension,
+    # CreateView,
+    # Distinct,
+    # DropTable,
+    # Repartition,
+    # Partitioning,
+    # Window,
     WindowFrame,
 )
 
@@ -96,7 +99,6 @@
     "SQLOptions",
     "RuntimeConfig",
     "Expr",
-    "AggregateUDF",
     "ScalarUDF",
     "Window",
     "WindowFrame",
@@ -175,8 +177,6 @@ def column(value):
 
 
 def literal(value):
-    if not isinstance(value, pa.Scalar):
-        value = pa.scalar(value)
     return Expr.literal(value)
 
 
@@ -200,7 +200,7 @@ def udf(func, input_types, return_type, volatility, name=None):
     )
 
 
-def udaf(accum, input_type, return_type, state_type, volatility, name=None):
+def udaf(accum, input_types, return_type, state_type, volatility, name=None):
     """
     Create a new User Defined Aggregate Function
     """
@@ -208,12 +208,12 @@ def udaf(accum, input_types, return_type, state_type, volatility, name=None):
         raise TypeError("`accum` must implement the abstract base class Accumulator")
     if name is None:
         name = accum.__qualname__.lower()
-    if isinstance(input_type, pa.lib.DataType):
-        input_type = [input_type]
+    if isinstance(input_types, pa.lib.DataType):
+        input_types = [input_types]
     return AggregateUDF(
         name=name,
         accumulator=accum,
-        input_type=input_type,
+        input_types=input_types,
         return_type=return_type,
         state_type=state_type,
         volatility=volatility,
diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py
new file mode 100644
index 00000000..0764e63f
--- /dev/null
+++ b/python/datafusion/catalog.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import datafusion._internal as df_internal
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import pyarrow
+
+
+class Catalog:
+    def __init__(self, catalog: df_internal.Catalog) -> None:
+        self.catalog = catalog
+
+    def names(self) -> list[str]:
+        return self.catalog.names()
+
+    def database(self, name: str = "public") -> Database:
+        return Database(self.catalog.database(name))
+
+
+class Database:
+    def __init__(self, db: df_internal.Database) -> None:
+        self.db = db
+
+    def names(self) -> set[str]:
+        return self.db.names()
+
+    def table(self, name: str) -> Table:
+        return Table(self.db.table(name))
+
+
+class Table:
+    def __init__(self, table: df_internal.Table) -> None:
+        self.table = table
+
+    def schema(self) -> pyarrow.Schema:
+        return self.table.schema()
+
+    @property
+    def kind(self) -> str:
+        return self.table.kind()
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
new file mode 100644
index 00000000..f34bbf03
--- /dev/null
+++ b/python/datafusion/context.py
@@ -0,0 +1,1167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from ._internal import SessionConfig as SessionConfigInternal
+from ._internal import RuntimeConfig as RuntimeConfigInternal
+from ._internal import SQLOptions as SQLOptionsInternal
+from ._internal import SessionContext as SessionContextInternal
+from ._internal import LogicalPlan, ExecutionPlan  # TODO MAKE THIS A DEFINED CLASS
+
+from datafusion._internal import AggregateUDF
+from datafusion.catalog import Catalog, Table
+from datafusion.dataframe import DataFrame
+from datafusion.expr import Expr
+from datafusion.record_batch import RecordBatchStream
+from datafusion.udf import ScalarUDF
+
+from typing import Any, TYPE_CHECKING
+from typing_extensions import deprecated
+
+if TYPE_CHECKING:
+    import pyarrow
+    import pandas
+    import polars
+
+
+class SessionConfig:
+    def __init__(self, config_options: dict[str, str] = {}) -> None:
+        """Create a new `SessionConfig` with the given configuration options.
+
+        Parameters
+        ----------
+        config_options : dict[str, str]
+            Configuration options.
+        """
+        self.config_internal = SessionConfigInternal(config_options)
+
+    def with_create_default_catalog_and_schema(
+        self, enabled: bool = True
+    ) -> SessionConfig:
+        """Control whether the default catalog and schema will be automatically created.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether the default catalog and schema will be automatically created.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = (
+            self.config_internal.with_create_default_catalog_and_schema(enabled)
+        )
+        return self
+
+    def with_default_catalog_and_schema(
+        self, catalog: str, schema: str
+    ) -> SessionConfig:
+        """Select a name for the default catalog and schema.
+
+        Parameters
+        ----------
+        catalog : str
+            Catalog name.
+        schema : str
+            Schema name.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_default_catalog_and_schema(
+            catalog, schema
+        )
+        return self
+
+    def with_information_schema(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the inclusion of `information_schema` virtual tables.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to include `information_schema` virtual tables.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_information_schema(enabled)
+        return self
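Because each builder method updates the wrapped config and returns `self`, options chain naturally. A sketch (the config key shown is illustrative of the string key/value form, not prescribed by this patch):

```python
from datafusion.context import SessionConfig

config = (
    SessionConfig({"datafusion.execution.batch_size": "4096"})
    .with_default_catalog_and_schema("datafusion", "public")
    .with_information_schema(True)
)
```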
+    def with_batch_size(self, batch_size: int) -> SessionConfig:
+        """Customize batch size.
+
+        Parameters
+        ----------
+        batch_size : int
+            Batch size.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_batch_size(batch_size)
+        return self
+
+    def with_target_partitions(self, target_partitions: int) -> SessionConfig:
+        """Customize the number of target partitions for query execution.
+
+        Increasing partitions can increase concurrency.
+
+        Parameters
+        ----------
+        target_partitions : int
+            Number of target partitions.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_target_partitions(
+            target_partitions
+        )
+        return self
+
+    def with_repartition_aggregations(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the use of repartitioning for aggregations.
+
+        Enabling this improves parallelism.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to use repartitioning for aggregations.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_repartition_aggregations(
+            enabled
+        )
+        return self
+
+    def with_repartition_joins(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the use of repartitioning for joins to improve parallelism.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to use repartitioning for joins.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_repartition_joins(enabled)
+        return self
+
+    def with_repartition_windows(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the use of repartitioning for window functions to improve parallelism.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to use repartitioning for window functions.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_repartition_windows(enabled)
+        return self
+
+    def with_repartition_sorts(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the use of repartitioning for sorts to improve parallelism.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to use repartitioning for sorts.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_repartition_sorts(enabled)
+        return self
+
+    def with_repartition_file_scans(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the use of repartitioning for file scans.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to use repartitioning for file scans.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_repartition_file_scans(enabled)
+        return self
+
+    def with_repartition_file_min_size(self, size: int) -> SessionConfig:
+        """Set minimum file range size for repartitioning scans.
+
+        Parameters
+        ----------
+        size : int
+            Minimum file range size.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_repartition_file_min_size(size)
+        return self
+
+    def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig:
+        """Enable or disable the use of pruning predicate for parquet readers to skip row groups.
+
+        Parameters
+        ----------
+        enabled : bool
+            Whether to use pruning predicate for parquet readers.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.with_parquet_pruning(enabled)
+        return self
+
+    def set(self, key: str, value: str) -> SessionConfig:
+        """Set a configuration option.
+
+        Parameters
+        ----------
+        key : str
+            Option key.
+        value : str
+            Option value.
+
+        Returns
+        -------
+        SessionConfig
+            A new `SessionConfig` object with the updated setting.
+        """
+        self.config_internal = self.config_internal.set(key, value)
+        return self
+
+
+class RuntimeConfig:
+    def __init__(self) -> None:
+        """Create a new `RuntimeConfig` with default values."""
+        self.config_internal = RuntimeConfigInternal()
+
+    def with_disk_manager_disabled(self) -> RuntimeConfig:
+        """Disable the disk manager; attempts to create temporary files will error.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        >>> config = RuntimeConfig().with_disk_manager_disabled()
+        """
+        self.config_internal = self.config_internal.with_disk_manager_disabled()
+        return self
+
+    def with_disk_manager_os(self) -> RuntimeConfig:
+        """Use the operating system's temporary directory for disk manager.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        >>> config = RuntimeConfig().with_disk_manager_os()
+        """
+        self.config_internal = self.config_internal.with_disk_manager_os()
+        return self
+
+    def with_disk_manager_specified(self, paths: list[str]) -> RuntimeConfig:
+        """Use the specified paths for the disk manager's temporary files.
+
+        Parameters
+        ----------
+        paths : list[str]
+            Paths to use for the disk manager's temporary files.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        >>> config = RuntimeConfig().with_disk_manager_specified(["/tmp"])
+        """
+        self.config_internal = self.config_internal.with_disk_manager_specified(paths)
+        return self
+
+    def with_unbounded_memory_pool(self) -> RuntimeConfig:
+        """Use an unbounded memory pool.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        >>> config = RuntimeConfig().with_unbounded_memory_pool()
+        """
+        self.config_internal = self.config_internal.with_unbounded_memory_pool()
+        return self
+
+    def with_fair_spill_pool(self, size: int) -> RuntimeConfig:
+        """Use a fair spill pool with the specified size.
+
+        This pool works best when you know beforehand the query has multiple spillable
+        operators that will likely all need to spill. Sometimes it will cause spills
+        even when there was sufficient memory (reserved for other operators) to avoid
+        doing so.
+
+        ```text
+        ┌───────────────────────z──────────────────────z───────────────┐
+        │                       z                      z               │
+        │                       z                      z               │
+        │       Spillable       z      Unspillable     z     Free      │
+        │        Memory         z        Memory        z    Memory     │
+        │                       z                      z               │
+        │                       z                      z               │
+        └───────────────────────z──────────────────────z───────────────┘
+        ```
+
+        Parameters
+        ----------
+        size : int
+            Size of the memory pool in bytes.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        ```python
+        >>> config = RuntimeConfig().with_fair_spill_pool(1024)
+        ```
+        """
+        self.config_internal = self.config_internal.with_fair_spill_pool(size)
+        return self
+
+    def with_greedy_memory_pool(self, size: int) -> RuntimeConfig:
+        """Use a greedy memory pool with the specified size.
+
+        This pool works well for queries that do not need to spill or have a single
+        spillable operator. See `RuntimeConfig.with_fair_spill_pool` if there are
+        multiple spillable operators that all will spill.
+
+        Parameters
+        ----------
+        size : int
+            Size of the memory pool in bytes.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        >>> config = RuntimeConfig().with_greedy_memory_pool(1024)
+        """
+        self.config_internal = self.config_internal.with_greedy_memory_pool(size)
+        return self
+
+    def with_temp_file_path(self, path: str) -> RuntimeConfig:
+        """Use the specified path to create any needed temporary files.
+
+        Parameters
+        ----------
+        path : str
+            Path to use for temporary files.
+
+        Returns
+        -------
+        RuntimeConfig
+            A new `RuntimeConfig` object with the updated setting.
+
+        Examples
+        --------
+        >>> config = RuntimeConfig().with_temp_file_path("/tmp")
+        """
+        self.config_internal = self.config_internal.with_temp_file_path(path)
+        return self
+
+
+class SQLOptions:
+    def __init__(self) -> None:
+        """Create a new `SQLOptions` with default values.
+
+        The default values are:
+        - DDL commands are allowed
+        - DML commands are allowed
+        - Statements are allowed
+        """
+        self.options_internal = SQLOptionsInternal()
+
+    def with_allow_ddl(self, allow: bool = True) -> SQLOptions:
+        """Should DDL (Data Definition Language) commands be run?
+
+        Examples of DDL commands include `CREATE TABLE` and `DROP TABLE`.
+
+        Parameters
+        ----------
+        allow : bool
+            Allow DDL commands to be run.
+
+        Returns
+        -------
+        SQLOptions
+            A new `SQLOptions` object with the updated setting.
+
+        Examples
+        --------
+        >>> options = SQLOptions().with_allow_ddl(True)
+        """
+        self.options_internal = self.options_internal.with_allow_ddl(allow)
+        return self
+
+    def with_allow_dml(self, allow: bool = True) -> SQLOptions:
+        """Should DML (Data Manipulation Language) commands be run?
+
+        Examples of DML commands include `INSERT INTO` and `DELETE`.
+
+        Parameters
+        ----------
+        allow : bool
+            Allow DML commands to be run.
+
+        Returns
+        -------
+        SQLOptions
+            A new `SQLOptions` object with the updated setting.
+
+        Examples
+        --------
+        >>> options = SQLOptions().with_allow_dml(True)
+        """
+        self.options_internal = self.options_internal.with_allow_dml(allow)
+        return self
+
+    def with_allow_statements(self, allow: bool = True) -> SQLOptions:
+        """Should statements such as `SET VARIABLE` and `BEGIN TRANSACTION` be run?
+
+        Parameters
+        ----------
+        allow : bool
+            Allow statements to be run.
+
+        Returns
+        -------
+        SQLOptions
+            A new `SQLOptions` object with the updated setting.
+
+        Examples
+        --------
+        >>> options = SQLOptions().with_allow_statements(True)
+        """
+        self.options_internal = self.options_internal.with_allow_statements(allow)
+        return self
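Putting `SQLOptions` to work for read-only query execution; a sketch that pairs it with `sql_with_options` from the `SessionContext` defined just below:

```python
from datafusion.context import SessionContext, SQLOptions

ctx = SessionContext()
read_only = (
    SQLOptions()
    .with_allow_ddl(False)
    .with_allow_dml(False)
    .with_allow_statements(False)
)
# A plain SELECT passes validation; a DDL statement such as DROP TABLE would error.
df = ctx.sql_with_options("SELECT 1 AS x", read_only)
```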
+
+
+class SessionContext:
+    def __init__(
+        self, config: SessionConfig | None = None, runtime: RuntimeConfig | None = None
+    ) -> None:
+        """Main interface for executing queries with DataFusion.
+
+        Maintains the state of the connection between a user and an instance
+        of the DataFusion engine.
+
+        Parameters
+        ----------
+        config : SessionConfig | None
+            Session configuration options.
+        runtime : RuntimeConfig | None
+            Runtime configuration options.
+
+        Examples
+        --------
+        The following example demonstrates how to use the context to execute
+        a query against a CSV data source using the `DataFrame` API:
+
+        ```python
+        from datafusion import SessionContext
+
+        ctx = SessionContext()
+        df = ctx.read_csv("data.csv")
+        ```
+        """
+        config = config.config_internal if config is not None else None
+        runtime = runtime.config_internal if runtime is not None else None
+
+        self.ctx = SessionContextInternal(config, runtime)
+
+    def register_object_store(self, schema: str, store: Any, host: str | None) -> None:
+        self.ctx.register_object_store(schema, store, host)
+
+    def register_listing_table(
+        self,
+        name: str,
+        path: str,
+        table_partition_cols: list[tuple[str, str]] = [],
+        file_extension: str = ".parquet",
+        schema: pyarrow.Schema | None = None,
+        file_sort_order: list[list[Expr]] | None = None,
+    ) -> None:
+        if file_sort_order is not None:
+            file_sort_order = [[x.expr for x in xs] for xs in file_sort_order]
+        self.ctx.register_listing_table(
+            name, path, table_partition_cols, file_extension, schema, file_sort_order
+        )
+
+    def sql(self, query: str) -> DataFrame:
+        """Create a `DataFrame` from SQL query text.
+
+        Note: This API implements DDL statements such as `CREATE TABLE` and
+        `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory
+        default implementation. See `SessionContext.sql_with_options`.
+
+        Parameters
+        ----------
+        query : str
+            SQL query text.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the SQL query.
+        """
+        return DataFrame(self.ctx.sql(query))
+
+    def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
+        """Create a `DataFrame` from SQL query text, first validating that
+        the query is allowed by the provided options.
+
+        Parameters
+        ----------
+        query : str
+            SQL query text.
+        options : SQLOptions
+            SQL options.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the SQL query.
+        """
+        return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
+
+    def create_dataframe(
+        self,
+        partitions: list[list[pyarrow.RecordBatch]],
+        name: str | None = None,
+        schema: pyarrow.Schema | None = None,
+    ) -> DataFrame:
+        return DataFrame(self.ctx.create_dataframe(partitions, name, schema))
+
+    def create_dataframe_from_logical_plan(self, plan: LogicalPlan) -> DataFrame:
+        """Create a `DataFrame` from an existing logical plan.
+
+        Parameters
+        ----------
+        plan : LogicalPlan
+            Logical plan.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the logical plan.
+        """
+        return DataFrame(self.ctx.create_dataframe_from_logical_plan(plan))
+
+    def from_pylist(
+        self, data: list[dict[str, Any]], name: str | None = None
+    ) -> DataFrame:
+        """Create a `DataFrame` from a list of dictionaries.
+
+        Parameters
+        ----------
+        data : list[dict[str, Any]]
+            List of dictionaries.
+        name : str | None
+            Name of the DataFrame.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the list of dictionaries.
+        """
+        return DataFrame(self.ctx.from_pylist(data, name))
+
+    def from_pydict(
+        self, data: dict[str, list[Any]], name: str | None = None
+    ) -> DataFrame:
+        """Create a `DataFrame` from a dictionary of lists.
+
+        Parameters
+        ----------
+        data : dict[str, list[Any]]
+            Dictionary of lists.
+        name : str | None
+            Name of the DataFrame.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the dictionary of lists.
+        """
+        return DataFrame(self.ctx.from_pydict(data, name))
+
+    def from_arrow_table(
+        self, data: pyarrow.Table, name: str | None = None
+    ) -> DataFrame:
+        """Create a `DataFrame` from an Arrow table.
+
+        Parameters
+        ----------
+        data : pyarrow.Table
+            Arrow table.
+        name : str | None
+            Name of the DataFrame.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the Arrow table.
+        """
+        return DataFrame(self.ctx.from_arrow_table(data, name))
+
+    def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFrame:
+        """Create a `DataFrame` from a Pandas DataFrame.
+
+        Parameters
+        ----------
+        data : pandas.DataFrame
+            Pandas DataFrame.
+        name : str | None
+            Name of the DataFrame.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the Pandas DataFrame.
+        """
+        return DataFrame(self.ctx.from_pandas(data, name))
+
+    def from_polars(self, data: polars.DataFrame, name: str | None = None) -> DataFrame:
+        """Create a `DataFrame` from a Polars DataFrame.
+
+        Parameters
+        ----------
+        data : polars.DataFrame
+            Polars DataFrame.
+        name : str | None
+            Name of the DataFrame.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the Polars DataFrame.
+        """
+        return DataFrame(self.ctx.from_polars(data, name))
+
+    def register_table(self, name: str, table: pyarrow.Table) -> None:
+        self.ctx.register_table(name, table)
+
+    def deregister_table(self, name: str) -> None:
+        self.ctx.deregister_table(name)
+
+    def register_record_batches(
+        self, name: str, partitions: list[list[pyarrow.RecordBatch]]
+    ) -> None:
+        self.ctx.register_record_batches(name, partitions)
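A sketch of registering a file-backed table and querying it by name, using `register_parquet` as defined just below (the file path is illustrative):

```python
from datafusion import SessionContext

ctx = SessionContext()
ctx.register_parquet("trips", "data/trips.parquet")  # hypothetical file
ctx.sql("SELECT COUNT(*) FROM trips").show()
```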
+    def register_parquet(
+        self,
+        name: str,
+        path: str,
+        table_partition_cols: list[tuple[str, str]] = [],
+        parquet_pruning: bool = True,
+        file_extension: str = ".parquet",
+        skip_metadata: bool = True,
+        schema: pyarrow.Schema | None = None,
+        file_sort_order: list[list[Expr]] | None = None,
+    ) -> None:
+        """Register a Parquet file as a table.
+
+        The registered table can be referenced from SQL statement executed against
+        this context.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to register.
+        path : str
+            Path to the Parquet file.
+        table_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        parquet_pruning : bool, optional
+            Whether the parquet reader should use the predicate to prune row groups, by default True
+        file_extension : str, optional
+            File extension; only files with this extension are selected for data input, by default ".parquet"
+        skip_metadata : bool, optional
+            Whether the parquet reader should skip any metadata that may be in the file
+            schema. This can help avoid schema conflicts due to metadata, by default True
+        schema : pyarrow.Schema | None, optional
+            The data source schema, by default None
+        file_sort_order : list[list[Expr]] | None, optional
+            Sort order for the file, by default None
+        """
+        self.ctx.register_parquet(
+            name,
+            path,
+            table_partition_cols,
+            parquet_pruning,
+            file_extension,
+            skip_metadata,
+            schema,
+            file_sort_order,
+        )
+
+    def register_csv(
+        self,
+        name: str,
+        path: str,
+        schema: pyarrow.Schema | None = None,
+        has_header: bool = True,
+        delimiter: str = ",",
+        schema_infer_max_records: int = 1000,
+        file_extension: str = ".csv",
+        file_compression_type: str | None = None,
+    ) -> None:
+        """Register a CSV file as a table.
+
+        The registered table can be referenced from SQL statement executed against
+        this context.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to register.
+        path : str
+            Path to the CSV file.
+        schema : pyarrow.Schema | None, optional
+            An optional schema representing the CSV file. If None, the CSV reader will try to infer it based on data in file, by default None
+        has_header : bool, optional
+            Whether the CSV file has a header. If schema inference is run on a file with no headers, default column names are created, by default True
+        delimiter : str, optional
+            An optional column delimiter, by default ","
+        schema_infer_max_records : int, optional
+            Maximum number of rows to read from CSV files for schema inference if needed, by default 1000
+        file_extension : str, optional
+            File extension; only files with this extension are selected for data input, by default ".csv"
+        file_compression_type : str | None, optional
+            File compression type, by default None
+        """
+        self.ctx.register_csv(
+            name,
+            path,
+            schema,
+            has_header,
+            delimiter,
+            schema_infer_max_records,
+            file_extension,
+            file_compression_type,
+        )
+
+    def register_json(
+        self,
+        name: str,
+        path: str,
+        schema: pyarrow.Schema | None = None,
+        schema_infer_max_records: int = 1000,
+        file_extension: str = ".json",
+        table_partition_cols: list[tuple[str, str]] = [],
+        file_compression_type: str | None = None,
+    ) -> None:
+        """Register a JSON file as a table.
+
+        The registered table can be referenced from SQL statement executed against
+        this context.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to register.
+        path : str
+            Path to the JSON file.
+        schema : pyarrow.Schema | None, optional
+            The data source schema, by default None
+        schema_infer_max_records : int, optional
+            Maximum number of rows to read from JSON files for schema inference if needed, by default 1000
+        file_extension : str, optional
+            File extension; only files with this extension are selected for data input, by default ".json"
+        table_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        file_compression_type : str | None, optional
+            File compression type, by default None
+        """
+        self.ctx.register_json(
+            name,
+            path,
+            schema,
+            schema_infer_max_records,
+            file_extension,
+            table_partition_cols,
+            file_compression_type,
+        )
+
+    def register_avro(
+        self,
+        name: str,
+        path: str,
+        schema: pyarrow.Schema | None = None,
+        file_extension: str = ".avro",
+        table_partition_cols: list[tuple[str, str]] = [],
+    ) -> None:
+        """Register an Avro file as a table.
+
+        The registered table can be referenced from SQL statement executed against
+        this context.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to register.
+        path : str
+            Path to the Avro file.
+        schema : pyarrow.Schema | None, optional
+            The data source schema, by default None
+        file_extension : str, optional
+            File extension to select, by default ".avro"
+        table_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        """
+        self.ctx.register_avro(name, path, schema, file_extension, table_partition_cols)
+
+    def register_dataset(self, name: str, dataset: pyarrow.dataset.Dataset) -> None:
+        """
+        Register a `pyarrow.dataset.Dataset` as a table.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to register.
+        dataset : dataset.Dataset
+            PyArrow dataset.
+        """
+        self.ctx.register_dataset(name, dataset)
+
+    def register_udf(self, udf: ScalarUDF) -> None:
+        """Register a user-defined function (UDF) with the context.
+
+        Parameters
+        ----------
+        udf : ScalarUDF
+            User-defined function.
+        """
+        self.ctx.register_udf(udf.udf)
+
+    def register_udaf(self, udaf: AggregateUDF) -> None:
+        """Register a user-defined aggregation function (UDAF) with the context.
+
+        Parameters
+        ----------
+        udaf : AggregateUDF
+            User-defined aggregation function.
+        """
+        self.ctx.register_udaf(udaf)
+
+    def catalog(self, name: str = "datafusion") -> Catalog:
+        """Retrieve a catalog by name.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the catalog to retrieve, by default "datafusion".
+
+        Returns
+        -------
+        Catalog
+            Catalog representation.
+        """
+        return Catalog(self.ctx.catalog(name))
+
+    @deprecated(
+        "Use the catalog provider interface `SessionContext.catalog` to "
+        "examine available catalogs, schemas and tables"
+    )
+    def tables(self) -> set[str]:
+        return self.ctx.tables()
+
+    def table(self, name: str) -> DataFrame:
+        """Retrieve a `DataFrame` representing a previously registered table.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to retrieve.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the table.
+        """
+        return DataFrame(self.ctx.table(name))
+
+    def table_exist(self, name: str) -> bool:
+        """Return whether a table with the given name exists.
+
+        Parameters
+        ----------
+        name : str
+            Name of the table to check.
+
+        Returns
+        -------
+        bool
+            Whether a table with the given name exists.
+        """
+        return self.ctx.table_exist(name)
+
+    def empty_table(self) -> DataFrame:
+        """Create an empty `DataFrame`.
+
+        Returns
+        -------
+        DataFrame
+            An empty DataFrame.
+        """
+        return DataFrame(self.ctx.empty_table())
+
+    def session_id(self) -> str:
+        """Return an id that uniquely identifies this `SessionContext`.
+
+        Returns
+        -------
+        str
+            Unique session identifier
+        """
+        return self.ctx.session_id()
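A sketch of registering a Python UDF and calling it from SQL, combining `register_udf` above with the `udf` helper whose signature this patch changes in `__init__.py` (function body and names are illustrative):

```python
import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext, udf

def double_array(arr: pa.Array) -> pa.Array:
    # The UDF is vectorized: it receives one pyarrow Array per batch.
    return pc.multiply(arr, 2)

double = udf(double_array, [pa.int64()], pa.int64(), "immutable", name="double")
ctx = SessionContext()
ctx.register_udf(double)
df = ctx.sql("SELECT double(1)")
```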
+    def read_json(
+        self,
+        path: str,
+        schema: pyarrow.Schema | None = None,
+        schema_infer_max_records: int = 1000,
+        file_extension: str = ".json",
+        table_partition_cols: list[tuple[str, str]] = [],
+        file_compression_type: str | None = None,
+    ) -> DataFrame:
+        """Create a `DataFrame` for reading a line-delimited JSON data source.
+
+        Parameters
+        ----------
+        path : str
+            Path to the JSON file
+        schema : pyarrow.Schema | None, optional
+            The data source schema, by default None
+        schema_infer_max_records : int, optional
+            Maximum number of rows to read from JSON files for schema inference if needed, by default 1000
+        file_extension : str, optional
+            File extension; only files with this extension are selected for data input, by default ".json"
+        table_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        file_compression_type : str | None, optional
+            File compression type, by default None
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the read JSON files
+        """
+        return DataFrame(
+            self.ctx.read_json(
+                path,
+                schema,
+                schema_infer_max_records,
+                file_extension,
+                table_partition_cols,
+                file_compression_type,
+            )
+        )
+
+    def read_csv(
+        self,
+        path: str,
+        schema: pyarrow.Schema | None = None,
+        has_header: bool = True,
+        delimiter: str = ",",
+        schema_infer_max_records: int = 1000,
+        file_extension: str = ".csv",
+        table_partition_cols: list[tuple[str, str]] = [],
+        file_compression_type: str | None = None,
+    ) -> DataFrame:
+        """Create a `DataFrame` for reading a CSV data source.
+
+        Parameters
+        ----------
+        path : str
+            Path to the CSV file
+        schema : pyarrow.Schema | None, optional
+            An optional schema representing the CSV files. If None, the CSV reader will try to infer it based on data in file, by default None
+        has_header : bool, optional
+            Whether the CSV file has a header. If schema inference is run on a file with no headers, default column names are created, by default True
+        delimiter : str, optional
+            An optional column delimiter, by default ","
+        schema_infer_max_records : int, optional
+            Maximum number of rows to read from CSV files for schema inference if needed, by default 1000
+        file_extension : str, optional
+            File extension; only files with this extension are selected for data input, by default ".csv"
+        table_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        file_compression_type : str | None, optional
+            File compression type, by default None
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the read CSV files
+        """
+        return DataFrame(
+            self.ctx.read_csv(
+                path,
+                schema,
+                has_header,
+                delimiter,
+                schema_infer_max_records,
+                file_extension,
+                table_partition_cols,
+                file_compression_type,
+            )
+        )
+
+    def read_parquet(
+        self,
+        path: str,
+        table_partition_cols: list[tuple[str, str]] = [],
+        parquet_pruning: bool = True,
+        file_extension: str = ".parquet",
+        skip_metadata: bool = True,
+        schema: pyarrow.Schema | None = None,
+        file_sort_order: list[list[Expr]] | None = None,
+    ) -> DataFrame:
+        """Create a `DataFrame` for reading Parquet data source.
+
+        Parameters
+        ----------
+        path: str
+            Path to the Parquet file
+        table_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        parquet_pruning : bool, optional
+            Whether the parquet reader should use the predicate to prune row groups, by default True
+        file_extension : str, optional
+            File extension; only files with this extension are selected for data input, by default ".parquet"
+        skip_metadata : bool, optional
+            Whether the parquet reader should skip any metadata that may be in the file
+            schema. This can help avoid schema conflicts due to metadata, by default True
+        schema : pyarrow.Schema | None, optional
+            An optional schema representing the parquet files. If None, the parquet
+            reader will try to infer it based on data in the file, by default None
+        file_sort_order : list[list[Expr]] | None, optional
+            Sort order for the file, by default None
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the read Parquet files
+        """
+        return DataFrame(
+            self.ctx.read_parquet(
+                path,
+                table_partition_cols,
+                parquet_pruning,
+                file_extension,
+                skip_metadata,
+                schema,
+                file_sort_order,
+            )
+        )
+
+    def read_avro(
+        self,
+        path: str,
+        schema: pyarrow.Schema | None = None,
+        file_partition_cols: list[tuple[str, str]] = [],
+        file_extension: str = ".avro",
+    ) -> DataFrame:
+        """Create a `DataFrame` for reading Avro data source.
+
+        Parameters
+        ----------
+        path : str
+            Path to the Avro file
+        schema : pyarrow.Schema | None, optional
+            The data source schema, by default None
+        file_partition_cols : list[tuple[str, str]], optional
+            Partition columns, by default []
+        file_extension : str, optional
+            File extension to select, by default ".avro"
+
+        Returns
+        -------
+        DataFrame
+            DataFrame representation of the read Avro file
+        """
+        return DataFrame(
+            self.ctx.read_avro(path, schema, file_partition_cols, file_extension)
+        )
+
+    def read_table(self, table: Table) -> DataFrame:
+        return DataFrame(self.ctx.read_table(table))
+
+    def execute(self, plan: ExecutionPlan, part: int) -> RecordBatchStream:
+        return RecordBatchStream(self.ctx.execute(plan, part))
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
new file mode 100644
index 00000000..55df376d
--- /dev/null
+++ b/python/datafusion/dataframe.py
@@ -0,0 +1,561 @@
+from __future__ import annotations
+
+from typing import Any, List
+from datafusion.record_batch import RecordBatchStream
+from typing_extensions import deprecated
+import pyarrow as pa
+import pandas as pd
+import polars as pl
+
+from datafusion._internal import DataFrame as DataFrameInternal
+from datafusion.expr import Expr
+from datafusion._internal import (
+    LogicalPlan,
+    ExecutionPlan,
+)  # TODO make these first class python classes
+
+
+class DataFrame:
+    def __init__(self, df: DataFrameInternal) -> None:
+        self.df = df
+
+    def __getitem__(self, key: str | List[str]) -> DataFrame:
+        """Return a new `DataFrame` with the specified column or columns.
+
+        Parameters
+        ----------
+        key : str | List[str]
+            Column name or list of column names to select.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame with the specified column or columns.
+        """
+        return DataFrame(self.df.__getitem__(key))
+
+    def __repr__(self) -> str:
+        """Return a string representation of the DataFrame.
+
+        Returns
+        -------
+        str
+            String representation of the DataFrame.
+        """
+        return self.df.__repr__()
+
+    def describe(self) -> DataFrame:
+        """Return a new `DataFrame` that has statistics for a DataFrame.
+
+        Only numeric datatypes are summarized at the moment; nulls are returned
+        for non-numeric datatypes.
+
+        The output format is modeled after pandas.
+
+        Returns
+        -------
+        DataFrame
+            A summary DataFrame containing statistics.
+        """
+        return DataFrame(self.df.describe())
+
+    def schema(self) -> pa.Schema:
+        """Return the `pyarrow.Schema` describing the output of this DataFrame.
+
+        The output schema contains information on the name, data type, and
+        nullability for each column.
+
+        Returns
+        -------
+        pa.Schema
+            Describing schema of the DataFrame
+        """
+        return self.df.schema()
+
+    def select_columns(self, *args: str) -> DataFrame:
+        """Filter the DataFrame by columns.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame only containing the specified columns.
+        """
+        return DataFrame(self.df.select_columns(*args))
+
+    def select(self, *args: Expr) -> DataFrame:
+        """Project arbitrary expressions (like SQL SELECT expressions) into a new `DataFrame`.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after projection. It has one column for each expression.
+        """
+        args = [arg.expr for arg in args]
+        return DataFrame(self.df.select(*args))
+
+    def filter(self, predicate: Expr) -> DataFrame:
+        """Return a DataFrame for which `predicate` evaluates to `True`.
+
+        Rows for which `predicate` evaluates to `False` or `None` are filtered out.
+
+        Parameters
+        ----------
+        predicate : Expr
+            Predicate expression to filter the DataFrame.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after filtering.
+        """
+        return DataFrame(self.df.filter(predicate.expr))
+
+    def with_column(self, name: str, expr: Expr) -> DataFrame:
+        """Add an additional column to the DataFrame.
+
+        Parameters
+        ----------
+        name : str
+            Name of the column to add.
+        expr : Expr
+            Expression to compute the column.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame with the new column.
+        """
+        return DataFrame(self.df.with_column(name, expr.expr))
+
+    def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
+        """Rename one column by applying a new projection.
+
+        This is a no-op if the column to be renamed does not exist.
+
+        The method supports case-sensitive renaming by wrapping the column name
+        in one of the following symbols (" or ' or `).
+
+        Parameters
+        ----------
+        old_name : str
+            Old column name.
+        new_name : str
+            New column name.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame with the column renamed.
+        """
+        return DataFrame(self.df.with_column_renamed(old_name, new_name))
+
+    def aggregate(self, group_by: list[Expr], aggs: list[Expr]) -> DataFrame:
+        """Return a new `DataFrame` that aggregates the rows of the current DataFrame.
+
+        First optionally grouping by the given expressions.
+
+        Parameters
+        ----------
+        group_by : list[Expr]
+            List of expressions to group by.
+        aggs : list[Expr]
+            List of expressions to aggregate.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after aggregation.
+        """
+        group_by = [e.expr for e in group_by]
+        aggs = [e.expr for e in aggs]
+        return DataFrame(self.df.aggregate(group_by, aggs))
+
+    def sort(self, *exprs: Expr) -> DataFrame:
+        """Sort the DataFrame by the specified sorting expressions.
+
+        Note that any expression can be turned into a sort expression by
+        calling its `sort` method.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after sorting.
+        """
+        exprs = [expr.expr for expr in exprs]
+        return DataFrame(self.df.sort(*exprs))
+
+    def limit(self, count: int, offset: int = 0) -> DataFrame:
+        """Return a new `DataFrame` with a limited number of rows.
+
+        Parameters
+        ----------
+        count : int
+            Number of rows to limit the DataFrame to.
+        offset : int, optional
+            Number of rows to skip, by default 0
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after limiting.
+        """
+        return DataFrame(self.df.limit(count, offset))
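A typical plan-then-execute flow using the wrappers above; nothing runs until `collect()` (defined just below) is called. A sketch with illustrative data:

```python
from datafusion import SessionContext, column
from datafusion import functions as f

ctx = SessionContext()
df = ctx.from_pydict({"grp": ["a", "a", "b"], "v": [1, 2, 3]})
# Build the plan: group by `grp` and sum `v`.
df = df.aggregate([column("grp")], [f.sum(column("v"))])
batches = df.collect()  # execution is triggered here
```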
+ """ + return self.df.collect() + + def cache(self) -> DataFrame: + """Cache the DataFrame as a memory table. + + Returns + ------- + DataFrame + Cached DataFrame. + """ + return DataFrame(self.df.cache()) + + def collect_partitioned(self) -> list[list[pa.RecordBatch]]: + """Execute this DataFrame and collect all results into a list of list of + `pyarrow.RecordBatch`es maintaining the input partitioning. + + Returns + ------- + list[list[pa.RecordBatch]] + List of list of `pyarrow.RecordBatch`es collected from the DataFrame. + """ + return self.df.collect_partitioned() + + def show(self, num: int = 20) -> None: + """Execute the DataFrame and print the result to the console. + + Parameters + ---------- + num : int, optional + Number of lines to show, by default 20 + """ + self.df.show(num) + + def distinct(self) -> DataFrame: + """Return a new `DataFrame` with all duplicated rows removed. + + Returns + ------- + DataFrame + DataFrame after removing duplicates. + """ + return DataFrame(self.df.distinct()) + + def join( + self, + right: DataFrame, + join_keys: tuple[list[str], list[str]], + how: str, + ) -> DataFrame: + """Join this `DataFrame` with another `DataFrame` using explicitly + specified columns. + + Parameters + ---------- + right : DataFrame + Other DataFrame to join with. + join_keys : tuple[list[str], list[str]] + Tuple of two lists of column names to join on. + how : str + Type of join to perform. Supported types are "inner", "left", "right", "full", "semi", "anti". + + Returns + ------- + DataFrame + DataFrame after join. + """ + return DataFrame(self.df.join(right.df, join_keys, how)) + + def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: + """Return a DataFrame with the explanation of its plan so far. + + If `analyze` is specified, runs the plan and reports metrics. + + Parameters + ---------- + verbose : bool, optional + If `True`, more details will be included, by default False + analyze : bool, optional + If `True`, the plan will run and metrics reported, by default False + + Returns + ------- + DataFrame + DataFrame with the explanation of its plan. + """ + return DataFrame(self.df.explain(verbose, analyze)) + + def logical_plan(self) -> LogicalPlan: + """Return the unoptimized `LogicalPlan` that comprises this `DataFrame`. + + Returns + ------- + LogicalPlan + Unoptimized logical plan. + """ + return self.df.logical_plan() + + def optimized_logical_plan(self) -> LogicalPlan: + """Return the optimized `LogicalPlan` that comprises this `DataFrame`. + + Returns + ------- + LogicalPlan + Optimized logical plan. + """ + return self.df.optimized_logical_plan() + + def execution_plan(self) -> ExecutionPlan: + """Return the execution/physical plan that comprises this `DataFrame`. + + Returns + ------- + ExecutionPlan + Execution plan. + """ + return self.df.execution_plan() + + def repartition(self, num: int) -> DataFrame: + """Repartition a DataFrame into `num` partitions. + + The batches allocation uses a round-robin algorithm. + + Parameters + ---------- + num : int + Number of partitions to repartition the DataFrame into. + + Returns + ------- + DataFrame + Repartitioned DataFrame. + """ + return DataFrame(self.df.repartition(num)) + + def repartition_by_hash(self, *args: Expr, num: int) -> DataFrame: + """Repartition a DataFrame into `num` partitions using a hash partitioning scheme. + + Parameters + ---------- + num : int + Number of partitions to repartition the DataFrame into. + + Returns + ------- + DataFrame + Repartitioned DataFrame. 
+ """ + args = [expr.expr for expr in args] + return DataFrame(self.df.repartition_by_hash(*args, num=num)) + + def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the union of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + + Parameters + ---------- + other : DataFrame + DataFrame to union with. + distinct : bool, optional + If `True`, duplicate rows will be removed, by default False + + Returns + ------- + DataFrame + DataFrame after union. + """ + return DataFrame(self.df.union(other.df, distinct)) + + def union_distinct(self, other: DataFrame) -> DataFrame: + """Calculate the distinct union of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + Any duplicate rows are discarded. + + Parameters + ---------- + other : DataFrame + DataFrame to union with. + + Returns + ------- + DataFrame + DataFrame after union. + """ + return DataFrame(self.df.union_distinct(other.df)) + + def intersect(self, other: DataFrame) -> DataFrame: + """Calculate the intersection of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + + Parameters + ---------- + other : DataFrame + DataFrame to intersect with. + + Returns + ------- + DataFrame + DataFrame after intersection. + """ + return DataFrame(self.df.intersect(other.df)) + + def except_all(self, other: DataFrame) -> DataFrame: + """Calculate the exception of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + + Parameters + ---------- + other : DataFrame + DataFrame to calculate exception with. + + Returns + ------- + DataFrame + DataFrame after exception. + """ + return DataFrame(self.df.except_all(other.df)) + + def write_csv(self, path: str) -> None: + """Execute the `DataFrame` and write the results to a CSV file. + + Parameters + ---------- + path : str + Path of the CSV file to write. + """ + self.df.write_csv(path) + + def write_parquet( + self, + path: str, + compression: str = "uncompressed", + compression_level: int | None = None, + ) -> None: + """Execute the `DataFrame` and write the results to a Parquet file. + + Parameters + ---------- + path : str + Path of the Parquet file to write. + compression : str, optional + Compression type to use, by default "uncompressed" + compression_level : int | None, optional + Compression level to use, by default None + """ + self.df.write_parquet(path, compression, compression_level) + + def write_json(self, path: str) -> None: + """Execute the `DataFrame` and write the results to a JSON file. + + Parameters + ---------- + path : str + Path of the JSON file to write. + """ + self.df.write_json(path) + + def to_arrow_table(self) -> pa.Table: + """Execute the `DataFrame` and convert it into an Arrow Table. + + Returns + ------- + pa.Table + Arrow Table. + """ + return self.df.to_arrow_table() + + def execute_stream(self) -> RecordBatchStream: + """ + TODO add descriptive text + """ + return RecordBatchStream(self.df.execute_stream()) + + def execute_stream_partitioned(self) -> list[RecordBatchStream]: + """ + TODO add descriptive text + """ + streams = self.df.execute_stream_partitioned() + return [RecordBatchStream(rbs) for rbs in streams] + + def to_pandas(self) -> pd.DataFrame: + """Execute the `DataFrame` and convert it into a Pandas DataFrame. + + Returns + ------- + pd.DataFrame + Pandas DataFrame. + """ + return self.df.to_pandas() + + def to_pylist(self) -> list[dict[str, Any]]: + """Execute the `DataFrame` and convert it into a list of dictionaries. 
+
+        Returns
+        -------
+        list[dict[str, Any]]
+            List of dictionaries.
+        """
+        return self.df.to_pylist()
+
+    def to_pydict(self) -> dict[str, list[Any]]:
+        """Execute the `DataFrame` and convert it into a dictionary of lists.
+
+        Returns
+        -------
+        dict[str, list[Any]]
+            Dictionary of lists.
+        """
+        return self.df.to_pydict()
+
+    def to_polars(self) -> pl.DataFrame:
+        """Execute the `DataFrame` and convert it into a Polars DataFrame.
+
+        Returns
+        -------
+        pl.DataFrame
+            Polars DataFrame.
+        """
+        return self.df.to_polars()
+
+    def count(self) -> int:
+        """Return the total number of rows in this `DataFrame`.
+
+        Note that this method will actually run a plan to calculate the
+        count, which may be slow for large or complicated DataFrames.
+
+        Returns
+        -------
+        int
+            Number of rows in the DataFrame.
+        """
+        return self.df.count()
+
+    @deprecated("Use :func:`unnest_columns` instead.")
+    def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame:
+        """Expand a column of arrays into multiple rows.
+
+        Parameters
+        ----------
+        column : str
+            Name of the array column to expand.
+        preserve_nulls : bool, optional
+            If `True`, null entries are preserved in the output, by default True
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after unnesting.
+        """
+        return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls))
+
+    def unnest_columns(
+        self, columns: list[str], preserve_nulls: bool = True
+    ) -> DataFrame:
+        """Expand each of the given columns of arrays into multiple rows.
+
+        Parameters
+        ----------
+        columns : list[str]
+            Names of the array columns to expand.
+        preserve_nulls : bool, optional
+            If `True`, null entries are preserved in the output, by default True
+
+        Returns
+        -------
+        DataFrame
+            DataFrame after unnesting.
+        """
+        return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index e914b85d..93ebddba 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -15,9 +15,256 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import annotations
 
-from ._internal import expr
+from ._internal import expr as expr_internal, LogicalPlan
+from datafusion.common import RexType, DataTypeMap
+from typing import Any
+import pyarrow as pa
 
+# The following are imported from the internal representation. We may choose to
+# give these all proper wrappers, or to simply leave as is. These were added
+# in order to support passing the `test_imports` unit test.
+# Tim Saucer note: It is not clear to me what the use case is for exposing
+# these definitions to the end user.
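+#
+# A minimal, hypothetical sketch of how the `Expr` wrapper defined below
+# interacts with these re-exported variants (the names used are illustrative
+# only):
+#
+#     expr = Expr.column("a") + Expr.literal(1)
+#     variant = expr.to_variant()
+#     assert isinstance(variant, BinaryExpr)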
-def __getattr__(name):
-    return getattr(expr, name)
+Alias = expr_internal.Alias
+Analyze = expr_internal.Analyze
+Aggregate = expr_internal.Aggregate
+AggregateFunction = expr_internal.AggregateFunction
+Between = expr_internal.Between
+BinaryExpr = expr_internal.BinaryExpr
+Case = expr_internal.Case
+Cast = expr_internal.Cast
+Column = expr_internal.Column
+CreateMemoryTable = expr_internal.CreateMemoryTable
+CreateView = expr_internal.CreateView
+CrossJoin = expr_internal.CrossJoin
+Distinct = expr_internal.Distinct
+DropTable = expr_internal.DropTable
+Exists = expr_internal.Exists
+Explain = expr_internal.Explain
+Extension = expr_internal.Extension
+Filter = expr_internal.Filter
+GroupingSet = expr_internal.GroupingSet
+Join = expr_internal.Join
+ILike = expr_internal.ILike
+InList = expr_internal.InList
+InSubquery = expr_internal.InSubquery
+IsFalse = expr_internal.IsFalse
+IsNotTrue = expr_internal.IsNotTrue
+IsTrue = expr_internal.IsTrue
+IsUnknown = expr_internal.IsUnknown
+IsNotFalse = expr_internal.IsNotFalse
+IsNotNull = expr_internal.IsNotNull
+IsNotUnknown = expr_internal.IsNotUnknown
+JoinConstraint = expr_internal.JoinConstraint
+JoinType = expr_internal.JoinType
+Like = expr_internal.Like
+Limit = expr_internal.Limit
+Literal = expr_internal.Literal
+Negative = expr_internal.Negative
+Not = expr_internal.Not
+Partitioning = expr_internal.Partitioning
+Placeholder = expr_internal.Placeholder
+Projection = expr_internal.Projection
+Repartition = expr_internal.Repartition
+ScalarSubquery = expr_internal.ScalarSubquery
+ScalarVariable = expr_internal.ScalarVariable
+SimilarTo = expr_internal.SimilarTo
+Sort = expr_internal.Sort
+Subquery = expr_internal.Subquery
+SubqueryAlias = expr_internal.SubqueryAlias
+TableScan = expr_internal.TableScan
+TryCast = expr_internal.TryCast
+Union = expr_internal.Union
+
+
+class Expr:
+    def __init__(self, expr: expr_internal.Expr) -> None:
+        self.expr = expr
+
+    def to_variant(self) -> Any:
+        return self.expr.to_variant()
+
+    def display_name(self) -> str:
+        return self.expr.display_name()
+
+    def canonical_name(self) -> str:
+        return self.expr.canonical_name()
+
+    def variant_name(self) -> str:
+        return self.expr.variant_name()
+
+    def __richcmp__(self, other: Expr, op: int) -> Expr:
+        return Expr(self.expr.__richcmp__(other, op))
+
+    def __repr__(self) -> str:
+        return self.expr.__repr__()
+
+    def __add__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__add__(rhs.expr))
+
+    def __sub__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__sub__(rhs.expr))
+
+    def __truediv__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__truediv__(rhs.expr))
+
+    def __mul__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__mul__(rhs.expr))
+
+    def __mod__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__mod__(rhs.expr))
+
+    def __and__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__and__(rhs.expr))
+
+    def __or__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__or__(rhs.expr))
+
+    def __invert__(self) -> Expr:
+        return Expr(self.expr.__invert__())
+
+    def __getitem__(self, key: str) -> Expr:
+        return Expr(self.expr.__getitem__(key))
+
+    def __eq__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__eq__(rhs.expr))
+
+    def __ne__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__ne__(rhs.expr))
+
+    def __ge__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__ge__(rhs.expr))
+
+    def __gt__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__gt__(rhs.expr))
+
+    def __le__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__le__(rhs.expr))
+
+    def __lt__(self, rhs: Expr) -> Expr:
+        return Expr(self.expr.__lt__(rhs.expr))
+
+    @staticmethod
+    def literal(value: Any) -> Expr:
+        if not isinstance(value, pa.Scalar):
+            value = pa.scalar(value)
+        return Expr(expr_internal.Expr.literal(value))
+
+    @staticmethod
+    def column(value: str) -> Expr:
+        return Expr(expr_internal.Expr.column(value))
+
+    def alias(self, name: str) -> Expr:
+        return Expr(self.expr.alias(name))
+
+    def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr:
+        return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first))
+
+    def is_null(self) -> Expr:
+        return Expr(self.expr.is_null())
+
+    def cast(self, to: pa.DataType[Any]) -> Expr:
+        return Expr(self.expr.cast(to))
+
+    def rex_type(self) -> RexType:
+        return self.expr.rex_type()
+
+    def types(self) -> DataTypeMap:
+        return self.expr.types()
+
+    def python_value(self) -> Any:
+        return self.expr.python_value()
+
+    def rex_call_operands(self) -> list[Expr]:
+        return [Expr(e) for e in self.expr.rex_call_operands()]
+
+    def rex_call_operator(self) -> str:
+        return self.expr.rex_call_operator()
+
+    def column_name(self, plan: LogicalPlan) -> str:
+        return self.expr.column_name(plan)
+
+
+class WindowFrame:
+    def __init__(
+        self, units: str, start_bound: int | None, end_bound: int | None
+    ) -> None:
+        """
+        :param units: Should be one of `rows`, `range`, or `groups`
+        :param start_bound: Sets the preceding bound. Must be >= 0. If None, this will be set to unbounded. If unit type is `groups`, this parameter must be set.
+        :param end_bound: Sets the following bound. Must be >= 0. If None, this will be set to unbounded. If unit type is `groups`, this parameter must be set.
+        """
+        self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound)
+
+    def get_frame_units(self) -> str:
+        """
+        Returns the window frame units for the bounds.
+        """
+        return self.window_frame.get_frame_units()
+
+    def get_lower_bound(self) -> WindowFrameBound:
+        """
+        Returns the starting bound.
+        """
+        return WindowFrameBound(self.window_frame.get_lower_bound())
+
+    def get_upper_bound(self) -> WindowFrameBound:
+        """
+        Returns the end bound.
+        """
+        return WindowFrameBound(self.window_frame.get_upper_bound())
+
+
+class WindowFrameBound:
+    def __init__(self, frame_bound: expr_internal.WindowFrameBound) -> None:
+        self.frame_bound = frame_bound
+
+    def get_offset(self) -> int | None:
+        """
+        Returns the offset of the window frame.
+        """
+        return self.frame_bound.get_offset()
+
+    def is_current_row(self) -> bool:
+        """
+        Returns True if the frame bound is the current row.
+        """
+        return self.frame_bound.is_current_row()
+
+    def is_following(self) -> bool:
+        """
+        Returns True if the frame bound is following.
+        """
+        return self.frame_bound.is_following()
+
+    def is_preceding(self) -> bool:
+        """
+        Returns True if the frame bound is preceding.
+        """
+        return self.frame_bound.is_preceding()
+
+    def is_unbounded(self) -> bool:
+        """
+        Returns True if the frame bound is unbounded.
+        """
+        return self.frame_bound.is_unbounded()
+
+
+class CaseBuilder:
+    def __init__(self, case_builder: expr_internal.CaseBuilder) -> None:
+        """
+        :param case_builder: Internal object. This constructor is not expected to be used by the end user. Instead use :func:`case` to construct.
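+
+        Hypothetical usage, via the ``case`` helper in ``datafusion.functions``::
+
+            case(column("a")).when(literal(1), literal("one")).otherwise(literal("other"))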
+ """ + self.case_builder = case_builder + + def when(self, when_expr: Expr, then_expr: Expr) -> CaseBuilder: + return CaseBuilder(self.case_builder.when(when_expr.expr, then_expr.expr)) + + def otherwise(self, else_expr: Expr) -> Expr: + return Expr(self.case_builder.otherwise(else_expr.expr)) + + def end(self) -> Expr: + return Expr(self.case_builder.end()) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 782ecba2..de2264dc 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -15,9 +15,1731 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from ._internal import functions +# from datafusion._internal.context import SessionContext +# from datafusion._internal.expr import Expr +# from datafusion._internal.expr.conditional_expr import CaseBuilder +# from datafusion._internal.expr.window import WindowFrame +from datafusion._internal import functions as f, common +from datafusion.expr import CaseBuilder, Expr, WindowFrame +from datafusion.context import SessionContext -def __getattr__(name): - return getattr(functions, name) + +def isnan(expr: Expr) -> Expr: + """ + Returns true if a given number is +NaN or -NaN otherwise returns false. + """ + return Expr(f.isnan(expr.expr)) + + +def nullif(expr1: Expr, expr2: Expr) -> Expr: + """ + Returns NULL if expr1 equals expr2; otherwise it returns expr1. This can be used to perform the inverse operation of the COALESCE expression. + """ + return Expr(f.nullif(expr1.expr, expr2.expr)) + + +def encode(input: Expr, encoding: Expr) -> Expr: + """ + Encode the `input`, using the `encoding`. encoding can be base64 or hex. + """ + return Expr(f.encode(input.expr, encoding.expr)) + + +def decode(input: Expr, encoding: Expr) -> Expr: + """ + Decode the `input`, using the `encoding`. encoding can be base64 or hex. + """ + return Expr(f.decode(input.expr, encoding.expr)) + + +def array_to_string(expr: Expr, delimiter: Expr) -> Expr: + """ + Converts each element to its text representation. + """ + return Expr(f.array_to_string(expr.expr, delimiter.expr)) + + +def array_join(expr: Expr, delimiter: Expr) -> Expr: + """ + Converts each element to its text representation. + This is an alias for :func:`array_to_string`. + """ + return array_to_string(expr, delimiter) + + +def list_to_string(expr: Expr, delimiter: Expr) -> Expr: + """ + Converts each element to its text representation. + This is an alias for :func:`array_to_string`. + """ + return array_to_string(expr, delimiter) + + +def list_join(expr: Expr, delimiter: Expr) -> Expr: + """ + Converts each element to its text representation. + This is an alias for :func:`array_to_string`. + """ + return array_to_string(expr, delimiter) + + +def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: + """ + Returns whether the argument is contained within the list `values`. + """ + values = [v.expr for v in values] + return Expr(f.in_list(arg.expr, values, negated)) + + +def digest(value: Expr, method: Expr) -> Expr: + """ + Computes the binary hash of an expression using the specified algorithm. + Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. + """ + return Expr(f.digest(value.expr, method.expr)) + + +def concat(*args: Expr) -> Expr: + """ + Concatenates the text representations of all the arguments. NULL arguments are ignored. 
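+
+    Hypothetical example, assuming string columns ``a`` and ``b`` exist::
+
+        df.select(concat(column("a"), column("b")))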
+ """ + args = [arg.expr for arg in args] + return Expr(f.concat(*args)) + + +def concat_ws(separator: str, *args: Expr) -> Expr: + """ + Concatenates the list `args` with the separator. `NULL` arugments are ignored. `separator` should not be `NULL`. + """ + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, *args)) + + +def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> Expr: + """ + Creates a new sort expression. + """ + return Expr(f.order_by(expr.expr, ascending, nulls_first)) + + +def alias(expr: Expr, name: str) -> Expr: + """ + Creates an alias expression. + """ + return Expr(f.alias(expr.expr, name)) + + +def col(name: str) -> Expr: + """ + Creates a column reference expression. + """ + return Expr(f.col(name)) + + +def count_star() -> Expr: + """ + Create a COUNT(1) aggregate expression. + """ + return Expr(f.count_star()) + + +def case(expr: Expr) -> CaseBuilder: + """ + Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. + """ + return CaseBuilder(f.case(expr.expr)) + + +def window( + name: str, + args: list[Expr], + partition_by: list[Expr] | None = None, + order_by: list[Expr] | None = None, + window_frame: WindowFrame | None = None, + ctx: SessionContext | None = None, +) -> Expr: + """ + Creates a new Window function expression. + """ + args = [a.expr for a in args] + partition_by = [e.expr for e in partition_by] if partition_by is not None else None + order_by = [o.expr for o in order_by] if order_by is not None else None + window_frame = window_frame.window_frame if window_frame is not None else None + return Expr(f.window(name, args, partition_by, order_by, window_frame, ctx)) + + +# scalar functions +def abs(arg: Expr) -> Expr: + """ + Return the absolute value of a given number. + + Returns + ------- + Expr + A new expression representing the absolute value of the input expression. + """ + return Expr(f.abs(arg.expr)) + + +def acos(arg: Expr) -> Expr: + """ + Returns the arc cosine or inverse cosine of a number. + + Returns + ------- + Expr + A new expression representing the arc cosine of the input expression. + """ + return Expr(f.acos(arg.expr)) + + +def acosh(arg: Expr) -> Expr: + """ + Returns inverse hyperbolic cosine. + """ + return Expr(f.acosh(arg.expr)) + + +def ascii(arg: Expr) -> Expr: + """ + Returns the numeric code of the first character of the argument. + """ + return Expr(f.ascii(arg.expr)) + + +def asin(arg: Expr) -> Expr: + """ + Returns the arc sine or inverse sine of a number. + """ + return Expr(f.asin(arg.expr)) + + +def asinh(arg: Expr) -> Expr: + """ + Returns inverse hyperbolic sine. + """ + return Expr(f.asinh(arg.expr)) + + +def atan(arg: Expr) -> Expr: + """ + Returns inverse tangent of a number. + """ + return Expr(f.atan(arg.expr)) + + +def atanh(arg: Expr) -> Expr: + """ + Returns inverse hyperbolic tangent. + """ + return Expr(f.atanh(arg.expr)) + + +def atan2(y: Expr, x: Expr) -> Expr: + """ + Returns inverse tangent of a division given in the argument. + """ + return Expr(f.atan2(y.expr, x.expr)) + + +def bit_length(arg: Expr) -> Expr: + """ + Returns the number of bits in the string argument. + """ + return Expr(f.bit_length(arg.expr)) + + +def btrim(arg: Expr) -> Expr: + """ + Removes all characters, spaces by default, from both sides of a string. + """ + return Expr(f.btrim(arg.expr)) + + +def cbrt(arg: Expr) -> Expr: + """ + Returns the cube root of a number. 
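+
+    For example, ``cbrt(literal(27.0))`` evaluates to ``3.0``.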
+ """ + return Expr(f.cbrt(arg.expr)) + + +def ceil(arg: Expr) -> Expr: + """ + Returns the nearest integer greater than or equal to argument. + """ + return Expr(f.ceil(arg.expr)) + + +def character_length(arg: Expr) -> Expr: + """ + Returns the number of characters in the argument. + """ + return Expr(f.character_length(arg.expr)) + + +def length(string: Expr) -> Expr: + """ + The number of characters in the `string` + """ + return Expr(f.length(string.expr)) + + +def char_length(string: Expr) -> Expr: + """ + The number of characters in the `string`. + """ + return Expr(f.char_length(string.expr)) + + +def chr(arg: Expr) -> Expr: + """ + Converts the Unicode code point to a UTF8 character. + """ + return Expr(f.chr(arg.expr)) + + +def coalesce(*args: Expr) -> Expr: + """ + Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL. + """ + args = [arg.expr for arg in args] + return Expr(f.coalesce(*args)) + + +def cos(arg: Expr) -> Expr: + """ + Returns the cosine of the argument. + """ + return Expr(f.cos(arg.expr)) + + +def cosh(arg: Expr) -> Expr: + """ + Returns the hyperbolic cosine of the argument. + """ + return Expr(f.cosh(arg.expr)) + + +def cot(arg: Expr) -> Expr: + """ + Returns the cotangent of the argument. + """ + return Expr(f.cot(arg.expr)) + + +def degrees(arg: Expr) -> Expr: + """ + Converts the argument from radians to degrees. + """ + return Expr(f.degrees(arg.expr)) + + +def ends_with(arg: Expr, suffix: Expr) -> Expr: + """ + Returns true if the `string` ends with the `suffix`, false otherwise. + """ + return Expr(f.ends_with(arg.expr, suffix.expr)) + + +def exp(arg: Expr) -> Expr: + """ + Returns the exponential of the arugment. + """ + return Expr(f.exp(arg.expr)) + + +def factorial(arg: Expr) -> Expr: + """ + Returns the factorial of the argument. + """ + return Expr(f.factorial(arg.expr)) + + +def find_in_set(string: Expr, string_list: Expr) -> Expr: + """ + Returns a value in the range of 1 to N if the string is in the string list `string_list` consisting of N substrings. + The string list is a string composed of substrings separated by `,` characters. + """ + return Expr(f.find_in_set(string.expr, string_list.expr)) + + +def floor(arg: Expr) -> Expr: + """ + Returns the nearest integer less than or equal to the argument. + """ + return Expr(f.floor(arg.expr)) + + +def gcd(x: Expr, y: Expr) -> Expr: + """ + Returns the greatest common divisor. + """ + return Expr(f.gcd(x.expr, y.expr)) + + +def initcap(string: Expr) -> Expr: + """ + Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase. + """ + return Expr(f.initcap(string.expr)) + + +def instr(string: Expr, substring: Expr) -> Expr: + """ + Finds the position from where the `substring` matches the `string`. + This is an alias for :func:`strpos`. + """ + return strpos(string, substring) + + +def iszero(arg: Expr) -> Expr: + """ + Returns true if a given number is +0.0 or -0.0 otherwise returns false. + """ + return Expr(f.iszero(arg.expr)) + + +def lcm(x: Expr, y: Expr) -> Expr: + """ + Returns the least common multiple. + """ + return Expr(f.lcm(x.expr, y.expr)) + + +def left(string: Expr, n: Expr) -> Expr: + """ + Returns the first `n` characters in the `string`. 
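+
+    For example, ``left(literal("datafusion"), literal(4))`` evaluates to ``"data"``.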
+ """ + return Expr(f.left(string.expr, n.expr)) + + +def levenshtein(string1: Expr, string2: Expr) -> Expr: + """ + Returns the Levenshtein distance between the two given strings + """ + return Expr(f.levenshtein(string1.expr, string2.expr)) + + +def ln(arg: Expr) -> Expr: + """ + Returns the natural logarithm (base e) of the argument. + """ + return Expr(f.ln(arg.expr)) + + +def log(base: Expr, num: Expr) -> Expr: + """ + Returns the logarithm of a number for a particular `base` + """ + return Expr(f.log(base.expr, num.expr)) + + +def log10(arg: Expr) -> Expr: + """ + Base 10 logarithm of the argument. + """ + return Expr(f.log10(arg.expr)) + + +def log2(arg: Expr) -> Expr: + """ + Base 2 logarithm of the argument. + """ + return Expr(f.log2(arg.expr)) + + +def lower(arg: Expr) -> Expr: + """ + Converts a string to lowercase. + """ + return Expr(f.lower(arg.expr)) + + +def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: + """ + Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). + """ + characters = characters if characters is not None else Expr.literal(" ") + return Expr(f.lpad(string.expr, count.expr, characters.expr)) + + +def ltrim(arg: Expr) -> Expr: + """ + Removes all characters, spaces by default, from the beginning of a string. + """ + return Expr(f.ltrim(arg.expr)) + + +def md5(arg: Expr) -> Expr: + """ + Computes an MD5 128-bit checksum for a string expression. + """ + return Expr(f.md5(arg.expr)) + + +def nanvl(x: Expr, y: Expr) -> Expr: + """ + Returns `x` if `x` is not `NaN`. Otherwise returns `y`. + """ + return Expr(f.nanvl(x.expr, y.expr)) + + +def octet_length(arg: Expr) -> Expr: + """ + Returns the number of bytes of a string. + """ + return Expr(f.octet_length(arg.expr)) + + +# TODO: `overlay` in datafusion needs to be updated from generic `args` definition, and then exposed in this repo. +# def overlay(string: Expr, substring: Expr, start: Expr, length: Expr | None = None) -> Expr: +# """ +# Replace the substring of string that starts at the `start`'th character and extends for `length` characters with new substring +# """ +# return Expr() + + +def pi() -> Expr: + """ + Returns an approximate value of π. + """ + return Expr(f.pi()) + + +def position(string: Expr, substring: Expr) -> Expr: + """ + Finds the position from where the `substring` matches the `string`. + This is an alias for :func:`strpos`. + """ + return strpos(string, substring) + + +def power(base: Expr, exponent: Expr) -> Expr: + """ + Returns `base` raised to the power of `exponent`. + """ + return Expr(f.power(base.expr, exponent.expr)) + + +def pow(base: Expr, exponent: Expr) -> Expr: + """ + Returns `base` raised to the power of `exponent`. + This is an alias of `power`. + """ + return power(base, exponent) + + +def radians(arg: Expr) -> Expr: + """ + Converts the argument from degrees to radians. + """ + return Expr(f.radians(arg.expr)) + + +def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: + """ + Tests a string using a regular expression returning true if at + least one match, false otherwise. 
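+
+    Hypothetical example, assuming a string column ``s``::
+
+        df.filter(regexp_like(column("s"), literal("^[A-Z]")))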
+ """ + if flags is not None: + flags = flags.expr + return Expr(f.regexp_like(string.expr, regex.expr, flags)) + + +def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: + """ + Returns an array with each element containing the leftmost-first + match of the corresponding index in `regex` to string in `string` + + If there is no match, the list element is NULL. + + If a match is found, and the pattern contains no capturing parenthesized subexpressions, + then the list element is a single-element [`GenericStringArray`] containing the substring + matching the whole pattern. + + If a match is found, and the pattern contains capturing parenthesized subexpressions, then the + list element is a [`GenericStringArray`] whose n'th element is the substring matching + the n'th capturing parenthesized subexpression of the pattern. + """ + + # TODO VALIDATE THIS IS CORRECT FOR DATAFRAME RESULTS + if flags is not None: + flags = flags.expr + return Expr(f.regexp_match(string.expr, regex.expr, flags)) + + +def regexp_replace( + string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None +) -> Expr: + """ + Replaces substring(s) matching a PCRE-like regular expression. + + The full list of supported features and syntax can be found at + + + Supported flags with the addition of 'g' can be found at + + """ + if flags is not None: + flags = flags.expr + return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + + +def repeat(string: Expr, n: Expr) -> Expr: + """ + Repeats the `string` to `n` times. + """ + return Expr(f.repeat(string.expr, n.expr)) + + +def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: + """ + Replaces all occurrences of `from` with `to` in the `string`. + """ + return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) + + +def reverse(arg: Expr) -> Expr: + """ + Reverse the string argument. + """ + return Expr(f.reverse(arg.expr)) + + +def right(string: Expr, n: Expr) -> Expr: + """ + Returns the last `n` characters in the `string`. + """ + return Expr(f.right(string.expr, n.expr)) + + +def round(arg: Expr) -> Expr: + """ + Round the argument to the nearest integer. + """ + return Expr(f.round(arg.expr)) + + +def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: + """ + Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. + """ + characters = characters if characters is not None else Expr.literal(" ") + return Expr(f.rpad(string.expr, count.expr, characters.expr)) + + +def rtrim(arg: Expr) -> Expr: + """ + Removes all characters, spaces by default, from the end of a string. + """ + return Expr(f.rtrim(arg.expr)) + + +def sha224(arg: Expr) -> Expr: + """ + Computes the SHA-224 hash of a binary string. + """ + return Expr(f.sha224(arg.expr)) + + +def sha256(arg: Expr) -> Expr: + """ + Computes the SHA-256 hash of a binary string. + """ + return Expr(f.sha256(arg.expr)) + + +def sha384(arg: Expr) -> Expr: + """ + Computes the SHA-384 hash of a binary string. + """ + return Expr(f.sha384(arg.expr)) + + +def sha512(arg: Expr) -> Expr: + """ + Computes the SHA-512 hash of a binary string. + """ + return Expr(f.sha512(arg.expr)) + + +def signum(arg: Expr) -> Expr: + """ + Returns the sign of the argument (-1, 0, +1). + """ + return Expr(f.signum(arg.expr)) + + +def sin(arg: Expr) -> Expr: + """ + Returns the sine of the argument. 
+    """
+    return Expr(f.sin(arg.expr))
+
+
+def sinh(arg: Expr) -> Expr:
+    """
+    Returns the hyperbolic sine of the argument.
+    """
+    return Expr(f.sinh(arg.expr))
+
+
+def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr:
+    """
+    Splits a string based on a delimiter and picks out the desired field based on the index.
+    """
+    return Expr(f.split_part(string.expr, delimiter.expr, index.expr))
+
+
+def sqrt(arg: Expr) -> Expr:
+    """
+    Returns the square root of the argument.
+    """
+    return Expr(f.sqrt(arg.expr))
+
+
+def starts_with(string: Expr, prefix: Expr) -> Expr:
+    """
+    Returns true if string starts with prefix.
+    """
+    return Expr(f.starts_with(string.expr, prefix.expr))
+
+
+def strpos(string: Expr, substring: Expr) -> Expr:
+    """
+    Finds the position from where the `substring` matches the `string`.
+    """
+    return Expr(f.strpos(string.expr, substring.expr))
+
+
+def substr(string: Expr, position: Expr) -> Expr:
+    """
+    Substring from the `position` to the end.
+    """
+    return Expr(f.substr(string.expr, position.expr))
+
+
+def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr:
+    """
+    Returns the substring from `string` before `count` occurrences of `delimiter`.
+    """
+    return Expr(f.substr_index(string.expr, delimiter.expr, count.expr))
+
+
+def substring(string: Expr, position: Expr, length: Expr) -> Expr:
+    """
+    Substring from the `position` with `length` characters.
+    """
+    return Expr(f.substring(string.expr, position.expr, length.expr))
+
+
+def tan(arg: Expr) -> Expr:
+    """
+    Returns the tangent of the argument.
+    """
+    return Expr(f.tan(arg.expr))
+
+
+def tanh(arg: Expr) -> Expr:
+    """
+    Returns the hyperbolic tangent of the argument.
+    """
+    return Expr(f.tanh(arg.expr))
+
+
+def to_hex(arg: Expr) -> Expr:
+    """
+    Converts an integer to a hexadecimal string.
+    """
+    return Expr(f.to_hex(arg.expr))
+
+
+def now() -> Expr:
+    """
+    Returns the current timestamp in nanoseconds, using the same value for all instances of now() in the same statement.
+    """
+    return Expr(f.now())
+
+
+def to_timestamp(arg: Expr, *formatters: Expr) -> Expr:
+    """
+    Converts a string and optional formats to a `Timestamp` in nanoseconds.
+    """
+    # TODO Add a detailed description of how to use formatters.
+    if not formatters:
+        return Expr(f.to_timestamp(arg.expr))
+
+    return Expr(f.to_timestamp(arg.expr, *[fmt.expr for fmt in formatters]))
+
+
+def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr:
+    """
+    Converts a string and optional formats to a `Timestamp` in milliseconds.
+    See `to_timestamp` for a description on how to use formatters.
+    """
+    return Expr(f.to_timestamp_millis(arg.expr, *[fmt.expr for fmt in formatters]))
+
+
+def to_timestamp_micros(arg: Expr, *formatters: Expr) -> Expr:
+    """
+    Converts a string and optional formats to a `Timestamp` in microseconds.
+    See `to_timestamp` for a description on how to use formatters.
+    """
+    return Expr(f.to_timestamp_micros(arg.expr, *[fmt.expr for fmt in formatters]))
+
+
+def to_timestamp_nanos(arg: Expr, *formatters: Expr) -> Expr:
+    """
+    Converts a string and optional formats to a `Timestamp` in nanoseconds.
+    See `to_timestamp` for a description on how to use formatters.
+    """
+    return Expr(f.to_timestamp_nanos(arg.expr, *[fmt.expr for fmt in formatters]))
+
+
+def to_timestamp_seconds(arg: Expr, *formatters: Expr) -> Expr:
+    """
+    Converts a string and optional formats to a `Timestamp` in seconds.
+    See `to_timestamp` for a description on how to use formatters.
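+
+    Hypothetical example, assuming a chrono-style format string::
+
+        to_timestamp_seconds(column("ts"), literal("%Y-%m-%d %H:%M:%S"))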
+    """
+    return Expr(f.to_timestamp_seconds(arg.expr, *[fmt.expr for fmt in formatters]))
+
+
+def to_unixtime(string: Expr, *format_arguments: Expr) -> Expr:
+    """
+    Converts a string and optional formats to a Unixtime.
+    """
+    # TODO verify if the format arguments are the same as to_timestamp and update documentation appropriately.
+    args = [fmt.expr for fmt in format_arguments]
+    return Expr(f.to_unixtime(string.expr, *args))
+
+
+def current_date() -> Expr:
+    """
+    Returns current UTC date as a Date32 value.
+    """
+    return Expr(f.current_date())
+
+
+def current_time() -> Expr:
+    """
+    Returns current UTC time as a Time64 value.
+    """
+    return Expr(f.current_time())
+
+
+def datepart(part: Expr, date: Expr) -> Expr:
+    """
+    Return a specified part of a date.
+    This is an alias for `date_part`.
+    """
+    return date_part(part, date)
+
+
+def date_part(part: Expr, date: Expr) -> Expr:
+    """
+    Extracts a subfield from the date.
+    """
+    return Expr(f.date_part(part.expr, date.expr))
+
+
+def date_trunc(part: Expr, date: Expr) -> Expr:
+    """
+    Truncates the date to a specified level of precision.
+    """
+    return Expr(f.date_trunc(part.expr, date.expr))
+
+
+def datetrunc(part: Expr, date: Expr) -> Expr:
+    """
+    Truncates the date to a specified level of precision.
+    This is an alias for `date_trunc`.
+    """
+    return date_trunc(part, date)
+
+
+def date_bin(stride: Expr, source: Expr, origin: Expr) -> Expr:
+    """
+    Coerces an arbitrary timestamp to the start of the nearest specified interval.
+    """
+    return Expr(f.date_bin(stride.expr, source.expr, origin.expr))
+
+
+def make_date(year: Expr, month: Expr, day: Expr) -> Expr:
+    """
+    Make a date from year, month and day component parts.
+    """
+    return Expr(f.make_date(year.expr, month.expr, day.expr))
+
+
+def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr:
+    """
+    Replaces the characters in `from_val` with the counterpart in `to_val`.
+    """
+    return Expr(f.translate(string.expr, from_val.expr, to_val.expr))
+
+
+def trim(arg: Expr) -> Expr:
+    """
+    Removes all characters, spaces by default, from both sides of a string.
+    """
+    return Expr(f.trim(arg.expr))
+
+
+def trunc(num: Expr, precision: Expr | None = None) -> Expr:
+    """
+    Truncate the number toward zero with optional precision.
+    """
+    if precision is not None:
+        return Expr(f.trunc(num.expr, precision.expr))
+    return Expr(f.trunc(num.expr))
+
+
+def upper(arg: Expr) -> Expr:
+    """
+    Converts a string to uppercase.
+    """
+    return Expr(f.upper(arg.expr))
+
+
+def make_array(*args: Expr) -> Expr:
+    """
+    Returns an array using the specified input expressions.
+    """
+    args = [arg.expr for arg in args]
+    return Expr(f.make_array(*args))
+
+
+def array(*args: Expr) -> Expr:
+    """
+    Returns an array using the specified input expressions.
+    This is an alias for `make_array`.
+    """
+    return make_array(*args)
+
+
+def range(start: Expr, stop: Expr, step: Expr) -> Expr:
+    """
+    Create a list of values in the range between start and stop.
+    """
+    return Expr(f.range(start.expr, stop.expr, step.expr))
+
+
+def uuid(arg: Expr) -> Expr:
+    """
+    Returns uuid v4 as a string value.
+    """
+    return Expr(f.uuid(arg.expr))
+
+
+def struct(*args: Expr) -> Expr:
+    """
+    Returns a struct with the given arguments.
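+
+    Hypothetical example, packing two columns into a single struct column::
+
+        df.select(struct(column("a"), column("b")).alias("ab"))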
+ """ + args = [arg.expr for arg in args] + return Expr(f.struct(*args)) + + +def named_struct(name_pairs: list[(str, Expr)]) -> Expr: + """ + Returns a struct with the given names and arguments pairs + """ + name_pairs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs] + + # flatten + name_pairs = [x.expr for xs in name_pairs for x in xs] + return Expr(f.named_struct(*name_pairs)) + + +def from_unixtime(arg: Expr) -> Expr: + """ + Converts an integer to RFC3339 timestamp format string. + """ + return Expr(f.from_unixtime(arg.expr)) + + +def arrow_typeof(arg: Expr) -> Expr: + """ + Returns the Arrow type of the expression. + """ + return Expr(f.arrow_typeof(arg.expr)) + + +def random() -> Expr: + """ + Returns a random value in the range 0.0 <= x < 1.0 + """ + return Expr(f.random()) + + +def array_append(array: Expr, element: Expr) -> Expr: + """ + Appends an element to the end of an array. + """ + return Expr(f.array_append(array.expr, element.expr)) + + +def array_push_back(array: Expr, element: Expr) -> Expr: + """ + Appends an element to the end of an array. + This is an alias for `array_append`. + """ + return array_append(array, element) + + +def list_append(array: Expr, element: Expr) -> Expr: + """ + Appends an element to the end of an array. + This is an alias for `array_append`. + """ + return array_append(array, element) + + +def list_push_back(array: Expr, element: Expr) -> Expr: + """ + Appends an element to the end of an array. + This is an alias for `array_append`. + """ + return array_append(array, element) + + +def array_concat(*args: Expr) -> Expr: + """ + Concatenates the input arrays. + """ + args = [arg.expr for arg in args] + return Expr(f.array_concat(*args)) + + +def array_cat(*args: Expr) -> Expr: + """ + Concatenates the input arrays. + This is an alias for `array_concat`. + """ + return array_concat(*args) + + +def array_dims(array: Expr) -> Expr: + """ + Returns an array of the array's dimensions. + """ + return Expr(f.array_dims(array.expr)) + + +def array_distinct(array: Expr) -> Expr: + """ + Returns distinct values from the array after removing duplicates. + """ + return Expr(f.array_distinct(array.expr)) + + +def list_distinct(array: Expr) -> Expr: + """ + Returns distinct values from the array after removing duplicates. + This is an alias for `array_distinct`. + """ + return array_distinct(array) + + +def list_dims(array: Expr) -> Expr: + """ + Returns an array of the array's dimensions. + This is an alias for `array_dims`. + """ + return array_dims(array) + + +def array_element(array: Expr, n: Expr) -> Expr: + """ + Extracts the element with the index n from the array. + """ + return Expr(f.array_element(array.expr, n.expr)) + + +def array_extract(array: Expr, n: Expr) -> Expr: + """ + Extracts the element with the index n from the array. + This is an alias for `array_element`. + """ + return array_element(array, n) + + +def list_element(array: Expr, n: Expr) -> Expr: + """ + Extracts the element with the index n from the array. + This is an alias for `array_element`. + """ + return array_element(array, n) + + +def list_extract(array: Expr, n: Expr) -> Expr: + """ + Extracts the element with the index n from the array. + This is an alias for `array_element`. + """ + return array_element(array, n) + + +def array_length(array: Expr) -> Expr: + """ + Returns the length of the array. + """ + return Expr(f.array_length(array.expr)) + + +def list_length(array: Expr) -> Expr: + """ + Returns the length of the array. 
+ This is an alias for `array_length`. + """ + return array_length(array) + + +def array_has(first_array: Expr, second_array: Expr) -> Expr: + """ + Returns true if the element appears in the first array, otherwise false. + """ + return Expr(f.array_has(first_array.expr, second_array.expr)) + + +def array_has_all(first_array: Expr, second_array: Expr) -> Expr: + """ + Returns true if each element of the second array appears in the first array. Otherwise, it returns false. + """ + return Expr(f.array_has_all(first_array.expr, second_array.expr)) + + +def array_has_any(first_array: Expr, second_array: Expr) -> Expr: + """ + Returns true if at least one element of the second array appears in the first array. Otherwise, it returns false. + """ + return Expr(f.array_has_any(first_array.expr, second_array.expr)) + + +def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """ + Searches for an element in the array and returns the position of the first occurrence. + """ + return Expr(f.array_position(array.expr, element.expr, index)) + + +def array_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """ + Searches for an element in the array and returns the position of the first occurrence. + This is an alias for `array_position`. + """ + return array_position(array, element, index) + + +def list_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """ + Searches for an element in the array and returns the position of the first occurrence. + This is an alias for `array_position`. + """ + return array_position(array, element, index) + + +def list_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """ + Searches for an element in the array and returns the position of the first occurrence. + This is an alias for `array_position`. + """ + return array_position(array, element, index) + + +def array_positions(array: Expr, element: Expr) -> Expr: + """ + Searches for an element in the array and returns all occurrences. + """ + return Expr(f.array_positions(array.expr, element.expr)) + + +def list_positions(array: Expr, element: Expr) -> Expr: + """ + Searches for an element in the array and returns all occurrences. + This is an alias for `array_positions`. + """ + return array_positions(array, element) + + +def array_ndims(array: Expr) -> Expr: + """ + Returns the number of dimensions of the array. + """ + return Expr(f.array_ndims(array.expr)) + + +def list_ndims(array: Expr) -> Expr: + """ + Returns the number of dimensions of the array. + This is an alias for `array_ndims`. + """ + return array_ndims(array) + + +def array_prepend(element: Expr, array: Expr) -> Expr: + """ + Prepends an element to the beginning of an array. + """ + return Expr(f.array_prepend(element.expr, array.expr)) + + +def array_push_front(element: Expr, array: Expr) -> Expr: + """ + Prepends an element to the beginning of an array. + This is an alias for `array_prepend`. + """ + return array_prepend(element, array) + + +def list_prepend(element: Expr, array: Expr) -> Expr: + """ + Prepends an element to the beginning of an array. + This is an alias for `array_prepend`. + """ + return array_prepend(element, array) + + +def list_push_front(element: Expr, array: Expr) -> Expr: + """ + Prepends an element to the beginning of an array. + This is an alias for `array_prepend`. + """ + return array_prepend(element, array) + + +def array_pop_back(array: Expr) -> Expr: + """ + Returns the array without the last element. 
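+
+    For example, ``array_pop_back`` applied to ``[1, 2, 3]`` yields ``[1, 2]``.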
+ """ + return Expr(f.array_pop_back(array.expr)) + + +def array_pop_front(array: Expr) -> Expr: + """ + Returns the array without the first element. + """ + return Expr(f.array_pop_front(array.expr)) + + +def array_remove(array: Expr, element: Expr) -> Expr: + """ + Removes the first element from the array equal to the given value. + """ + return Expr(f.array_remove(array.expr, element.expr)) + + +def list_remove(array: Expr, element: Expr) -> Expr: + """ + Removes the first element from the array equal to the given value. + This is an alias for `array_remove`. + """ + return array_remove(array, element) + + +def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: + """ + Removes the first `max` elements from the array equal to the given value. + """ + return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) + + +def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: + """ + Removes the first `max` elements from the array equal to the given value. + This is an alias for `array_remove_n`. + """ + return array_remove_n(array, element, max) + + +def array_remove_all(array: Expr, element: Expr) -> Expr: + """ + Removes all elements from the array equal to the given value. + """ + return Expr(f.array_remove_all(array.expr, element.expr)) + + +def list_remove_all(array: Expr, element: Expr) -> Expr: + """ + Removes all elements from the array equal to the given value. + This is an alias for `array_remove_all`. + """ + return array_remove_all(array, element) + + +def array_repeat(element: Expr, count: Expr) -> Expr: + """ + Returns an array containing `element` `count` times. + """ + return Expr(f.array_repeat(element.expr, count.expr)) + + +def array_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """ + Replaces the first occurrence of the specified element with another specified element. + """ + return Expr(f.array_replace(array.expr, from_val.expr, to_val.expr)) + + +def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """ + Replaces the first occurrence of the specified element with another specified element. + This is an alias for `array_replace`. + """ + return array_replace(array, from_val, to_val) + + +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: + """ + Replaces the first `max` occurrences of the specified element with another specified element. + """ + return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) + + +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: + """ + Replaces the first `max` occurrences of the specified element with another specified element. + This is an alias for `array_replace_n`. + """ + return array_replace_n(array, from_val, to_val, max) + + +def array_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """ + Replaces all occurrences of the specified element with another specified element. + """ + return Expr(f.array_replace_all(array.expr, from_val.expr, to_val.expr)) + + +def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """ + Replaces all occurrences of the specified element with another specified element. + This is an alias for `array_replace_all`. + """ + return array_replace_all(array, from_val, to_val) + + +def array_slice( + array: Expr, begin: Expr, end: Expr, stride: Expr | None = None +) -> Expr: + """ + Returns a slice of the array. 
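+
+    Indexing is assumed to be one-based and inclusive, so ``array_slice`` of
+    ``[1, 2, 3, 4]`` with ``begin=2`` and ``end=3`` would yield ``[2, 3]``.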
+ """ + if stride is not None: + stride = stride.expr + return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + + +def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: + """ + Returns a slice of the array. + This is an alias for `array_slice`. + """ + return array_slice(array, begin, end, stride) + + +def array_intersect(array1: Expr, array2: Expr) -> Expr: + """ + Returns an array of the elements in the intersection of array1 and array2. + """ + return Expr(f.array_intersect(array1.expr, array2.expr)) + + +def list_intersect(array1: Expr, array2: Expr) -> Expr: + """ + Returns an array of the elements in the intersection of array1 and array2. + This is an alias for `array_intersect`. + """ + return array_intersect(array1, array2) + + +def array_union(array1: Expr, array2: Expr) -> Expr: + """ + Returns an array of the elements in the union of array1 and array2 without duplicates. + """ + return Expr(f.array_union(array1.expr, array2.expr)) + + +def list_union(array1: Expr, array2: Expr) -> Expr: + """ + Returns an array of the elements in the union of array1 and array2 without duplicates. + This is an alias for `array_union`. + """ + return array_union(array1, array2) + + +def array_except(array1: Expr, array2: Expr) -> Expr: + """ + Returns an array of the elements that appear in `array1` but not in the `array2`. + """ + return Expr(f.array_except(array1.expr, array2.expr)) + + +def list_except(array1: Expr, array2: Expr) -> Expr: + """ + Returns an array of the elements that appear in `array1` but not in the `array2`. + This is an alias for `array_except`. + """ + return array_except(array1, array2) + + +def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: + """ + Returns an array with the specified size filled. If `size` is greater than the `array` length, the additional entries will be filled with the given `value`. + """ + return Expr(f.array_resize(array.expr, size.expr, value.expr)) + + +def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: + """ + Returns an array with the specified size filled. If `size` is greater than the `array` length, the additional entries will be filled with the given `value`. + This is an alias for `array_resize`. + """ + return array_resize(array, size, value) + + +def flatten(array: Expr) -> Expr: + """ + Flattens an array of arrays into a single array. + """ + return Expr(f.flatten(array.expr)) + + +# aggregate functions +def approx_distinct(arg: Expr) -> Expr: + """ + Returns the approximate number of distinct values. + """ + return Expr(f.approx_distinct(arg.expr, distinct=True)) + + +def approx_median(arg: Expr, distinct: bool = False) -> Expr: + """ + Returns the approximate median value. + """ + return Expr(f.approx_median(arg.expr, distinct=distinct)) + + +def approx_percentile_cont( + arg: Expr, + percentile: Expr, + num_centroids: int | None = None, + distinct: bool = False, +) -> Expr: + """ + Returns the value that is approximately at a given percentile of a distribution of values. 
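+
+    Hypothetical example, estimating the median of column ``x``::
+
+        approx_percentile_cont(column("x"), literal(0.5))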
+    """
+    # TODO validate that these parameters are passed properly
+    if num_centroids is None:
+        return Expr(
+            f.approx_percentile_cont(arg.expr, percentile.expr, distinct=distinct)
+        )
+
+    return Expr(
+        f.approx_percentile_cont(
+            arg.expr, percentile.expr, num_centroids, distinct=distinct
+        )
+    )
+
+
+def approx_percentile_cont_with_weight(
+    arg: Expr, weight: Expr, percentile: Expr, distinct: bool = False
+) -> Expr:
+    """
+    Returns the value that is approximately at a given percentile of a distribution of values with associated weights.
+    """
+    # TODO validate that these parameters are passed properly
+    return Expr(
+        f.approx_percentile_cont_with_weight(
+            arg.expr, weight.expr, percentile.expr, distinct=distinct
+        )
+    )
+
+
+def array_agg(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Aggregate values into an array.
+    """
+    return Expr(f.array_agg(arg.expr, distinct=distinct))
+
+
+def avg(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Returns the average value.
+    """
+    return Expr(f.avg(arg.expr, distinct=distinct))
+
+
+def corr(value1: Expr, value2: Expr, distinct: bool = False) -> Expr:
+    """
+    Returns the correlation coefficient between `value1` and `value2`.
+    """
+    return Expr(f.corr(value1.expr, value2.expr, distinct=distinct))
+
+
+def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr:
+    """
+    Returns the number of rows that match the given arguments.
+    """
+    if args is None:
+        args = []
+    elif isinstance(args, list):
+        args = [arg.expr for arg in args]
+    elif isinstance(args, Expr):
+        args = [args.expr]
+    return Expr(f.count(*args, distinct=distinct))
+
+
+def covar(y: Expr, x: Expr) -> Expr:
+    """
+    Computes the sample covariance.
+    This is an alias for `covar_samp`.
+    """
+    return Expr(f.covar(y.expr, x.expr))
+
+
+def covar_pop(y: Expr, x: Expr) -> Expr:
+    """
+    Computes the population covariance.
+    """
+    return Expr(f.covar_pop(y.expr, x.expr))
+
+
+def covar_samp(y: Expr, x: Expr) -> Expr:
+    """
+    Computes the sample covariance.
+    """
+    return Expr(f.covar_samp(y.expr, x.expr))
+
+
+def grouping(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Returns 1 if the value of the argument is aggregated across the grouping set
+    in the returned row, and 0 otherwise.
+    """
+    return Expr(f.grouping([arg.expr], distinct=distinct))
+
+
+def max(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Returns the maximum value of the argument.
+    """
+    return Expr(f.max(arg.expr, distinct=distinct))
+
+
+def mean(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Returns the average (mean) value of the argument.
+    This is an alias for `avg`.
+    """
+    return avg(arg, distinct)
+
+
+def median(arg: Expr) -> Expr:
+    """
+    Computes the median of a set of numbers.
+    """
+    return Expr(f.median(arg.expr))
+
+
+def min(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Returns the minimum value of the argument.
+    """
+    return Expr(f.min(arg.expr, distinct=distinct))
+
+
+def sum(arg: Expr) -> Expr:
+    """
+    Computes the sum of a set of numbers.
+    """
+    return Expr(f.sum(arg.expr))
+
+
+def stddev(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the standard deviation of the argument.
+    """
+    return Expr(f.stddev(arg.expr, distinct=distinct))
+
+
+def stddev_pop(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the population standard deviation of the argument.
+    """
+    return Expr(f.stddev_pop(arg.expr, distinct=distinct))
+
+
+def stddev_samp(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the sample standard deviation of the argument.
+    This is an alias for `stddev`.
+    """
+    return stddev(arg, distinct)
+
+
+def var(arg: Expr) -> Expr:
+    """
+    Computes the sample variance of the argument.
+    This is an alias for `var_samp`.
+    """
+    return var_samp(arg)
+
+
+def var_pop(arg: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the population variance of the argument.
+    """
+    return Expr(f.var_pop(arg.expr, distinct=distinct))
+
+
+def var_samp(arg: Expr) -> Expr:
+    """
+    Computes the sample variance of the argument.
+    """
+    return Expr(f.var_samp(arg.expr))
+
+
+def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the average of the independent variable `x` for non-null pairs of the inputs.
+    """
+    return Expr(f.regr_avgx(y.expr, x.expr, distinct=distinct))
+
+
+def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the average of the dependent variable `y` for non-null pairs of the inputs.
+    """
+    return Expr(f.regr_avgy(y.expr, x.expr, distinct=distinct))
+
+
+def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Counts the number of input rows in which both expressions are not null.
+    """
+    return Expr(f.regr_count(y.expr, x.expr, distinct=distinct))
+
+
+def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the intercept from the linear regression.
+    """
+    return Expr(f.regr_intercept(y.expr, x.expr, distinct=distinct))
+
+
+def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the R-squared value from linear regression.
+    """
+    return Expr(f.regr_r2(y.expr, x.expr, distinct=distinct))
+
+
+def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the slope from linear regression.
+    """
+    return Expr(f.regr_slope(y.expr, x.expr, distinct=distinct))
+
+
+def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the sum of squares of the independent variable `x`.
+    """
+    return Expr(f.regr_sxx(y.expr, x.expr, distinct=distinct))
+
+
+def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the sum of products of pairs of numbers.
+    """
+    return Expr(f.regr_sxy(y.expr, x.expr, distinct=distinct))
+
+
+def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the sum of squares of the dependent variable `y`.
+    """
+    return Expr(f.regr_syy(y.expr, x.expr, distinct=distinct))
+
+
+def first_value(
+    arg: Expr,
+    distinct: bool = False,
+    filter: Expr | None = None,
+    order_by: Expr | None = None,
+    null_treatment: common.NullTreatment | None = None,
+) -> Expr:
+    """
+    Returns the first value in a group of values.
+    """
+    return Expr(
+        f.first_value(
+            arg.expr,
+            distinct=distinct,
+            filter=filter.expr if filter is not None else None,
+            order_by=order_by.expr if order_by is not None else None,
+            null_treatment=null_treatment,
+        )
+    )
+
+
+def last_value(
+    arg: Expr,
+    distinct: bool = False,
+    filter: Expr | None = None,
+    order_by: Expr | None = None,
+    null_treatment: common.NullTreatment | None = None,
+) -> Expr:
+    """
+    Returns the last value in a group of values.
+    """
+    return Expr(
+        f.last_value(
+            arg.expr,
+            distinct=distinct,
+            filter=filter.expr if filter is not None else None,
+            order_by=order_by.expr if order_by is not None else None,
+            null_treatment=null_treatment,
+        )
+    )
+
+
+def bit_and(*args: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the bitwise AND of the arguments.
+    """
+    args = [arg.expr for arg in args]
+    return Expr(f.bit_and(*args, distinct=distinct))
+
+
+def bit_or(*args: Expr, distinct: bool = False) -> Expr:
+    """
+    Computes the bitwise OR of the arguments.
+ """ + args = [arg.expr for arg in args] + return Expr(f.bit_or(*args, distinct=distinct)) + + +def bit_xor(*args: Expr, distinct: bool = False) -> Expr: + """ + Computes the bitwise XOR of the argument. + """ + args = [arg.expr for arg in args] + return Expr(f.bit_xor(*args, distinct=distinct)) + + +def bool_and(*args: Expr, distinct: bool = False) -> Expr: + """ + Computes the boolean AND of the arugment. + """ + args = [arg.expr for arg in args] + return Expr(f.bool_and(*args, distinct=distinct)) + + +def bool_or(*args: Expr, distinct: bool = False) -> Expr: + """ + Computes the boolean OR of the arguement. + """ + args = [arg.expr for arg in args] + return Expr(f.bool_or(*args, distinct=distinct)) diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py new file mode 100644 index 00000000..eb7f07f4 --- /dev/null +++ b/python/datafusion/record_batch.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow + import datafusion._internal as df_internal + + +class RecordBatch: + def __init__(self, record_batch: df_internal.RecordBatch) -> None: + self.record_batch = record_batch + + def to_pyarrow(self) -> pyarrow.RecordBatch: + return self.record_batch.to_pyarrow() + + +class RecordBatchStream: + def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: + self.rbs = record_batch_stream + + def next(self) -> RecordBatch | None: + try: + next_batch = next(self) + except StopIteration: + return None + + return next_batch + + def __next__(self) -> RecordBatch | None: + next_batch = next(self.rbs) + return RecordBatch(next_batch) if next_batch is not None else None + + def __iter__(self) -> RecordBatchStream: + return self diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index eff809a0..cc17b2a9 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -15,9 +15,156 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations -from ._internal import substrait +from ._internal import substrait as substrait_internal +from typing import TYPE_CHECKING -def __getattr__(name): - return getattr(substrait, name) +if TYPE_CHECKING: + from datafusion.context import SessionContext + from datafusion._internal import LogicalPlan + + +class plan: + def __init__(self, plan: substrait_internal.plan) -> None: + self.plan_internal = plan + + def encode(self) -> bytes: + """Encode the plan to bytes. + + Returns + ------- + bytes + Encoded plan. + """ + return self.plan_internal.encode() + + +class serde: + @staticmethod + def serialize(sql: str, ctx: SessionContext, path: str) -> None: + """Serialize a SQL query to a Substrait plan and write it to a file. + + Parameters + ---------- + sql : str + SQL query to serialize. + ctx : SessionContext + SessionContext to use. + path : str + Path to write the Substrait plan to. + """ + return substrait_internal.serde.serialize(sql, ctx.ctx, path) + + @staticmethod + def serialize_to_plan(sql: str, ctx: SessionContext) -> plan: + """Serialize a SQL query to a Substrait plan. + + Parameters + ---------- + sql : str + SQL query to serialize. + ctx : SessionContext + SessionContext to use. + + Returns + ------- + plan + Substrait plan. + """ + return plan(substrait_internal.serde.serialize_to_plan(sql, ctx.ctx)) + + @staticmethod + def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: + """Serialize a SQL query to a Substrait plan as bytes. 
+ + Parameters + ---------- + sql : str + SQL query to serialize. + ctx : SessionContext + SessionContext to use. + + Returns + ------- + bytes + Substrait plan as bytes. + """ + return substrait_internal.serde.serialize_bytes(sql, ctx.ctx) + + @staticmethod + def deserialize(path: str) -> plan: + """Deserialize a Substrait plan from a file. + + Parameters + ---------- + path : str + Path to read the Substrait plan from. + + Returns + ------- + plan + Substrait plan. + """ + return plan(substrait_internal.serde.deserialize(path)) + + @staticmethod + def deserialize_bytes(proto_bytes: bytes) -> plan: + """Deserialize a Substrait plan from bytes. + + Parameters + ---------- + proto_bytes : bytes + Bytes to read the Substrait plan from. + + Returns + ------- + plan + Substrait plan. + """ + return plan(substrait_internal.serde.deserialize_bytes(proto_bytes)) + + +class producer: + @staticmethod + def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> plan: + """Convert a DataFusion LogicalPlan to a Substrait plan. + + Parameters + ---------- + plan : LogicalPlan + LogicalPlan to convert. + ctx : SessionContext + SessionContext to use. + + Returns + ------- + plan + Substrait plan. + """ + return plan( + substrait_internal.producer.to_substrait_plan(logical_plan, ctx.ctx) + ) + + +class consumer: + @staticmethod + def from_substrait_plan(ctx: SessionContext, plan: plan) -> LogicalPlan: + """Convert a Substrait plan to a DataFusion LogicalPlan. + + Parameters + ---------- + ctx : SessionContext + SessionContext to use. + plan : plan + Substrait plan to convert. + + Returns + ------- + LogicalPlan + LogicalPlan. + """ + return substrait_internal.consumer.from_substrait_plan( + ctx.ctx, plan.plan_internal + ) diff --git a/python/datafusion/tests/conftest.py b/python/datafusion/tests/conftest.py index a4eec41e..1cc07e50 100644 --- a/python/datafusion/tests/conftest.py +++ b/python/datafusion/tests/conftest.py @@ -18,6 +18,7 @@ import pytest from datafusion import SessionContext import pyarrow as pa +from pyarrow.csv import write_csv @pytest.fixture @@ -37,7 +38,7 @@ def database(ctx, tmp_path): ], names=["int", "str", "float"], ) - pa.csv.write_csv(table, path) + write_csv(table, path) ctx.register_csv("csv", path) ctx.register_csv("csv1", str(path)) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 2f6a818e..5f26063e 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -17,6 +17,7 @@ import os import pyarrow as pa +from pyarrow.csv import write_csv import pyarrow.parquet as pq import pytest @@ -379,7 +380,7 @@ def test_get_dataframe(tmp_path): ], names=["int", "str", "float"], ) - pa.csv.write_csv(table, path) + write_csv(table, path) ctx.register_csv("csv", path) @@ -659,6 +660,8 @@ def test_to_arrow_table(df): def test_execute_stream(df): stream = df.execute_stream() + for s in stream: + print(type(s)) assert all(batch is not None for batch in stream) assert not list(stream) # after one iteration the generator must be exhausted @@ -795,3 +798,15 @@ def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compre with pytest.raises(ValueError): df.write_parquet(str(path), compression=compression) + + +# ctx = SessionContext() + +# # create a RecordBatch and a new DataFrame from it +# batch = pa.RecordBatch.from_arrays( +# [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])], +# names=["a", "b", "c"], +# ) + +# df = 
diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py
index 449f706c..85fed622 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -50,16 +50,16 @@ def df():
     return ctx.create_dataframe([[batch]])
 
 
+# TODO Update documentation of PR to indicate this is a user-facing change to how named_struct is called
 def test_named_struct(df):
     df = df.with_column(
         "d",
         f.named_struct(
-            literal("a"),
-            column("a"),
-            literal("b"),
-            column("b"),
-            literal("c"),
-            column("c"),
+            [
+                ("a", column("a")),
+                ("b", column("b")),
+                ("c", column("c")),
+            ]
         ),
     )
diff --git a/python/datafusion/tests/test_imports.py b/python/datafusion/tests/test_imports.py
index bd4e7c31..3d324fb6 100644
--- a/python/datafusion/tests/test_imports.py
+++ b/python/datafusion/tests/test_imports.py
@@ -94,13 +94,24 @@ def test_datafusion_python_version():
 
 
 def test_class_module_is_datafusion():
+    # context
     for klass in [
         SessionContext,
+    ]:
+        assert klass.__module__ == "datafusion.context"
+
+    # dataframe
+    for klass in [
         DataFrame,
-        ScalarUDF,
+    ]:
+        assert klass.__module__ == "datafusion.dataframe"
+
+    # udf
+    for klass in [
         AggregateUDF,
+        ScalarUDF,
     ]:
-        assert klass.__module__ == "datafusion"
+        assert klass.__module__ == "datafusion.udf"
 
     # expressions
     for klass in [Expr, Column, Literal, BinaryExpr, AggregateFunction]:
diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py
index 8ec2ffb1..ec0e4c57 100644
--- a/python/datafusion/tests/test_sql.py
+++ b/python/datafusion/tests/test_sql.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 import pyarrow as pa
+from pyarrow.csv import write_csv
 import pyarrow.dataset as ds
 import pytest
 from datafusion.object_store import LocalFileSystem
@@ -45,7 +46,7 @@ def test_register_csv(ctx, tmp_path):
         ],
         names=["int", "str", "float"],
     )
-    pa.csv.write_csv(table, path)
+    write_csv(table, path)
 
     with open(path, "rb") as csv_file:
         with gzip.open(gzip_path, "wb") as gzipped_file:
diff --git a/python/datafusion/tests/test_substrait.py b/python/datafusion/tests/test_substrait.py
index 62f6413a..260db5eb 100644
--- a/python/datafusion/tests/test_substrait.py
+++ b/python/datafusion/tests/test_substrait.py
@@ -38,14 +38,14 @@ def test_substrait_serialization(ctx):
     assert ctx.tables() == {"t"}
 
     # For now just make sure the method calls blow up
-    substrait_plan = ss.substrait.serde.serialize_to_plan("SELECT * FROM t", ctx)
+    substrait_plan = ss.serde.serialize_to_plan("SELECT * FROM t", ctx)
     substrait_bytes = substrait_plan.encode()
 
     assert isinstance(substrait_bytes, bytes)
 
-    substrait_bytes = ss.substrait.serde.serialize_bytes("SELECT * FROM t", ctx)
-    substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes)
-    logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan)
+    substrait_bytes = ss.serde.serialize_bytes("SELECT * FROM t", ctx)
+    substrait_plan = ss.serde.deserialize_bytes(substrait_bytes)
+    logical_plan = ss.consumer.from_substrait_plan(ctx, substrait_plan)
 
     # demonstrate how to create a DataFrame from a deserialized logical plan
     df = ctx.create_dataframe_from_logical_plan(logical_plan)
 
-    substrait_plan = ss.substrait.producer.to_substrait_plan(df.logical_plan(), ctx)
+    substrait_plan = ss.producer.to_substrait_plan(df.logical_plan(), ctx)
diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
new file mode 100644
index 00000000..7e713573
--- /dev/null
+++ 
b/python/datafusion/udf.py @@ -0,0 +1,41 @@ +import datafusion._internal as df_internal +from datafusion.expr import Expr +import pyarrow +from typing import Callable + + +class ScalarUDF: + def __init__( + self, + name: str | None, + func: Callable, + input_types: list[pyarrow.DataType], + return_type: pyarrow.DataType, + volatility: str, + ) -> None: + self.udf = df_internal.ScalarUDF( + name, func, input_types, return_type, volatility + ) + + def __call__(self, *args: Expr) -> Expr: + args = [arg.expr for arg in args] + return Expr(self.udf.__call__(*args)) + + +class AggregateUDF: + def __init__( + self, + name: str | None, + accumulator: Callable, + input_types: list[pyarrow.DataType], + return_type: pyarrow.DataType, + state_type: list[pyarrow.DataType], + volatility: str, + ) -> None: + self.udf = df_internal.AggregateUDF( + name, accumulator, input_types, return_type, state_type, volatility + ) + + def __call__(self, *args: Expr) -> Expr: + args = [arg.expr for arg in args] + return Expr(self.udf.__call__(*args)) From a3429ab60aa245bccdf2e7889bf34c835fde1d6a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 9 Jul 2024 07:11:02 -0400 Subject: [PATCH 03/55] Remove extra level of python path that is no longer required --- examples/substrait.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/substrait.py b/examples/substrait.py index 23cd7464..7a268a78 100644 --- a/examples/substrait.py +++ b/examples/substrait.py @@ -18,6 +18,7 @@ from datafusion import SessionContext from datafusion import substrait as ss +# TODO add user changing interface note to PR that datafusion.substrait.substrait is simplified to datafusion.substrait # Create a DataFusion context ctx = SessionContext() @@ -25,9 +26,7 @@ # Register table with context ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv") -substrait_plan = ss.substrait.serde.serialize_to_plan( - "SELECT * FROM aggregate_test_data", ctx -) +substrait_plan = ss.serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx) # type(substrait_plan) -> # Encode it to bytes @@ -38,17 +37,15 @@ # Alternative serialization approaches # type(substrait_bytes) -> , at this point the bytes can be distributed to file, network, etc safely # where they could subsequently be deserialized on the receiving end. -substrait_bytes = ss.substrait.serde.serialize_bytes( - "SELECT * FROM aggregate_test_data", ctx -) +substrait_bytes = ss.serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx) # Imagine here bytes would be read from network, file, etc ... 
for brevity in this example it is omitted and the variable is simply reused
 
 # type(substrait_plan) -> 
-substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes)
+substrait_plan = ss.serde.deserialize_bytes(substrait_bytes)
 
 # type(df_logical_plan) -> 
-df_logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan)
+df_logical_plan = ss.consumer.from_substrait_plan(ctx, substrait_plan)
 
 # Back to Substrait Plan just for demonstration purposes
 # type(substrait_plan) -> 
-substrait_plan = ss.substrait.producer.to_substrait_plan(df_logical_plan)
+substrait_plan = ss.producer.to_substrait_plan(df_logical_plan, ctx)

From 7937963a15e63d2699b00ad1c19575ef02320a10 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Tue, 9 Jul 2024 08:08:15 -0400
Subject: [PATCH 04/55] Move import to only happen for type checking for hints

---
 python/datafusion/dataframe.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 55df376d..559e2348 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
-from typing import Any, List
+from typing import Any, List, TYPE_CHECKING
 from datafusion.record_batch import RecordBatchStream
 from typing_extensions import deprecated
-import pyarrow as pa
-import pandas as pd
-import polars as pl
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+    import pandas as pd
+    import polars as pl
 
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion.expr import Expr

From 1f4c82941b0359b5aa26df75d73de8a8d6d380c1 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Tue, 9 Jul 2024 08:09:03 -0400
Subject: [PATCH 05/55] Comment out classes from __all__ in the top level that
 are not currently exposed.
--- python/datafusion/__init__.py | 88 +++++++++++++++++------------------ 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 8c3e9de7..8ba95260 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -100,54 +100,54 @@ "RuntimeConfig", "Expr", "ScalarUDF", - "Window", + # "Window", "WindowFrame", "column", "literal", - "TableScan", - "Projection", + # "TableScan", + # "Projection", "DFSchema", - "DFField", - "Analyze", - "Sort", - "Limit", - "Filter", - "Like", - "ILike", - "SimilarTo", - "ScalarVariable", - "Alias", - "Not", - "IsNotNull", - "IsTrue", - "IsFalse", - "IsUnknown", - "IsNotTrue", - "IsNotFalse", - "IsNotUnknown", - "Negative", - "ScalarFunction", - "BuiltinScalarFunction", - "InList", - "Exists", - "Subquery", - "InSubquery", - "ScalarSubquery", - "GroupingSet", - "Placeholder", - "Case", - "Cast", - "TryCast", - "Between", - "Explain", - "SubqueryAlias", - "Extension", - "CreateMemoryTable", - "CreateView", - "Distinct", - "DropTable", - "Repartition", - "Partitioning", + # "DFField", + # "Analyze", + # "Sort", + # "Limit", + # "Filter", + # "Like", + # "ILike", + # "SimilarTo", + # "ScalarVariable", + # "Alias", + # "Not", + # "IsNotNull", + # "IsTrue", + # "IsFalse", + # "IsUnknown", + # "IsNotTrue", + # "IsNotFalse", + # "IsNotUnknown", + # "Negative", + # "ScalarFunction", + # "BuiltinScalarFunction", + # "InList", + # "Exists", + # "Subquery", + # "InSubquery", + # "ScalarSubquery", + # "GroupingSet", + # "Placeholder", + # "Case", + # "Cast", + # "TryCast", + # "Between", + # "Explain", + # "SubqueryAlias", + # "Extension", + # "CreateMemoryTable", + # "CreateView", + # "Distinct", + # "DropTable", + # "Repartition", + # "Partitioning", ] From d7f5f68d46e9e054461dcc07ec7cde8ac659cee1 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 9 Jul 2024 08:12:11 -0400 Subject: [PATCH 06/55] Add license comments --- python/datafusion/catalog.py | 17 +++++++++++++++++ python/datafusion/dataframe.py | 17 +++++++++++++++++ python/datafusion/record_batch.py | 17 +++++++++++++++++ python/datafusion/udf.py | 23 +++++++++++++++++++++-- 4 files changed, 72 insertions(+), 2 deletions(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 0764e63f..1379b692 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ from __future__ import annotations import datafusion._internal as df_internal diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 559e2348..00748341 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from __future__ import annotations from typing import Any, List, TYPE_CHECKING diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index eb7f07f4..f26458e8 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 7e713573..acd70773 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -1,7 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import datafusion._internal as df_internal from datafusion.expr import Expr -import pyarrow -from typing import Callable +from typing import Callable, TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow class ScalarUDF: From 79bb196357304126fdd157eb71db472e80a7e7e2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 9 Jul 2024 08:33:09 -0400 Subject: [PATCH 07/55] Add missing import --- python/datafusion/udf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index acd70773..4a9aebf9 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import datafusion._internal as df_internal from datafusion.expr import Expr from typing import Callable, TYPE_CHECKING From 685a257bc0807608d87c3206ee6829d7e35edf8e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 9 Jul 2024 17:01:51 -0400 Subject: [PATCH 08/55] Functions now only has one level of depth --- docs/source/api/functions.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/api/functions.rst b/docs/source/api/functions.rst index 958606df..6f10d826 100644 --- a/docs/source/api/functions.rst +++ b/docs/source/api/functions.rst @@ -24,4 +24,4 @@ Functions .. autosummary:: :toctree: ../generated/ - functions.functions + functions From 45ee5ab69d100412751b09e0dab38bdffdd7b8c9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 9 Jul 2024 17:04:00 -0400 Subject: [PATCH 09/55] Applying google docstring formatting --- python/datafusion/functions.py | 881 ++++++++++----------------------- 1 file changed, 273 insertions(+), 608 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index de2264dc..75bf08a5 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -28,128 +28,102 @@ def isnan(expr: Expr) -> Expr: - """ - Returns true if a given number is +NaN or -NaN otherwise returns false. - """ + """Returns true if a given number is +NaN or -NaN otherwise returns false.""" return Expr(f.isnan(expr.expr)) def nullif(expr1: Expr, expr2: Expr) -> Expr: - """ - Returns NULL if expr1 equals expr2; otherwise it returns expr1. This can be used to perform the inverse operation of the COALESCE expression. - """ + """Returns NULL if expr1 equals expr2; otherwise it returns expr1. This can be used to perform the inverse operation of the COALESCE expression.""" return Expr(f.nullif(expr1.expr, expr2.expr)) def encode(input: Expr, encoding: Expr) -> Expr: - """ - Encode the `input`, using the `encoding`. encoding can be base64 or hex. - """ + """Encode the `input`, using the `encoding`. encoding can be base64 or hex.""" return Expr(f.encode(input.expr, encoding.expr)) def decode(input: Expr, encoding: Expr) -> Expr: - """ - Decode the `input`, using the `encoding`. encoding can be base64 or hex. - """ + """Decode the `input`, using the `encoding`. encoding can be base64 or hex.""" return Expr(f.decode(input.expr, encoding.expr)) def array_to_string(expr: Expr, delimiter: Expr) -> Expr: - """ - Converts each element to its text representation. - """ + """Converts each element to its text representation.""" return Expr(f.array_to_string(expr.expr, delimiter.expr)) def array_join(expr: Expr, delimiter: Expr) -> Expr: - """ - Converts each element to its text representation. + """Converts each element to its text representation. 
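+
+    For example (a sketch), joining ``[1, 2, 3]`` with delimiter ``"-"`` would
+    yield the string ``"1-2-3"``.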
+
+    This is an alias for :func:`array_to_string`.
+    """
     return array_to_string(expr, delimiter)
 
 
 def list_to_string(expr: Expr, delimiter: Expr) -> Expr:
-    """
-    Converts each element to its text representation.
-    """
+    """Converts each element to its text representation.
+
+    This is an alias for :func:`array_to_string`.
+    """
     return array_to_string(expr, delimiter)
 
 
 def list_join(expr: Expr, delimiter: Expr) -> Expr:
-    """
-    Converts each element to its text representation.
-    """
+    """Converts each element to its text representation.
+
+    This is an alias for :func:`array_to_string`.
+    """
     return array_to_string(expr, delimiter)
 
 
 def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr:
-    """
-    Returns whether the argument is contained within the list `values`.
-    """
+    """Returns whether the argument is contained within the list `values`."""
     values = [v.expr for v in values]
     return Expr(f.in_list(arg.expr, values, negated))
 
 
 def digest(value: Expr, method: Expr) -> Expr:
-    """
-    Computes the binary hash of an expression using the specified algorithm.
-    """
+    """Computes the binary hash of an expression using the specified algorithm.
+
+    Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3.
+    """
     return Expr(f.digest(value.expr, method.expr))
 
 
 def concat(*args: Expr) -> Expr:
-    """
-    Concatenates the text representations of all the arguments. NULL arguments are ignored.
-    """
+    """Concatenates the text representations of all the arguments. NULL arguments are ignored."""
     args = [arg.expr for arg in args]
     return Expr(f.concat(*args))
 
 
 def concat_ws(separator: str, *args: Expr) -> Expr:
-    """
-    Concatenates the list `args` with the separator. `NULL` arugments are ignored. `separator` should not be `NULL`.
-    """
+    """Concatenates the list `args` with the separator. `NULL` arguments are ignored. `separator` should not be `NULL`."""
     args = [arg.expr for arg in args]
     return Expr(f.concat_ws(separator, *args))
 
 
 def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> Expr:
-    """
-    Creates a new sort expression.
-    """
+    """Creates a new sort expression."""
    return Expr(f.order_by(expr.expr, ascending, nulls_first))
 
 
 def alias(expr: Expr, name: str) -> Expr:
-    """
-    Creates an alias expression.
-    """
+    """Creates an alias expression."""
     return Expr(f.alias(expr.expr, name))
 
 
 def col(name: str) -> Expr:
-    """
-    Creates a column reference expression.
-    """
+    """Creates a column reference expression."""
     return Expr(f.col(name))
 
 
 def count_star() -> Expr:
-    """
-    Create a COUNT(1) aggregate expression.
-    """
+    """Create a COUNT(1) aggregate expression."""
     return Expr(f.count_star())
 
 
 def case(expr: Expr) -> CaseBuilder:
-    """
-    Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
-    """
+    """Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression."""
     return CaseBuilder(f.case(expr.expr))
 
 
 def window(
     name: str,
     args: list[Expr],
     partition_by: list[Expr] | None = None,
     order_by: list[Expr] | None = None,
     window_frame: WindowFrame | None = None,
     ctx: SessionContext | None = None,
 ) -> Expr:
-    """
-    Creates a new Window function expression.
-    """
+    """Creates a new Window function expression."""
     args = [a.expr for a in args]
     partition_by = [e.expr for e in partition_by] if partition_by is not None else None
     order_by = [o.expr for o in order_by] if order_by is not None else None
     window_frame = window_frame.window_frame if window_frame is not None else None
     return Expr(f.window(name, args, partition_by, order_by, window_frame, ctx))
 
 
 # scalar functions
 def abs(arg: Expr) -> Expr:
-    """
-    Return the absolute value of a given number.
+    """Return the absolute value of a given number.
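+
+    For example (a sketch), ``f.abs(lit(-1))`` evaluates to ``1``.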
- Returns - ------- + Returns: + -------- Expr A new expression representing the absolute value of the input expression. """ @@ -185,11 +156,10 @@ def abs(arg: Expr) -> Expr: def acos(arg: Expr) -> Expr: - """ - Returns the arc cosine or inverse cosine of a number. + """Returns the arc cosine or inverse cosine of a number. - Returns - ------- + Returns: + -------- Expr A new expression representing the arc cosine of the input expression. """ @@ -197,300 +167,220 @@ def acos(arg: Expr) -> Expr: def acosh(arg: Expr) -> Expr: - """ - Returns inverse hyperbolic cosine. - """ + """Returns inverse hyperbolic cosine.""" return Expr(f.acosh(arg.expr)) def ascii(arg: Expr) -> Expr: - """ - Returns the numeric code of the first character of the argument. - """ + """Returns the numeric code of the first character of the argument.""" return Expr(f.ascii(arg.expr)) def asin(arg: Expr) -> Expr: - """ - Returns the arc sine or inverse sine of a number. - """ + """Returns the arc sine or inverse sine of a number.""" return Expr(f.asin(arg.expr)) def asinh(arg: Expr) -> Expr: - """ - Returns inverse hyperbolic sine. - """ + """Returns inverse hyperbolic sine.""" return Expr(f.asinh(arg.expr)) def atan(arg: Expr) -> Expr: - """ - Returns inverse tangent of a number. - """ + """Returns inverse tangent of a number.""" return Expr(f.atan(arg.expr)) def atanh(arg: Expr) -> Expr: - """ - Returns inverse hyperbolic tangent. - """ + """Returns inverse hyperbolic tangent.""" return Expr(f.atanh(arg.expr)) def atan2(y: Expr, x: Expr) -> Expr: - """ - Returns inverse tangent of a division given in the argument. - """ + """Returns inverse tangent of a division given in the argument.""" return Expr(f.atan2(y.expr, x.expr)) def bit_length(arg: Expr) -> Expr: - """ - Returns the number of bits in the string argument. - """ + """Returns the number of bits in the string argument.""" return Expr(f.bit_length(arg.expr)) def btrim(arg: Expr) -> Expr: - """ - Removes all characters, spaces by default, from both sides of a string. - """ + """Removes all characters, spaces by default, from both sides of a string.""" return Expr(f.btrim(arg.expr)) def cbrt(arg: Expr) -> Expr: - """ - Returns the cube root of a number. - """ + """Returns the cube root of a number.""" return Expr(f.cbrt(arg.expr)) def ceil(arg: Expr) -> Expr: - """ - Returns the nearest integer greater than or equal to argument. - """ + """Returns the nearest integer greater than or equal to argument.""" return Expr(f.ceil(arg.expr)) def character_length(arg: Expr) -> Expr: - """ - Returns the number of characters in the argument. - """ + """Returns the number of characters in the argument.""" return Expr(f.character_length(arg.expr)) def length(string: Expr) -> Expr: - """ - The number of characters in the `string` - """ + """The number of characters in the `string`.""" return Expr(f.length(string.expr)) def char_length(string: Expr) -> Expr: - """ - The number of characters in the `string`. - """ + """The number of characters in the `string`.""" return Expr(f.char_length(string.expr)) def chr(arg: Expr) -> Expr: - """ - Converts the Unicode code point to a UTF8 character. - """ + """Converts the Unicode code point to a UTF8 character.""" return Expr(f.chr(arg.expr)) def coalesce(*args: Expr) -> Expr: - """ - Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL. 
-    """
+    """Returns the value of the first expr in `args` which is not NULL."""
     args = [arg.expr for arg in args]
     return Expr(f.coalesce(*args))
 
 
 def cos(arg: Expr) -> Expr:
-    """
-    Returns the cosine of the argument.
-    """
+    """Returns the cosine of the argument."""
     return Expr(f.cos(arg.expr))
 
 
 def cosh(arg: Expr) -> Expr:
-    """
-    Returns the hyperbolic cosine of the argument.
-    """
+    """Returns the hyperbolic cosine of the argument."""
     return Expr(f.cosh(arg.expr))
 
 
 def cot(arg: Expr) -> Expr:
-    """
-    Returns the cotangent of the argument.
-    """
+    """Returns the cotangent of the argument."""
     return Expr(f.cot(arg.expr))
 
 
 def degrees(arg: Expr) -> Expr:
-    """
-    Converts the argument from radians to degrees.
-    """
+    """Converts the argument from radians to degrees."""
     return Expr(f.degrees(arg.expr))
 
 
 def ends_with(arg: Expr, suffix: Expr) -> Expr:
-    """
-    Returns true if the `string` ends with the `suffix`, false otherwise.
-    """
+    """Returns true if the `string` ends with the `suffix`, false otherwise."""
     return Expr(f.ends_with(arg.expr, suffix.expr))
 
 
 def exp(arg: Expr) -> Expr:
-    """
-    Returns the exponential of the arugment.
-    """
+    """Returns the exponential of the argument."""
     return Expr(f.exp(arg.expr))
 
 
 def factorial(arg: Expr) -> Expr:
-    """
-    Returns the factorial of the argument.
-    """
+    """Returns the factorial of the argument."""
     return Expr(f.factorial(arg.expr))
 
 
 def find_in_set(string: Expr, string_list: Expr) -> Expr:
-    """
-    Returns a value in the range of 1 to N if the string is in the string list `string_list` consisting of N substrings.
-    """
+    """Returns a value in the range of 1 to N if the string is in the string list `string_list` consisting of N substrings.
+
+    The string list is a string composed of substrings separated by `,` characters.
+    """
     return Expr(f.find_in_set(string.expr, string_list.expr))
 
 
 def floor(arg: Expr) -> Expr:
-    """
-    Returns the nearest integer less than or equal to the argument.
-    """
+    """Returns the nearest integer less than or equal to the argument."""
     return Expr(f.floor(arg.expr))
 
 
 def gcd(x: Expr, y: Expr) -> Expr:
-    """
-    Returns the greatest common divisor.
-    """
+    """Returns the greatest common divisor."""
     return Expr(f.gcd(x.expr, y.expr))
 
 
 def initcap(string: Expr) -> Expr:
-    """
-    Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase.
-    """
+    """Converts the first letter of each word in `string` to uppercase and the remaining characters to lowercase."""
     return Expr(f.initcap(string.expr))
 
 
 def instr(string: Expr, substring: Expr) -> Expr:
-    """
-    Finds the position from where the `substring` matches the `string`.
-    """
+    """Finds the position from where the `substring` matches the `string`.
+
+    This is an alias for :func:`strpos`.
+    """
     return strpos(string, substring)
 
 
 def iszero(arg: Expr) -> Expr:
-    """
-    Returns true if a given number is +0.0 or -0.0 otherwise returns false.
-    """
+    """Returns true if a given number is +0.0 or -0.0 otherwise returns false."""
     return Expr(f.iszero(arg.expr))
 
 
 def lcm(x: Expr, y: Expr) -> Expr:
-    """
-    Returns the least common multiple.
-    """
+    """Returns the least common multiple."""
     return Expr(f.lcm(x.expr, y.expr))
 
 
 def left(string: Expr, n: Expr) -> Expr:
-    """
-    Returns the first `n` characters in the `string`.
- """ + """Returns the first `n` characters in the `string`.""" return Expr(f.left(string.expr, n.expr)) def levenshtein(string1: Expr, string2: Expr) -> Expr: - """ - Returns the Levenshtein distance between the two given strings - """ + """Returns the Levenshtein distance between the two given strings.""" return Expr(f.levenshtein(string1.expr, string2.expr)) def ln(arg: Expr) -> Expr: - """ - Returns the natural logarithm (base e) of the argument. - """ + """Returns the natural logarithm (base e) of the argument.""" return Expr(f.ln(arg.expr)) def log(base: Expr, num: Expr) -> Expr: - """ - Returns the logarithm of a number for a particular `base` - """ + """Returns the logarithm of a number for a particular `base`.""" return Expr(f.log(base.expr, num.expr)) def log10(arg: Expr) -> Expr: - """ - Base 10 logarithm of the argument. - """ + """Base 10 logarithm of the argument.""" return Expr(f.log10(arg.expr)) def log2(arg: Expr) -> Expr: - """ - Base 2 logarithm of the argument. - """ + """Base 2 logarithm of the argument.""" return Expr(f.log2(arg.expr)) def lower(arg: Expr) -> Expr: - """ - Converts a string to lowercase. - """ + """Converts a string to lowercase.""" return Expr(f.lower(arg.expr)) def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: - """ - Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). - """ + """Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).""" characters = characters if characters is not None else Expr.literal(" ") return Expr(f.lpad(string.expr, count.expr, characters.expr)) def ltrim(arg: Expr) -> Expr: - """ - Removes all characters, spaces by default, from the beginning of a string. - """ + """Removes all characters, spaces by default, from the beginning of a string.""" return Expr(f.ltrim(arg.expr)) def md5(arg: Expr) -> Expr: - """ - Computes an MD5 128-bit checksum for a string expression. - """ + """Computes an MD5 128-bit checksum for a string expression.""" return Expr(f.md5(arg.expr)) def nanvl(x: Expr, y: Expr) -> Expr: - """ - Returns `x` if `x` is not `NaN`. Otherwise returns `y`. - """ + """Returns `x` if `x` is not `NaN`. Otherwise returns `y`.""" return Expr(f.nanvl(x.expr, y.expr)) def octet_length(arg: Expr) -> Expr: - """ - Returns the number of bytes of a string. - """ + """Returns the number of bytes of a string.""" return Expr(f.octet_length(arg.expr)) @@ -503,56 +393,45 @@ def octet_length(arg: Expr) -> Expr: def pi() -> Expr: - """ - Returns an approximate value of π. - """ + """Returns an approximate value of π.""" return Expr(f.pi()) def position(string: Expr, substring: Expr) -> Expr: - """ - Finds the position from where the `substring` matches the `string`. + """Finds the position from where the `substring` matches the `string`. + This is an alias for :func:`strpos`. """ return strpos(string, substring) def power(base: Expr, exponent: Expr) -> Expr: - """ - Returns `base` raised to the power of `exponent`. - """ + """Returns `base` raised to the power of `exponent`.""" return Expr(f.power(base.expr, exponent.expr)) def pow(base: Expr, exponent: Expr) -> Expr: - """ - Returns `base` raised to the power of `exponent`. + """Returns `base` raised to the power of `exponent`. + This is an alias of `power`. 
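+
+    For example (a sketch), ``f.pow(lit(2.0), lit(3.0))`` evaluates to ``8.0``.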
""" return power(base, exponent) def radians(arg: Expr) -> Expr: - """ - Converts the argument from degrees to radians. - """ + """Converts the argument from degrees to radians.""" return Expr(f.radians(arg.expr)) def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: - """ - Tests a string using a regular expression returning true if at - least one match, false otherwise. - """ + """Tests a string using a regular expression returning true if at least one match, false otherwise.""" if flags is not None: flags = flags.expr return Expr(f.regexp_like(string.expr, regex.expr, flags)) def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: - """ - Returns an array with each element containing the leftmost-first - match of the corresponding index in `regex` to string in `string` + """Returns an array with each element containing the leftmost-first match of the corresponding index in `regex` to string in `string`. If there is no match, the list element is NULL. @@ -564,7 +443,6 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: list element is a [`GenericStringArray`] whose n'th element is the substring matching the n'th capturing parenthesized subexpression of the pattern. """ - # TODO VALIDATE THIS IS CORRECT FOR DATAFRAME RESULTS if flags is not None: flags = flags.expr @@ -574,8 +452,7 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: def regexp_replace( string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None ) -> Expr: - """ - Replaces substring(s) matching a PCRE-like regular expression. + """Replaces substring(s) matching a PCRE-like regular expression. The full list of supported features and syntax can be found at @@ -589,185 +466,136 @@ def regexp_replace( def repeat(string: Expr, n: Expr) -> Expr: - """ - Repeats the `string` to `n` times. - """ + """Repeats the `string` to `n` times.""" return Expr(f.repeat(string.expr, n.expr)) def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: - """ - Replaces all occurrences of `from` with `to` in the `string`. - """ + """Replaces all occurrences of `from` with `to` in the `string`.""" return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) def reverse(arg: Expr) -> Expr: - """ - Reverse the string argument. - """ + """Reverse the string argument.""" return Expr(f.reverse(arg.expr)) def right(string: Expr, n: Expr) -> Expr: - """ - Returns the last `n` characters in the `string`. - """ + """Returns the last `n` characters in the `string`.""" return Expr(f.right(string.expr, n.expr)) def round(arg: Expr) -> Expr: - """ - Round the argument to the nearest integer. - """ + """Round the argument to the nearest integer.""" return Expr(f.round(arg.expr)) def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: - """ - Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. - """ + """Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.""" characters = characters if characters is not None else Expr.literal(" ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) def rtrim(arg: Expr) -> Expr: - """ - Removes all characters, spaces by default, from the end of a string. 
- """ + """Removes all characters, spaces by default, from the end of a string.""" return Expr(f.rtrim(arg.expr)) def sha224(arg: Expr) -> Expr: - """ - Computes the SHA-224 hash of a binary string. - """ + """Computes the SHA-224 hash of a binary string.""" return Expr(f.sha224(arg.expr)) def sha256(arg: Expr) -> Expr: - """ - Computes the SHA-256 hash of a binary string. - """ + """Computes the SHA-256 hash of a binary string.""" return Expr(f.sha256(arg.expr)) def sha384(arg: Expr) -> Expr: - """ - Computes the SHA-384 hash of a binary string. - """ + """Computes the SHA-384 hash of a binary string.""" return Expr(f.sha384(arg.expr)) def sha512(arg: Expr) -> Expr: - """ - Computes the SHA-512 hash of a binary string. - """ + """Computes the SHA-512 hash of a binary string.""" return Expr(f.sha512(arg.expr)) def signum(arg: Expr) -> Expr: - """ - Returns the sign of the argument (-1, 0, +1). - """ + """Returns the sign of the argument (-1, 0, +1).""" return Expr(f.signum(arg.expr)) def sin(arg: Expr) -> Expr: - """ - Returns the sine of the argument. - """ + """Returns the sine of the argument.""" return Expr(f.sin(arg.expr)) def sinh(arg: Expr) -> Expr: - """ - Returns the hyperbolic sine of the argument. - """ + """Returns the hyperbolic sine of the argument.""" return Expr(f.sinh(arg.expr)) def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: - """ - Splits a string based on a delimiter and picks out the desired field based on the index. - """ + """Splits a string based on a delimiter and picks out the desired field based on the index.""" return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) def sqrt(arg: Expr) -> Expr: - """ - Returns the square root of the argument. - """ + """Returns the square root of the argument.""" return Expr(f.sqrt(arg.expr)) def starts_with(string: Expr, prefix: Expr) -> Expr: - """ - Returns true if string starts with prefix. - """ + """Returns true if string starts with prefix.""" return Expr(f.starts_with(string.expr, prefix.expr)) def strpos(string: Expr, substring: Expr) -> Expr: - """ - Finds the position from where the `substring` matches the `string`. - """ + """Finds the position from where the `substring` matches the `string`.""" return Expr(f.strpos(string.expr, substring.expr)) def substr(string: Expr, position: Expr) -> Expr: - """ - Substring from the `position` to the end. - """ + """Substring from the `position` to the end.""" return Expr(f.substr(string.expr, position.expr)) def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: - """ - Returns the substring from `string` before `count` occurrences of `delimiter`. - """ + """Returns the substring from `string` before `count` occurrences of `delimiter`.""" return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) def substring(string: Expr, position: Expr, length: Expr) -> Expr: - """ - Substring from the `position` with `length` characters. - """ + """Substring from the `position` with `length` characters.""" return Expr(f.substring(string.expr, position.expr, length.expr)) def tan(arg: Expr) -> Expr: - """ - Returns the tangent of the argument. - """ + """Returns the tangent of the argument.""" return Expr(f.tan(arg.expr)) def tanh(arg: Expr) -> Expr: - """ - Returns the hyperbolic tangent of the argument. - """ + """Returns the hyperbolic tangent of the argument.""" return Expr(f.tanh(arg.expr)) def to_hex(arg: Expr) -> Expr: - """ - Converts an integer to a hexadecimal string. 
- """ + """Converts an integer to a hexadecimal string.""" return Expr(f.to_hex(arg.expr)) def now() -> Expr: - """ - Returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement. + """Returns the current timestamp in nanoseconds. + + This will use the same value for all instances of now() in same statement. """ return Expr(f.now()) def to_timestamp(arg: Expr, *formatters: Expr) -> Expr: - """ - Converts a string and optional formats to a `Timestamp` in nanoseconds. - """ + """Converts a string and optional formats to a `Timestamp` in nanoseconds.""" # TODO Add a detailed description of how to use formatters. if formatters is None: return f.to_timestamp(arg.expr) @@ -777,176 +605,144 @@ def to_timestamp(arg: Expr, *formatters: Expr) -> Expr: def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr: - """ - Converts a string and optional formats to a `Timestamp` in milliseconds. + """Converts a string and optional formats to a `Timestamp` in milliseconds. + See `to_timestamp` for a description on how to use formatters. """ return Expr(f.to_timestamp_millis(arg.expr, *formatters)) def to_timestamp_micros(arg: Expr, *formatters: Expr) -> Expr: - """ - Converts a string and optional formats to a `Timestamp` in microseconds. + """Converts a string and optional formats to a `Timestamp` in microseconds. + See `to_timestamp` for a description on how to use formatters. """ return Expr(f.to_timestamp_micros(arg.expr, *formatters)) def to_timestamp_nanos(arg: Expr, *formatters: Expr) -> Expr: - """ - Converts a string and optional formats to a `Timestamp` in nanoseconds. + """Converts a string and optional formats to a `Timestamp` in nanoseconds. + See `to_timestamp` for a description on how to use formatters. """ return Expr(f.to_timestamp_nanos(arg.expr, *formatters)) def to_timestamp_seconds(arg: Expr, *formatters: Expr) -> Expr: - """ - Converts a string and optional formats to a `Timestamp` in seconds. + """Converts a string and optional formats to a `Timestamp` in seconds. + See `to_timestamp` for a description on how to use formatters. """ return Expr(f.to_timestamp_seconds(arg.expr, *formatters)) def to_unixtime(string: Expr, *format_arguments: Expr) -> Expr: - """ - Converts a string and optional formats to a Unixtime. - """ + """Converts a string and optional formats to a Unixtime.""" # TODO verify if the format arguments are the same as to_timestamp and update documentation appropriately. args = [f.expr for f in format_arguments] return Expr(f.to_unixtime(string.expr, *args)) def current_date() -> Expr: - """ - Returns current UTC date as a Date32 value. - """ + """Returns current UTC date as a Date32 value.""" return Expr(f.current_date()) def current_time() -> Expr: - """ - Returns current UTC time as a Time64 value. - """ + """Returns current UTC time as a Time64 value.""" return Expr(f.current_time()) def datepart(part: Expr, date: Expr) -> Expr: - """ - Return a specified part of a date. + """Return a specified part of a date. + This is an alias for `date_part`. """ return date_part(part, date) def date_part(part: Expr, date: Expr) -> Expr: - """ - Extracts a subfield from the date. - """ + """Extracts a subfield from the date.""" return Expr(f.date_part(part.expr, date.expr)) def date_trunc(part: Expr, date: Expr) -> Expr: - """ - Truncates the date to a specified level of precision. 
- """ + """Truncates the date to a specified level of precision.""" return Expr(f.date_trunc(part.expr, date.expr)) def datetrunc(part: Expr, date: Expr) -> Expr: - """ - Truncates the date to a specified level of precision. + """Truncates the date to a specified level of precision. + This is an alias for `date_trunc`. """ return date_trunc(part, date) def date_bin(stride: Expr, source: Expr, origin: Expr) -> Expr: - """ - Coerces an arbitrary timestamp to the start of the nearest specified interval. - """ + """Coerces an arbitrary timestamp to the start of the nearest specified interval.""" return Expr(f.date_bin(stride.expr, source.expr, origin.expr)) def make_date(year: Expr, month: Expr, day: Expr) -> Expr: - """ - Make a date from year, month and day component parts. - """ + """Make a date from year, month and day component parts.""" return Expr(f.make_date(year.expr, month.expr, day.expr)) def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: - """ - Replaces the characters in `from_val` with the counterpart in `to_val`. - """ + """Replaces the characters in `from_val` with the counterpart in `to_val`.""" return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) def trim(arg: Expr) -> Expr: - """ - Removes all characters, spaces by default, from both sides of a string. - """ + """Removes all characters, spaces by default, from both sides of a string.""" return Expr(f.trim(arg.expr)) def trunc(num: Expr, precision: Expr | None = None) -> Expr: - """ - Truncate the number toward zero with optional precision. - """ + """Truncate the number toward zero with optional precision.""" if precision is not None: return Expr(f.trunc(num.expr, precision.expr)) return Expr(f.trunc(num.expr)) def upper(arg: Expr) -> Expr: - """ - Converts a string to uppercase. - """ + """Converts a string to uppercase.""" return Expr(f.upper(arg.expr)) def make_array(*args: Expr) -> Expr: - """ - Returns an array using the specified input expressions. - """ + """Returns an array using the specified input expressions.""" args = [arg.expr for arg in args] return Expr(f.make_array(*args)) def array(*args: Expr) -> Expr: - """ - Returns an array using the specified input expressions. + """Returns an array using the specified input expressions. + This is an alias for `make_array`. """ return make_array(args) def range(start: Expr, stop: Expr, step: Expr) -> Expr: - """ - Create a list of values in the range between start and stop. - """ + """Create a list of values in the range between start and stop.""" return Expr(f.range(start.expr, stop.expr, step.expr)) def uuid(arg: Expr) -> Expr: - """ - Returns uuid v4 as a string value. - """ + """Returns uuid v4 as a string value.""" return Expr(f.uuid(arg.expr)) def struct(*args: Expr) -> Expr: - """ - Returns a struct with the given arguments. - """ + """Returns a struct with the given arguments.""" args = [arg.expr for arg in args] return Expr(f.struct(*args)) def named_struct(name_pairs: list[(str, Expr)]) -> Expr: - """ - Returns a struct with the given names and arguments pairs - """ + """Returns a struct with the given names and arguments pairs.""" name_pairs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs] # flatten @@ -955,368 +751,318 @@ def named_struct(name_pairs: list[(str, Expr)]) -> Expr: def from_unixtime(arg: Expr) -> Expr: - """ - Converts an integer to RFC3339 timestamp format string. 
- """ + """Converts an integer to RFC3339 timestamp format string.""" return Expr(f.from_unixtime(arg.expr)) def arrow_typeof(arg: Expr) -> Expr: - """ - Returns the Arrow type of the expression. - """ + """Returns the Arrow type of the expression.""" return Expr(f.arrow_typeof(arg.expr)) def random() -> Expr: - """ - Returns a random value in the range 0.0 <= x < 1.0 - """ + """Returns a random value in the range `0.0 <= x < 1.0`.""" return Expr(f.random()) def array_append(array: Expr, element: Expr) -> Expr: - """ - Appends an element to the end of an array. - """ + """Appends an element to the end of an array.""" return Expr(f.array_append(array.expr, element.expr)) def array_push_back(array: Expr, element: Expr) -> Expr: - """ - Appends an element to the end of an array. + """Appends an element to the end of an array. + This is an alias for `array_append`. """ return array_append(array, element) def list_append(array: Expr, element: Expr) -> Expr: - """ - Appends an element to the end of an array. + """Appends an element to the end of an array. + This is an alias for `array_append`. """ return array_append(array, element) def list_push_back(array: Expr, element: Expr) -> Expr: - """ - Appends an element to the end of an array. + """Appends an element to the end of an array. + This is an alias for `array_append`. """ return array_append(array, element) def array_concat(*args: Expr) -> Expr: - """ - Concatenates the input arrays. - """ + """Concatenates the input arrays.""" args = [arg.expr for arg in args] return Expr(f.array_concat(*args)) def array_cat(*args: Expr) -> Expr: - """ - Concatenates the input arrays. + """Concatenates the input arrays. + This is an alias for `array_concat`. """ return array_concat(*args) def array_dims(array: Expr) -> Expr: - """ - Returns an array of the array's dimensions. - """ + """Returns an array of the array's dimensions.""" return Expr(f.array_dims(array.expr)) def array_distinct(array: Expr) -> Expr: - """ - Returns distinct values from the array after removing duplicates. - """ + """Returns distinct values from the array after removing duplicates.""" return Expr(f.array_distinct(array.expr)) def list_distinct(array: Expr) -> Expr: - """ - Returns distinct values from the array after removing duplicates. + """Returns distinct values from the array after removing duplicates. + This is an alias for `array_distinct`. """ return array_distinct(array) def list_dims(array: Expr) -> Expr: - """ - Returns an array of the array's dimensions. + """Returns an array of the array's dimensions. + This is an alias for `array_dims`. """ return array_dims(array) def array_element(array: Expr, n: Expr) -> Expr: - """ - Extracts the element with the index n from the array. - """ + """Extracts the element with the index n from the array.""" return Expr(f.array_element(array.expr, n.expr)) def array_extract(array: Expr, n: Expr) -> Expr: - """ - Extracts the element with the index n from the array. + """Extracts the element with the index n from the array. + This is an alias for `array_element`. """ return array_element(array, n) def list_element(array: Expr, n: Expr) -> Expr: - """ - Extracts the element with the index n from the array. + """Extracts the element with the index n from the array. + This is an alias for `array_element`. """ return array_element(array, n) def list_extract(array: Expr, n: Expr) -> Expr: - """ - Extracts the element with the index n from the array. + """Extracts the element with the index n from the array. 
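+
+    For example (a sketch), ``array_element(make_array(lit(1), lit(2), lit(3)), lit(2))``
+    would return the second element; the index is 1-based.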
+ This is an alias for `array_element`. """ return array_element(array, n) def array_length(array: Expr) -> Expr: - """ - Returns the length of the array. - """ + """Returns the length of the array.""" return Expr(f.array_length(array.expr)) def list_length(array: Expr) -> Expr: - """ - Returns the length of the array. + """Returns the length of the array. + This is an alias for `array_length`. """ return array_length(array) def array_has(first_array: Expr, second_array: Expr) -> Expr: - """ - Returns true if the element appears in the first array, otherwise false. - """ + """Returns true if the element appears in the first array, otherwise false.""" return Expr(f.array_has(first_array.expr, second_array.expr)) def array_has_all(first_array: Expr, second_array: Expr) -> Expr: - """ - Returns true if each element of the second array appears in the first array. Otherwise, it returns false. - """ + """Returns true if each element of the second array appears in the first array. Otherwise, it returns false.""" return Expr(f.array_has_all(first_array.expr, second_array.expr)) def array_has_any(first_array: Expr, second_array: Expr) -> Expr: - """ - Returns true if at least one element of the second array appears in the first array. Otherwise, it returns false. - """ + """Returns true if at least one element of the second array appears in the first array. Otherwise, it returns false.""" return Expr(f.array_has_any(first_array.expr, second_array.expr)) def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: - """ - Searches for an element in the array and returns the position of the first occurrence. - """ + """Searches for an element in the array and returns the position of the first occurrence.""" return Expr(f.array_position(array.expr, element.expr, index)) def array_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr: - """ - Searches for an element in the array and returns the position of the first occurrence. + """Searches for an element in the array and returns the position of the first occurrence. + This is an alias for `array_position`. """ return array_position(array, element, index) def list_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: - """ - Searches for an element in the array and returns the position of the first occurrence. + """Searches for an element in the array and returns the position of the first occurrence. + This is an alias for `array_position`. """ return array_position(array, element, index) def list_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr: - """ - Searches for an element in the array and returns the position of the first occurrence. + """Searches for an element in the array and returns the position of the first occurrence. + This is an alias for `array_position`. """ return array_position(array, element, index) def array_positions(array: Expr, element: Expr) -> Expr: - """ - Searches for an element in the array and returns all occurrences. - """ + """Searches for an element in the array and returns all occurrences.""" return Expr(f.array_positions(array.expr, element.expr)) def list_positions(array: Expr, element: Expr) -> Expr: - """ - Searches for an element in the array and returns all occurrences. + """Searches for an element in the array and returns all occurrences. + This is an alias for `array_positions`. """ return array_positions(array, element) def array_ndims(array: Expr) -> Expr: - """ - Returns the number of dimensions of the array. 
- """ + """Returns the number of dimensions of the array.""" return Expr(f.array_ndims(array.expr)) def list_ndims(array: Expr) -> Expr: - """ - Returns the number of dimensions of the array. + """Returns the number of dimensions of the array. + This is an alias for `array_ndims`. """ return array_ndims(array) def array_prepend(element: Expr, array: Expr) -> Expr: - """ - Prepends an element to the beginning of an array. - """ + """Prepends an element to the beginning of an array.""" return Expr(f.array_prepend(element.expr, array.expr)) def array_push_front(element: Expr, array: Expr) -> Expr: - """ - Prepends an element to the beginning of an array. + """Prepends an element to the beginning of an array. + This is an alias for `array_prepend`. """ return array_prepend(element, array) def list_prepend(element: Expr, array: Expr) -> Expr: - """ - Prepends an element to the beginning of an array. + """Prepends an element to the beginning of an array. + This is an alias for `array_prepend`. """ return array_prepend(element, array) def list_push_front(element: Expr, array: Expr) -> Expr: - """ - Prepends an element to the beginning of an array. + """Prepends an element to the beginning of an array. + This is an alias for `array_prepend`. """ return array_prepend(element, array) def array_pop_back(array: Expr) -> Expr: - """ - Returns the array without the last element. - """ + """Returns the array without the last element.""" return Expr(f.array_pop_back(array.expr)) def array_pop_front(array: Expr) -> Expr: - """ - Returns the array without the first element. - """ + """Returns the array without the first element.""" return Expr(f.array_pop_front(array.expr)) def array_remove(array: Expr, element: Expr) -> Expr: - """ - Removes the first element from the array equal to the given value. - """ + """Removes the first element from the array equal to the given value.""" return Expr(f.array_remove(array.expr, element.expr)) def list_remove(array: Expr, element: Expr) -> Expr: - """ - Removes the first element from the array equal to the given value. + """Removes the first element from the array equal to the given value. + This is an alias for `array_remove`. """ return array_remove(array, element) def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: - """ - Removes the first `max` elements from the array equal to the given value. - """ + """Removes the first `max` elements from the array equal to the given value.""" return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: - """ - Removes the first `max` elements from the array equal to the given value. + """Removes the first `max` elements from the array equal to the given value. + This is an alias for `array_remove_n`. """ return array_remove_n(array, element, max) def array_remove_all(array: Expr, element: Expr) -> Expr: - """ - Removes all elements from the array equal to the given value. - """ + """Removes all elements from the array equal to the given value.""" return Expr(f.array_remove_all(array.expr, element.expr)) def list_remove_all(array: Expr, element: Expr) -> Expr: - """ - Removes all elements from the array equal to the given value. + """Removes all elements from the array equal to the given value. + This is an alias for `array_remove_all`. """ return array_remove_all(array, element) def array_repeat(element: Expr, count: Expr) -> Expr: - """ - Returns an array containing `element` `count` times. 
- """ + """Returns an array containing `element` `count` times.""" return Expr(f.array_repeat(element.expr, count.expr)) def array_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: - """ - Replaces the first occurrence of the specified element with another specified element. - """ + """Replaces the first occurrence of the specified element with another specified element.""" return Expr(f.array_replace(array.expr, from_val.expr, to_val.expr)) def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: - """ - Replaces the first occurrence of the specified element with another specified element. + """Replaces the first occurrence of the specified element with another specified element. + This is an alias for `array_replace`. """ return array_replace(array, from_val, to_val) def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: - """ - Replaces the first `max` occurrences of the specified element with another specified element. - """ + """Replaces the first `max` occurrences of the specified element with another specified element.""" return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: - """ - Replaces the first `max` occurrences of the specified element with another specified element. + """Replaces the first `max` occurrences of the specified element with another specified element. + This is an alias for `array_replace_n`. """ return array_replace_n(array, from_val, to_val, max) def array_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: - """ - Replaces all occurrences of the specified element with another specified element. - """ + """Replaces all occurrences of the specified element with another specified element.""" return Expr(f.array_replace_all(array.expr, from_val.expr, to_val.expr)) def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: - """ - Replaces all occurrences of the specified element with another specified element. + """Replaces all occurrences of the specified element with another specified element. + This is an alias for `array_replace_all`. """ return array_replace_all(array, from_val, to_val) @@ -1325,101 +1071,86 @@ def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: def array_slice( array: Expr, begin: Expr, end: Expr, stride: Expr | None = None ) -> Expr: - """ - Returns a slice of the array. - """ + """Returns a slice of the array.""" if stride is not None: stride = stride.expr return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: - """ - Returns a slice of the array. + """Returns a slice of the array. + This is an alias for `array_slice`. """ return array_slice(array, begin, end, stride) def array_intersect(array1: Expr, array2: Expr) -> Expr: - """ - Returns an array of the elements in the intersection of array1 and array2. - """ + """Returns an array of the elements in the intersection of array1 and array2.""" return Expr(f.array_intersect(array1.expr, array2.expr)) def list_intersect(array1: Expr, array2: Expr) -> Expr: - """ - Returns an array of the elements in the intersection of array1 and array2. + """Returns an array of the elements in the intersection of `array1` and `array2`. + This is an alias for `array_intersect`. 
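+
+    For example (a sketch), the intersection of ``[1, 2, 3]`` and ``[2, 3, 4]``
+    would be ``[2, 3]``.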
""" return array_intersect(array1, array2) def array_union(array1: Expr, array2: Expr) -> Expr: - """ - Returns an array of the elements in the union of array1 and array2 without duplicates. - """ + """Returns an array of the elements in the union of array1 and array2 without duplicates.""" return Expr(f.array_union(array1.expr, array2.expr)) def list_union(array1: Expr, array2: Expr) -> Expr: - """ - Returns an array of the elements in the union of array1 and array2 without duplicates. + """Returns an array of the elements in the union of array1 and array2 without duplicates. + This is an alias for `array_union`. """ return array_union(array1, array2) def array_except(array1: Expr, array2: Expr) -> Expr: - """ - Returns an array of the elements that appear in `array1` but not in the `array2`. - """ + """Returns an array of the elements that appear in `array1` but not in the `array2`.""" return Expr(f.array_except(array1.expr, array2.expr)) def list_except(array1: Expr, array2: Expr) -> Expr: - """ - Returns an array of the elements that appear in `array1` but not in the `array2`. + """Returns an array of the elements that appear in `array1` but not in the `array2`. + This is an alias for `array_except`. """ return array_except(array1, array2) def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: - """ - Returns an array with the specified size filled. If `size` is greater than the `array` length, the additional entries will be filled with the given `value`. - """ + """Returns an array with the specified size filled. If `size` is greater than the `array` length, the additional entries will be filled with the given `value`.""" return Expr(f.array_resize(array.expr, size.expr, value.expr)) def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: - """ - Returns an array with the specified size filled. If `size` is greater than the `array` length, the additional entries will be filled with the given `value`. + """Returns an array with the specified size filled. + + If `size` is greater than the `array` length, the additional entries will be filled with the given `value`. This is an alias for `array_resize`. """ return array_resize(array, size, value) def flatten(array: Expr) -> Expr: - """ - Flattens an array of arrays into a single array. - """ + """Flattens an array of arrays into a single array.""" return Expr(f.flatten(array.expr)) # aggregate functions def approx_distinct(arg: Expr) -> Expr: - """ - Returns the approximate number of distinct values. - """ + """Returns the approximate number of distinct values.""" return Expr(f.approx_distinct(arg.expr, distinct=True)) def approx_median(arg: Expr, distinct: bool = False) -> Expr: - """ - Returns the approximate median value. - """ + """Returns the approximate median value.""" return Expr(f.approx_median(arg.expr, distinct=distinct)) @@ -1429,9 +1160,7 @@ def approx_percentile_cont( num_centroids: int | None = None, distinct: bool = False, ) -> Expr: - """ - Returns the value that is approximately at a given percentile of a distribution of values. 
- """ + """Returns the value that is approximately at a given percentile of a distribution of values.""" # TODO validate that these parameters are passed properly if num_centroids is None: return Expr( @@ -1448,9 +1177,7 @@ def approx_percentile_cont( def approx_percentile_cont_with_weight( arg: Expr, weight: Expr, percentile: Expr, distinct: bool = False ) -> Expr: - """ - Returns the value that is approximately at a given percentile of a distribution of values with associated weights. - """ + """Returns the value that is approximately at a given percentile of a distribution of values with associated weights.""" # TODO validate that these parameters are passed properly return Expr( f.approx_percentile_cont_with_weight( @@ -1460,30 +1187,22 @@ def approx_percentile_cont_with_weight( def array_agg(arg: Expr, distinct: bool = False) -> Expr: - """ - Aggregate values into an array. - """ + """Aggregate values into an array.""" return Expr(f.array_agg(arg.expr, distinct=distinct)) def avg(arg: Expr, distinct: bool = False) -> Expr: - """ - Returns the average value. - """ + """Returns the average value.""" return Expr(f.avg(arg.expr, distinct=distinct)) def corr(value1: Expr, value2: Expr, distinct: bool = False) -> Expr: - """ - Returns the correlation coefficient between `value1` and `value2`. - """ + """Returns the correlation coefficient between `value1` and `value2`.""" return Expr(f.corr(value1.expr, value2.expr, distinct=distinct)) def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr: - """ - Returns the number of rows that match the given arguments. - """ + """Returns the number of rows that match the given arguments.""" if isinstance(args, list): args = [arg.expr for arg in args] elif isinstance(args, Expr): @@ -1492,174 +1211,134 @@ def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr def covar(y: Expr, x: Expr) -> Expr: - """ - Computes the sample covariance. + """Computes the sample covariance. + This is an alias for `covar_samp`. """ return Expr(f.covar(y.expr, x.expr)) def covar_pop(y: Expr, x: Expr) -> Expr: - """ - Computes the population covariance. - """ + """Computes the population covariance.""" return Expr(f.covar_pop(y.expr, x.expr)) def covar_samp(y: Expr, x: Expr) -> Expr: - """ - Computes the sample covariance. - """ + """Computes the sample covariance.""" return Expr(f.covar_samp(y.expr, x.expr)) def grouping(arg: Expr, distinct: bool = False) -> Expr: - """ - Returns 1 if the value of the argument in the returned row is a null value. - """ + """Returns 1 if the value of the argument in the returned row is a null value.""" return Expr(f.grouping([arg.expr], distinct=distinct)) def max(arg: Expr, distinct: bool = False) -> Expr: - """ - Returns the maximum value of the arugment. - """ + """Returns the maximum value of the arugment.""" return Expr(f.max(arg.expr, distinct=distinct)) def mean(arg: Expr, distinct: bool = False) -> Expr: - """ - Returns the average (mean) value of the argument. + """Returns the average (mean) value of the argument. + This is an alias for `avg`. """ return avg(arg, distinct) def median(arg: Expr) -> Expr: - """ - Computes the median of a set of numbers. - """ + """Computes the median of a set of numbers.""" return Expr(f.median(arg.expr)) def min(arg: Expr, distinct: bool = False) -> Expr: - """ - Returns the minimum value of the argument. 
- """ + """Returns the minimum value of the argument.""" return Expr(f.min(arg.expr, distinct=distinct)) def sum(arg: Expr) -> Expr: - """ - Computes the sum of a set of numbers. - """ + """Computes the sum of a set of numbers.""" return Expr(f.sum(arg.expr)) def stddev(arg: Expr, distinct: bool = False) -> Expr: - """ - Computes the standard deviation of the argument. - """ + """Computes the standard deviation of the argument.""" return Expr(f.stddev(arg.expr, distinct=distinct)) def stddev_pop(arg: Expr, distinct: bool = False) -> Expr: - """ - Computes the population standard deviation of the argument. - """ + """Computes the population standard deviation of the argument.""" return Expr(f.stddev_pop(arg.expr, distinct=distinct)) def stddev_samp(arg: Expr, distinct: bool = False) -> Expr: - """ - Computes the sample standard deviation of the argument. + """Computes the sample standard deviation of the argument. + This is an alias for `stddev`. """ return stddev(arg, distinct) def var(arg: Expr) -> Expr: - """ - Computes the sample variance of the argument. + """Computes the sample variance of the argument. + This is an alias for `var_samp`. """ return var_samp(arg) def var_pop(arg: Expr, distinct: bool = False) -> Expr: - """ - Computes the population variance of the argument. - """ + """Computes the population variance of the argument.""" return Expr(f.var_pop(arg.expr, distinct=distinct)) def var_samp(arg: Expr) -> Expr: - """ - Computes the sample variance of the argument. - """ + """Computes the sample variance of the argument.""" return Expr(f.var_samp(arg.expr)) def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the average of the independent variable `x` for non-null pairs of the inputs. - """ + """Computes the average of the independent variable `x` for non-null pairs of the inputs.""" return Expr(f.regr_avgx[y.expr, x.expr], distinct) def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the average of the dependent variable `y` for non-null pairs of the inputs. - """ + """Computes the average of the dependent variable `y` for non-null pairs of the inputs.""" return Expr(f.regr_avgy[y.expr, x.expr], distinct) def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Counts the number of input rows in which both expressions are not null. - """ + """Counts the number of input rows in which both expressions are not null.""" return Expr(f.regr_count[y.expr, x.expr], distinct) def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the intercept from the linear regression. - """ + """Computes the intercept from the linear regression.""" return Expr(f.regr_intercept[y.expr, x.expr], distinct) def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the R-squared value from linear regression. - """ + """Computes the R-squared value from linear regression.""" return Expr(f.regr_r2[y.expr, x.expr], distinct) def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the slope from linear regression. - """ + """Computes the slope from linear regression.""" return Expr(f.regr_slope[y.expr, x.expr], distinct) def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the sum of squares of the independent variable `x`. 
- """ + """Computes the sum of squares of the independent variable `x`.""" return Expr(f.regr_sxx[y.expr, x.expr], distinct) def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the sum of products of pairs of numbers - """ + """Computes the sum of products of pairs of numbers.""" return Expr(f.regr_sxy[y.expr, x.expr], distinct) def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """ - Computes the sum of squares of the dependent variable `y`. - """ + """Computes the sum of squares of the dependent variable `y`.""" return Expr(f.regr_syy[y.expr, x.expr], distinct) @@ -1670,9 +1349,7 @@ def first_value( order_by: Expr | None = None, null_treatment: common.NullTreatment | None = None, ) -> Expr: - """ - Returns the first value in a group of values. - """ + """Returns the first value in a group of values.""" return Expr( f.first_value( arg.expr, @@ -1691,9 +1368,7 @@ def last_value( order_by: Expr | None = None, null_treatment: common.NullTreatment | None = None, ) -> Expr: - """ - Returns the last value in a group of values. - """ + """Returns the last value in a group of values.""" return Expr( f.last_value( arg.expr, @@ -1706,40 +1381,30 @@ def last_value( def bit_and(*args: Expr, distinct: bool = False) -> Expr: - """ - Computes the bitwise AND of the argument. - """ + """Computes the bitwise AND of the argument.""" args = [arg.expr for arg in args] return Expr(f.bit_and(*args, distinct=distinct)) def bit_or(*args: Expr, distinct: bool = False) -> Expr: - """ - Computes the bitwise OR of the argument. - """ + """Computes the bitwise OR of the argument.""" args = [arg.expr for arg in args] return Expr(f.bit_or(*args, distinct=distinct)) def bit_xor(*args: Expr, distinct: bool = False) -> Expr: - """ - Computes the bitwise XOR of the argument. - """ + """Computes the bitwise XOR of the argument.""" args = [arg.expr for arg in args] return Expr(f.bit_xor(*args, distinct=distinct)) def bool_and(*args: Expr, distinct: bool = False) -> Expr: - """ - Computes the boolean AND of the arugment. - """ + """Computes the boolean AND of the arugment.""" args = [arg.expr for arg in args] return Expr(f.bool_and(*args, distinct=distinct)) def bool_or(*args: Expr, distinct: bool = False) -> Expr: - """ - Computes the boolean OR of the arguement. 
- """ + """Computes the boolean OR of the arguement.""" args = [arg.expr for arg in args] return Expr(f.bool_or(*args, distinct=distinct)) From b8239e738281fddac47d0cf1fae0cc3c79b0ddd6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 10 Jul 2024 08:59:01 -0400 Subject: [PATCH 10/55] Addressing PR request to add google formatted docstrings --- python/datafusion/__init__.py | 18 +-- python/datafusion/catalog.py | 17 +++ python/datafusion/common.py | 2 +- python/datafusion/context.py | 134 ++++++++++++--------- python/datafusion/dataframe.py | 97 ++++++++-------- python/datafusion/expr.py | 135 +++++++++++++++++----- python/datafusion/functions.py | 1 + python/datafusion/input/__init__.py | 5 + python/datafusion/input/base.py | 11 +- python/datafusion/input/location.py | 9 +- python/datafusion/object_store.py | 2 +- python/datafusion/record_batch.py | 15 +++ python/datafusion/substrait.py | 33 ++++-- python/datafusion/tests/test_functions.py | 4 +- python/datafusion/tests/test_sql.py | 4 +- python/datafusion/tests/test_udaf.py | 4 +- python/datafusion/udf.py | 22 ++++ 17 files changed, 351 insertions(+), 162 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 8ba95260..2a69f58a 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,6 +15,12 @@ # specific language governing permissions and limitations # under the License. +"""DataFusion python package. + +This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. +See https://datafusion.apache.org/python/index.html for more information. +""" + from abc import ABCMeta, abstractmethod from typing import List @@ -169,7 +175,8 @@ def evaluate(self) -> pa.Scalar: pass -def column(value): +def column(value: str): + """Create a column expression.""" return Expr.column(value) @@ -177,6 +184,7 @@ def column(value): def literal(value): + """Create a literal expression.""" return Expr.literal(value) @@ -184,9 +192,7 @@ def literal(value): def udf(func, input_types, return_type, volatility, name=None): - """ - Create a new User Defined Function - """ + """Create a new User Defined Function.""" if not callable(func): raise TypeError("`func` argument must be callable") if name is None: @@ -201,9 +207,7 @@ def udf(func, input_types, return_type, volatility, name=None): def udaf(accum, input_types, return_type, state_type, volatility, name=None): - """ - Create a new User Defined Aggregate Function - """ + """Create a new User Defined Aggregate Function.""" if not issubclass(accum, Accumulator): raise TypeError("`accum` must implement the abstract base class Accumulator") if name is None: diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 1379b692..cec0be76 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
+"""Data catalog providers.""" + from __future__ import annotations import datafusion._internal as df_internal @@ -26,34 +28,49 @@ class Catalog: + """DataFusion data catalog.""" + def __init__(self, catalog: df_internal.Catalog) -> None: + """This constructor is not typically called by the end user.""" self.catalog = catalog def names(self) -> list[str]: + """Returns the list of databases in this catalog.""" return self.catalog.names() def database(self, name: str = "public") -> Database: + """Returns the database with the given `name` from this catalog.""" return Database(self.catalog.database(name)) class Database: + """DataFusion Database.""" + def __init__(self, db: df_internal.Database) -> None: + """This constructor is not typically called by the end user.""" self.db = db def names(self) -> set[str]: + """Returns the list of all tables in this database.""" return self.db.names() def table(self, name: str) -> Table: + """Return the table with the given `name` from this database.""" return Table(self.db.table(name)) class Table: + """DataFusion table.""" + def __init__(self, table: df_internal.Table) -> None: + """This constructor is not typically called by the end user.""" self.table = table def schema(self) -> pyarrow.Schema: + """Returns the schema associated with this table.""" return self.table.schema() @property def kind(self) -> str: + """Returns the kind of table.""" return self.table.kind() diff --git a/python/datafusion/common.py b/python/datafusion/common.py index dd56640a..2351845b 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Common data types used throughout the DataFusion project.""" from ._internal import common diff --git a/python/datafusion/context.py b/python/datafusion/context.py index f34bbf03..40462a53 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +"""Session Context and it's associated configuration.""" + from __future__ import annotations from ._internal import SessionConfig as SessionConfigInternal @@ -40,6 +42,8 @@ class SessionConfig: + """Session configuration options.""" + def __init__(self, config_options: dict[str, str] = {}) -> None: """Create a new `SessionConfig` with the given configuration options. @@ -60,7 +64,7 @@ def with_create_default_catalog_and_schema( enabled : bool Whether the default catalog and schema will be automatically created. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -82,7 +86,7 @@ def with_default_catalog_and_schema( schema : str Schema name. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -100,7 +104,7 @@ def with_information_schema(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to include `information_schema` virtual tables. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -116,7 +120,7 @@ def with_batch_size(self, batch_size: int) -> SessionConfig: batch_size : int Batch size. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -134,7 +138,7 @@ def with_target_partitions(self, target_partitions: int) -> SessionConfig: target_partitions : int Number of target partitions. 
- Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -154,7 +158,7 @@ def with_repartition_aggregations(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to use repartitioning for aggregations. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -172,7 +176,7 @@ def with_repartition_joins(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to use repartitioning for joins. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -188,7 +192,7 @@ def with_repartition_windows(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to use repartitioning for window functions. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -204,7 +208,7 @@ def with_repartition_sorts(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to use repartitioning for sorts. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -220,7 +224,7 @@ def with_repartition_file_scans(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to use repartitioning for file scans. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -236,7 +240,7 @@ def with_repartition_file_min_size(self, size: int) -> SessionConfig: size : int Minimum file range size. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -252,7 +256,7 @@ def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig: enabled : bool Whether to use pruning predicate for parquet readers. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -270,7 +274,7 @@ def set(self, key: str, value: str) -> SessionConfig: value : str Option value. - Returns + Returns: ------- SessionConfig A new `SessionConfig` object with the updated setting. @@ -280,6 +284,8 @@ class RuntimeConfig: + """Runtime configuration options.""" + def __init__(self) -> None: """Create a new `RuntimeConfig` with default values.""" self.config_internal = RuntimeConfigInternal() @@ -287,12 +293,12 @@ def __init__(self) -> None: def with_disk_manager_disabled(self) -> RuntimeConfig: """Disable the disk manager; attempts to create temporary files will error. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples + Examples: -------- >>> config = RuntimeConfig().with_disk_manager_disabled() """ @@ -302,12 +308,12 @@ def with_disk_manager_disabled(self) -> RuntimeConfig: def with_disk_manager_os(self) -> RuntimeConfig: """Use the operating system's temporary directory for disk manager. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples + Examples: -------- >>> config = RuntimeConfig().with_disk_manager_os() """ @@ -322,12 +328,12 @@ def with_disk_manager_specified(self, paths: list[str]) -> RuntimeConfig: paths : list[str] Paths to use for the disk manager's temporary files. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting.
- Examples + Examples: -------- >>> config = RuntimeConfig().with_disk_manager_specified(["/tmp"]) """ @@ -337,12 +343,12 @@ def with_disk_manager_specified(self, paths: list[str]) -> RuntimeConfig: def with_unbounded_memory_pool(self) -> RuntimeConfig: """Use an unbounded memory pool. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples + Examples: -------- >>> config = RuntimeConfig().with_unbounded_memory_pool() """ @@ -373,12 +379,12 @@ def with_fair_spill_pool(self, size: int) -> RuntimeConfig: size : int Size of the memory pool in bytes. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples + Examples: -------- ```python >>> config = RuntimeConfig().with_fair_spill_pool(1024) @@ -399,12 +405,12 @@ def with_greedy_memory_pool(self, size: int) -> RuntimeConfig: size : int Size of the memory pool in bytes. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples + Examples: -------- >>> config = RuntimeConfig().with_greedy_memory_pool(1024) """ @@ -419,12 +425,12 @@ def with_temp_file_path(self, path: str) -> RuntimeConfig: path : str Path to use for temporary files. - Returns + Returns: ------- RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples + Examples: -------- >>> config = RuntimeConfig().with_temp_file_path("/tmp") """ @@ -433,6 +439,8 @@ def with_temp_file_path(self, path: str) -> RuntimeConfig: class SQLOptions: + """Options to be used when performing SQL queries on the ``SessionContext``.""" + def __init__(self) -> None: """Create a new `SQLOptions` with default values. @@ -453,13 +461,13 @@ def with_allow_ddl(self, allow: bool = True) -> SQLOptions: allow : bool Allow DDL commands to be run. - Returns + Returns: ------- SQLOptions A new `SQLOptions` object with the updated setting. - Examples + Examples: -------- >>> options = SQLOptions().with_allow_ddl(True) """ @@ -476,13 +484,13 @@ def with_allow_dml(self, allow: bool = True) -> SQLOptions: allow : bool Allow DML commands to be run. - Returns + Returns: ------- SQLOptions A new `SQLOptions` object with the updated setting. - Examples + Examples: -------- >>> options = SQLOptions().with_allow_dml(True) """ @@ -497,12 +505,12 @@ def with_allow_statements(self, allow: bool = True) -> SQLOptions: allow : bool Allow statements to be run. - Returns + Returns: ------- SQLOptions A new `SQLOptions` object with the updated setting. - Examples + Examples: -------- >>> options = SQLOptions().with_allow_statements(True) """ @@ -511,6 +519,11 @@ def with_allow_statements(self, allow: bool = True) -> SQLOptions: class SessionContext: + """This is the main interface for executing queries and creating DataFrames. + + See https://datafusion.apache.org/python/user-guide/basics.html for additional information. + """ + def __init__( self, config: SessionConfig | None = None, runtime: RuntimeConfig | None = None ) -> None: @@ -527,7 +540,7 @@ def __init__( runtime : RuntimeConfig | None Runtime configuration options. 
- Examples + Examples: -------- The following example demonstrates how to use the context to execute a query against a CSV data source using the `DataFrame` API: @@ -545,6 +558,7 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) def register_object_store(self, schema: str, store: Any, host: str | None) -> None: + """Add a new object store into the session.""" self.ctx.register_object_store(schema, store, host) def register_listing_table( @@ -556,6 +570,7 @@ def register_listing_table( schema: pyarrow.Schema | None = None, file_sort_order: list[list[Expr]] | None = None, ) -> None: + """Registers a Table that can assemble multiple files from locations in an ``ObjectStore`` instance into a single table.""" if file_sort_order is not None: file_sort_order = [[x.expr for x in xs] for xs in file_sort_order] self.ctx.register_listing_table( @@ -574,7 +589,7 @@ def sql(self, query: str) -> DataFrame: query : str SQL query text. - Returns + Returns: ------- DataFrame DataFrame representation of the SQL query. """ return DataFrame(self.ctx.sql(query)) def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: - """Create a `DataFrame` from SQL query text, first validating that - the query is allowed by the provided options. + """Create a `DataFrame` from SQL query text, first validating that the query is allowed by the provided options. Parameters ---------- query : str SQL query text. options : SQLOptions SQL options. - Returns + Returns: ------- DataFrame DataFrame representation of the SQL query. @@ -605,6 +619,7 @@ def create_dataframe( name: str | None = None, schema: pyarrow.Schema | None = None, ) -> DataFrame: + """Create and return a dataframe using the provided partitions.""" return DataFrame(self.ctx.create_dataframe(partitions, name, schema)) def create_dataframe_from_logical_plan(self, plan: LogicalPlan) -> DataFrame: @@ -615,7 +630,7 @@ def create_dataframe_from_logical_plan(self, plan: LogicalPlan) -> DataFrame: plan : LogicalPlan Logical plan. - Returns + Returns: ------- DataFrame DataFrame representation of the logical plan. @@ -634,7 +649,7 @@ def from_pylist( name : str | None Name of the DataFrame. - Returns + Returns: ------- DataFrame DataFrame representation of the list of dictionaries. @@ -653,7 +668,7 @@ def from_pydict( name : str | None Name of the DataFrame. - Returns + Returns: ------- DataFrame DataFrame representation of the dictionary of lists. @@ -672,7 +687,7 @@ def from_arrow_table( name : str | None Name of the DataFrame. - Returns + Returns: ------- DataFrame DataFrame representation of the Arrow table. @@ -689,7 +704,7 @@ def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFr name : str | None Name of the DataFrame. - Returns + Returns: ------- DataFrame DataFrame representation of the Pandas DataFrame. @@ -706,7 +721,7 @@ def from_polars(self, data: polars.DataFrame, name: str | None = None) -> DataFr name : str | None Name of the DataFrame. - Returns + Returns: ------- DataFrame DataFrame representation of the Polars DataFrame.
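As a quick illustration of how these newly documented constructors compose, here is a minimal sketch. The table name `t`, the data, and the query are hypothetical and not part of this patch; the calls themselves (`from_pydict`, `SQLOptions.with_allow_ddl`, `sql_with_options`) are the ones documented above.

```python
from datafusion import SessionContext
from datafusion.context import SQLOptions

ctx = SessionContext()
# from_pydict both builds a DataFrame and registers it under the given name.
ctx.from_pydict({"a": [1, 2, 3]}, name="t")

# Run a plain query while rejecting any DDL statements.
options = SQLOptions().with_allow_ddl(False)
df = ctx.sql_with_options("SELECT a FROM t", options)
```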
@@ -714,14 +729,17 @@ def from_polars(self, data: polars.DataFrame, name: str | None = None) -> DataFr return DataFrame(self.ctx.from_polars(data, name)) def register_table(self, name: str, table: pyarrow.Table) -> None: + """Register a table with the given name into the session.""" self.ctx.register_table(name, table) def deregister_table(self, name: str) -> None: + """Remove a table from the session.""" self.ctx.deregister_table(name) def register_record_batches( self, name: str, partitions: list[list[pyarrow.RecordBatch]] ) -> None: + """Convert the provided partitions into a table and register it into the session using the given name.""" self.ctx.register_record_batches(name, partitions) def register_parquet( @@ -887,8 +905,7 @@ def register_avro( self.ctx.register_avro(name, path, schema, file_extension, table_partition_cols) def register_dataset(self, name: str, dataset: pyarrow.dataset.Dataset) -> None: - """ - Register a `pyarrow.dataset.Dataset` as a table. + """Register a `pyarrow.dataset.Dataset` as a table. Parameters ---------- @@ -927,7 +944,7 @@ def catalog(self, name: str = "datafusion") -> Catalog: name : str, optional Name of the catalog to retrieve, by default "datafusion". - Returns + Returns: ------- Catalog Catalog representation. @@ -939,6 +956,7 @@ def catalog(self, name: str = "datafusion") -> Catalog: "examine available catalogs, schemas and tables" ) def tables(self) -> set[str]: + """Deprecated.""" return self.ctx.tables() def table(self, name: str) -> DataFrame: @@ -949,7 +967,7 @@ def table(self, name: str) -> DataFrame: name : str Name of the table to retrieve. - Returns + Returns: ------- DataFrame DataFrame representation of the table. @@ -964,7 +982,7 @@ def table_exist(self, name: str) -> bool: name : str Name of the table to check. - Returns + Returns: ------- bool Whether a table with the given name exists. @@ -974,7 +992,7 @@ def table_exist(self, name: str) -> bool: def empty_table(self) -> DataFrame: """Create an empty `DataFrame`. - Returns + Returns: ------- DataFrame An empty DataFrame. @@ -984,7 +1002,7 @@ def empty_table(self) -> DataFrame: def session_id(self) -> str: """Return an id that uniquely identifies this `SessionContext`.
- Returns + Returns: ------- str Unique session identifier @@ -1017,7 +1035,7 @@ def read_json( file_compression_type : str | None, optional File compression type, by default None - Returns + Returns: ------- DataFrame DataFrame representation of the read JSON files @@ -1065,7 +1083,7 @@ def read_csv( file_compression_type : str | None, optional File compression type, by default None - Returns + Returns: ------- DataFrame DataFrame representation of the read CSV files @@ -1114,7 +1132,7 @@ def read_parquet( file_sort_order : list[list[Expr]] | None, optional Sort order for the file, by default None - Returns + Returns: ------- DataFrame DataFrame representation of the read Parquet files @@ -1151,7 +1169,7 @@ def read_avro( file_extension : str, optional File extension to select, by default ".avro" - Returns + Returns: ------- DataFrame DataFrame representation of the read Avro file @@ -1161,7 +1179,9 @@ def read_avro( ) def read_table(self, table: Table) -> DataFrame: + """Creates a ``DataFrame`` for a ``Table`` such as a ``ListingTable``.""" return DataFrame(self.ctx.read_table(table)) - def execute(self, plan: ExecutionPlan, part: int) -> RecordBatchStream: - return RecordBatchStream(self.ctx.execute(plan, part)) + def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: + """Execute the `plan` and return the results.""" + return RecordBatchStream(self.ctx.execute(plan, partitions)) diff --git a/python/datafusion/dataframe.py index 00748341..3028d5c5 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -14,6 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""DataFrame is one of the core concepts in DataFusion. + +See https://datafusion.apache.org/python/user-guide/basics.html for more information. +""" from __future__ import annotations @@ -35,7 +39,16 @@ class DataFrame: + """Two dimensional representation of data as rows and columns in a table. + + See https://datafusion.apache.org/python/user-guide/basics.html for more information. + """ + def __init__(self, df: DataFrameInternal) -> None: + """This constructor is not to be used by the end user. + + See ``SessionContext`` for methods to create DataFrames. + """ self.df = df def __getitem__(self, key: str | List[str]) -> DataFrame: @@ -46,7 +59,7 @@ def __getitem__(self, key: str | List[str]) -> DataFrame: key : Any Column name or list of column names to select. - Returns + Returns: ------- DataFrame DataFrame with the specified column or columns. @@ -56,7 +69,7 @@ def __repr__(self) -> str: """Return a string representation of the DataFrame. - Returns + Returns: ------- str String representation of the DataFrame. @@ -71,7 +84,7 @@ def describe(self) -> DataFrame: The output format is modeled after pandas. - Returns + Returns: ------- DataFrame A summary DataFrame containing statistics. @@ -84,7 +97,7 @@ def schema(self) -> pa.Schema: The output schema contains information on the name, data type, and nullability for each column. - Returns + Returns: ------- pa.Schema Describing schema of the DataFrame @@ -94,7 +107,7 @@ def select_columns(self, *args: str) -> DataFrame: """Filter the DataFrame by columns. - Returns + Returns: ------- DataFrame DataFrame only containing the specified columns.
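A minimal sketch of the read path above feeding the `DataFrame` wrapper; the file name `data.csv` and columns `a`/`b` are placeholders rather than anything from this patch:

```python
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.read_csv("data.csv")     # hypothetical file
df = df.select_columns("a", "b")  # hypothetical column names
print(df.schema())                # pyarrow.Schema of the projection
```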
@@ -104,7 +117,7 @@ def select_columns(self, *args: str) -> DataFrame: def select(self, *args: Expr) -> DataFrame: """Project arbitrary expressions (like SQL SELECT expressions) into a new `DataFrame`. - Returns + Returns: ------- DataFrame DataFrame after projection. It has one column for each expression. @@ -122,7 +135,7 @@ def filter(self, predicate: Expr) -> DataFrame: predicate : Expr Predicate expression to filter the DataFrame. - Returns + Returns: ------- DataFrame DataFrame after filtering. @@ -139,7 +152,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame: expr : Expr Expression to compute the column. - Returns + Returns: ------- DataFrame DataFrame with the new column. @@ -161,7 +174,7 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: new_name : str New column name. - Returns + Returns: ------- DataFrame DataFrame with the column renamed. @@ -180,7 +193,7 @@ def aggregate(self, group_by: list[Expr], aggs: list[Expr]) -> DataFrame: aggs : list[Expr] List of expressions to aggregate. - Returns + Returns: ------- DataFrame DataFrame after aggregation. @@ -195,7 +208,7 @@ def sort(self, *exprs: Expr) -> DataFrame: Note that any expression can be turned into a sort expression by calling its `sort` method. - Returns + Returns: ------- DataFrame DataFrame after sorting. @@ -213,7 +226,7 @@ def limit(self, count: int, offset: int = 0) -> DataFrame: offset : int, optional Number of rows to skip, by default 0 - Returns + Returns: ------- DataFrame DataFrame after limiting. @@ -227,7 +240,7 @@ def collect(self) -> list[pa.RecordBatch]: (no actual computation is performed). Calling `collect` triggers the computation. - Returns + Returns: ------- list[pa.RecordBatch] List of `pyarrow.RecordBatch`es collected from the DataFrame. @@ -237,7 +250,7 @@ def collect(self) -> list[pa.RecordBatch]: def cache(self) -> DataFrame: """Cache the DataFrame as a memory table. - Returns + Returns: ------- DataFrame Cached DataFrame. @@ -245,10 +258,9 @@ def cache(self) -> DataFrame: return DataFrame(self.df.cache()) def collect_partitioned(self) -> list[list[pa.RecordBatch]]: - """Execute this DataFrame and collect all results into a list of list of - `pyarrow.RecordBatch`es maintaining the input partitioning. + """Execute this DataFrame and collect all results into a list of list of `pyarrow.RecordBatch`es maintaining the input partitioning. - Returns + Returns: ------- list[list[pa.RecordBatch]] List of list of `pyarrow.RecordBatch`es collected from the DataFrame. @@ -268,7 +280,7 @@ def show(self, num: int = 20) -> None: def distinct(self) -> DataFrame: """Return a new `DataFrame` with all duplicated rows removed. - Returns + Returns: ------- DataFrame DataFrame after removing duplicates. @@ -281,8 +293,7 @@ def join( join_keys: tuple[list[str], list[str]], how: str, ) -> DataFrame: - """Join this `DataFrame` with another `DataFrame` using explicitly - specified columns. + """Join this `DataFrame` with another `DataFrame` using explicitly specified columns. Parameters ---------- @@ -293,7 +304,7 @@ def join( how : str Type of join to perform. Supported types are "inner", "left", "right", "full", "semi", "anti". - Returns + Returns: ------- DataFrame DataFrame after join. @@ -312,7 +323,7 @@ def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: analyze : bool, optional If `True`, the plan will run and metrics reported, by default False - Returns + Returns: ------- DataFrame DataFrame with the explanation of its plan. 
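To see several of the transformation methods documented above chained together, a small sketch; the tables `l` and `r` and their columns are hypothetical:

```python
from datafusion import SessionContext, col, functions as f

ctx = SessionContext()
left = ctx.from_pydict({"id": [1, 2, 3], "val": [10, 20, 30]}, name="l")
right = ctx.from_pydict({"id": [1, 2], "name": ["a", "b"]}, name="r")

# Inner join on "id", then sum "val" per "name".
joined = left.join(right, join_keys=(["id"], ["id"]), how="inner")
result = joined.aggregate([col("name")], [f.sum(col("val"))])
result.show()
```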
@@ -322,7 +333,7 @@ def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: def logical_plan(self) -> LogicalPlan: """Return the unoptimized `LogicalPlan` that comprises this `DataFrame`. - Returns + Returns: ------- LogicalPlan Unoptimized logical plan. @@ -332,7 +343,7 @@ def logical_plan(self) -> LogicalPlan: def optimized_logical_plan(self) -> LogicalPlan: """Return the optimized `LogicalPlan` that comprises this `DataFrame`. - Returns + Returns: ------- LogicalPlan Optimized logical plan. @@ -342,7 +353,7 @@ def optimized_logical_plan(self) -> LogicalPlan: def execution_plan(self) -> ExecutionPlan: """Return the execution/physical plan that comprises this `DataFrame`. - Returns + Returns: ------- ExecutionPlan Execution plan. @@ -359,7 +370,7 @@ def repartition(self, num: int) -> DataFrame: num : int Number of partitions to repartition the DataFrame into. - Returns + Returns: ------- DataFrame Repartitioned DataFrame. @@ -374,7 +385,7 @@ def repartition_by_hash(self, *args: Expr, num: int) -> DataFrame: num : int Number of partitions to repartition the DataFrame into. - Returns + Returns: ------- DataFrame Repartitioned DataFrame. @@ -394,7 +405,7 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: distinct : bool, optional If `True`, duplicate rows will be removed, by default False - Returns + Returns: ------- DataFrame DataFrame after union. @@ -412,7 +423,7 @@ def union_distinct(self, other: DataFrame) -> DataFrame: other : DataFrame DataFrame to union with. - Returns + Returns: ------- DataFrame DataFrame after union. @@ -429,7 +440,7 @@ def intersect(self, other: DataFrame) -> DataFrame: other : DataFrame DataFrame to intersect with. - Returns + Returns: ------- DataFrame DataFrame after intersection. @@ -446,7 +457,7 @@ def except_all(self, other: DataFrame) -> DataFrame: other : DataFrame DataFrame to calculate exception with. - Returns + Returns: ------- DataFrame DataFrame after exception. @@ -495,7 +506,7 @@ def write_json(self, path: str) -> None: def to_arrow_table(self) -> pa.Table: """Execute the `DataFrame` and convert it into an Arrow Table. - Returns + Returns: ------- pa.Table Arrow Table. @@ -503,22 +514,18 @@ def to_arrow_table(self) -> pa.Table: return self.df.to_arrow_table() def execute_stream(self) -> RecordBatchStream: - """ - TODO add descriptive text - """ + """Executes this DataFrame and returns a stream over a single partition.""" return RecordBatchStream(self.df.execute_stream()) def execute_stream_partitioned(self) -> list[RecordBatchStream]: - """ - TODO add descriptive text - """ + """Executes this DataFrame and returns a stream for each partition.""" streams = self.df.execute_stream_partitioned() return [RecordBatchStream(rbs) for rbs in streams] def to_pandas(self) -> pd.DataFrame: """Execute the `DataFrame` and convert it into a Pandas DataFrame. - Returns + Returns: ------- pd.DataFrame Pandas DataFrame. @@ -528,7 +535,7 @@ def to_pandas(self) -> pd.DataFrame: def to_pylist(self) -> list[dict[str, Any]]: """Execute the `DataFrame` and convert it into a list of dictionaries. - Returns + Returns: ------- list[dict[str, Any]] List of dictionaries. @@ -538,7 +545,7 @@ def to_pylist(self) -> list[dict[str, Any]]: def to_pydict(self) -> dict[str, list[Any]]: """Execute the `DataFrame` and convert it into a dictionary of lists. - Returns + Returns: ------- dict[str, list[Any]] Dictionary of lists. 
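The `execute_stream` docstrings above pair with the `RecordBatchStream` changes later in this patch. A sketch of the streaming path, continuing the hypothetical `df` from the previous examples:

```python
# Iterate the results one RecordBatch at a time rather than
# collecting everything into memory with collect().
for batch in df.execute_stream():
    print(batch.to_pyarrow().num_rows)
```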
@@ -548,7 +555,7 @@ def to_pydict(self) -> dict[str, list[Any]]: def to_polars(self) -> pl.DataFrame: """Execute the `DataFrame` and convert it into a Polars DataFrame. - Returns + Returns: ------- pl.DataFrame Polars DataFrame. @@ -561,7 +568,7 @@ def count(self) -> int: Note that this method will actually run a plan to calculate the count, which may be slow for large or complicated DataFrames. - Returns + Returns: ------- int Number of rows in the DataFrame. @@ -570,11 +577,11 @@ def count(self) -> int: @deprecated("Use :func:`unnest_columns` instead.") def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: - """ """ + """See ``unnest_columns``.""" return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) def unnest_columns( self, columns: list[str], preserve_nulls: bool = True ) -> DataFrame: - """ """ + """Expand columns of arrays into a single row per array element.""" return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 93ebddba..35a84791 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -15,6 +15,11 @@ # specific language governing permissions and limitations # under the License. +"""This module supports expressions, one of the core concepts in DataFusion. + +See ``Expr`` for more details. +""" + from __future__ import annotations from ._internal import expr as expr_internal, LogicalPlan @@ -81,118 +86,184 @@ class Expr: + """Expression object. + + Expressions are one of the core concepts in DataFusion. See + [the online help](https://datafusion.apache.org/python/user-guide/common-operations/expressions.html) + for more information. + """ + def __init__(self, expr: expr_internal.Expr) -> None: + """This constructor should not be called by the end user.""" self.expr = expr def to_variant(self) -> Any: + """Convert this expression into a python object if possible.""" return self.expr.to_variant() def display_name(self) -> str: + """Returns the name of this expression as it should appear in a schema. + + This name will not include any CAST expressions. + """ return self.expr.display_name() def canonical_name(self) -> str: + """Returns a full and complete string representation of this expression.""" return self.expr.canonical_name() def variant_name(self) -> str: + """Returns the name of the Expr variant. 
+ + Ex: ``IsNotNull``, ``Literal``, ``BinaryExpr``, etc. + """ return self.expr.variant_name() def __richcmp__(self, other: Expr, op: int) -> Expr: + """Comparison operator.""" return Expr(self.expr.__richcmp__(other, op)) def __repr__(self) -> str: + """Generate a string representation of this expression.""" return self.expr.__repr__() def __add__(self, rhs: Expr) -> Expr: + """Addition operator.""" return Expr(self.expr.__add__(rhs.expr)) def __sub__(self, rhs: Expr) -> Expr: + """Subtraction operator.""" return Expr(self.expr.__sub__(rhs.expr)) def __truediv__(self, rhs: Expr) -> Expr: + """Division operator.""" return Expr(self.expr.__truediv__(rhs.expr)) def __mul__(self, rhs: Expr) -> Expr: + """Multiplication operator.""" return Expr(self.expr.__mul__(rhs.expr)) def __mod__(self, rhs: Expr) -> Expr: + """Modulo operator (%).""" return Expr(self.expr.__mod__(rhs.expr)) def __and__(self, rhs: Expr) -> Expr: + """Logical AND.""" return Expr(self.expr.__and__(rhs.expr)) def __or__(self, rhs: Expr) -> Expr: + """Logical OR.""" return Expr(self.expr.__or__(rhs.expr)) def __invert__(self) -> Expr: + """Binary not (~).""" return Expr(self.expr.__invert__()) def __getitem__(self, key: str) -> Expr: + """For struct data types, return the field indicated by ``key``.""" return Expr(self.expr.__getitem__(key)) def __eq__(self, rhs: Expr) -> Expr: + """Equal to.""" return Expr(self.expr.__eq__(rhs.expr)) def __ne__(self, rhs: Expr) -> Expr: + """Not equal to.""" return Expr(self.expr.__ne__(rhs.expr)) def __ge__(self, rhs: Expr) -> Expr: + """Greater than or equal to.""" return Expr(self.expr.__ge__(rhs.expr)) def __gt__(self, rhs: Expr) -> Expr: + """Greater than.""" return Expr(self.expr.__gt__(rhs.expr)) def __le__(self, rhs: Expr) -> Expr: + """Less than or equal to.""" return Expr(self.expr.__le__(rhs.expr)) def __lt__(self, rhs: Expr) -> Expr: + """Less than.""" return Expr(self.expr.__lt__(rhs.expr)) @staticmethod def literal(value: Any) -> Expr: + """Creates a new expression representing a scalar value. + + `value` must be a valid PyArrow scalar value or easily castable to one. + """ if not isinstance(value, pa.Scalar): value = pa.scalar(value) return Expr(expr_internal.Expr.literal(value)) @staticmethod def column(value: str) -> Expr: + """Creates a new expression representing a column in a ``DataFrame``.""" return Expr(expr_internal.Expr.column(value)) def alias(self, name: str) -> Expr: + """Assign a name to the expression.""" return Expr(self.expr.alias(name)) def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr: + """Creates a sort ``Expr`` from an existing ``Expr``.""" return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first)) def is_null(self) -> Expr: + """Returns ``True`` if this expression is null.""" return Expr(self.expr.is_null()) def cast(self, to: pa.DataType[Any]) -> Expr: + """Cast to a new data type.""" return Expr(self.expr.cast(to)) def rex_type(self) -> RexType: + """Return the Rex Type of this expression. + + A Rex (Row Expression) specifies a single row of data. That specification + could include user defined functions or types. RexType identifies the row + as one of the possible valid ``RexType``(s). + """ return self.expr.rex_type() def types(self) -> DataTypeMap: + """Return the ``DataTypeMap`` representing the PythonType, Arrow DataType, and SqlType Enum of this expression.""" return self.expr.types() def python_value(self) -> Any: + """Extracts the Expr value into a PyObject that can be shared with Python.
+ + This is only valid for literal expressions. + """ return self.expr.python_value() def rex_call_operands(self) -> list[Expr]: + """Return the operands of the expression based on its variant type. + + Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s), + store those operands in different data structures. This function examines the Expr variant and returns + the operands to the calling logic. + """ return [Expr(e) for e in self.expr.rex_call_operands()] def rex_call_operator(self) -> str: + """Extracts the operator associated with a row expression type ``Call``.""" return self.expr.rex_call_operator() def column_name(self, plan: LogicalPlan) -> str: - return self.expr.column_name() + """Compute the output column name based on the provided logical plan.""" + return self.expr.column_name(plan) class WindowFrame: + """Defines a window frame for performing window operations.""" + def __init__( self, units: str, start_bound: int | None, end_bound: int | None ) -> None: - """ + """Construct a window frame using the given parameters. + :param units: Should be one of `rows`, `range`, or `groups` :param start_bound: Sets the preceding bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. :param end_bound: Sets the following bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. """ self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound) def get_frame_units(self) -> str: - """ - Returns the window frame units for the bounds - """ + """Returns the window frame units for the bounds.""" return self.window_frame.get_frame_units() def get_lower_bound(self) -> WindowFrameBound: - """ - Returns starting bound - """ + """Returns starting bound.""" return WindowFrameBound(self.window_frame.get_lower_bound()) def get_upper_bound(self): - """ - Returns end bound - """ + """Returns end bound.""" return WindowFrameBound(self.window_frame.get_upper_bound()) class WindowFrameBound: + """Defines a single window frame bound. + + ``WindowFrame`` typically requires a start and end bound. + """ + def __init__(self, frame_bound: expr_internal.WindowFrameBound) -> None: + """Constructs a window frame bound.""" self.frame_bound = frame_bound def get_offset(self) -> int | None: - """ - Returns the offset of the window frame - """ + """Returns the offset of the window frame.""" return self.frame_bound.get_offset() def is_current_row(self) -> bool: - """ - Returns if the frame bound is current row - """ + """Returns if the frame bound is current row.""" return self.frame_bound.is_current_row() def is_following(self) -> bool: - """ - Returns if the frame bound is following - """ + """Returns if the frame bound is following.""" return self.frame_bound.is_following() def is_preceding(self) -> bool: - """ - Returns if the frame bound is preceding - """ + """Returns if the frame bound is preceding.""" return self.frame_bound.is_preceding() def is_unbounded(self) -> bool: - """ - Returns if the frame bound is unbounded - """ + """Returns if the frame bound is unbounded.""" return self.frame_bound.is_unbounded() class CaseBuilder: + """Builder class for constructing case statements.
+ + An example usage would be as follows: + + ```python + import datafusion.functions as f + from datafusion import lit, col + df.select(f.case(col("column_a")).when(lit(1), lit("One")).when(lit(2), lit("Two")).otherwise(lit("Unknown"))) + ``` + """ + def __init__(self, case_builder: expr_internal.CaseBuilder) -> None: - """ + """Constructs a case builder. + + This is not typically called by the end user directly. See ``datafusion.functions.case`` instead. + :param case_builder: Internal object. This constructor is not expected to be used by the end user. Instead use :func:`case` to construct. """ self.case_builder = case_builder def when(self, when_expr: Expr, then_expr: Expr) -> CaseBuilder: + """Add a case to match against.""" return CaseBuilder(self.case_builder.when(when_expr.expr, then_expr.expr)) def otherwise(self, else_expr: Expr) -> Expr: + """Set a default value for the case statement.""" return Expr(self.case_builder.otherwise(else_expr.expr)) def end(self) -> Expr: + """Finish building a case statement. + + Any non-matching cases will end in a `null` value. + """ return Expr(self.case_builder.end()) diff --git a/python/datafusion/functions.py index 75bf08a5..36a8aba8 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module contains the user functions for operating on ``Expr``.""" from __future__ import annotations diff --git a/python/datafusion/input/__init__.py b/python/datafusion/input/__init__.py index 27e39b8c..f85ce21f 100644 --- a/python/datafusion/input/__init__.py +++ b/python/datafusion/input/__init__.py @@ -15,6 +15,11 @@ # specific language governing permissions and limitations # under the License. +"""This package provides input sources. + +The primary class used within DataFusion is ``LocationInputPlugin``. +""" + from .location import LocationInputPlugin __all__ = [ diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py index efcaf769..b91e0a1e 100644 --- a/python/datafusion/input/base.py +++ b/python/datafusion/input/base.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +"""This module provides ``BaseInputSource`` which a user can extend to provide a custom input source.""" + from abc import ABC, abstractmethod from typing import Any @@ -22,18 +24,19 @@ class BaseInputSource(ABC): - """ - If a consuming library would like to provider their own InputSource - this is the class they should extend to write their own. Once - completed the Plugin InputSource can be registered with the + """If a consuming library would like to provide their own InputSource, this is the class they should extend to write their own. + + Once completed, the Plugin InputSource can be registered with the SessionContext to ensure that it will be used in order to obtain the SqlTable information from the custom datasource.
""" @abstractmethod def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool: + """Returns `True` if the input is valid.""" pass @abstractmethod def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable: + """Create a table from the input source.""" pass diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py index 16e632d1..7454829d 100644 --- a/python/datafusion/input/location.py +++ b/python/datafusion/input/location.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +"""This module provides ``LocationInputPlugin`` which is the default input source for DataFusion.""" + import os import glob from typing import Any @@ -24,12 +26,10 @@ class LocationInputPlugin(BaseInputSource): - """ - Input Plugin for everything, which can be read - in from a file (on disk, remote etc.) - """ + """Input Plugin for everything, which can be read in from a file (on disk, remote etc.).""" def is_correct_input(self, input_item: Any, table_name: str, **kwargs): + """Returns `True` if the input is valid.""" return isinstance(input_item, str) def build_table( @@ -38,6 +38,7 @@ def build_table( table_name: str, **kwargs, ) -> SqlTable: + """Create a table from the input source.""" _, extension = os.path.splitext(input_file) format = extension.lstrip(".").lower() num_rows = 0 # Total number of rows in the file. Used for statistics diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py index 70ecbd2b..06db9a25 100644 --- a/python/datafusion/object_store.py +++ b/python/datafusion/object_store.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""This module contains functionality for operating with different types of object stores.""" from ._internal import object_store diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index f26458e8..4b7416d9 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +"""This module provides the classes for handling record batches, which are typically the result of dataframe `execute_stream` operations.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -25,18 +27,29 @@ class RecordBatch: + """This class is essentially a wrapper for ``pyarrow.RecordBatch``.""" + def __init__(self, record_batch: df_internal.RecordBatch) -> None: + """This constructor is generally not called by the end user. + + See the ``RecordBatchStream`` iterator for generating this class. 
+ """ self.record_batch = record_batch def to_pyarrow(self) -> pyarrow.RecordBatch: + """Convert to pyarrow ``RecordBatch``.""" return self.record_batch.to_pyarrow() class RecordBatchStream: + """This class represents a stream of record batches, typically as the result of a DataFrame `execute_stream` operation.""" + def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: + """This constructor is typically not called by the end user.""" self.rbs = record_batch_stream def next(self) -> RecordBatch | None: + """See ``__next__`` for the iterator function.""" try: next_batch = next(self) except StopIteration: @@ -45,8 +58,10 @@ def next(self) -> RecordBatch | None: return next_batch def __next__(self) -> RecordBatch | None: + """Iterator function.""" next_batch = next(self.rbs) return RecordBatch(next_batch) if next_batch is not None else None def __iter__(self) -> RecordBatchStream: + """Iterator function.""" return self diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index cc17b2a9..6390662e 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -15,6 +15,12 @@ # specific language governing permissions and limitations # under the License. +"""This module provides support for using substrait with datafusion. + +For additional information about substrait, see https://substrait.io/ for more information +about substrait. +""" + from __future__ import annotations from ._internal import substrait as substrait_internal @@ -27,13 +33,20 @@ class plan: + """A class representing an encodable substrait plan.""" + def __init__(self, plan: substrait_internal.plan) -> None: + """Create a substrait plan. + + The user should not have to call this constructor directly. Rather, it should be created + via ``serde`` or ``producer`` classes in this module. + """ self.plan_internal = plan def encode(self) -> bytes: """Encode the plan to bytes. - Returns + Returns: ------- bytes Encoded plan. @@ -42,6 +55,8 @@ def encode(self) -> bytes: class serde: + """Provides the serialization and deserialization required to convert to and from a Substrait plan.""" + @staticmethod def serialize(sql: str, ctx: SessionContext, path: str) -> None: """Serialize a SQL query to a Substrait plan and write it to a file. @@ -68,7 +83,7 @@ def serialize_to_plan(sql: str, ctx: SessionContext) -> plan: ctx : SessionContext SessionContext to use. - Returns + Returns: ------- plan Substrait plan. @@ -86,7 +101,7 @@ def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: ctx : SessionContext SessionContext to use. - Returns + Returns: ------- bytes Substrait plan as bytes. @@ -102,7 +117,7 @@ def deserialize(path: str) -> plan: path : str Path to read the Substrait plan from. - Returns + Returns: ------- plan Substrait plan. @@ -118,7 +133,7 @@ def deserialize_bytes(proto_bytes: bytes) -> plan: proto_bytes : bytes Bytes to read the Substrait plan from. - Returns + Returns: ------- plan Substrait plan. @@ -127,6 +142,8 @@ def deserialize_bytes(proto_bytes: bytes) -> plan: class producer: + """Generates substrait plans from a logical plan.""" + @staticmethod def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> plan: """Convert a DataFusion LogicalPlan to a Substrait plan. @@ -138,7 +155,7 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> plan: ctx : SessionContext SessionContext to use. - Returns + Returns: ------- plan Substrait plan. 
@@ -149,6 +166,8 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> plan:
 
 
 class consumer:
+    """Generates a logical plan from a substrait plan."""
+
     @staticmethod
     def from_substrait_plan(ctx: SessionContext, plan: plan) -> LogicalPlan:
         """Convert a Substrait plan to a DataFusion LogicalPlan.
@@ -160,7 +179,7 @@ def from_substrait_plan(ctx: SessionContext, plan: plan) -> LogicalPlan:
         plan : plan
             Substrait plan to convert.
 
-        Returns
+        Returns:
         -------
         LogicalPlan
             LogicalPlan.
diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py
index 85fed622..2e601c28 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -97,9 +97,7 @@ def test_literal(df):
 
 
 def test_lit_arith(df):
-    """
-    Test literals with arithmetic operations
-    """
+    """Test literals with arithmetic operations"""
     df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
     result = df.collect()
     assert len(result) == 1
diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py
index ec0e4c57..f9eb588d 100644
--- a/python/datafusion/tests/test_sql.py
+++ b/python/datafusion/tests/test_sql.py
@@ -281,9 +281,7 @@ def test_execute(ctx, tmp_path):
 
 
 def test_cast(ctx, tmp_path):
-    """
-    Verify that we can cast
-    """
+    """Verify that we can cast"""
     path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
     ctx.register_parquet("t", path)
 
diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py
index c2b29d19..81194927 100644
--- a/python/datafusion/tests/test_udaf.py
+++ b/python/datafusion/tests/test_udaf.py
@@ -25,9 +25,7 @@
 
 
 class Summarize(Accumulator):
-    """
-    Interface of a user-defined accumulation.
-    """
+    """Interface of a user-defined accumulation."""
 
     def __init__(self):
         self._sum = pa.scalar(0.0)
diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index 4a9aebf9..90cd593b 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+"""This module provides the user defined functions for evaluation of dataframes."""
+
 from __future__ import annotations
 
 import datafusion._internal as df_internal
@@ -26,6 +28,11 @@
 
 
 class ScalarUDF:
+    """Class for performing scalar user defined functions (UDF).
+
+    Scalar UDFs operate on a row by row basis. See also ``AggregateUDF`` for operating on a group of rows.
+    """
+
     def __init__(
         self,
         name: str | None,
@@ -34,16 +41,26 @@ def __init__(
         return_type: pyarrow.DataType,
         volatility: str,
     ) -> None:
+        """Instantiate a scalar user defined function (UDF)."""
        self.udf = df_internal.ScalarUDF(
             name, func, input_types, return_type, volatility
         )
 
     def __call__(self, *args: Expr) -> Expr:
+        """Execute the UDF.
+
+        This function is not typically called by an end user. These calls will occur during the evaluation of the dataframe.
+        """
         args = [arg.expr for arg in args]
         return Expr(self.udf.__call__(*args))
 
 
 class AggregateUDF:
+    """Class for performing aggregate user defined functions (UDAF).
+
+    Aggregate UDFs operate on a group of rows and return a single value. See also ``ScalarUDF`` for operating on a row by row basis.
+    """
+
     def __init__(
         self,
         name: str | None,
@@ -53,10 +70,15 @@ def __init__(
         state_type: list[pyarrow.DataType],
         volatility: str,
     ) -> None:
+        """Instantiate a user defined aggregate function (UDAF)."""
         self.udf = df_internal.AggregateUDF(
             name, accumulator, input_types, return_type, state_type, volatility
         )
 
     def __call__(self, *args: Expr) -> Expr:
+        """Execute the UDAF.
+
+        This function is not typically called by an end user. These calls will occur during the evaluation of the dataframe.
+        """
         args = [arg.expr for arg in args]
         return Expr(self.udf.__call__(*args))

From 4c8073e07a7c3b0342f0e23f56912c2bcfaf762d Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 10 Jul 2024 08:59:24 -0400
Subject: [PATCH 11/55] Small docstring for ruff

---
 docs/source/conf.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index c0da8b2c..2946efe3 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+"""Documentation generation."""
+
 # Configuration file for the Sphinx documentation builder.
 #
 # This file only contains a selection of the most common options. For a full

From 411c91ce904a48e5b5f97a2fe7cd5b51566d6e30 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 10 Jul 2024 08:59:39 -0400
Subject: [PATCH 12/55] Linting

---
 examples/tpch/_tests.py                       | 30 ++++++++++++++-----
 examples/tpch/convert_data_to_parquet.py      |  3 +-
 examples/tpch/q08_market_share.py             |  4 ++-
 .../tpch/q09_product_type_profit_measure.py   |  4 ++-
 examples/tpch/q13_customer_distribution.py    |  4 ++-
 examples/tpch/q14_promotion_effect.py         |  4 ++-
 .../tpch/q16_part_supplier_relationship.py    |  3 +-
 examples/tpch/q17_small_quantity_order.py     |  8 ++++-
 examples/tpch/q20_potential_part_promotion.py |  4 ++-
 examples/tpch/q22_global_sales_opportunity.py |  4 ++-
 examples/tpch/util.py                         |  7 +++--
 11 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
index 3f973d9f..cc201a31 100644
--- a/examples/tpch/_tests.py
+++ b/examples/tpch/_tests.py
@@ -21,6 +21,7 @@
 from datafusion import col, lit, functions as F
 from util import get_answer_file
 
+
 def df_selection(col_name, col_type):
     if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
         return F.round(col(col_name), lit(2)).alias(col_name)
@@ -29,6 +30,7 @@ def df_selection(col_name, col_type):
     else:
         return col(col_name)
 
+
 def load_schema(col_name, col_type):
     if col_type == pa.int64() or col_type == pa.int32():
         return col_name, pa.string()
@@ -36,7 +38,8 @@ def load_schema(col_name, col_type):
         return col_name, pa.float64()
     else:
         return col_name, col_type
-    
+
+
 def expected_selection(col_name, col_type):
     if col_type == pa.int64() or col_type == pa.int32():
         return F.trim(col(col_name)).cast(col_type).alias(col_name)
@@ -45,20 +48,23 @@ def expected_selection(col_name, col_type):
     else:
         return col(col_name)
 
+
 def selections_and_schema(original_schema):
-    columns = [ (c, original_schema.field(c).type) for c in original_schema.names ]
+    columns = [(c, original_schema.field(c).type) for c in original_schema.names]
 
-    df_selections = [ df_selection(c, t) for (c, t) in columns]
-    expected_schema = [ load_schema(c, t) for (c, t) in columns]
-    expected_selections = [ expected_selection(c, t) for (c, t) in columns]
+    df_selections = [df_selection(c, t) for (c, t) in columns]
+    expected_schema = [load_schema(c, t) for (c, t) in columns]
+    expected_selections = [expected_selection(c, t) for (c, t) in
columns] return (df_selections, expected_schema, expected_selections) + def check_q17(df): raw_value = float(df.collect()[0]["avg_yearly"][0].as_py()) value = round(raw_value, 2) assert abs(value - 348406.05) < 0.001 + @pytest.mark.parametrize( ("query_code", "answer_file"), [ @@ -73,7 +79,8 @@ def check_q17(df): ("q09_product_type_profit_measure", "q9"), ("q10_returned_item_reporting", "q10"), pytest.param( - "q11_important_stock_identification", "q11", + "q11_important_stock_identification", + "q11", ), ("q12_ship_mode_order_priority", "q12"), ("q13_customer_distribution", "q13"), @@ -97,13 +104,20 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str): if answer_file == "q17": return check_q17(df) - (df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema()) + (df_selections, expected_schema, expected_selections) = selections_and_schema( + df.schema() + ) df = df.select(*df_selections) read_schema = pa.schema(expected_schema) - df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out") + df_expected = module.ctx.read_csv( + get_answer_file(answer_file), + schema=read_schema, + delimiter="|", + file_extension=".out", + ) df_expected = df_expected.select(*expected_selections) diff --git a/examples/tpch/convert_data_to_parquet.py b/examples/tpch/convert_data_to_parquet.py index d81ec290..a8091a70 100644 --- a/examples/tpch/convert_data_to_parquet.py +++ b/examples/tpch/convert_data_to_parquet.py @@ -117,7 +117,6 @@ curr_dir = os.path.dirname(os.path.abspath(__file__)) for filename, curr_schema in all_schemas.items(): - # For convenience, go ahead and convert the schema column names to lowercase curr_schema = [(s[0].lower(), s[1]) for s in curr_schema] @@ -125,7 +124,7 @@ # in to handle the trailing | in the file output_cols = [r[0] for r in curr_schema] - curr_schema = [ pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema] + curr_schema = [pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema] # Trailing | requires extra field for in processing curr_schema.append(("some_null", pyarrow.null())) diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index d13a71df..cd6bc1fa 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -47,7 +47,9 @@ ctx = SessionContext() -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_type" +) df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns( "s_suppkey", "s_nationkey" ) diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index 29ffceed..b4a7369f 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -39,7 +39,9 @@ ctx = SessionContext() -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_name" +) df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns( "s_suppkey", "s_nationkey" ) diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 2b6e7e20..bc0a5bd1 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -41,7 
+41,9 @@ df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns( "o_custkey", "o_comment" ) -df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns("c_custkey") +df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns( + "c_custkey" +) # Use a regex to remove special cases df_orders = df_orders.filter( diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index 75fa363a..8cb1e4c5 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -44,7 +44,9 @@ df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns( "l_partkey", "l_shipdate", "l_extendedprice", "l_discount" ) -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_type" +) # Check part type begins with PROMO diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 0db2d1b8..fdcb5b4d 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -62,7 +62,8 @@ # Select the parts we are interested in df_part = df_part.filter(col("p_brand") != lit(BRAND)) df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) != lit(TYPE_TO_IGNORE) + F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) + != lit(TYPE_TO_IGNORE) ) # Python conversion of integer to literal casts it to int64 but the data for diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 5880e7ed..e0ee8bb9 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -56,7 +56,13 @@ # Find the average quantity window_frame = WindowFrame("rows", None, None) df = df.with_column( - "avg_quantity", F.window("avg", [col("l_quantity")], window_frame=window_frame, partition_by=[col("l_partkey")]) + "avg_quantity", + F.window( + "avg", + [col("l_quantity")], + window_frame=window_frame, + partition_by=[col("l_partkey")], + ), ) df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 85e7226f..05a26745 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -40,7 +40,9 @@ ctx = SessionContext() -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_name" +) df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns( "l_shipdate", "l_partkey", "l_suppkey", "l_quantity" ) diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index dfde19cb..622c1429 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -38,7 +38,9 @@ df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns( "c_phone", "c_acctbal", "c_custkey" ) -df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns("o_custkey") +df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns( + "o_custkey" +) # The nation code is a two digit number, but we need to convert it to a 
string literal nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) diff --git a/examples/tpch/util.py b/examples/tpch/util.py index 191fa609..7e3d659d 100644 --- a/examples/tpch/util.py +++ b/examples/tpch/util.py @@ -20,14 +20,17 @@ """ import os -from pathlib import Path + def get_data_path(filename: str) -> str: path = os.path.dirname(os.path.abspath(__file__)) return os.path.join(path, "data", filename) + def get_answer_file(answer_file: str) -> str: path = os.path.dirname(os.path.abspath(__file__)) - return os.path.join(path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out") + return os.path.join( + path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out" + ) From 610adda10f0503edeba8fcbd44aca44e64279aa7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 10 Jul 2024 09:00:47 -0400 Subject: [PATCH 13/55] Add docstring format checking to pre-commit stage --- pyproject.toml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b706065a..8f21dc48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,3 +64,18 @@ exclude = [".github/**", "ci/**", ".asf.yaml"] # Require Cargo.lock is up to date locked = true features = ["substrait"] + +# Enable docstring linting using the google style guide +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "D"] +ignore = ["D417"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +# Disable docstring checking for these directories +[tool.ruff.lint.per-file-ignores] +"python/datafusion/tests/*" = ["D"] +"examples/*" = ["D"] +"dev/*" = ["D"] +"benchmarks/*" = ["D", "F"] From 265aeb7b33f6438e6412efac888e3e61302f67a6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 11 Jul 2024 07:12:24 -0400 Subject: [PATCH 14/55] Set explicit return types on UDFs --- python/datafusion/udf.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 90cd593b..109282cf 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -21,11 +21,13 @@ import datafusion._internal as df_internal from datafusion.expr import Expr -from typing import Callable, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING, TypeVar if TYPE_CHECKING: import pyarrow + _R = TypeVar("_R", bound=pyarrow.DataType) + class ScalarUDF: """Class for performing scalar user defined functions (UDF). 
@@ -36,9 +38,9 @@ class ScalarUDF: def __init__( self, name: str | None, - func: Callable, + func: Callable[..., _R], input_types: list[pyarrow.DataType], - return_type: pyarrow.DataType, + return_type: _R, volatility: str, ) -> None: """Instantiate a scalar user defined function (UDF).""" @@ -64,9 +66,9 @@ class AggregateUDF: def __init__( self, name: str | None, - accumulator: Callable, + accumulator: Callable[..., _R], input_types: list[pyarrow.DataType], - return_type: pyarrow.DataType, + return_type: _R, state_type: list[pyarrow.DataType], volatility: str, ) -> None: From 02564dea39009b8738dddfb4242ff499838419f9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jul 2024 08:13:20 -0400 Subject: [PATCH 15/55] Add options of passing either a path or a string --- python/datafusion/context.py | 54 +++++++++++++++++----------- python/datafusion/dataframe.py | 13 +++---- python/datafusion/substrait.py | 66 +++++++++++++++++++++++++--------- 3 files changed, 90 insertions(+), 43 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 40462a53..9ec80e8a 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -39,6 +39,7 @@ import pyarrow import pandas import polars + import pathlib class SessionConfig: @@ -320,7 +321,9 @@ def with_disk_manager_os(self) -> RuntimeConfig: self.config_internal = self.config_internal.with_disk_manager_os() return self - def with_disk_manager_specified(self, paths: list[str]) -> RuntimeConfig: + def with_disk_manager_specified( + self, paths: list[str] | list[pathlib.Path] + ) -> RuntimeConfig: """Use the specified paths for the disk manager's temporary files. Parameters @@ -337,6 +340,7 @@ def with_disk_manager_specified(self, paths: list[str]) -> RuntimeConfig: -------- >>> config = RuntimeConfig().with_disk_manager_specified(["/tmp"]) """ + paths = [str(p) for p in paths] self.config_internal = self.config_internal.with_disk_manager_specified(paths) return self @@ -417,7 +421,7 @@ def with_greedy_memory_pool(self, size: int) -> RuntimeConfig: self.config_internal = self.config_internal.with_greedy_memory_pool(size) return self - def with_temp_file_path(self, path: str) -> RuntimeConfig: + def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeConfig: """Use the specified path to create any needed temporary files. 
Parameters @@ -434,7 +438,7 @@ def with_temp_file_path(self, path: str) -> RuntimeConfig: -------- >>> config = RuntimeConfig().with_temp_file_path("/tmp") """ - self.config_internal = self.config_internal.with_temp_file_path(path) + self.config_internal = self.config_internal.with_temp_file_path(str(path)) return self @@ -564,7 +568,7 @@ def register_object_store(self, schema: str, store: Any, host: str | None) -> No def register_listing_table( self, name: str, - path: str, + path: str | pathlib.Path, table_partition_cols: list[tuple[str, str]] = [], file_extension: str = ".parquet", schema: pyarrow.Schema | None = None, @@ -573,8 +577,14 @@ def register_listing_table( """Registers a Table that can assemble multiple files from locations in an ``ObjectStore`` instance into a single table.""" if file_sort_order is not None: file_sort_order = [[x.expr for x in xs] for xs in file_sort_order] + # TODO add unit test for pathlib path self.ctx.register_listing_table( - name, path, table_partition_cols, file_extension, schema, file_sort_order + name, + str(path), + table_partition_cols, + file_extension, + schema, + file_sort_order, ) def sql(self, query: str) -> DataFrame: @@ -745,7 +755,7 @@ def register_record_batches( def register_parquet( self, name: str, - path: str, + path: str | pathlib.Path, table_partition_cols: list[tuple[str, str]] = [], parquet_pruning: bool = True, file_extension: str = ".parquet", @@ -780,7 +790,7 @@ def register_parquet( """ self.ctx.register_parquet( name, - path, + str(path), table_partition_cols, parquet_pruning, file_extension, @@ -792,7 +802,7 @@ def register_parquet( def register_csv( self, name: str, - path: str, + path: str | pathlib.Path, schema: pyarrow.Schema | None = None, has_header: bool = True, delimiter: str = ",", @@ -825,7 +835,7 @@ def register_csv( """ self.ctx.register_csv( name, - path, + str(path), schema, has_header, delimiter, @@ -837,7 +847,7 @@ def register_csv( def register_json( self, name: str, - path: str, + path: str | pathlib.Path, schema: pyarrow.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", @@ -868,7 +878,7 @@ def register_json( """ self.ctx.register_json( name, - path, + str(path), schema, schema_infer_max_records, file_extension, @@ -879,7 +889,7 @@ def register_json( def register_avro( self, name: str, - path: str, + path: str | pathlib.Path, schema: pyarrow.Schema | None = None, file_extension: str = ".avro", table_partition_cols: list[tuple[str, str]] = [], @@ -902,7 +912,9 @@ def register_avro( table_partition_cols : list[tuple[str, str]], optional Partition columns, by default [] """ - self.ctx.register_avro(name, path, schema, file_extension, table_partition_cols) + self.ctx.register_avro( + name, str(path), schema, file_extension, table_partition_cols + ) def register_dataset(self, name: str, dataset: pyarrow.dataset.Dataset) -> None: """Register a `pyarrow.dataset.Dataset` as a table. 
@@ -1011,7 +1023,7 @@ def session_id(self) -> str: def read_json( self, - path: str, + path: str | pathlib.Path, schema: pyarrow.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", @@ -1042,7 +1054,7 @@ def read_json( """ return DataFrame( self.ctx.read_json( - path, + str(path), schema, schema_infer_max_records, file_extension, @@ -1053,7 +1065,7 @@ def read_json( def read_csv( self, - path: str, + path: str | pathlib.Path, schema: pyarrow.Schema | None = None, has_header: bool = True, delimiter: str = ",", @@ -1090,7 +1102,7 @@ def read_csv( """ return DataFrame( self.ctx.read_csv( - path, + str(path), schema, has_header, delimiter, @@ -1103,7 +1115,7 @@ def read_csv( def read_parquet( self, - path: str, + path: str | pathlib.Path, table_partition_cols: list[tuple[str, str]] = [], parquet_pruning: bool = True, file_extension: str = ".parquet", @@ -1139,7 +1151,7 @@ def read_parquet( """ return DataFrame( self.ctx.read_parquet( - path, + str(path), table_partition_cols, parquet_pruning, file_extension, @@ -1151,7 +1163,7 @@ def read_parquet( def read_avro( self, - path: str, + path: str | pathlib.Path, schema: pyarrow.Schema | None = None, file_partition_cols: list[tuple[str, str]] = [], file_extension: str = ".avro", @@ -1175,7 +1187,7 @@ def read_avro( DataFrame representation of the read Avro file """ return DataFrame( - self.ctx.read_avro(path, schema, file_partition_cols, file_extension) + self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) def read_table(self, table: Table) -> DataFrame: diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 3028d5c5..d3212578 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -29,6 +29,7 @@ import pyarrow as pa import pandas as pd import polars as pl + import pathlib from datafusion._internal import DataFrame as DataFrameInternal from datafusion.expr import Expr @@ -464,7 +465,7 @@ def except_all(self, other: DataFrame) -> DataFrame: """ return DataFrame(self.df.except_all(other.df)) - def write_csv(self, path: str) -> None: + def write_csv(self, path: str | pathlib.Path) -> None: """Execute the `DataFrame` and write the results to a CSV file. Parameters @@ -472,11 +473,11 @@ def write_csv(self, path: str) -> None: path : str Path of the CSV file to write. """ - self.df.write_csv(path) + self.df.write_csv(str(path)) def write_parquet( self, - path: str, + path: str | pathlib.Path, compression: str = "uncompressed", compression_level: int | None = None, ) -> None: @@ -491,9 +492,9 @@ def write_parquet( compression_level : int | None, optional Compression level to use, by default None """ - self.df.write_parquet(path, compression, compression_level) + self.df.write_parquet(str(path), compression, compression_level) - def write_json(self, path: str) -> None: + def write_json(self, path: str | pathlib.Path) -> None: """Execute the `DataFrame` and write the results to a JSON file. Parameters @@ -501,7 +502,7 @@ def write_json(self, path: str) -> None: path : str Path of the JSON file to write. """ - self.df.write_json(path) + self.df.write_json(str(path)) def to_arrow_table(self) -> pa.Table: """Execute the `DataFrame` and convert it into an Arrow Table. 
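# A short sketch of what the str | pathlib.Path signatures above enable
# (illustrative only; the directory and file names are hypothetical):
#
#     import pathlib
#     from datafusion import SessionContext
#
#     ctx = SessionContext()
#     data_dir = pathlib.Path("/tmp/datafusion_example")
#
#     # str and pathlib.Path are now interchangeable; each path is normalized
#     # with str(path) before being handed to the internal bindings.
#     df = ctx.read_parquet(data_dir / "input.parquet")
#     df.write_csv(data_dir / "output.csv")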
diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index 6390662e..dea5fb1c 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -26,20 +26,22 @@ from ._internal import substrait as substrait_internal from typing import TYPE_CHECKING +from typing_extensions import deprecated +import pathlib if TYPE_CHECKING: from datafusion.context import SessionContext from datafusion._internal import LogicalPlan -class plan: +class Plan: """A class representing an encodable substrait plan.""" def __init__(self, plan: substrait_internal.plan) -> None: """Create a substrait plan. The user should not have to call this constructor directly. Rather, it should be created - via ``serde`` or ``producer`` classes in this module. + via `Serde` or `Producer` classes in this module. """ self.plan_internal = plan @@ -54,11 +56,18 @@ def encode(self) -> bytes: return self.plan_internal.encode() -class serde: +@deprecated("Use `Plan` instead.") +class plan(Plan): + """See `Plan`.""" + + pass + + +class Serde: """Provides the serialization and deserialization required to convert to and from a Substrait plan.""" @staticmethod - def serialize(sql: str, ctx: SessionContext, path: str) -> None: + def serialize(sql: str, ctx: SessionContext, path: str | pathlib.Path) -> None: """Serialize a SQL query to a Substrait plan and write it to a file. Parameters @@ -69,11 +78,13 @@ def serialize(sql: str, ctx: SessionContext, path: str) -> None: SessionContext to use. path : str Path to write the Substrait plan to. + + TODO add unit test on passing in as path instead of str """ - return substrait_internal.serde.serialize(sql, ctx.ctx, path) + return substrait_internal.serde.serialize(sql, ctx.ctx, str(path)) @staticmethod - def serialize_to_plan(sql: str, ctx: SessionContext) -> plan: + def serialize_to_plan(sql: str, ctx: SessionContext) -> Plan: """Serialize a SQL query to a Substrait plan. Parameters @@ -88,7 +99,7 @@ def serialize_to_plan(sql: str, ctx: SessionContext) -> plan: plan Substrait plan. """ - return plan(substrait_internal.serde.serialize_to_plan(sql, ctx.ctx)) + return Plan(substrait_internal.serde.serialize_to_plan(sql, ctx.ctx)) @staticmethod def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: @@ -109,7 +120,7 @@ def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: return substrait_internal.serde.serialize_bytes(sql, ctx.ctx) @staticmethod - def deserialize(path: str) -> plan: + def deserialize(path: str | pathlib.Path) -> Plan: """Deserialize a Substrait plan from a file. Parameters @@ -121,11 +132,13 @@ def deserialize(path: str) -> plan: ------- plan Substrait plan. + + TODO add unit test for passing in as path """ - return plan(substrait_internal.serde.deserialize(path)) + return Plan(substrait_internal.serde.deserialize(path)) @staticmethod - def deserialize_bytes(proto_bytes: bytes) -> plan: + def deserialize_bytes(proto_bytes: bytes) -> Plan: """Deserialize a Substrait plan from bytes. Parameters @@ -138,14 +151,21 @@ def deserialize_bytes(proto_bytes: bytes) -> plan: plan Substrait plan. 
""" - return plan(substrait_internal.serde.deserialize_bytes(proto_bytes)) + return Plan(substrait_internal.serde.deserialize_bytes(proto_bytes)) + + +@deprecated("Use `Serde` instead.") +class serde(Serde): + """See `Serde` instead.""" + pass -class producer: + +class Producer: """Generates substrait plans from a logical plan.""" @staticmethod - def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> plan: + def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: """Convert a DataFusion LogicalPlan to a Substrait plan. Parameters @@ -160,16 +180,23 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> plan: plan Substrait plan. """ - return plan( + return Plan( substrait_internal.producer.to_substrait_plan(logical_plan, ctx.ctx) ) -class consumer: +@deprecated("Use `Producer` instead.") +class producer(Producer): + """Use `Producer` instead.""" + + pass + + +class Consumer: """Generates a logical plan from a substrait plan.""" @staticmethod - def from_substrait_plan(ctx: SessionContext, plan: plan) -> LogicalPlan: + def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: """Convert a Substrait plan to a DataFusion LogicalPlan. Parameters @@ -187,3 +214,10 @@ def from_substrait_plan(ctx: SessionContext, plan: plan) -> LogicalPlan: return substrait_internal.consumer.from_substrait_plan( ctx.ctx, plan.plan_internal ) + + +@deprecated("Use `Consumer` instead.") +class consumer(Consumer): + """Use `Consumer` instead.""" + + pass From e0e55a8c2b6906bfe7891a535b327535b91c0037 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jul 2024 08:13:43 -0400 Subject: [PATCH 16/55] Switch to google docstring style --- python/datafusion/expr.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 35a84791..01a7149e 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -264,9 +264,10 @@ def __init__( ) -> None: """Construct a window frame using the given parameters. - :param units: Should be one of `rows`, `range`, or `groups` - :param start_bound: Sets the preceeding bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. - :param end_bound: Sets the following bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. + Args: + units: Should be one of `rows`, `range`, or `groups`. + start_bound: Sets the preceeding bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. + end_bound: Sets the following bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. """ self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound) @@ -330,8 +331,6 @@ def __init__(self, case_builder: expr_internal.CaseBuilder) -> None: """Constructs a case builder. This is not typically called by the end user directly. See ``datafusion.functions.case`` instead. - - :param case_builder: Internal object. This constructor is not expected to be used by the end user. Instead use :func:`case` to construct. 
""" self.case_builder = case_builder From dcd5211dedc5ea4e87136e2f9e3976b54604d471 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jul 2024 09:07:00 -0400 Subject: [PATCH 17/55] Update unit tests to include registering via path or string --- python/datafusion/context.py | 3 +- python/datafusion/tests/test_sql.py | 54 ++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 9ec80e8a..fb83bc9d 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -574,10 +574,9 @@ def register_listing_table( schema: pyarrow.Schema | None = None, file_sort_order: list[list[Expr]] | None = None, ) -> None: - """Registers a Table that can assemble multiple files from locations in an ``ObjectStore`` instance into a single table.""" + """Registers a Table that can assemble multiple files from locations in an `ObjectStore` instance into a single table.""" if file_sort_order is not None: file_sort_order = [[x.expr for x in xs] for xs in file_sort_order] - # TODO add unit test for pathlib path self.ctx.register_listing_table( name, str(path), diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index f9eb588d..d85f380e 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -77,7 +77,13 @@ def test_register_csv(ctx, tmp_path): ) ctx.register_csv("csv3", path, schema=alternative_schema) - assert ctx.tables() == {"csv", "csv1", "csv2", "csv3", "csv_gzip"} + assert ctx.catalog().database().names() == { + "csv", + "csv1", + "csv2", + "csv3", + "csv_gzip", + } for table in ["csv", "csv1", "csv2", "csv_gzip"]: result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect() @@ -101,14 +107,16 @@ def test_register_csv(ctx, tmp_path): def test_register_parquet(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) ctx.register_parquet("t", path) - assert ctx.tables() == {"t"} + ctx.register_parquet("t1", str(path)) + assert ctx.catalog().database().names() == {"t", "t1"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) assert result.to_pydict() == {"cnt": [100]} -def test_register_parquet_partitioned(ctx, tmp_path): +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a").mkdir(exist_ok=False) @@ -125,14 +133,16 @@ def test_register_parquet_partitioned(ctx, tmp_path): pa.parquet.write_table(table.slice(0, 3), dir_root / "grp=a/file.parquet") pa.parquet.write_table(table.slice(3, 4), dir_root / "grp=b/file.parquet") + dir_root = str(dir_root) if path_to_str else dir_root + ctx.register_parquet( "datapp", - str(dir_root), + dir_root, table_partition_cols=[("grp", "string")], parquet_pruning=True, file_extension=".parquet", ) - assert ctx.tables() == {"datapp"} + assert ctx.catalog().database().names() == {"datapp"} result = ctx.sql("SELECT grp, COUNT(*) AS cnt FROM datapp GROUP BY grp").collect() result = pa.Table.from_batches(result) @@ -141,12 +151,14 @@ def test_register_parquet_partitioned(ctx, tmp_path): assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1} -def test_register_dataset(ctx, tmp_path): +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_register_dataset(ctx, tmp_path, path_to_str): path = helpers.write_parquet(tmp_path / "a.parquet", 
helpers.data()) + path = str(path) if path_to_str else path dataset = ds.dataset(path, format="parquet") ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) @@ -175,6 +187,12 @@ def test_register_json(ctx, tmp_path): file_extension="gz", file_compression_type="gzip", ) + ctx.register_json( + "json_gzip1", + str(gzip_path), + file_extension="gz", + file_compression_type="gzip", + ) alternative_schema = pa.schema( [ @@ -185,7 +203,14 @@ def test_register_json(ctx, tmp_path): ) ctx.register_json("json3", path, schema=alternative_schema) - assert ctx.tables() == {"json", "json1", "json2", "json3", "json_gzip"} + assert ctx.catalog().database().names() == { + "json", + "json1", + "json2", + "json3", + "json_gzip", + "json_gzip1", + } for table in ["json", "json1", "json2", "json_gzip"]: result = ctx.sql(f'SELECT COUNT("B") AS cnt FROM {table}').collect() @@ -235,7 +260,7 @@ def test_execute(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data)) ctx.register_parquet("t", path) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} # count result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL").collect() @@ -378,7 +403,10 @@ def test_simple_select(ctx, tmp_path, arr): @pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]])) @pytest.mark.parametrize("pass_schema", (True, False)) -def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order): +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_register_listing_table( + ctx, tmp_path, pass_schema, file_sort_order, path_to_str +): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True) @@ -403,16 +431,18 @@ def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order): table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet" ) + dir_root = f"file://{dir_root}/" if path_to_str else dir_root + ctx.register_object_store("file://local", LocalFileSystem(), None) ctx.register_listing_table( "my_table", - f"file://{dir_root}/", + dir_root, table_partition_cols=[("grp", "string"), ("date_id", "int")], file_extension=".parquet", schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, ) - assert ctx.tables() == {"my_table"} + assert ctx.catalog().database().names() == {"my_table"} result = ctx.sql( "SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp" From 1063cffc1005518f1d0973e6c523a25ab76f5ef2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jul 2024 09:11:53 -0400 Subject: [PATCH 18/55] Add py.typed file --- python/datafusion/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/datafusion/py.typed diff --git a/python/datafusion/py.typed b/python/datafusion/py.typed new file mode 100644 index 00000000..e69de29b From 5ba2017cfa33501cb584180132b218680087eb27 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jul 2024 21:07:26 -0400 Subject: [PATCH 19/55] Resolve deprecation warnings in unit tests --- python/datafusion/dataframe.py | 5 ++- python/datafusion/tests/test_context.py | 38 +++++++++++++---------- python/datafusion/tests/test_dataframe.py | 4 +-- python/datafusion/tests/test_substrait.py | 2 +- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git 
a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index d3212578..90225f23 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -581,8 +581,7 @@ def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: """See ``unnest_columns``.""" return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) - def unnest_columns( - self, columns: list[str], preserve_nulls: bool = True - ) -> DataFrame: + def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame: """Expand columns of arrays into a single row per array element.""" + columns = [c for c in columns] return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index abc324db..2d5e21b7 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -17,6 +17,7 @@ import gzip import os import datetime as dt +import pathlib import pyarrow as pa import pyarrow.dataset as ds @@ -68,7 +69,7 @@ def test_register_record_batches(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -84,7 +85,7 @@ def test_create_dataframe_registers_unique_table_name(ctx): ) df = ctx.create_dataframe([[batch]]) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -104,7 +105,7 @@ def test_create_dataframe_registers_with_defined_table_name(ctx): ) df = ctx.create_dataframe([[batch]], name="tbl") - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -118,7 +119,7 @@ def test_from_arrow_table(ctx): # convert to DataFrame df = ctx.from_arrow_table(table) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -134,7 +135,7 @@ def test_from_arrow_table_with_name(ctx): # convert to DataFrame with optional name df = ctx.from_arrow_table(table, name="tbl") - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert tables[0] == "tbl" @@ -147,7 +148,7 @@ def test_from_arrow_table_empty(ctx): # convert to DataFrame df = ctx.from_arrow_table(table) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -162,7 +163,7 @@ def test_from_arrow_table_empty_no_schema(ctx): # convert to DataFrame df = ctx.from_arrow_table(table) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -180,7 +181,7 @@ def test_from_pylist(ctx): ] df = ctx.from_pylist(data) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -194,7 +195,7 @@ def test_from_pydict(ctx): data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = ctx.from_pydict(data) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -210,7 +211,7 @@ def test_from_pandas(ctx): pandas_df = pd.DataFrame(data) df = ctx.from_pandas(pandas_df) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -226,7 +227,7 @@ def test_from_polars(ctx): polars_df = pd.DataFrame(data) df = ctx.from_polars(polars_df) - tables = list(ctx.tables()) + tables = 
list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -273,7 +274,7 @@ def test_register_dataset(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -290,7 +291,7 @@ def test_dataset_filter(ctx, capfd): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5") # Make sure the filter was pushed down in Physical Plan @@ -370,7 +371,7 @@ def test_dataset_filter_nested_data(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} df = ctx.table("t") @@ -468,8 +469,13 @@ def test_read_csv_compressed(ctx, tmp_path): def test_read_parquet(ctx): - csv_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet") - csv_df.show() + parquet_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet") + parquet_df.show() + assert parquet_df is not None + + path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet" + parquet_df = ctx.read_parquet(path=path) + assert parquet_df is not None def test_read_avro(ctx): diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 5f26063e..7a7f0994 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -176,7 +176,7 @@ def test_with_column_renamed(df): def test_unnest(nested_df): - nested_df = nested_df.unnest_column("a") + nested_df = nested_df.unnest_columns("a") # execute and collect the first (and only) batch result = nested_df.collect()[0] @@ -186,7 +186,7 @@ def test_unnest(nested_df): def test_unnest_without_nulls(nested_df): - nested_df = nested_df.unnest_column("a", preserve_nulls=False) + nested_df = nested_df.unnest_columns("a", preserve_nulls=False) # execute and collect the first (and only) batch result = nested_df.collect()[0] diff --git a/python/datafusion/tests/test_substrait.py b/python/datafusion/tests/test_substrait.py index 260db5eb..b0ba1b7e 100644 --- a/python/datafusion/tests/test_substrait.py +++ b/python/datafusion/tests/test_substrait.py @@ -35,7 +35,7 @@ def test_substrait_serialization(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} # For now just make sure the method calls blow up substrait_plan = ss.serde.serialize_to_plan("SELECT * FROM t", ctx) From 438afa0d12d0ffedb1391ff41705476391c2c639 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jul 2024 21:09:31 -0400 Subject: [PATCH 20/55] Add path to unit test --- python/datafusion/tests/test_context.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index 2d5e21b7..a81f0e10 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -479,8 +479,13 @@ def test_read_parquet(ctx): def test_read_avro(ctx): - csv_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro") - csv_df.show() + avro_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro") + avro_df.show() + assert avro_df is not None + + path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro" + avro_df = 
ctx.read_avro(path=path) + assert avro_df is not None def test_create_sql_options(): From 837e3b2c0ab8006e1bcaa659f8ce6071121fe38b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 06:06:21 -0400 Subject: [PATCH 21/55] Expose an option in write_csv to include header and add unit test --- python/datafusion/dataframe.py | 4 ++-- python/datafusion/tests/test_dataframe.py | 13 +++++++++++++ src/dataframe.rs | 17 +++++++++++------ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 90225f23..ac10b5ba 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -465,7 +465,7 @@ def except_all(self, other: DataFrame) -> DataFrame: """ return DataFrame(self.df.except_all(other.df)) - def write_csv(self, path: str | pathlib.Path) -> None: + def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None: """Execute the `DataFrame` and write the results to a CSV file. Parameters @@ -473,7 +473,7 @@ def write_csv(self, path: str | pathlib.Path) -> None: path : str Path of the CSV file to write. """ - self.df.write_csv(str(path)) + self.df.write_csv(str(path), with_header) def write_parquet( self, diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 7a7f0994..f624529d 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -739,6 +739,19 @@ def test_describe(df): } +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_write_csv(ctx, df, tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path + + df.write_csv(path, with_header=True) + + ctx.register_csv("csv", path) + result = ctx.table("csv").to_pydict() + expected = df.to_pydict() + + assert result == expected + + def test_write_parquet(df, tmp_path): path = tmp_path diff --git a/src/dataframe.rs b/src/dataframe.rs index 53e11234..4db59d4f 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; -use datafusion::config::TableParquetOptions; +use datafusion::config::{CsvOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; @@ -349,13 +349,18 @@ impl PyDataFrame { } /// Write a `DataFrame` to a CSV file. 
- fn write_csv(&self, path: &str, py: Python) -> PyResult<()> { + fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyResult<()> { + let csv_options = CsvOptions { + has_header: Some(with_header), + ..Default::default() + }; wait_for_future( py, - self.df - .as_ref() - .clone() - .write_csv(path, DataFrameWriteOptions::new(), None), + self.df.as_ref().clone().write_csv( + path, + DataFrameWriteOptions::new(), + Some(csv_options), + ), )?; Ok(()) } From 6e75eeef7958ef3e6cc0e59d4335c9393f8901b2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 06:07:45 -0400 Subject: [PATCH 22/55] Update write_parquet unit test to include paths or strings --- python/datafusion/tests/test_dataframe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index f624529d..2b38d0d1 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -752,8 +752,9 @@ def test_write_csv(ctx, df, tmp_path, path_to_str): assert result == expected -def test_write_parquet(df, tmp_path): - path = tmp_path +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_write_parquet(df, tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path df.write_parquet(str(path)) result = pq.read_table(str(path)).to_pydict() From 2ebe2e5cc68da400958bb23e4b6dbcc50dbb181a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 06:09:04 -0400 Subject: [PATCH 23/55] Add unit test for write_json --- python/datafusion/tests/test_dataframe.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 2b38d0d1..cd919796 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -752,6 +752,19 @@ def test_write_csv(ctx, df, tmp_path, path_to_str): assert result == expected +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_write_json(ctx, df, tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path + + df.write_json(path) + + ctx.register_json("json", path) + result = ctx.table("json").to_pydict() + expected = df.to_pydict() + + assert result == expected + + @pytest.mark.parametrize("path_to_str", (True, False)) def test_write_parquet(df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path From dad0d26f51c6790402cbfb47411f9811ec547bb3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 07:38:43 -0400 Subject: [PATCH 24/55] Add unit test for substrait serialization to a file --- examples/substrait.py | 10 +++---- python/datafusion/substrait.py | 6 ++-- python/datafusion/tests/test_substrait.py | 36 +++++++++++++++++++---- src/substrait.rs | 2 +- 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/examples/substrait.py b/examples/substrait.py index 7a268a78..66f8a30d 100644 --- a/examples/substrait.py +++ b/examples/substrait.py @@ -26,7 +26,7 @@ # Register table with context ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv") -substrait_plan = ss.serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx) +substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx) # type(substrait_plan) -> # Encode it to bytes @@ -37,15 +37,15 @@ # Alternative serialization approaches # type(substrait_bytes) -> , at this point the bytes can be distributed to file, network, 
etc safely # where they could subsequently be deserialized on the receiving end. -substrait_bytes = ss.serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx) +substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx) # Imagine here bytes would be read from network, file, etc ... for example brevity this is omitted and variable is simply reused # type(substrait_plan) -> -substrait_plan = ss.serde.deserialize_bytes(substrait_bytes) +substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes) # type(df_logical_plan) -> -df_logical_plan = ss.consumer.from_substrait_plan(ctx, substrait_plan) +df_logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan) # Back to Substrait Plan just for demonstration purposes # type(substrait_plan) -> -substrait_plan = ss.producer.to_substrait_plan(df_logical_plan) +substrait_plan = ss.Producer.to_substrait_plan(df_logical_plan) diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index dea5fb1c..a6010ca2 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -37,7 +37,7 @@ class Plan: """A class representing an encodable substrait plan.""" - def __init__(self, plan: substrait_internal.plan) -> None: + def __init__(self, plan: substrait_internal.Plan) -> None: """Create a substrait plan. The user should not have to call this constructor directly. Rather, it should be created @@ -78,8 +78,6 @@ def serialize(sql: str, ctx: SessionContext, path: str | pathlib.Path) -> None: SessionContext to use. path : str Path to write the Substrait plan to. - - TODO add unit test on passing in as path instead of str """ return substrait_internal.serde.serialize(sql, ctx.ctx, str(path)) @@ -135,7 +133,7 @@ def deserialize(path: str | pathlib.Path) -> Plan: TODO add unit test for passing in as path """ - return Plan(substrait_internal.serde.deserialize(path)) + return Plan(substrait_internal.serde.deserialize(str(path))) @staticmethod def deserialize_bytes(proto_bytes: bytes) -> Plan: diff --git a/python/datafusion/tests/test_substrait.py b/python/datafusion/tests/test_substrait.py index b0ba1b7e..2071c8f3 100644 --- a/python/datafusion/tests/test_substrait.py +++ b/python/datafusion/tests/test_substrait.py @@ -38,14 +38,40 @@ def test_substrait_serialization(ctx): assert ctx.catalog().database().names() == {"t"} # For now just make sure the method calls blow up - substrait_plan = ss.serde.serialize_to_plan("SELECT * FROM t", ctx) + substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM t", ctx) substrait_bytes = substrait_plan.encode() assert isinstance(substrait_bytes, bytes) - substrait_bytes = ss.serde.serialize_bytes("SELECT * FROM t", ctx) - substrait_plan = ss.serde.deserialize_bytes(substrait_bytes) - logical_plan = ss.consumer.from_substrait_plan(ctx, substrait_plan) + substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM t", ctx) + substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes) + logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan) # demonstrate how to create a DataFrame from a deserialized logical plan df = ctx.create_dataframe_from_logical_plan(logical_plan) - substrait_plan = ss.producer.to_substrait_plan(df.logical_plan(), ctx) + substrait_plan = ss.Producer.to_substrait_plan(df.logical_plan(), ctx) + + +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_substrait_file_serialization(ctx, tmp_path, path_to_str): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + 
+ ctx.register_record_batches("t", [[batch]]) + + assert ctx.catalog().database().names() == {"t"} + + path = tmp_path / "substrait_plan" + path = str(path) if path_to_str else path + + sql_command = "SELECT * FROM T" + ss.Serde.serialize(sql_command, ctx, path) + + expected_plan = ss.Serde.serialize_to_plan(sql_command, ctx) + actual_plan = ss.Serde.deserialize(path) + + expected_logical_plan = ss.Consumer.from_substrait_plan(ctx, expected_plan) + expected_actual_plan = ss.Consumer.from_substrait_plan(ctx, actual_plan) + + assert str(expected_logical_plan) == str(expected_actual_plan) diff --git a/src/substrait.rs b/src/substrait.rs index 1e9e16c7..60a52380 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -27,7 +27,7 @@ use datafusion_substrait::serializer; use datafusion_substrait::substrait::proto::Plan; use prost::Message; -#[pyclass(name = "plan", module = "datafusion.substrait", subclass)] +#[pyclass(name = "Plan", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PyPlan { pub plan: Plan, From ae569ff3f90f8696e6ed899049bcfb6d7c8dc619 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 07:51:20 -0400 Subject: [PATCH 25/55] Add unit tests for runtime config --- python/datafusion/context.py | 4 +--- python/datafusion/tests/test_context.py | 30 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index fb83bc9d..65d83f05 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -321,9 +321,7 @@ def with_disk_manager_os(self) -> RuntimeConfig: self.config_internal = self.config_internal.with_disk_manager_os() return self - def with_disk_manager_specified( - self, paths: list[str] | list[pathlib.Path] - ) -> RuntimeConfig: + def with_disk_manager_specified(self, *paths: str | pathlib.Path) -> RuntimeConfig: """Use the specified paths for the disk manager's temporary files. 
Parameters diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index a81f0e10..fb60360d 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -38,6 +38,36 @@ def test_create_context_no_args(): SessionContext() +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_runtime_configs(tmp_path, path_to_str): + path1 = tmp_path / "dir1" + path2 = tmp_path / "dir2" + + path1 = str(path1) if path_to_str else path1 + path2 = str(path2) if path_to_str else path2 + + runtime = RuntimeConfig().with_disk_manager_specified(path1, path2) + config = SessionConfig().with_default_catalog_and_schema("foo", "bar") + ctx = SessionContext(config, runtime) + assert ctx is not None + + db = ctx.catalog("foo").database("bar") + assert db is not None + + +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_temporary_files(tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path + + runtime = RuntimeConfig().with_temp_file_path(path) + config = SessionConfig().with_default_catalog_and_schema("foo", "bar") + ctx = SessionContext(config, runtime) + assert ctx is not None + + db = ctx.catalog("foo").database("bar") + assert db is not None + + def test_create_context_with_all_valid_args(): runtime = RuntimeConfig().with_disk_manager_os().with_fair_spill_pool(10000000) config = ( From 4f973af0c75e7f2a837d41e1dd3e53c13af55515 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 07:53:56 -0400 Subject: [PATCH 26/55] Setting return type to typing_extensions.Self per PR recommendation --- python/datafusion/record_batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 4b7416d9..b0549255 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: import pyarrow import datafusion._internal as df_internal + import typing_extensions class RecordBatch: @@ -62,6 +63,6 @@ def __next__(self) -> RecordBatch | None: next_batch = next(self.rbs) return RecordBatch(next_batch) if next_batch is not None else None - def __iter__(self) -> RecordBatchStream: + def __iter__(self) -> typing_extensions.Self: """Iterator function.""" return self From f2ed822772d5b8b56015b5d170b129b31e2bcccf Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 07:54:59 -0400 Subject: [PATCH 27/55] Correcting __next__ to not return None since it will raise an exception instead. 
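This change follows the Python iterator protocol: `__next__` signals
exhaustion by raising `StopIteration` rather than by returning `None`, so a
`RecordBatch | None` return annotation was misleading. A minimal sketch of
the contract (illustrative only, not part of the patched file):

    class Stream:
        def __init__(self, batches):
            self._it = iter(batches)

        def __iter__(self):
            return self

        def __next__(self):
            # next() raises StopIteration once the underlying iterator is
            # exhausted; for-loops end on that signal, not on a None value.
            return next(self._it)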
---
 python/datafusion/record_batch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py
index b0549255..ddb90178 100644
--- a/python/datafusion/record_batch.py
+++ b/python/datafusion/record_batch.py
@@ -58,10 +58,10 @@ def next(self) -> RecordBatch | None:

         return next_batch

-    def __next__(self) -> RecordBatch | None:
+    def __next__(self) -> RecordBatch:
         """Iterator function."""
         next_batch = next(self.rbs)
-        return RecordBatch(next_batch) if next_batch is not None else None
+        return RecordBatch(next_batch)

     def __iter__(self) -> typing_extensions.Self:
         """Iterator function."""
         return self

From c2ee65dfe32db5f9e6835fbb7b2aa5677249b0b1 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sat, 13 Jul 2024 08:07:18 -0400
Subject: [PATCH 28/55] Add optional parameter of decimal places to round and add unit test

---
 python/datafusion/functions.py            |  7 ++--
 python/datafusion/tests/test_functions.py | 42 ++++++++++++-----------
 2 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 36a8aba8..087680c7 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -486,9 +486,12 @@ def right(string: Expr, n: Expr) -> Expr:
     return Expr(f.right(string.expr, n.expr))


-def round(arg: Expr) -> Expr:
+def round(value: Expr, decimal_places: Expr | None = None) -> Expr:
     """Round the argument to the nearest integer."""
-    return Expr(f.round(arg.expr))
+    if decimal_places is None:
+        return Expr(f.round(value.expr))
+
+    return Expr(f.round(value.expr, decimal_places.expr))


 def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr:
diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py
index 2e601c28..77297114 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -138,6 +138,7 @@ def test_math_functions():
         f.power(col_v, literal(pa.scalar(3))),
         f.pow(col_v, literal(pa.scalar(4))),
         f.round(col_v),
+        f.round(col_v, literal(pa.scalar(3))),
         f.sqrt(col_v),
         f.signum(col_v),
         f.trunc(col_v),
@@ -181,29 +182,30 @@ def test_math_functions():
     np.testing.assert_array_almost_equal(result.column(15), np.power(values, 3))
     np.testing.assert_array_almost_equal(result.column(16), np.power(values, 4))
     np.testing.assert_array_almost_equal(result.column(17), np.round(values))
-    np.testing.assert_array_almost_equal(result.column(18), np.sqrt(values))
-    np.testing.assert_array_almost_equal(result.column(19), np.sign(values))
-    np.testing.assert_array_almost_equal(result.column(20), np.trunc(values))
-    np.testing.assert_array_almost_equal(result.column(21), np.arcsinh(values))
-    np.testing.assert_array_almost_equal(result.column(22), np.arccosh(values))
-    np.testing.assert_array_almost_equal(result.column(23), np.arctanh(values))
-    np.testing.assert_array_almost_equal(result.column(24), np.cbrt(values))
-    np.testing.assert_array_almost_equal(result.column(25), np.cosh(values))
-    np.testing.assert_array_almost_equal(result.column(26), np.degrees(values))
-    np.testing.assert_array_almost_equal(result.column(27), np.gcd(9, 3))
-    np.testing.assert_array_almost_equal(result.column(28), np.lcm(6, 4))
+    np.testing.assert_array_almost_equal(result.column(18), np.round(values, 3))
+    np.testing.assert_array_almost_equal(result.column(19), np.sqrt(values))
+    np.testing.assert_array_almost_equal(result.column(20), np.sign(values))
+
np.testing.assert_array_almost_equal(result.column(21), np.trunc(values)) + np.testing.assert_array_almost_equal(result.column(22), np.arcsinh(values)) + np.testing.assert_array_almost_equal(result.column(23), np.arccosh(values)) + np.testing.assert_array_almost_equal(result.column(24), np.arctanh(values)) + np.testing.assert_array_almost_equal(result.column(25), np.cbrt(values)) + np.testing.assert_array_almost_equal(result.column(26), np.cosh(values)) + np.testing.assert_array_almost_equal(result.column(27), np.degrees(values)) + np.testing.assert_array_almost_equal(result.column(28), np.gcd(9, 3)) + np.testing.assert_array_almost_equal(result.column(29), np.lcm(6, 4)) np.testing.assert_array_almost_equal( - result.column(29), np.where(np.isnan(na_values), 5, na_values) + result.column(30), np.where(np.isnan(na_values), 5, na_values) ) - np.testing.assert_array_almost_equal(result.column(30), np.pi) - np.testing.assert_array_almost_equal(result.column(31), np.radians(values)) - np.testing.assert_array_almost_equal(result.column(32), np.sinh(values)) - np.testing.assert_array_almost_equal(result.column(33), np.tanh(values)) - np.testing.assert_array_almost_equal(result.column(34), math.factorial(6)) - np.testing.assert_array_almost_equal(result.column(35), np.isnan(na_values)) - np.testing.assert_array_almost_equal(result.column(36), na_values == 0) + np.testing.assert_array_almost_equal(result.column(31), np.pi) + np.testing.assert_array_almost_equal(result.column(32), np.radians(values)) + np.testing.assert_array_almost_equal(result.column(33), np.sinh(values)) + np.testing.assert_array_almost_equal(result.column(34), np.tanh(values)) + np.testing.assert_array_almost_equal(result.column(35), math.factorial(6)) + np.testing.assert_array_almost_equal(result.column(36), np.isnan(na_values)) + np.testing.assert_array_almost_equal(result.column(37), na_values == 0) np.testing.assert_array_almost_equal( - result.column(37), np.emath.logn(3, values + 1.0) + result.column(38), np.emath.logn(3, values + 1.0) ) From 835e37481409ec4793473b62b8dea9d4b6b8db4d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 08:21:45 -0400 Subject: [PATCH 29/55] Improve docstrings --- python/datafusion/functions.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 087680c7..d6886f49 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -486,16 +486,22 @@ def right(string: Expr, n: Expr) -> Expr: return Expr(f.right(string.expr, n.expr)) -def round(value: Expr, decimal_places: Expr | None = None) -> Expr: - """Round the argument to the nearest integer.""" - if decimal_places is None: - return Expr(f.round(value.expr)) +def round(value: Expr, decimal_places: Expr = Expr.literal(0)) -> Expr: + """Round the argument to the nearest integer. + If the optional ``decimal_places`` is specified, round to the nearest number of + decimal places. You can specify a negative number of decimal places. For example + `round(lit(125.2345), lit(-2))` would yield a value of `100.0`. + """ return Expr(f.round(value.expr, decimal_places.expr)) def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: - """Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.""" + """Add right padding to a string. 
+ + Extends the string to length length by appending the characters fill (a space + by default). If the string is already longer than length then it is truncated. + """ characters = characters if characters is not None else Expr.literal(" ") return Expr(f.rpad(string.expr, count.expr, characters.expr)) From 08b83ac01781705a18baeef74a186e05ae333684 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 08:22:26 -0400 Subject: [PATCH 30/55] Set default to None instead of empty dict --- python/datafusion/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 65d83f05..a3fbc581 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -45,7 +45,7 @@ class SessionConfig: """Session configuration options.""" - def __init__(self, config_options: dict[str, str] = {}) -> None: + def __init__(self, config_options: dict[str, str] | None = None) -> None: """Create a new `SessionConfig` with the given configuration options. Parameters From 2ccd5ad6c120ccd335913354b63d56260ae13274 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 08:43:30 -0400 Subject: [PATCH 31/55] User request to allow passing multiple arguments to filter() --- python/datafusion/dataframe.py | 12 ++++++++---- python/datafusion/expr.py | 2 +- python/datafusion/tests/test_dataframe.py | 16 ++++++++++++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index ac10b5ba..79d00fe5 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -126,22 +126,26 @@ def select(self, *args: Expr) -> DataFrame: args = [arg.expr for arg in args] return DataFrame(self.df.select(*args)) - def filter(self, predicate: Expr) -> DataFrame: + def filter(self, *predicates: Expr) -> DataFrame: """Return a DataFrame for which `predicate` evaluates to `True`. Rows for which `predicate` evaluates to `False` or `None` are filtered out. Parameters ---------- - predicate : Expr - Predicate expression to filter the DataFrame. + predicates : Predicate expression(s) to filter the DataFrame. If more than one + is provided, these predicates will be combined as a logical AND. If more complex + logic is required, see logical operations in `datafusion.functions`. Returns: ------- DataFrame DataFrame after filtering. """ - return DataFrame(self.df.filter(predicate.expr)) + df = self.df + for p in predicates: + df = df.filter(p.expr) + return DataFrame(df) def with_column(self, name: str, expr: Expr) -> DataFrame: """Add an additional column to the DataFrame. 
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 01a7149e..6132994a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -169,7 +169,7 @@ def __eq__(self, rhs: Expr) -> Expr: def __ne__(self, rhs: Expr) -> Expr: """Not equal to.""" - return Expr(self.expr.__eq__(rhs.expr)) + return Expr(self.expr.__ne__(rhs.expr)) def __ge__(self, rhs: Expr) -> Expr: """Greater than or equal to.""" diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index cd919796..82f27ad8 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -108,17 +108,29 @@ def test_select_columns(df): def test_filter(df): - df = df.filter(column("a") > literal(2)).select( + df1 = df.filter(column("a") > literal(2)).select( column("a") + column("b"), column("a") - column("b"), ) # execute and collect the first (and only) batch - result = df.collect()[0] + result = df1.collect()[0] assert result.column(0) == pa.array([9]) assert result.column(1) == pa.array([-3]) + df.show() + # verify that if there is no filter applied, internal dataframe is unchanged + df2 = df.filter() + assert df.df == df2.df + + df3 = df.filter(column("a") > literal(1), column("b") != literal(6)) + result = df3.collect()[0] + + assert result.column(0) == pa.array([2]) + assert result.column(1) == pa.array([5]) + assert result.column(2) == pa.array([5]) + def test_sort(df): df = df.sort(column("b").sort(ascending=False)) From 13be857399c4c13109b95ef93a5882849e2ab020 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 09:14:08 -0400 Subject: [PATCH 32/55] Enhance Expr comparison operators to accept any python value and attempt to convert it to a literal --- python/datafusion/expr.py | 103 +++++++++++++++++++++------ python/datafusion/tests/test_expr.py | 25 ++++++- 2 files changed, 105 insertions(+), 23 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 6132994a..0739c6cf 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -127,32 +127,61 @@ def __repr__(self) -> str: """Generate a string representation of this expression.""" return self.expr.__repr__() - def __add__(self, rhs: Expr) -> Expr: - """Addition operator.""" + def __add__(self, rhs: Any) -> Expr: + """Addition operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__add__(rhs.expr)) - def __sub__(self, rhs: Expr) -> Expr: - """Subtraction operator.""" + def __sub__(self, rhs: Any) -> Expr: + """Subtraction operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__sub__(rhs.expr)) - def __truediv__(self, rhs: Expr) -> Expr: - """Division operator.""" + def __truediv__(self, rhs: Any) -> Expr: + """Division operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__truediv__(rhs.expr)) - def __mul__(self, rhs: Expr) -> Expr: - """Multiplication operator.""" + def __mul__(self, rhs: Any) -> Expr: + """Multiplication operator. + + Accepts either an expression or any valid PyArrow scalar literal value. 
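+
+        For example (an illustrative sketch, assuming a numeric column ``a``):
+
+        ```python
+        col("a") * 2  # the plain integer 2 is converted to a literal expression
+        ```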
+ """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__mul__(rhs.expr)) - def __mod__(self, rhs: Expr) -> Expr: - """Modulo operator (%).""" + def __mod__(self, rhs: Any) -> Expr: + """Modulo operator (%). + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__mod__(rhs.expr)) def __and__(self, rhs: Expr) -> Expr: """Logical AND.""" + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__and__(rhs.expr)) def __or__(self, rhs: Expr) -> Expr: """Logical OR.""" + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__or__(rhs.expr)) def __invert__(self) -> Expr: @@ -163,28 +192,58 @@ def __getitem__(self, key: str) -> Expr: """For struct data types, return the field indicated by ``key``.""" return Expr(self.expr.__getitem__(key)) - def __eq__(self, rhs: Expr) -> Expr: - """Equal to.""" + def __eq__(self, rhs: Any) -> Expr: + """Equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__eq__(rhs.expr)) - def __ne__(self, rhs: Expr) -> Expr: - """Not equal to.""" + def __ne__(self, rhs: Any) -> Expr: + """Not equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__ne__(rhs.expr)) - def __ge__(self, rhs: Expr) -> Expr: - """Greater than or equal to.""" + def __ge__(self, rhs: Any) -> Expr: + """Greater than or equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__ge__(rhs.expr)) - def __gt__(self, rhs: Expr) -> Expr: - """Greater than.""" + def __gt__(self, rhs: Any) -> Expr: + """Greater than. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__gt__(rhs.expr)) - def __le__(self, rhs: Expr) -> Expr: - """Less than or equal to.""" + def __le__(self, rhs: Any) -> Expr: + """Less than or equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__le__(rhs.expr)) - def __lt__(self, rhs: Expr) -> Expr: - """Less than.""" + def __lt__(self, rhs: Any) -> Expr: + """Less than. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) return Expr(self.expr.__lt__(rhs.expr)) @staticmethod diff --git a/python/datafusion/tests/test_expr.py b/python/datafusion/tests/test_expr.py index 73f7d087..c9f0e98d 100644 --- a/python/datafusion/tests/test_expr.py +++ b/python/datafusion/tests/test_expr.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from datafusion import SessionContext +from datafusion import SessionContext, col from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction from datafusion.expr import ( Projection, @@ -25,6 +25,7 @@ Sort, TableScan, ) +import pyarrow import pytest @@ -116,3 +117,25 @@ def test_sort(test_ctx): plan = plan.to_variant() assert isinstance(plan, Sort) + + +def test_relational_expr(test_ctx): + ctx = SessionContext() + + batch = pyarrow.RecordBatch.from_arrays( + [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]], name="batch_array") + + assert df.filter(col("a") == 1).count() == 1 + assert df.filter(col("a") != 1).count() == 2 + assert df.filter(col("a") >= 1).count() == 3 + assert df.filter(col("a") > 1).count() == 2 + assert df.filter(col("a") <= 3).count() == 3 + assert df.filter(col("a") < 3).count() == 2 + + assert df.filter(col("b") == "beta").count() == 1 + assert df.filter(col("b") != "beta").count() == 2 + + assert df.filter(col("a") == "beta").count() == 0 From 8f1bb6568bb418c54011a9b8432de23c7f5ea982 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 10:19:47 -0400 Subject: [PATCH 33/55] Expose overlay and add unit test --- python/datafusion/functions.py | 13 +++++++------ python/datafusion/tests/test_functions.py | 2 ++ src/functions.rs | 2 ++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index d6886f49..6afb47fc 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -385,13 +385,14 @@ def octet_length(arg: Expr) -> Expr: return Expr(f.octet_length(arg.expr)) -# TODO: `overlay` in datafusion needs to be updated from generic `args` definition, and then exposed in this repo. -# def overlay(string: Expr, substring: Expr, start: Expr, length: Expr | None = None) -> Expr: -# """ -# Replace the substring of string that starts at the `start`'th character and extends for `length` characters with new substring -# """ -# return Expr() +def overlay(string: Expr, substring: Expr, start: Expr, length: Expr | None = None) -> Expr: + """Replace a substring with a new substring. + Replace the substring of string that starts at the `start`'th character and extends for `length` characters with new substring. 
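+
+    For example (an illustrative sketch, mirroring the unit test added below):
+
+    ```python
+    overlay(literal("Hello"), literal("--"), literal(2))  # yields "H--lo"
+    ```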
+ """ + if length is None: + return Expr(f.overlay(string.expr, substring.expr, start.expr)) + return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) def pi() -> Expr: """Returns an approximate value of π.""" diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 77297114..37cf58c6 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -591,6 +591,7 @@ def test_string_functions(df): f.trim(column("c")), f.upper(column("c")), f.ends_with(column("a"), literal("llo")), + f.overlay(column("a"), literal("--"), literal(2)), ) result = df.collect() assert len(result) == 1 @@ -632,6 +633,7 @@ def test_string_functions(df): assert result.column(26) == pa.array(["hello", "world", "!"]) assert result.column(27) == pa.array(["HELLO ", " WORLD ", " !"]) assert result.column(28) == pa.array([True, False, False]) + assert result.column(29) == pa.array(["H--lo", "W--ld", "--"]) def test_hash_functions(df): diff --git a/src/functions.rs b/src/functions.rs index 42d1d058..d2f3c7ed 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -528,6 +528,7 @@ expr_fn!( ); expr_fn!(nullif, arg_1 arg_2); expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); +expr_fn_vec!(overlay); expr_fn!(pi); expr_fn!(power, base exponent); expr_fn!(pow, power, base exponent); @@ -774,6 +775,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(nullif))?; m.add_wrapped(wrap_pyfunction!(octet_length))?; m.add_wrapped(wrap_pyfunction!(order_by))?; + m.add_wrapped(wrap_pyfunction!(overlay))?; m.add_wrapped(wrap_pyfunction!(pi))?; m.add_wrapped(wrap_pyfunction!(power))?; m.add_wrapped(wrap_pyfunction!(pow))?; From 75e129a129ecd1cab0fc3a517ab435a89227a7df Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 10:20:12 -0400 Subject: [PATCH 34/55] Allow select() to take either str for column names or a full expr --- python/datafusion/dataframe.py | 8 +++++--- python/datafusion/functions.py | 5 ++++- python/datafusion/tests/test_dataframe.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 79d00fe5..178bfad1 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -113,9 +113,9 @@ def select_columns(self, *args: str) -> DataFrame: DataFrame DataFrame only containing the specified columns. """ - return DataFrame(self.df.select_columns(*args)) + return self.select(*args) - def select(self, *args: Expr) -> DataFrame: + def select(self, *args: Expr | str) -> DataFrame: """Project arbitrary expressions (like SQL SELECT expressions) into a new `DataFrame`. Returns: @@ -123,7 +123,9 @@ def select(self, *args: Expr) -> DataFrame: DataFrame DataFrame after projection. It has one column for each expression. 
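+
+        Example (an illustrative sketch, assuming columns ``a`` and ``b`` exist):
+
+        ```python
+        df = df.select("a", column("b") + column("a"))
+        ```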
""" - args = [arg.expr for arg in args] + args = [ + arg.expr if isinstance(arg, Expr) else Expr.column(arg).expr for arg in args + ] return DataFrame(self.df.select(*args)) def filter(self, *predicates: Expr) -> DataFrame: diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 6afb47fc..9fda512e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -385,7 +385,9 @@ def octet_length(arg: Expr) -> Expr: return Expr(f.octet_length(arg.expr)) -def overlay(string: Expr, substring: Expr, start: Expr, length: Expr | None = None) -> Expr: +def overlay( + string: Expr, substring: Expr, start: Expr, length: Expr | None = None +) -> Expr: """Replace a substring with a new substring. Replace the substring of string that starts at the `start`'th character and extends for `length` characters with new substring. @@ -394,6 +396,7 @@ def overlay(string: Expr, substring: Expr, start: Expr, length: Expr | None = No return Expr(f.overlay(string.expr, substring.expr, start.expr)) return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) + def pi() -> Expr: """Returns an approximate value of π.""" return Expr(f.pi()) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 82f27ad8..f5db9fdb 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -97,6 +97,16 @@ def test_select(df): assert result.column(1) == pa.array([-3, -3, -3]) +def test_select_mixed_expr_string(df): + df = df.select_columns(column("b"), "a") + + # execute and collect the first (and only) batch + result = df.collect()[0] + + assert result.column(0) == pa.array([4, 5, 6]) + assert result.column(1) == pa.array([1, 2, 3]) + + def test_select_columns(df): df = df.select_columns("b", "a") From f2b15e0c46c9be49378adebfad0781dd7ade7d5f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 13 Jul 2024 10:29:53 -0400 Subject: [PATCH 35/55] Update comments on regexp and add unit tests --- python/datafusion/functions.py | 14 +------------- python/datafusion/tests/test_functions.py | 7 +++++++ 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 9fda512e..561e4aa1 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -436,19 +436,7 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: - """Returns an array with each element containing the leftmost-first match of the corresponding index in `regex` to string in `string`. - - If there is no match, the list element is NULL. - - If a match is found, and the pattern contains no capturing parenthesized subexpressions, - then the list element is a single-element [`GenericStringArray`] containing the substring - matching the whole pattern. - - If a match is found, and the pattern contains capturing parenthesized subexpressions, then the - list element is a [`GenericStringArray`] whose n'th element is the substring matching - the n'th capturing parenthesized subexpression of the pattern. 
-    """
-    # TODO VALIDATE THIS IS CORRECT FOR DATAFRAME RESULTS
+    """Returns an array with each element containing the leftmost-first match of the corresponding index in `regex` to string in `string`."""
     if flags is not None:
         flags = flags.expr
     return Expr(f.regexp_match(string.expr, regex.expr, flags))
diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py
index 37cf58c6..c3a59698 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -592,7 +592,11 @@ def test_string_functions(df):
         f.upper(column("c")),
         f.ends_with(column("a"), literal("llo")),
         f.overlay(column("a"), literal("--"), literal(2)),
+        f.regexp_like(column("a"), literal("(ell|orl)")),
+        f.regexp_match(column("a"), literal("(ell|orl)")),
+        f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
     )
+
     result = df.collect()
     assert len(result) == 1
     result = result[0]
@@ -634,6 +638,9 @@ def test_string_functions(df):
     assert result.column(27) == pa.array(["HELLO ", " WORLD ", " !"])
     assert result.column(28) == pa.array([True, False, False])
     assert result.column(29) == pa.array(["H--lo", "W--ld", "--"])
+    assert result.column(30) == pa.array([True, True, False])
+    assert result.column(31) == pa.array([["ell"], ["orl"], None])
+    assert result.column(32) == pa.array(["H-o", "W-d", "!"])


 def test_hash_functions(df):

From b76d1052345ddb19cea651314601eb6655ec3284 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sat, 13 Jul 2024 10:43:08 -0400
Subject: [PATCH 36/55] Remove TODO markings no longer applicable

---
 python/datafusion/context.py              |  2 +-
 python/datafusion/dataframe.py            |  2 +-
 python/datafusion/functions.py            | 11 ++++++-----
 python/datafusion/substrait.py            |  2 --
 python/datafusion/tests/test_functions.py |  1 -
 5 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index a3fbc581..7dfbeaf1 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -23,7 +23,7 @@
 from ._internal import RuntimeConfig as RuntimeConfigInternal
 from ._internal import SQLOptions as SQLOptionsInternal
 from ._internal import SessionContext as SessionContextInternal
-from ._internal import LogicalPlan, ExecutionPlan  # TODO MAKE THIS A DEFINED CLASS
+from ._internal import LogicalPlan, ExecutionPlan

 from datafusion._internal import AggregateUDF
 from datafusion.catalog import Catalog, Table
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 178bfad1..99d9b035 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -36,7 +36,7 @@
 from datafusion._internal import (
     LogicalPlan,
     ExecutionPlan,
-)  # TODO make these first class python classes
+)


 class DataFrame:
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 561e4aa1..cd726f4c 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -597,8 +597,12 @@ def now() -> Expr:


 def to_timestamp(arg: Expr, *formatters: Expr) -> Expr:
-    """Converts a string and optional formats to a `Timestamp` in nanoseconds."""
-    # TODO Add a detailed description of how to use formatters.
+    """Converts a string and optional formats to a `Timestamp` in nanoseconds.
+
+    For usage of ``formatters``, see the ``strftime`` module of the Rust chrono crate.
+ + [Documentation here.](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) + """ if formatters is None: return f.to_timestamp(arg.expr) @@ -640,7 +644,6 @@ def to_timestamp_seconds(arg: Expr, *formatters: Expr) -> Expr: def to_unixtime(string: Expr, *format_arguments: Expr) -> Expr: """Converts a string and optional formats to a Unixtime.""" - # TODO verify if the format arguments are the same as to_timestamp and update documentation appropriately. args = [f.expr for f in format_arguments] return Expr(f.to_unixtime(string.expr, *args)) @@ -1163,7 +1166,6 @@ def approx_percentile_cont( distinct: bool = False, ) -> Expr: """Returns the value that is approximately at a given percentile of a distribution of values.""" - # TODO validate that these parameters are passed properly if num_centroids is None: return Expr( f.approx_percentile_cont(arg.expr, percentile.expr, distinct=distinct) @@ -1180,7 +1182,6 @@ def approx_percentile_cont_with_weight( arg: Expr, weight: Expr, percentile: Expr, distinct: bool = False ) -> Expr: """Returns the value that is approximately at a given percentile of a distribution of values with associated weights.""" - # TODO validate that these parameters are passed properly return Expr( f.approx_percentile_cont_with_weight( arg.expr, weight.expr, percentile.expr, distinct=distinct diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index a6010ca2..f3e5f59b 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -130,8 +130,6 @@ def deserialize(path: str | pathlib.Path) -> Plan: ------- plan Substrait plan. - - TODO add unit test for passing in as path """ return Plan(substrait_internal.serde.deserialize(str(path))) diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index c3a59698..2384b6ab 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -50,7 +50,6 @@ def df(): return ctx.create_dataframe([[batch]]) -# TODO Update documentation of PR to indicate this is a user facing change to how named_struct is called def test_named_struct(df): df = df.with_column( "d", From 6e87d73f606fe056be9fd03e7111ae836163f747 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 14 Jul 2024 08:14:06 -0400 Subject: [PATCH 37/55] Update udf documentation --- python/datafusion/__init__.py | 144 +------------------------- python/datafusion/udf.py | 183 +++++++++++++++++++++++++++++++--- 2 files changed, 175 insertions(+), 152 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 2a69f58a..6fd1a887 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -21,16 +21,11 @@ See https://datafusion.apache.org/python/index.html for more information. """ -from abc import ABCMeta, abstractmethod -from typing import List - try: import importlib.metadata as importlib_metadata except ImportError: import importlib_metadata -import pyarrow as pa - from .context import ( SessionContext, SessionConfig, @@ -41,7 +36,7 @@ # The following imports are okay to remain as opaque to the user. 
from ._internal import Config -from .udf import ScalarUDF, AggregateUDF +from .udf import ScalarUDF, AggregateUDF, Accumulator from .common import ( DFSchema, @@ -50,54 +45,14 @@ from .dataframe import DataFrame from .expr import ( - # Alias, - # Analyze, Expr, - # Filter, - # Limit, - # Like, - # ILike, - # Projection, - # SimilarTo, - # ScalarVariable, - # Sort, - # TableScan, - # Not, - # IsNotNull, - # IsTrue, - # IsFalse, - # IsUnknown, - # IsNotTrue, - # IsNotFalse, - # IsNotUnknown, - # Negative, - # InList, - # Exists, - # Subquery, - # InSubquery, - # ScalarSubquery, - # GroupingSet, - # Placeholder, - # Case, - # Cast, - # TryCast, - # Between, - # Explain, - # CreateMemoryTable, - # SubqueryAlias, - # Extension, - # CreateView, - # Distinct, - # DropTable, - # Repartition, - # Partitioning, - # Window, WindowFrame, ) __version__ = importlib_metadata.version(__name__) __all__ = [ + "Accumulator", "Config", "DataFrame", "SessionContext", @@ -106,75 +61,13 @@ "RuntimeConfig", "Expr", "ScalarUDF", - # "Window", "WindowFrame", "column", "literal", - # "TableScan", - # "Projection", "DFSchema", - # "DFField", - # "Analyze", - # "Sort", - # "Limit", - # "Filter", - # "Like", - # "ILike", - # "SimilarTo", - # "ScalarVariable", - # "Alias", - # "Not", - # "IsNotNull", - # "IsTrue", - # "IsFalse", - # "IsUnknown", - # "IsNotTrue", - # "IsNotFalse", - # "IsNotUnknown", - # "Negative", - # "ScalarFunction", - # "BuiltinScalarFunction", - # "InList", - # "Exists", - # "Subquery", - # "InSubquery", - # "ScalarSubquery", - # "GroupingSet", - # "Placeholder", - # "Case", - # "Cast", - # "TryCast", - # "Between", - # "Explain", - # "SubqueryAlias", - # "Extension", - # "CreateMemoryTable", - # "CreateView", - # "Distinct", - # "DropTable", - # "Repartition", - # "Partitioning", ] -class Accumulator(metaclass=ABCMeta): - @abstractmethod - def state(self) -> List[pa.Scalar]: - pass - - @abstractmethod - def update(self, values: pa.Array) -> None: - pass - - @abstractmethod - def merge(self, states: pa.Array) -> None: - pass - - @abstractmethod - def evaluate(self) -> pa.Scalar: - pass - - def column(value: str): """Create a column expression.""" return Expr.column(value) @@ -190,35 +83,6 @@ def literal(value): lit = literal +udf = ScalarUDF.udf -def udf(func, input_types, return_type, volatility, name=None): - """Create a new User Defined Function.""" - if not callable(func): - raise TypeError("`func` argument must be callable") - if name is None: - name = func.__qualname__.lower() - return ScalarUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) - - -def udaf(accum, input_types, return_type, state_type, volatility, name=None): - """Create a new User Defined Aggregate Function.""" - if not issubclass(accum, Accumulator): - raise TypeError("`accum` must implement the abstract base class Accumulator") - if name is None: - name = accum.__qualname__.lower() - if isinstance(input_types, pa.lib.DataType): - input_types = [input_types] - return AggregateUDF( - name=name, - accumulator=accum, - input_types=input_types, - return_type=return_type, - state_type=state_type, - volatility=volatility, - ) +udaf = AggregateUDF.udaf diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 109282cf..688a6433 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,17 +22,64 @@ import datafusion._internal as df_internal from datafusion.expr import Expr from typing import Callable, TYPE_CHECKING, TypeVar +from abc import 
ABCMeta, abstractmethod
+from typing import List
+from enum import Enum
+
+import pyarrow

 if TYPE_CHECKING:
-    import pyarrow
-
     _R = TypeVar("_R", bound=pyarrow.DataType)


+class Volatility(Enum):
+    """Defines how stable or volatile a function is.
+
+    When setting the volatility of a function, you can either pass this
+    enumeration or a `str`. The `str` equivalent is the lower case value of the
+    name (`"immutable"`, `"stable"`, or `"volatile"`).
+    """
+
+    Immutable = 1
+    """An immutable function will always return the same output when given the
+    same input.
+
+    DataFusion will attempt to inline immutable functions during planning.
+    """
+
+    Stable = 2
+    """
+    Returns the same value for a given input within a single query.
+
+    A stable function may return different values given the same input across
+    different queries but must return the same value for a given input within a
+    query. An example of this is the `Now` function. DataFusion will attempt to
+    inline `Stable` functions during planning, when possible. For the query
+    `select col1, now() from t1`, it might take a while to execute, but the
+    `now()` column will be the same for each output row because it is evaluated
+    during planning.
+    """
+
+    Volatile = 3
+    """A volatile function may change the return value from evaluation to
+    evaluation.
+
+    Multiple invocations of a volatile function may return different results
+    when used in the same query. An example of this is the random() function.
+    DataFusion cannot evaluate such functions during planning. In the query
+    `select col1, random() from t1`, the `random()` function will be evaluated
+    for each output row, resulting in a unique random value for each row.
+    """
+
+    def __str__(self):
+        """Returns the string equivalent."""
+        return self.name.lower()
+
+
 class ScalarUDF:
     """Class for performing scalar user defined functions (UDF).

-    Scalar UDFs operate on a row by row basis. See also ``AggregateUDF`` for operating on a group of rows.
+    Scalar UDFs operate on a row by row basis. See also ``AggregateUDF`` for
+    operating on a group of rows.
     """

     def __init__(
@@ -41,46 +88,158 @@ def __init__(
         func: Callable[..., _R],
         input_types: list[pyarrow.DataType],
         return_type: _R,
-        volatility: str,
+        volatility: Volatility | str,
     ) -> None:
         """Instantiate a scalar user defined function (UDF)."""
         self.udf = df_internal.ScalarUDF(
-            name, func, input_types, return_type, volatility
+            name, func, input_types, return_type, str(volatility)
         )

     def __call__(self, *args: Expr) -> Expr:
         """Execute the UDF.

-        This function is not typically called by an end user. These calls will occur during the evaluation of the dataframe.
+        This function is not typically called by an end user. These calls will
+        occur during the evaluation of the dataframe.
         """
         args = [arg.expr for arg in args]
         return Expr(self.udf.__call__(*args))

+    @staticmethod
+    def udf(
+        func: Callable[..., _R],
+        input_types: list[pyarrow.DataType],
+        return_type: _R,
+        volatility: Volatility | str,
+        name: str | None = None,
+    ) -> ScalarUDF:
+        """Create a new User Defined Function.
+
+        Args:
+            func: A callable python function.
+            input_types: The data types of the arguments to `func`. This list
+                must be of the same length as the number of arguments.
+            return_type: The data type of the return value from the python
+                function.
+            volatility: See ~`Volatility` for allowed values.
+            name: A descriptive name for the function.
+
+        Returns:
+            A user defined (scalar) function that can be used in expressions.
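+
+        A minimal sketch of usage (illustrative only; the function and its
+        name are hypothetical):
+
+        ```python
+        import pyarrow.compute as pc
+
+        double_it = ScalarUDF.udf(
+            lambda arr: pc.multiply(arr, 2),
+            [pyarrow.int64()],
+            pyarrow.int64(),
+            Volatility.Immutable,
+        )
+        ```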
+        """
+        if not callable(func):
+            raise TypeError("`func` argument must be callable")
+        if name is None:
+            name = func.__qualname__.lower()
+        return ScalarUDF(
+            name=name,
+            func=func,
+            input_types=input_types,
+            return_type=return_type,
+            volatility=volatility,
+        )
+
+
+class Accumulator(metaclass=ABCMeta):
+    """Defines how an `AggregateUDF` accumulates values during an evaluation."""
+
+    @abstractmethod
+    def state(self) -> List[pyarrow.Scalar]:
+        """Return the current state."""
+        pass
+
+    @abstractmethod
+    def update(self, values: pyarrow.Array) -> None:
+        """Evaluate an array of values and update state."""
+        pass
+
+    @abstractmethod
+    def merge(self, states: pyarrow.Array) -> None:
+        """Merge a set of states."""
+        pass
+
+    @abstractmethod
+    def evaluate(self) -> pyarrow.Scalar:
+        """Return the resultant value."""
+        pass
+
+
+if TYPE_CHECKING:
+    _A = TypeVar("_A", bound=(Callable[..., _R], Accumulator))
+

 class AggregateUDF:
     """Class for performing aggregate user defined functions (UDAF).

-    Aggregate UDFs operate on a group of rows and return a single value. See also ``ScalarUDF`` for operating on a row by row basis.
+    Aggregate UDFs operate on a group of rows and return a single value. See
+    also ``ScalarUDF`` for operating on a row by row basis.
     """

     def __init__(
         self,
         name: str | None,
-        accumulator: Callable[..., _R],
+        accumulator: _A,
         input_types: list[pyarrow.DataType],
         return_type: _R,
         state_type: list[pyarrow.DataType],
-        volatility: str,
+        volatility: Volatility | str,
     ) -> None:
-        """Instantiate a user defined aggregate function (UDAF)."""
+        """Instantiate a user defined aggregate function (UDAF).
+
+        See `AggregateUDF.udaf` for a convenience function and argument
+        descriptions.
+        """
         self.udf = df_internal.AggregateUDF(
-            name, accumulator, input_types, return_type, state_type, volatility
+            name, accumulator, input_types, return_type, state_type, str(volatility)
        )

     def __call__(self, *args: Expr) -> Expr:
         """Execute the UDAF.

-        This function is not typically called by an end user. These calls will occur during the evaluation of the dataframe.
+        This function is not typically called by an end user. These calls will
+        occur during the evaluation of the dataframe.
         """
         args = [arg.expr for arg in args]
         return Expr(self.udf.__call__(*args))
+
+    @staticmethod
+    def udaf(
+        accum: _A,
+        input_types: list[pyarrow.DataType],
+        return_type: _R,
+        state_type: list[pyarrow.DataType],
+        volatility: Volatility | str,
+        name: str | None = None,
+    ) -> AggregateUDF:
+        """Create a new User Defined Aggregate Function.
+
+        The accumulator function must be callable and implement `Accumulator`.
+
+        Args:
+            accum: The accumulator python function.
+            input_types: The data types of the arguments to `accum`.
+            return_type: The data type of the return value.
+            state_type: The data types of the intermediate accumulation.
+            volatility: See `Volatility` for allowed values.
+            name: A descriptive name for the function.
+
+        Returns:
+            A user defined aggregate function, which can be used in either data
+            aggregation or window function calls.
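+
+        A minimal sketch of usage (illustrative only; the accumulator shown
+        is a hypothetical summation):
+
+        ```python
+        import pyarrow.compute as pc
+
+        class Summation(Accumulator):
+            def __init__(self) -> None:
+                self._sum = 0.0
+
+            def state(self) -> List[pyarrow.Scalar]:
+                return [pyarrow.scalar(self._sum)]
+
+            def update(self, values: pyarrow.Array) -> None:
+                self._sum = self._sum + pc.sum(values).as_py()
+
+            def merge(self, states: pyarrow.Array) -> None:
+                self._sum = self._sum + pc.sum(states).as_py()
+
+            def evaluate(self) -> pyarrow.Scalar:
+                return pyarrow.scalar(self._sum)
+
+        summation = AggregateUDF.udaf(
+            Summation,
+            pyarrow.float64(),
+            pyarrow.float64(),
+            [pyarrow.float64()],
+            Volatility.Immutable,
+        )
+        ```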
+        """
+        if not issubclass(accum, Accumulator):
+            raise TypeError(
+                "`accum` must implement the abstract base class Accumulator"
+            )
+        if name is None:
+            name = accum.__qualname__.lower()
+        if isinstance(input_types, pyarrow.lib.DataType):
+            input_types = [input_types]
+        return AggregateUDF(
+            name=name,
+            accumulator=accum,
+            input_types=input_types,
+            return_type=return_type,
+            state_type=state_type,
+            volatility=volatility,
+        )

From 39f18cb125810286183ec7e913e7359bba6e9b03 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 14 Jul 2024 09:25:25 -0400
Subject: [PATCH 38/55] Docstring formatting

---
 python/datafusion/context.py | 761 +++++++++++++----------------------
 1 file changed, 289 insertions(+), 472 deletions(-)

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 7dfbeaf1..da304d47 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -48,26 +48,21 @@ class SessionConfig:
     def __init__(self, config_options: dict[str, str] | None = None) -> None:
         """Create a new `SessionConfig` with the given configuration options.

-        Parameters
-        ----------
-        config_options : dict[str, str]
-            Configuration options.
+        Args:
+            config_options: Configuration options.
         """
         self.config_internal = SessionConfigInternal(config_options)

     def with_create_default_catalog_and_schema(
         self, enabled: bool = True
     ) -> SessionConfig:
-        """Control whether the default catalog and schema will be automatically created.
+        """Control if the default catalog and schema will be automatically created.

-        Parameters
-        ----------
-        enabled : bool
-            Whether the default catalog and schema will be automatically created.
+        Args:
+            enabled: Whether the default catalog and schema will be
+                automatically created.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = (
@@ -80,16 +75,11 @@ def with_default_catalog_and_schema(
     ) -> SessionConfig:
         """Select a name for the default catalog and schema.

-        Parameters
-        ----------
-        catalog : str
-            Catalog name.
-        schema : str
-            Schema name.
+        Args:
+            catalog: Catalog name.
+            schema: Schema name.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_default_catalog_and_schema(
@@ -100,14 +90,10 @@ def with_default_catalog_and_schema(
     def with_information_schema(self, enabled: bool = True) -> SessionConfig:
         """Enable or disable the inclusion of `information_schema` virtual tables.

-        Parameters
-        ----------
-        enabled : bool
-            Whether to include `information_schema` virtual tables.
+        Args:
+            enabled: Whether to include `information_schema` virtual tables.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_information_schema(enabled)
@@ -116,14 +102,10 @@ def with_batch_size(self, batch_size: int) -> SessionConfig:
     def with_batch_size(self, batch_size: int) -> SessionConfig:
         """Customize batch size.

-        Parameters
-        ----------
-        batch_size : int
-            Batch size.
+        Args:
+            batch_size: Batch size.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_batch_size(batch_size)
         return self

     def with_target_partitions(self, target_partitions: int) -> SessionConfig:
         """Customize the number of target partitions for query execution.

         Increasing partitions can increase concurrency.

-        Parameters
-        ----------
-        target_partitions : int
-            Number of target partitions.
+        Args:
+            target_partitions: Number of target partitions.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_target_partitions(
@@ -154,14 +132,10 @@ def with_repartition_aggregations(self, enabled: bool = True) -> SessionConfig:

         Enabling this improves parallelism.

-        Parameters
-        ----------
-        enabled : bool
-            Whether to use repartitioning for aggregations.
+        Args:
+            enabled: Whether to use repartitioning for aggregations.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_repartition_aggregations(
@@ -172,46 +146,38 @@ def with_repartition_joins(self, enabled: bool = True) -> SessionConfig:
     def with_repartition_joins(self, enabled: bool = True) -> SessionConfig:
         """Enable or disable the use of repartitioning for joins to improve parallelism.

-        Parameters
-        ----------
-        enabled : bool
-            Whether to use repartitioning for joins.
+        Args:
+            enabled: Whether to use repartitioning for joins.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_repartition_joins(enabled)
         return self

     def with_repartition_windows(self, enabled: bool = True) -> SessionConfig:
-        """Enable or disable the use of repartitioning for window functions to improve parallelism.
+        """Enable or disable the use of repartitioning for window functions.

-        Parameters
-        ----------
-        enabled : bool
-            Whether to use repartitioning for window functions.
+        This may improve parallelism.
+
+        Args:
+            enabled: Whether to use repartitioning for window functions.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_repartition_windows(enabled)
         return self

     def with_repartition_sorts(self, enabled: bool = True) -> SessionConfig:
-        """Enable or disable the use of repartitioning for window functions to improve parallelism.
+        """Enable or disable the use of repartitioning for sorts.
+
+        This may improve parallelism.

-        Parameters
-        ----------
-        enabled : bool
-            Whether to use repartitioning for window functions.
+        Args:
+            enabled: Whether to use repartitioning for sorts.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_repartition_sorts(enabled)
         return self

     def with_repartition_file_scans(self, enabled: bool = True) -> SessionConfig:
         """Enable or disable the use of repartitioning for file scans.

-        Parameters
-        ----------
-        enabled : bool
-            Whether to use repartitioning for file scans.
+        Args:
+            enabled: Whether to use repartitioning for file scans.

         Returns:
-        -------
-        SessionConfig
             A new `SessionConfig` object with the updated setting.
         """
         self.config_internal = self.config_internal.with_repartition_file_scans(enabled)
         return self

     def with_repartition_file_min_size(self, size: int) -> SessionConfig:
         """Set minimum file range size for repartitioning scans.

-        Parameters
-        ----------
-        size : int
-            Minimum file range size.
+        Args:
+            size: Minimum file range size.

         Returns:
             A new `SessionConfig` object with the updated setting.
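+
+        Example usage (the size shown is illustrative):
+        ```python
+        config = SessionConfig().with_repartition_file_min_size(10 * 1024 * 1024)
+        ```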
""" self.config_internal = self.config_internal.with_repartition_file_min_size(size) return self def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig: - """Enable or disable the use of pruning predicate for parquet readers to skip row groups. + """Enable or disable the use of pruning predicate for parquet readers. + + Pruning predicates will enable the reader to skip row groups. - Parameters - ---------- - enabled : bool - Whether to use pruning predicate for parquet readers. + Args: + enabled: Whether to use pruning predicate for parquet readers. Returns: - ------- - SessionConfig A new `SessionConfig` object with the updated setting. """ self.config_internal = self.config_internal.with_parquet_pruning(enabled) @@ -268,16 +224,11 @@ def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig: def set(self, key: str, value: str) -> SessionConfig: """Set a configuration option. - Parameters - ---------- - key : str - Option key. - value : str - Option value. + Args: + key: Option key. + value: Option value. Returns: - ------- - SessionConfig A new `SessionConfig` object with the updated setting. """ self.config_internal = self.config_internal.set(key, value) @@ -295,13 +246,7 @@ def with_disk_manager_disabled(self) -> RuntimeConfig: """Disable the disk manager, attempts to create temporary files will error. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. - - Examples: - -------- - >>> config = RuntimeConfig().with_disk_manager_disabled() """ self.config_internal = self.config_internal.with_disk_manager_disabled() return self @@ -310,13 +255,7 @@ def with_disk_manager_os(self) -> RuntimeConfig: """Use the operating system's temporary directory for disk manager. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. - - Examples: - -------- - >>> config = RuntimeConfig().with_disk_manager_os() """ self.config_internal = self.config_internal.with_disk_manager_os() return self @@ -324,19 +263,11 @@ def with_disk_manager_os(self) -> RuntimeConfig: def with_disk_manager_specified(self, *paths: str | pathlib.Path) -> RuntimeConfig: """Use the specified paths for the disk manager's temporary files. - Parameters - ---------- - paths : list[str] - Paths to use for the disk manager's temporary files. + Args: + paths: Paths to use for the disk manager's temporary files. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. - - Examples: - -------- - >>> config = RuntimeConfig().with_disk_manager_specified(["/tmp"]) """ paths = [str(p) for p in paths] self.config_internal = self.config_internal.with_disk_manager_specified(paths) @@ -346,13 +277,7 @@ def with_unbounded_memory_pool(self) -> RuntimeConfig: """Use an unbounded memory pool. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. - - Examples: - -------- - >>> config = RuntimeConfig().with_unbounded_memory_pool() """ self.config_internal = self.config_internal.with_unbounded_memory_pool() return self @@ -376,20 +301,15 @@ def with_fair_spill_pool(self, size: int) -> RuntimeConfig: └───────────────────────z──────────────────────z───────────────┘ ``` - Parameters - ---------- - size : int - Size of the memory pool in bytes. + Args: + size: Size of the memory pool in bytes. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. 
- Examples: - -------- + Examples usage: ```python - >>> config = RuntimeConfig().with_fair_spill_pool(1024) + config = RuntimeConfig().with_fair_spill_pool(1024) ``` """ self.config_internal = self.config_internal.with_fair_spill_pool(size) @@ -402,19 +322,16 @@ def with_greedy_memory_pool(self, size: int) -> RuntimeConfig: spillable operator. See `RuntimeConfig.with_fair_spill_pool` if there are multiple spillable operators that all will spill. - Parameters - ---------- - size : int - Size of the memory pool in bytes. + Args: + size: Size of the memory pool in bytes. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples: - -------- - >>> config = RuntimeConfig().with_greedy_memory_pool(1024) + Example usage: + ```python + config = RuntimeConfig().with_greedy_memory_pool(1024) + ``` """ self.config_internal = self.config_internal.with_greedy_memory_pool(size) return self @@ -422,19 +339,16 @@ def with_greedy_memory_pool(self, size: int) -> RuntimeConfig: def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeConfig: """Use the specified path to create any needed temporary files. - Parameters - ---------- - path : str - Path to use for temporary files. + Args: + path: Path to use for temporary files. Returns: - ------- - RuntimeConfig A new `RuntimeConfig` object with the updated setting. - Examples: - -------- - >>> config = RuntimeConfig().with_temp_file_path("/tmp") + Example usage: + ```python + config = RuntimeConfig().with_temp_file_path("/tmp") + ``` """ self.config_internal = self.config_internal.with_temp_file_path(str(path)) return self @@ -458,20 +372,16 @@ def with_allow_ddl(self, allow: bool = True) -> SQLOptions: Examples of DDL commands include `CREATE TABLE` and `DROP TABLE`. - Parameters - ---------- - allow : bool - Allow DDL commands to be run. + Args: + allow: Allow DDL commands to be run. Returns: - ------- - SQLOptions A new `SQLOptions` object with the updated setting. - - Examples: - -------- - >>> options = SQLOptions().with_allow_ddl(True) + Example usage: + ```python + options = SQLOptions().with_allow_ddl(True) + ``` """ self.options_internal = self.options_internal.with_allow_ddl(allow) return self @@ -481,20 +391,16 @@ def with_allow_dml(self, allow: bool = True) -> SQLOptions: Examples of DML commands include `INSERT INTO` and `DELETE`. - Parameters - ---------- - allow : bool - Allow DML commands to be run. + Args: + allow: Allow DML commands to be run. Returns: - ------- - SQLOptions A new `SQLOptions` object with the updated setting. - - Examples: - -------- - >>> options = SQLOptions().with_allow_dml(True) + Example usage: + ```python + options = SQLOptions().with_allow_dml(True) + ``` """ self.options_internal = self.options_internal.with_allow_dml(allow) return self @@ -502,19 +408,16 @@ def with_allow_dml(self, allow: bool = True) -> SQLOptions: def with_allow_statements(self, allow: bool = True) -> SQLOptions: """Should statements such as `SET VARIABLE` and `BEGIN TRANSACTION` be run? - Parameters - ---------- - allow : bool - Allow statements to be run. + Args: + allow: Allow statements to be run. Returns: - ------- - SQLOptions A new `SQLOptions` object with the updated setting. 
- Examples: - -------- - >>> options = SQLOptions().with_allow_statements(True) + Example usage: + ```python + options = SQLOptions().with_allow_statements(True) + ``` """ self.options_internal = self.options_internal.with_allow_statements(allow) return self @@ -523,7 +426,8 @@ def with_allow_statements(self, allow: bool = True) -> SQLOptions: class SessionContext: """This is the main interface for executing queries and creating DataFrames. - See https://datafusion.apache.org/python/user-guide/basics.html for additional information. + See https://datafusion.apache.org/python/user-guide/basics.html for + additional information. """ def __init__( @@ -535,15 +439,12 @@ def __init__( of the connection between a user and an instance of the DataFusion engine. - Parameters - ---------- - config : SessionConfig | None - Session configuration options. - runtime : RuntimeConfig | None - Runtime configuration options. + Args: + config: Session configuration options. + runtime: Runtime configuration options. + + Example usage: - Examples: - -------- The following example demonstrates how to use the context to execute a query against a CSV data source using the `DataFrame` API: ```python from datafusion import SessionContext ctx = SessionContext() df = ctx.read_csv("data.csv") ``` """ self.ctx = SessionContextInternal(config, runtime) def register_object_store(self, schema: str, store: Any, host: str | None) -> None: - """Add a new object store into the session.""" + """Add a new object store into the session. + + Args: + schema: The data source schema. + store: The `ObjectStore` to register. + host: URL for the host. + """ self.ctx.register_object_store(schema, store, host) def register_listing_table( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".parquet", schema: pyarrow.Schema | None = None, file_sort_order: list[list[Expr]] | None = None, ) -> None: - """Registers a Table that can assemble multiple files from locations in an `ObjectStore` instance into a single table.""" + """Register multiple files as a single table. + + Registers a `Table` that can assemble multiple files from locations in + an `ObjectStore` instance. + + Args: + name: Name of the resultant table. + path: Path to the file to register. + table_partition_cols: Partition columns. + file_extension: File extension of the provided table. + schema: The data source schema. + file_sort_order: Sort order for the file. + """ + if table_partition_cols is None: + table_partition_cols = [] if file_sort_order is not None: file_sort_order = [[x.expr for x in xs] for xs in file_sort_order] self.ctx.register_listing_table( @@ -584,41 +505,38 @@ def register_listing_table( file_sort_order, ) - def sql(self, query: str) -> DataFrame: + def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: """Create a `DataFrame` from SQL query text. Note: This API implements DDL statements such as `CREATE TABLE` and `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory default implementation. See `SessionContext.sql_with_options`. - Parameters - ---------- - query : str - SQL query text. + Args: + query: SQL query text. + options: If provided, the query will be validated against these options. Returns: - ------- - DataFrame DataFrame representation of the SQL query.
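With the `options` parameter added to `sql`, validation and execution happen in a single call. A sketch using the constructors shown in this patch:

```python
from datafusion import SessionContext, SQLOptions

ctx = SessionContext()
# Reject DDL and DML before the query ever executes.
read_only = SQLOptions().with_allow_ddl(False).with_allow_dml(False)
df = ctx.sql("SELECT 1", options=read_only)
```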
""" - return DataFrame(self.ctx.sql(query)) + if options is None: + return DataFrame(self.ctx.sql(query)) + return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: - """Create a `DataFrame` from SQL query text, first validating that the query is allowed by the provided options. + """Create a `DataFrame` from SQL query text. - Parameters - ---------- - query : str - SQL query text. - options : SQLOptions - SQL options. + This function will first validating that the query is allowed by the + provided options. + + Args: + query: SQL query text. + options: SQL options. Returns: - ------- - DataFrame DataFrame representation of the SQL query. """ - return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) + return self.sql(query, options) def create_dataframe( self, @@ -626,20 +544,25 @@ def create_dataframe( name: str | None = None, schema: pyarrow.Schema | None = None, ) -> DataFrame: - """Create and return a dataframe using the provided partitions.""" + """Create and return a dataframe using the provided partitions. + + Args: + partitions: `RecordBatch` partitions to register. + name: Resultant dataframe name. + schema: Schema for the partitions. + + Returns: + DataFrame representation of the SQL query. + """ return DataFrame(self.ctx.create_dataframe(partitions, name, schema)) def create_dataframe_from_logical_plan(self, plan: LogicalPlan) -> DataFrame: """Create a `DataFrame` from an existing logical plan. - Parameters - ---------- - plan : LogicalPlan - Logical plan. + Args: + plan: Logical plan. Returns: - ------- - DataFrame DataFrame representation of the logical plan. """ return DataFrame(self.ctx.create_dataframe_from_logical_plan(plan)) @@ -649,16 +572,11 @@ def from_pylist( ) -> DataFrame: """Create a `DataFrame` from a list of dictionaries. - Parameters - ---------- - data : list[dict[str, Any]] - List of dictionaries. - name : str | None - Name of the DataFrame. + Args: + data: List of dictionaries. + name: Name of the DataFrame. Returns: - ------- - DataFrame DataFrame representation of the list of dictionaries. """ return DataFrame(self.ctx.from_pylist(data, name)) @@ -668,16 +586,11 @@ def from_pydict( ) -> DataFrame: """Create a `DataFrame` from a dictionary of lists. - Parameters - ---------- - data : dict[str, list[Any]] - Dictionary of lists. - name : str | None - Name of the DataFrame. + Args: + data: Dictionary of lists. + name: Name of the DataFrame. Returns: - ------- - DataFrame DataFrame representation of the dictionary of lists. """ return DataFrame(self.ctx.from_pydict(data, name)) @@ -687,16 +600,11 @@ def from_arrow_table( ) -> DataFrame: """Create a `DataFrame` from an Arrow table. - Parameters - ---------- - data : pyarrow.Table - Arrow table. - name : str | None - Name of the DataFrame. + Args: + data: Arrow table. + name: Name of the DataFrame. Returns: - ------- - DataFrame DataFrame representation of the Arrow table. """ return DataFrame(self.ctx.from_arrow_table(data, name)) @@ -704,16 +612,11 @@ def from_arrow_table( def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFrame: """Create a `DataFrame` from a Pandas DataFrame. - Parameters - ---------- - data : pandas.DataFrame - Pandas DataFrame. - name : str | None - Name of the DataFrame. + Args: + data: Pandas DataFrame. + name: Name of the DataFrame. Returns: - ------- - DataFrame DataFrame representation of the Pandas DataFrame. 
""" return DataFrame(self.ctx.from_pandas(data, name)) @@ -721,22 +624,22 @@ def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFr def from_polars(self, data: polars.DataFrame, name: str | None = None) -> DataFrame: """Create a `DataFrame` from a Polars DataFrame. - Parameters - ---------- - data : polars.DataFrame - Polars DataFrame. - name : str | None - Name of the DataFrame. + Args: + data: Polars DataFrame. + name: Name of the DataFrame. Returns: - ------- - DataFrame DataFrame representation of the Polars DataFrame. """ return DataFrame(self.ctx.from_polars(data, name)) def register_table(self, name: str, table: pyarrow.Table) -> None: - """Register a table with the given name into the session.""" + """Register a table with the given name into the session. + + Args: + name: Name of the resultant table. + table: PyArrow table to add to the session context. + """ self.ctx.register_table(name, table) def deregister_table(self, name: str) -> None: @@ -746,14 +649,22 @@ def deregister_table(self, name: str) -> None: def register_record_batches( self, name: str, partitions: list[list[pyarrow.RecordBatch]] ) -> None: - """Convert the provided partitions into a table and register it into the session using the given name.""" + """Register record batches as a table. + + This function will convert the provided partitions into a table and + register it into the session using the given name. + + Args: + name: Name of the resultant table. + partitions: Record batches to register as a table. + """ self.ctx.register_record_batches(name, partitions) def register_parquet( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -762,29 +673,25 @@ def register_parquet( ) -> None: """Register a Parquet file as a table. - The registered table can be referenced from SQL statement executed against - this context. - - Parameters - ---------- - name : str - Name of the table to register. - path : str - Path to the Parquet file. - table_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] - parquet_pruning : bool, optional - Whether the parquet reader should use the predicate to prune row groups, by default True - file_extension : str, optional - File extension; only files with this extension are selected for data input, by default ".parquet" - skip_metadata : bool, optional - Whether the parquet reader should skip any metadata that may be in the file - schema. This can help avoid schema conflicts due to metadata. by default True - schema : pyarrow.Schema | None, optional - The data source schema, by default None - file_sort_order : list[list[Expr]] | None, optional - Sort order for the file, by default None - """ + The registered table can be referenced from SQL statement executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the Parquet file. + table_partition_cols: Partition columns. + parquet_pruning: Whether the parquet reader should use the + predicate to prune row groups. + file_extension: File extension; only files with this extension are + selected for data input. + skip_metadata: Whether the parquet reader should skip any metadata + that may be in the file schema. This can help avoid schema + conflicts due to metadata. + schema: The data source schema. + file_sort_order: Sort order for the file. 
+ """ + if table_partition_cols is None: + table_partition_cols = [] self.ctx.register_parquet( name, str(path), @@ -811,24 +718,20 @@ def register_csv( The registered table can be referenced from SQL statement executed against. - Parameters - ---------- - name : str - Name of the table to register. - path : str - Path to the CSV file. - schema : pyarrow.Schema | None, optional - An optional schema representing the CSV file. If None, the CSV reader will try to infer it based on data in file, by default None - has_header : bool, optional - Whether the CSV file have a header. If schema inference is run on a file with no headers, default column names are created, by default True - delimiter : str, optional - An optional column delimiter, by default "," - schema_infer_max_records : int, optional - Maximum number of rows to read from CSV files for schema inference if needed, by default 1000 - file_extension : str, optional - File extension; only files with this extension are selected for data input, by default ".csv" - file_compression_type : str | None, optional - File compression type, by default None + Args: + name: Name of the table to register. + path: Path to the CSV file. + schema: An optional schema representing the CSV file. If None, the + CSV reader will try to infer it based on data in file. + has_header: Whether the CSV file have a header. If schema inference + is run on a file with no headers, default column names are + created. + delimiter: An optional column delimiter. + schema_infer_max_records: Maximum number of rows to read from CSV + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + file_compression_type: File compression type. """ self.ctx.register_csv( name, @@ -848,31 +751,27 @@ def register_json( schema: pyarrow.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, file_compression_type: str | None = None, ) -> None: """Register a JSON file as a table. - The registered table can be referenced from SQL statement executed against - this context. - - Parameters - ---------- - name : str - Name of the table to register. - path : str - Path to the JSON file. - schema : pyarrow.Schema | None, optional - The data source schema, by default None - schema_infer_max_records : int, optional - Maximum number of rows to read from JSON files for schema inference if needed, by default 1000 - file_extension : str, optional - File extension; only files with this extension are selected for data input, by default ".json" - table_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] - file_compression_type : str | None, optional - File compression type, by default None - """ + The registered table can be referenced from SQL statement executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the JSON file. + schema: The data source schema. + schema_infer_max_records: Maximum number of rows to read from JSON + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. 
+ """ + if table_partition_cols is None: + table_partition_cols = [] self.ctx.register_json( name, str(path), @@ -889,26 +788,22 @@ def register_avro( path: str | pathlib.Path, schema: pyarrow.Schema | None = None, file_extension: str = ".avro", - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, ) -> None: """Register an Avro file as a table. The registered table can be referenced from SQL statement executed against this context. - Parameters - ---------- - name : str - Name of the table to register. - path : str - Path to the Avro file. - schema : pyarrow.Schema | None, optional - The data source schema, by default None - file_extension : str, optional - File extension to select, by default ".avro" - table_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] + Args: + name: Name of the table to register. + path: Path to the Avro file. + schema: The data source schema. + file_extension: File extension to select. + table_partition_cols: Partition columns. """ + if table_partition_cols is None: + table_partition_cols = [] self.ctx.register_avro( name, str(path), schema, file_extension, table_partition_cols ) @@ -916,48 +811,22 @@ def register_avro( def register_dataset(self, name: str, dataset: pyarrow.dataset.Dataset) -> None: """Register a `pyarrow.dataset.Dataset` as a table. - Parameters - ---------- - name : str - Name of the table to register. - dataset : dataset.Dataset - PyArrow dataset. + Args: + name: Name of the table to register. + dataset: PyArrow dataset. """ self.ctx.register_dataset(name, dataset) def register_udf(self, udf: ScalarUDF) -> None: - """Register a user-defined function (UDF) with the context. - - Parameters - ---------- - udf : ScalarUDF - User-defined function. - """ + """Register a user-defined function (UDF) with the context.""" self.ctx.register_udf(udf.udf) def register_udaf(self, udaf: AggregateUDF) -> None: - """Register a user-defined aggregation function (UDAF) with the context. - - Parameters - ---------- - udaf : AggregateUDF - User-defined aggregation function. - """ + """Register a user-defined aggregation function (UDAF) with the context.""" self.ctx.register_udaf(udaf) def catalog(self, name: str = "datafusion") -> Catalog: - """Retrieve a catalog by name. - - Parameters - ---------- - name : str, optional - Name of the catalog to retrieve, by default "datafusion". - - Returns: - ------- - Catalog - Catalog representation. - """ + """Retrieve a catalog by name.""" return self.ctx.catalog(name) @deprecated( @@ -969,53 +838,19 @@ def tables(self) -> set[str]: return self.ctx.tables() def table(self, name: str) -> DataFrame: - """Retrieve a `DataFrame` representing a previously registered table. - - Parameters - ---------- - name : str - Name of the table to retrieve. - - Returns: - ------- - DataFrame - DataFrame representation of the table. - """ + """Retrieve a `DataFrame` representing a previously registered table.""" return DataFrame(self.ctx.table(name)) def table_exist(self, name: str) -> bool: - """Return whether a table with the given name exists. - - Parameters - ---------- - name : str - Name of the table to check. - - Returns: - ------- - bool - Whether a table with the given name exists. - """ + """Return whether a table with the given name exists.""" return self.ctx.table_exist(name) def empty_table(self) -> DataFrame: - """Create an empty `DataFrame`. - - Returns: - ------- - DataFrame - An empty DataFrame. 
- """ + """Create an empty `DataFrame`.""" return DataFrame(self.ctx.empty_table()) def session_id(self) -> str: - """Retrun an id that uniquely identifies this `SessionContext`. - - Returns: - ------- - str - Unique session identifier - """ + """Retrun an id that uniquely identifies this `SessionContext`.""" return self.ctx.session_id() def read_json( @@ -1024,31 +859,26 @@ def read_json( schema: pyarrow.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Create a `DataFrame` for reading a line-delimited JSON data source. - Parameters - ---------- - path : str - Path to the JSON file - schema : pyarrow.Schema | None, optional - The data source schema, by default None - schema_infer_max_records : int, optional - Maximum number of rows to read from JSON files for schema inference if needed, by default 1000 - file_extension : str, optional - File extension; only files with this extension are selected for data input, by default ".json" - table_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] - file_compression_type : str | None, optional - File compression type, by default None + Args: + path: Path to the JSON file. + schema: The data source schema. + schema_infer_max_records: Maximum number of rows to read from JSON + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. Returns: - ------- - DataFrame - DataFrame representation of the read JSON files + DataFrame representation of the read JSON files. """ + if table_partition_cols is None: + table_partition_cols = [] return DataFrame( self.ctx.read_json( str(path), @@ -1068,35 +898,31 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Create a `DataFrame` for reading a CSV data source. - Parameters - ---------- - path : str - Path to the CSV file - schema : pyarrow.Schema | None, optional - An optional schema representing the CSV files. If None, the CSV reader will try to infer it based on data in file, by default None - has_header : bool, optional - Whether the CSV file have a header. If schema inference is run on a file with no headers, default column names are created, by default True - delimiter : str, optional - An optional column delimiter, by default "," - schema_infer_max_records : int, optional - Maximum number of rows to read from CSV files for schema inference if needed, by default 1000 - file_extension : str, optional - File extension; only files with this extension are selected for data input, by default ".csv" - table_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] - file_compression_type : str | None, optional - File compression type, by default None + Args: + path: Path to the CSV file + schema: An optional schema representing the CSV files. If None, the + CSV reader will try to infer it based on data in file. + has_header: Whether the CSV file have a header. If schema inference + is run on a file with no headers, default column names are + created. 
+ delimiter: An optional column delimiter. + schema_infer_max_records: Maximum number of rows to read from CSV + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. Returns: - ------- - DataFrame DataFrame representation of the read CSV files """ + if table_partition_cols is None: + table_partition_cols = [] return DataFrame( self.ctx.read_csv( str(path), @@ -1113,7 +939,7 @@ def read_csv( def read_parquet( self, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] = [], + table_partition_cols: list[tuple[str, str]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -1122,30 +948,26 @@ def read_parquet( ) -> DataFrame: """Create a `DataFrame` for reading Parquet data source. - Parameters - ---------- - path: str - Path to the Parquet file - table_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] - parquet_pruning : bool, optional - Whether the parquet reader should use the predicate to prune row groups, by default True - file_extension : str, optional - File extension; only files with this extension are selected for data input, by default ".parquet" - skip_metadata : bool, optional - Whether the parquet reader should skip any metadata that may be in the file - schema. This can help avoid schema conflicts due to metadata. by default True - schema : pyarrow.Schema | None, optional - An optional schema representing the parquet files. If None, the parquet - reader will try to infer it based on data in the file, by default None - file_sort_order : list[list[Expr]] | None, optional - Sort order for the file, by default None + Args: + path: Path to the Parquet file. + table_partition_cols: Partition columns. + parquet_pruning: Whether the parquet reader should use the predicate + to prune row groups. + file_extension: File extension; only files with this extension are + selected for data input. + skip_metadata: Whether the parquet reader should skip any metadata + that may be in the file schema. This can help avoid schema + conflicts due to metadata. + schema: An optional schema representing the parquet files. If None, + the parquet reader will try to infer it based on data in the + file. + file_sort_order: Sort order for the file. Returns: - ------- - DataFrame DataFrame representation of the read Parquet files """ + if table_partition_cols is None: + table_partition_cols = [] return DataFrame( self.ctx.read_parquet( str(path), @@ -1162,27 +984,22 @@ def read_avro( self, path: str | pathlib.Path, schema: pyarrow.Schema | None = None, - file_partition_cols: list[tuple[str, str]] = [], + file_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a `DataFrame` for reading Avro data source. - Parameters - ---------- - path : str - Path to the Avro file - schema : pyarrow.Schema | None, optional - The data source schema, by default None - file_partition_cols : list[tuple[str, str]], optional - Partition columns, by default [] - file_extension : str, optional - File extension to select, by default ".avro" + Args: + path: Path to the Avro file. + schema: The data source schema. + file_partition_cols: Partition columns. + file_extension: File extension to select. 
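The `read_*` methods above return a `DataFrame` directly, without a separate registration step; the path below is a placeholder:

```python
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.read_parquet("data.parquet")  # pruning and metadata defaults apply
df.show()
```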
Returns: - ------- - DataFrame DataFrame representation of the read Avro file """ + if file_partition_cols is None: + file_partition_cols = [] return DataFrame( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) From 94650b546d129292c0a42372f591bbe5d3234ca9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 14 Jul 2024 10:07:52 -0400 Subject: [PATCH 39/55] Updating docstring formatting --- python/datafusion/dataframe.py | 299 +++++++++++++-------------------- 1 file changed, 117 insertions(+), 182 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 99d9b035..6b815dad 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -16,7 +16,8 @@ # under the License. """DataFrame is one of the core concepts in DataFusion. -See https://datafusion.apache.org/python/user-guide/basics.html for more information. +See https://datafusion.apache.org/python/user-guide/basics.html for more +information. """ from __future__ import annotations @@ -40,9 +41,10 @@ class DataFrame: - """Two dimensional representation of data represented as rows and columns in a table. + """Two dimensional table representation of data. - See https://datafusion.apache.org/python/user-guide/basics.html for more information. + See https://datafusion.apache.org/python/user-guide/basics.html for more + information. """ def __init__(self, df: DataFrameInternal) -> None: @@ -55,14 +57,10 @@ def __init__(self, df: DataFrameInternal) -> None: def __getitem__(self, key: str | List[str]) -> DataFrame: """Return a new `DataFrame` with the specified column or columns. - Parameters - ---------- - key : Any - Column name or list of column names to select. + Args: + key: Column name or list of column names to select. Returns: - ------- - DataFrame DataFrame with the specified column or columns. """ return DataFrame(self.df.__getitem__(key)) @@ -71,8 +69,6 @@ def __repr__(self) -> str: """Return a string representation of the DataFrame. Returns: - ------- - str String representation of the DataFrame. """ return self.df.__repr__() @@ -86,8 +82,6 @@ def describe(self) -> DataFrame: The output format is modeled after pandas. Returns: - ------- - DataFrame A summary DataFrame containing statistics. """ return DataFrame(self.df.describe()) @@ -99,8 +93,6 @@ def schema(self) -> pa.Schema: nullability for each column. Returns: - ------- - pa.Schema Describing schema of the DataFrame """ return self.df.schema() @@ -109,39 +101,48 @@ def select_columns(self, *args: str) -> DataFrame: """Filter the DataFrame by columns. Returns: - ------- - DataFrame DataFrame only containing the specified columns. """ return self.select(*args) - def select(self, *args: Expr | str) -> DataFrame: - """Project arbitrary expressions (like SQL SELECT expressions) into a new `DataFrame`. + def select(self, *exprs: Expr | str) -> DataFrame: + """Project arbitrary expressions into a new `DataFrame`. + + Args: + exprs: Either column names or `Expr` to select. Returns: - ------- - DataFrame DataFrame after projection. It has one column for each expression. + + Example usage: + + The following example will return 3 columns from the original dataframe. + The first two columns will be the original column `a` and `b` since the + string "a" is assumed to refer to column selection. Also a duplicate of + column `a` will be returned with the column name `alternate_a`. 
+ + ```python + df = df.select("a", col("b"), col("a").alias("alternate_a")) + ``` """ exprs = [ arg.expr if isinstance(arg, Expr) else Expr.column(arg).expr for arg in exprs ] return DataFrame(self.df.select(*exprs)) def filter(self, *predicates: Expr) -> DataFrame: """Return a DataFrame for which `predicate` evaluates to `True`. - Rows for which `predicate` evaluates to `False` or `None` are filtered out. + Rows for which `predicate` evaluates to `False` or `None` are filtered + out. If more than one predicate is provided, these predicates will be + combined as a logical AND. If more complex logic is required, see the + logical operations in `datafusion.functions`. - Parameters - ---------- - predicates : Predicate expression(s) to filter the DataFrame. If more than one - is provided, these predicates will be combined as a logical AND. If more complex - logic is required, see logical operations in `datafusion.functions`. + Args: + predicates: Predicate expression(s) to filter the DataFrame. Returns: - ------- - DataFrame DataFrame after filtering. """ df = self.df for p in predicates: df = df.filter(p.expr) return DataFrame(df) def with_column(self, name: str, expr: Expr) -> DataFrame: """Add an additional column to the DataFrame. - Parameters - ---------- - name : str - Name of the column to add. - expr : Expr - Expression to compute the column. + Args: + name: Name of the column to add. + expr: Expression to compute the column. Returns: - ------- - DataFrame DataFrame with the new column. """ return DataFrame(self.df.with_column(name, expr.expr)) @@ -174,35 +170,23 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: The method supports case sensitive rename with wrapping column name into one of the following symbols (" or ' or `). - Parameters - ---------- - old_name : str - Old column name. - new_name : str - New column name. + Args: + old_name: Old column name. + new_name: New column name. Returns: - ------- - DataFrame DataFrame with the column renamed. """ return DataFrame(self.df.with_column_renamed(old_name, new_name)) def aggregate(self, group_by: list[Expr], aggs: list[Expr]) -> DataFrame: - """Return a new `DataFrame` that aggregates the rows of the current DataFrame. - - First optionally grouping by the given expressions. + """Aggregates the rows of the current DataFrame. - Parameters - ---------- - group_by : list[Expr] - List of expressions to group by. - aggs : list[Expr] - List of expressions to aggregate. + Args: + group_by: List of expressions to group by. + aggs: List of expressions to aggregate. Returns: - ------- - DataFrame DataFrame after aggregation. """ group_by = [e.expr for e in group_by] aggs = [e.expr for e in aggs] return DataFrame(self.df.aggregate(group_by, aggs)) @@ -215,9 +199,10 @@ def sort(self, *exprs: Expr) -> DataFrame: Note that any expression can be turned into a sort expression by calling its `sort` method. + Args: + exprs: Sort expressions, applied in order. + Returns: - ------- - DataFrame DataFrame after sorting. """ exprs = [expr.expr for expr in exprs] return DataFrame(self.df.sort(*exprs)) def limit(self, count: int, offset: int = 0) -> DataFrame: """Return a new `DataFrame` with a limited number of rows. - Parameters - ---------- - count : int - Number of rows to limit the DataFrame to. - offset : int, optional - Number of rows to skip, by default 0 + Args: + count: Number of rows to limit the DataFrame to.
+ offset: Number of rows to skip. Returns: - ------- - DataFrame DataFrame after limiting. """ return DataFrame(self.df.limit(count, offset)) def collect(self) -> list[pa.RecordBatch]: - """Execute this `DataFrame` and collect `pyarrow.RecordBatch`es into memory. + """Execute this `DataFrame` and collect results into memory. Prior to calling `collect`, modifying a DataFrame simply updates a plan (no actual computation is performed). Calling `collect` triggers the computation. Returns: - ------- - list[pa.RecordBatch] List of `pyarrow.RecordBatch`es collected from the DataFrame. """ return self.df.collect() @@ -258,29 +236,27 @@ def cache(self) -> DataFrame: """Cache the DataFrame as a memory table. Returns: - ------- - DataFrame Cached DataFrame. """ return DataFrame(self.df.cache()) def collect_partitioned(self) -> list[list[pa.RecordBatch]]: - """Execute this DataFrame and collect all results into a list of list of `pyarrow.RecordBatch`es maintaining the input partitioning. + """Execute this DataFrame and collect all partitioned results. + + This operation returns `pyarrow.RecordBatch`es maintaining the input + partitioning. Returns: - ------- - list[list[pa.RecordBatch]] - List of list of `pyarrow.RecordBatch`es collected from the DataFrame. + List of list of `pyarrow.RecordBatch`es collected from the + DataFrame. """ return self.df.collect_partitioned() def show(self, num: int = 20) -> None: """Execute the DataFrame and print the result to the console. - Parameters - ---------- - num : int, optional - Number of lines to show, by default 20 + Args: + num: Number of lines to show. """ self.df.show(num) @@ -288,8 +264,6 @@ def distinct(self) -> DataFrame: """Return a new `DataFrame` with all duplicated rows removed. Returns: - ------- - DataFrame DataFrame after removing duplicates. """ return DataFrame(self.df.distinct()) def join( self, right: DataFrame, join_keys: tuple[list[str], list[str]], how: str, ) -> DataFrame: - """Join this `DataFrame` with another `DataFrame` using explicitly specified columns. + """Join this `DataFrame` with another `DataFrame`. + + Join keys are a pair of lists of column names in the left and right + dataframes, respectively. These lists must have the same length. - Parameters - ---------- - right : DataFrame - Other DataFrame to join with. - join_keys : tuple[list[str], list[str]] - Tuple of two lists of column names to join on. - how : str - Type of join to perform. Supported types are "inner", "left", "right", "full", "semi", "anti". + Args: + right: Other DataFrame to join with. + join_keys: Tuple of two lists of column names to join on. + how: Type of join to perform. Supported types are "inner", "left", + "right", "full", "semi", "anti". Returns: - ------- - DataFrame DataFrame after join. """ return DataFrame(self.df.join(right.df, join_keys, how)) def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: """Return a DataFrame with the explanation of its plan so far. If `analyze` is specified, runs the plan and reports metrics. - Parameters - ---------- - verbose : bool, optional - If `True`, more details will be included, by default False - analyze : bool, optional - If `True`, the plan will run and metrics reported, by default False + Args: + verbose: If `True`, more details will be included. + analyze: If `True`, the plan will run and metrics reported. Returns: - ------- - DataFrame DataFrame with the explanation of its plan.
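A hedged sketch of the tuple-style join keys described above; the tables and column names are invented:

```python
from datafusion import SessionContext

ctx = SessionContext()
left = ctx.from_pydict({"id": [1, 2], "name": ["a", "b"]})
right = ctx.from_pydict({"customer_id": [1, 1, 2], "amount": [10, 20, 30]})

# Key lists pair up positionally: id on the left matches customer_id on the right.
joined = left.join(right, join_keys=(["id"], ["customer_id"]), how="inner")
joined.explain(verbose=False, analyze=False)
```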
""" return DataFrame(self.df.explain(verbose, analyze)) @@ -341,8 +308,6 @@ def logical_plan(self) -> LogicalPlan: """Return the unoptimized `LogicalPlan` that comprises this `DataFrame`. Returns: - ------- - LogicalPlan Unoptimized logical plan. """ return self.df.logical_plan() @@ -351,8 +316,6 @@ def optimized_logical_plan(self) -> LogicalPlan: """Return the optimized `LogicalPlan` that comprises this `DataFrame`. Returns: - ------- - LogicalPlan Optimized logical plan. """ return self.df.optimized_logical_plan() @@ -361,8 +324,6 @@ def execution_plan(self) -> ExecutionPlan: """Return the execution/physical plan that comprises this `DataFrame`. Returns: - ------- - ExecutionPlan Execution plan. """ return self.df.execution_plan() @@ -372,49 +333,37 @@ def repartition(self, num: int) -> DataFrame: The batches allocation uses a round-robin algorithm. - Parameters - ---------- - num : int - Number of partitions to repartition the DataFrame into. + Args: + num: Number of partitions to repartition the DataFrame into. Returns: - ------- - DataFrame Repartitioned DataFrame. """ return DataFrame(self.df.repartition(num)) - def repartition_by_hash(self, *args: Expr, num: int) -> DataFrame: - """Repartition a DataFrame into `num` partitions using a hash partitioning scheme. + def repartition_by_hash(self, *exprs: Expr, num: int) -> DataFrame: + """Repartition a DataFrame using a hash partitioning scheme. - Parameters - ---------- - num : int - Number of partitions to repartition the DataFrame into. + Args: + exprs: Expressions to evaluate and perform hashing on. + num: Number of partitions to repartition the DataFrame into. Returns: - ------- - DataFrame Repartitioned DataFrame. """ - args = [expr.expr for expr in args] - return DataFrame(self.df.repartition_by_hash(*args, num=num)) + exprs = [expr.expr for expr in exprs] + return DataFrame(self.df.repartition_by_hash(*exprs, num=num)) def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: """Calculate the union of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema. - Parameters - ---------- - other : DataFrame - DataFrame to union with. - distinct : bool, optional - If `True`, duplicate rows will be removed, by default False + Args: + other: DataFrame to union with. + distinct: If `True`, duplicate rows will be removed. Returns: - ------- - DataFrame DataFrame after union. """ return DataFrame(self.df.union(other.df, distinct)) @@ -425,14 +374,10 @@ def union_distinct(self, other: DataFrame) -> DataFrame: The two `DataFrame`s must have exactly the same schema. Any duplicate rows are discarded. - Parameters - ---------- - other : DataFrame - DataFrame to union with. + Args: + other: DataFrame to union with. Returns: - ------- - DataFrame DataFrame after union. """ return DataFrame(self.df.union_distinct(other.df)) @@ -442,14 +387,10 @@ def intersect(self, other: DataFrame) -> DataFrame: The two `DataFrame`s must have exactly the same schema. - Parameters - ---------- - other : DataFrame - DataFrame to intersect with. + Args: + other: DataFrame to intersect with. Returns: - ------- - DataFrame DataFrame after intersection. """ return DataFrame(self.df.intersect(other.df)) @@ -459,14 +400,10 @@ def except_all(self, other: DataFrame) -> DataFrame: The two `DataFrame`s must have exactly the same schema. - Parameters - ---------- - other : DataFrame - DataFrame to calculate exception with. + Args: + other: DataFrame to calculate exception with. Returns: - ------- - DataFrame DataFrame after exception. 
""" return DataFrame(self.df.except_all(other.df)) @@ -474,10 +411,9 @@ def except_all(self, other: DataFrame) -> DataFrame: def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None: """Execute the `DataFrame` and write the results to a CSV file. - Parameters - ---------- - path : str - Path of the CSV file to write. + Args: + path: Path of the CSV file to write. + with_header: If true, output the CSV header row. """ self.df.write_csv(str(path), with_header) @@ -489,24 +425,18 @@ def write_parquet( ) -> None: """Execute the `DataFrame` and write the results to a Parquet file. - Parameters - ---------- - path : str - Path of the Parquet file to write. - compression : str, optional - Compression type to use, by default "uncompressed" - compression_level : int | None, optional - Compression level to use, by default None + Args: + path: Path of the Parquet file to write. + compression: Compression type to use. + compression_level: Compression level to use. """ self.df.write_parquet(str(path), compression, compression_level) def write_json(self, path: str | pathlib.Path) -> None: """Execute the `DataFrame` and write the results to a JSON file. - Parameters - ---------- - path : str - Path of the JSON file to write. + Args: + path: Path of the JSON file to write. """ self.df.write_json(str(path)) @@ -514,18 +444,24 @@ def to_arrow_table(self) -> pa.Table: """Execute the `DataFrame` and convert it into an Arrow Table. Returns: - ------- - pa.Table Arrow Table. """ return self.df.to_arrow_table() def execute_stream(self) -> RecordBatchStream: - """Executes this DataFrame and returns a stream over a single partition.""" + """Executes this DataFrame and returns a stream over a single partition. + + Returns: + Record Batch Stream over a single partition. + """ return RecordBatchStream(self.df.execute_stream()) def execute_stream_partitioned(self) -> list[RecordBatchStream]: - """Executes this DataFrame and returns a stream for each partition.""" + """Executes this DataFrame and returns a stream for each partition. + + Returns: + One record batch stream per partition. + """ streams = self.df.execute_stream_partitioned() return [RecordBatchStream(rbs) for rbs in streams] @@ -533,8 +469,6 @@ def to_pandas(self) -> pd.DataFrame: """Execute the `DataFrame` and convert it into a Pandas DataFrame. Returns: - ------- - pd.DataFrame Pandas DataFrame. """ return self.df.to_pandas() @@ -543,8 +477,6 @@ def to_pylist(self) -> list[dict[str, Any]]: """Execute the `DataFrame` and convert it into a list of dictionaries. Returns: - ------- - list[dict[str, Any]] List of dictionaries. """ return self.df.to_pylist() @@ -553,8 +485,6 @@ def to_pydict(self) -> dict[str, list[Any]]: """Execute the `DataFrame` and convert it into a dictionary of lists. Returns: - ------- - dict[str, list[Any]] Dictionary of lists. """ return self.df.to_pydict() @@ -563,8 +493,6 @@ def to_polars(self) -> pl.DataFrame: """Execute the `DataFrame` and convert it into a Polars DataFrame. Returns: - ------- - pl.DataFrame Polars DataFrame. """ return self.df.to_polars() @@ -576,8 +504,6 @@ def count(self) -> int: count, which may be slow for large or complicated DataFrames. Returns: - ------- - int Number of rows in the DataFrame. 
""" return self.df.count() @@ -588,6 +514,15 @@ def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame: - """Expand columns of arrays into a single row per array element.""" + """Expand columns of arrays into a single row per array element. + + Args: + columns: Column names to perform unnest operation on. + preserve_nulls: If False, rows with null entries will not be + returned. + + Returns: + A DataFrame with the columns expanded. + """ columns = [c for c in columns] return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) From 95a4688be9c923de655ade4477611f6cb45a701e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 14 Jul 2024 10:08:18 -0400 Subject: [PATCH 40/55] Updating docstring formatting --- python/datafusion/context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index da304d47..33c54b7b 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -309,7 +309,7 @@ def with_fair_spill_pool(self, size: int) -> RuntimeConfig: Examples usage: ```python - config = RuntimeConfig().with_fair_spill_pool(1024) + config = RuntimeConfig().with_fair_spill_pool(1024) ``` """ self.config_internal = self.config_internal.with_fair_spill_pool(size) @@ -330,7 +330,7 @@ def with_greedy_memory_pool(self, size: int) -> RuntimeConfig: Example usage: ```python - config = RuntimeConfig().with_greedy_memory_pool(1024) + config = RuntimeConfig().with_greedy_memory_pool(1024) ``` """ self.config_internal = self.config_internal.with_greedy_memory_pool(size) @@ -347,7 +347,7 @@ def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeConfig: Example usage: ```python - config = RuntimeConfig().with_temp_file_path("/tmp") + config = RuntimeConfig().with_temp_file_path("/tmp") ``` """ self.config_internal = self.config_internal.with_temp_file_path(str(path)) @@ -987,7 +987,7 @@ def read_avro( file_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".avro", ) -> DataFrame: - """Create a `DataFrame` for reading Avro data source. + """Create a ``DataFrame`` for reading Avro data source. Args: path: Path to the Avro file. From 39d9c00234413dbc1368dc747386ecaeecf5aca6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 14 Jul 2024 10:12:53 -0400 Subject: [PATCH 41/55] Updating docstring formatting --- python/datafusion/expr.py | 50 +++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 0739c6cf..1c4f707f 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -109,7 +109,7 @@ def display_name(self) -> str: return self.expr.display_name() def canonical_name(self) -> str: - """Returns a full and complete string representation of this expression.""" + """Returns a complete string representation of this expression.""" return self.expr.canonical_name() def variant_name(self) -> str: @@ -266,7 +266,12 @@ def alias(self, name: str) -> Expr: return Expr(self.expr.alias(name)) def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr: - """Creates a sort ``Expr`` from an existing ``Expr``.""" + """Creates a sort ``Expr`` from an existing ``Expr``. + + Args: + ascending: If true, sort in ascending order. 
+ nulls_first: Return null values first. + """ return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first)) def is_null(self) -> Expr: @@ -281,33 +286,42 @@ def rex_type(self) -> RexType: """Return the Rex Type of this expression. A Rex (Row Expression) specifies a single row of data. That specification - could include user defined functions or types. RexType identifies the row - as one of the possible valid ``RexType``(s). + could include user defined functions or types. RexType identifies the + row as one of the possible valid ``RexType``(s). """ return self.expr.rex_type() def types(self) -> DataTypeMap: - """Return the ``DataTypeMap`` which represents the PythonType, Arrow DataType, and SqlType Enum which this expression represents.""" + """Return the ``DataTypeMap``. + + Returns: + DataTypeMap which represents the PythonType, Arrow DataType, and + SqlType Enum which this expression represents. + """ return self.expr.types() def python_value(self) -> Any: - """Extracts the Expr value into a PyObject that can be shared with Python. + """Extracts the Expr value into a PyObject. This is only valid for literal expressions. + + Returns: + Python object representing literal value of the expression. """ return self.expr.python_value() def rex_call_operands(self) -> list[Expr]: """Return the operands of the expression based on its variant type. - Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s), - store those operands in different datastructures. This function examines the Expr variant and returns + Row expressions, Rex(s), operate on the concept of operands. Different + variants of Expressions, Expr(s), store those operands in different + data structures. This function examines the Expr variant and returns the operands to the calling logic. """ return [Expr(e) for e in self.expr.rex_call_operands()] def rex_call_operator(self) -> str: - """Extracts the operator associated with a row expression type ``Call``.""" + """Extracts the operator associated with a row expression type call.""" return self.expr.rex_call_operator() def column_name(self, plan: LogicalPlan) -> str: @@ -325,8 +339,12 @@ def __init__( Args: units: Should be one of `rows`, `range`, or `groups`. - start_bound: Sets the preceeding bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. - end_bound: Sets the following bound. Must be >= 0. If none, this will be set to unbounded. If unit type is `groups`, this parameter must be set. + start_bound: Sets the preceding bound. Must be >= 0. If none, this + will be set to unbounded. If unit type is `groups`, this + parameter must be set. + end_bound: Sets the following bound. Must be >= 0. If none, this + will be set to unbounded. If unit type is `groups`, this + parameter must be set. """ self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound) @@ -382,14 +400,20 @@ class CaseBuilder: ```python import datafusion.functions as f from datafusion import lit, col - df.select(f.case(col("column_a").when(lit(1), lit("One")).when(lit(2), lit("Two")).otherwise(lit("Unknown"))) + df.select( + f.case(col("column_a") + .when(lit(1), lit("One")) + .when(lit(2), lit("Two")) + .otherwise(lit("Unknown")) + ) + ) ``` """ def __init__(self, case_builder: expr_internal.CaseBuilder) -> None: """Constructs a case builder. - This is not typically called by the end user directly. See ``datafusion.functions.case`` instead.
+ This is not typically called by the end user directly. See + ``datafusion.functions.case`` instead. """ self.case_builder = case_builder From 671d508d0090b0c1883ba5cfbc6452f2e78d0035 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 15 Jul 2024 07:24:59 -0400 Subject: [PATCH 42/55] Updating docstring formatting --- python/datafusion/record_batch.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index ddb90178..dcfd5548 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. -"""This module provides the classes for handling record batches, which are typically the result of dataframe `execute_stream` operations.""" +"""This module provides the classes for handling record batches. + +These are typically the result of dataframe `execute_stream` operations. +""" from __future__ import annotations @@ -43,7 +46,10 @@ def to_pyarrow(self) -> pyarrow.RecordBatch: class RecordBatchStream: - """This class represents a stream of record batches, typically as the result of a DataFrame `execute_stream` operation.""" + """This class represents a stream of record batches. + + These are typically the result of a ``DataFrame::execute_stream`` operation. + """ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: """This constructor is typically not called by the end user.""" From 49efdd04f27318a91f90d54a99ded7928b67e5c5 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 15 Jul 2024 07:29:26 -0400 Subject: [PATCH 43/55] Updating docstring formatting --- python/datafusion/substrait.py | 84 ++++++++++------------------------ 1 file changed, 25 insertions(+), 59 deletions(-) diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index f3e5f59b..a199dd73 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -17,8 +17,8 @@ """This module provides support for using substrait with datafusion. -For additional information about substrait, see https://substrait.io/ for more information -about substrait. +For additional information about substrait, see https://substrait.io/. """ from __future__ import annotations @@ -40,8 +40,8 @@ class Plan: def __init__(self, plan: substrait_internal.Plan) -> None: """Create a substrait plan. - The user should not have to call this constructor directly. Rather, it should be created - via `Serde` or `Producer` classes in this module. + The user should not have to call this constructor directly. Rather, it + should be created via `Serde` or `Producer` classes in this module. """ self.plan_internal = plan def encode(self) -> bytes: """Encode the plan to bytes. Returns: - ------- - bytes Encoded plan. """ return self.plan_internal.encode() @@ -64,20 +62,16 @@ class plan(Plan): class Serde: - """Provides the serialization and deserialization required to convert to and from a Substrait plan.""" + """Provides the ``Substrait`` serialization and deserialization.""" @staticmethod def serialize(sql: str, ctx: SessionContext, path: str | pathlib.Path) -> None: """Serialize a SQL query to a Substrait plan and write it to a file. + Args: + sql: SQL query to serialize.
+ ctx: SessionContext to use. + path: Path to write the Substrait plan to. """ return substrait_internal.serde.serialize(sql, ctx.ctx, str(path)) @@ -85,16 +79,11 @@ def serialize(sql: str, ctx: SessionContext, path: str | pathlib.Path) -> None: def serialize_to_plan(sql: str, ctx: SessionContext) -> Plan: """Serialize a SQL query to a Substrait plan. - Parameters - ---------- - sql : str - SQL query to serialize. - ctx : SessionContext - SessionContext to use. + Args: + sql: SQL query to serialize. + ctx: SessionContext to use. Returns: - ------- - plan Substrait plan. """ return Plan(substrait_internal.serde.serialize_to_plan(sql, ctx.ctx)) @@ -103,16 +92,11 @@ def serialize_to_plan(sql: str, ctx: SessionContext) -> Plan: def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: """Serialize a SQL query to a Substrait plan as bytes. - Parameters - ---------- - sql : str - SQL query to serialize. - ctx : SessionContext - SessionContext to use. + Args: + sql: SQL query to serialize. + ctx: SessionContext to use. Returns: - ------- - bytes Substrait plan as bytes. """ return substrait_internal.serde.serialize_bytes(sql, ctx.ctx) @@ -121,14 +105,10 @@ def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: def deserialize(path: str | pathlib.Path) -> Plan: """Deserialize a Substrait plan from a file. - Parameters - ---------- - path : str - Path to read the Substrait plan from. + Args: + path: Path to read the Substrait plan from. Returns: - ------- - plan Substrait plan. """ return Plan(substrait_internal.serde.deserialize(str(path))) @@ -137,14 +117,10 @@ def deserialize(path: str | pathlib.Path) -> Plan: def deserialize_bytes(proto_bytes: bytes) -> Plan: """Deserialize a Substrait plan from bytes. - Parameters - ---------- - proto_bytes : bytes - Bytes to read the Substrait plan from. + Args: + proto_bytes: Bytes to read the Substrait plan from. Returns: - ------- - plan Substrait plan. """ return Plan(substrait_internal.serde.deserialize_bytes(proto_bytes)) @@ -164,16 +140,11 @@ class Producer: def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: """Convert a DataFusion LogicalPlan to a Substrait plan. - Parameters - ---------- - plan : LogicalPlan - LogicalPlan to convert. - ctx : SessionContext - SessionContext to use. + Args: + logical_plan: LogicalPlan to convert. + ctx: SessionContext to use. Returns: - ------- - plan Substrait plan. """ return Plan( @@ -195,16 +166,11 @@ class Consumer: def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: """Convert a Substrait plan to a DataFusion LogicalPlan. - Parameters - ---------- - ctx : SessionContext - SessionContext to use. - plan : plan - Substrait plan to convert. + Args: + ctx: SessionContext to use. + plan: Substrait plan to convert. Returns: - ------- - LogicalPlan LogicalPlan. 
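Taken together, these classes allow a round trip between SQL, a Substrait plan, and a DataFusion logical plan. A sketch, assuming the module is importable as `datafusion.substrait` and using a placeholder CSV source:

```python
from datafusion import SessionContext
from datafusion.substrait import Serde, Consumer

ctx = SessionContext()
ctx.register_csv("t", "t.csv")  # "t.csv" is a placeholder data source

substrait_plan = Serde.serialize_to_plan("SELECT * FROM t", ctx)
logical_plan = Consumer.from_substrait_plan(ctx, substrait_plan)
df = ctx.create_dataframe_from_logical_plan(logical_plan)
```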
""" return substrait_internal.consumer.from_substrait_plan( From 3c7a8110826f0c64f4e908830af1b9be695e2001 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 15 Jul 2024 07:58:36 -0400 Subject: [PATCH 44/55] Cleaning up docstring line lengths --- python/datafusion/functions.py | 153 ++++++++++++++++++++++++--------- python/datafusion/udf.py | 11 ++- 2 files changed, 121 insertions(+), 43 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index cd726f4c..ad77712e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -34,7 +34,10 @@ def isnan(expr: Expr) -> Expr: def nullif(expr1: Expr, expr2: Expr) -> Expr: - """Returns NULL if expr1 equals expr2; otherwise it returns expr1. This can be used to perform the inverse operation of the COALESCE expression.""" + """Returns NULL if expr1 equals expr2; otherwise it returns expr1. + + This can be used to perform the inverse operation of the COALESCE expression. + """ return Expr(f.nullif(expr1.expr, expr2.expr)) @@ -86,19 +89,26 @@ def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: def digest(value: Expr, method: Expr) -> Expr: """Computes the binary hash of an expression using the specified algorithm. - Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. + Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, + blake2b, and blake3. """ return Expr(f.digest(value.expr, method.expr)) def concat(*args: Expr) -> Expr: - """Concatenates the text representations of all the arguments. NULL arguments are ignored.""" + """Concatenates the text representations of all the arguments. + + NULL arguments are ignored. + """ args = [arg.expr for arg in args] return Expr(f.concat(*args)) def concat_ws(separator: str, *args: Expr) -> Expr: - """Concatenates the list `args` with the separator. `NULL` arugments are ignored. `separator` should not be `NULL`.""" + """Concatenates the list `args` with the separator. + + `NULL` arugments are ignored. `separator` should not be `NULL`. + """ args = [arg.expr for arg in args] return Expr(f.concat_ws(separator, *args)) @@ -124,7 +134,10 @@ def count_star() -> Expr: def case(expr: Expr) -> CaseBuilder: - """Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.""" + """Create a ``CaseBuilder`` to match cases for the expression ``expr``. + + See ``datafusion.expr.CaseBuilder`` for detailed usage of ``CaseBuilder``. + """ return CaseBuilder(f.case(expr.expr)) @@ -284,7 +297,10 @@ def factorial(arg: Expr) -> Expr: def find_in_set(string: Expr, string_list: Expr) -> Expr: - """Returns a value in the range of 1 to N if the string is in the string list `string_list` consisting of N substrings. + """Find a string in a list of strings. + + Returns a value in the range of 1 to N if the string is in the string list + `string_list` consisting of N substrings. The string list is a string composed of substrings separated by `,` characters. """ @@ -302,7 +318,11 @@ def gcd(x: Expr, y: Expr) -> Expr: def initcap(string: Expr) -> Expr: - """Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase.""" + """Set the initial letter of each word to capital. + + Converts the first letter of each word in `string` to uppercase and the remaining + characters to lowercase. 
+    """
     return Expr(f.initcap(string.expr))
 
 
@@ -360,7 +380,12 @@ def lower(arg: Expr) -> Expr:
 
 
 def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr:
-    """Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."""
+    """Add left padding to a string.
+
+    Extends the string to length ``count`` by prepending the fill characters (a
+    space by default). If the string is already longer than ``count`` then it is
+    truncated (on the right).
+    """
     characters = characters if characters is not None else Expr.literal(" ")
     return Expr(f.lpad(string.expr, count.expr, characters.expr))
 
@@ -390,7 +415,8 @@ def overlay(
 ) -> Expr:
     """Replace a substring with a new substring.
 
-    Replace the substring of string that starts at the `start`'th character and extends for `length` characters with new substring.
+    Replace the substring of string that starts at the `start`'th character and
+    extends for `length` characters with new substring.
     """
     if length is None:
         return Expr(f.overlay(string.expr, substring.expr, start.expr))
@@ -429,14 +455,22 @@ def radians(arg: Expr) -> Expr:
 
 
 def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr:
-    """Tests a string using a regular expression returning true if at least one match, false otherwise."""
+    """Test whether a regular expression (regex) match exists.
+
+    Tests a string using a regular expression, returning true if at least one
+    match is found and false otherwise.
+    """
     if flags is not None:
         flags = flags.expr
     return Expr(f.regexp_like(string.expr, regex.expr, flags))
 
 
 def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr:
-    """Returns an array with each element containing the leftmost-first match of the corresponding index in `regex` to string in `string`."""
+    """Perform regular expression (regex) matching.
+
+    Returns an array with each element containing the leftmost-first match of the
+    corresponding index in `regex` to string in `string`.
+    """
     if flags is not None:
         flags = flags.expr
     return Expr(f.regexp_match(string.expr, regex.expr, flags))
@@ -539,7 +573,11 @@ def sinh(arg: Expr) -> Expr:
 
 
 def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr:
-    """Splits a string based on a delimiter and picks out the desired field based on the index."""
+    """Split a string and return one part.
+
+    Splits a string based on a delimiter and picks out the desired field based
+    on the index.
+    """
     return Expr(f.split_part(string.expr, delimiter.expr, index.expr))
 
 
@@ -887,22 +925,30 @@ def array_has(first_array: Expr, second_array: Expr) -> Expr:
 
 
 def array_has_all(first_array: Expr, second_array: Expr) -> Expr:
-    """Returns true if each element of the second array appears in the first array. Otherwise, it returns false."""
+    """Determines if there is complete overlap of ``second_array`` in ``first_array``.
+
+    Returns true if each element of the second array appears in the first array.
+    Otherwise, it returns false.
+    """
     return Expr(f.array_has_all(first_array.expr, second_array.expr))
 
 
 def array_has_any(first_array: Expr, second_array: Expr) -> Expr:
-    """Returns true if at least one element of the second array appears in the first array. Otherwise, it returns false."""
+    """Determine if there is an overlap between ``first_array`` and ``second_array``.
+
+    Returns true if at least one element of the second array appears in the first
+    array. Otherwise, it returns false.
+    """
     return Expr(f.array_has_any(first_array.expr, second_array.expr))
 
 
 def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr:
-    """Searches for an element in the array and returns the position of the first occurrence."""
+    """Return the position of the first occurrence of ``element`` in ``array``."""
     return Expr(f.array_position(array.expr, element.expr, index))
 
 
 def array_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr:
-    """Searches for an element in the array and returns the position of the first occurrence.
+    """Return the position of the first occurrence of ``element`` in ``array``.
 
     This is an alias for `array_position`.
     """
@@ -910,7 +956,7 @@ def array_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr:
 
 
 def list_position(array: Expr, element: Expr, index: int | None = 1) -> Expr:
-    """Searches for an element in the array and returns the position of the first occurrence.
+    """Return the position of the first occurrence of ``element`` in ``array``.
 
     This is an alias for `array_position`.
     """
@@ -918,7 +964,7 @@ def list_position(array: Expr, element: Expr, index: int | None = 1) -> Expr:
 
 
 def list_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr:
-    """Searches for an element in the array and returns the position of the first occurrence.
+    """Return the position of the first occurrence of ``element`` in ``array``.
 
     This is an alias for `array_position`.
     """
@@ -1035,12 +1081,12 @@ def array_repeat(element: Expr, count: Expr) -> Expr:
 
 
 def array_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr:
-    """Replaces the first occurrence of the specified element with another specified element."""
+    """Replaces the first occurrence of ``from_val`` with ``to_val``."""
     return Expr(f.array_replace(array.expr, from_val.expr, to_val.expr))
 
 
 def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr:
-    """Replaces the first occurrence of the specified element with another specified element.
+    """Replaces the first occurrence of ``from_val`` with ``to_val``.
 
     This is an alias for `array_replace`.
     """
@@ -1048,12 +1094,19 @@ def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr:
 
 
 def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr:
-    """Replaces the first `max` occurrences of the specified element with another specified element."""
+    """Replace the first `max` occurrences of ``from_val`` with ``to_val``.
+
+    Replaces the first `max` occurrences of the specified element with another
+    specified element.
+    """
     return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr))
 
 
 def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr:
-    """Replaces the first `max` occurrences of the specified element with another specified element.
+    """Replace the first `max` occurrences of ``from_val`` with ``to_val``.
+
+    Replaces the first `max` occurrences of the specified element with another
+    specified element.
 
     This is an alias for `array_replace_n`.
""" @@ -1061,12 +1114,12 @@ def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr def array_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: - """Replaces all occurrences of the specified element with another specified element.""" + """Replaces all occurrences of ``from_val`` with ``to_val``.""" return Expr(f.array_replace_all(array.expr, from_val.expr, to_val.expr)) def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: - """Replaces all occurrences of the specified element with another specified element. + """Replaces all occurrences of ``from_val`` with ``to_val``. This is an alias for `array_replace_all`. """ @@ -1104,12 +1157,17 @@ def list_intersect(array1: Expr, array2: Expr) -> Expr: def array_union(array1: Expr, array2: Expr) -> Expr: - """Returns an array of the elements in the union of array1 and array2 without duplicates.""" + """Returns an array of the elements in the union of array1 and array2. + + Duplicate rows will not be returned. + """ return Expr(f.array_union(array1.expr, array2.expr)) def list_union(array1: Expr, array2: Expr) -> Expr: - """Returns an array of the elements in the union of array1 and array2 without duplicates. + """Returns an array of the elements in the union of array1 and array2. + + Duplicate rows will not be returned. This is an alias for `array_union`. """ @@ -1117,7 +1175,7 @@ def list_union(array1: Expr, array2: Expr) -> Expr: def array_except(array1: Expr, array2: Expr) -> Expr: - """Returns an array of the elements that appear in `array1` but not in the `array2`.""" + """Returns an array of the elements that appear in `array1` but not in `array2`.""" return Expr(f.array_except(array1.expr, array2.expr)) @@ -1130,15 +1188,19 @@ def list_except(array1: Expr, array2: Expr) -> Expr: def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: - """Returns an array with the specified size filled. If `size` is greater than the `array` length, the additional entries will be filled with the given `value`.""" + """Returns an array with the specified size filled. + + If `size` is greater than the `array` length, the additional entries will be filled + with the given `value`. + """ return Expr(f.array_resize(array.expr, size.expr, value.expr)) def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: """Returns an array with the specified size filled. - If `size` is greater than the `array` length, the additional entries will be filled with the given `value`. - This is an alias for `array_resize`. + If `size` is greater than the `array` length, the additional entries will be + filled with the given `value`. This is an alias for `array_resize`. 
""" return array_resize(array, size, value) @@ -1160,20 +1222,20 @@ def approx_median(arg: Expr, distinct: bool = False) -> Expr: def approx_percentile_cont( - arg: Expr, + expr: Expr, percentile: Expr, num_centroids: int | None = None, distinct: bool = False, ) -> Expr: - """Returns the value that is approximately at a given percentile of a distribution of values.""" + """Returns the value that is approximately at a given percentile of ``expr``.""" if num_centroids is None: return Expr( - f.approx_percentile_cont(arg.expr, percentile.expr, distinct=distinct) + f.approx_percentile_cont(expr.expr, percentile.expr, distinct=distinct) ) return Expr( f.approx_percentile_cont( - arg.expr, percentile.expr, num_centroids, distinct=distinct + expr.expr, percentile.expr, num_centroids, distinct=distinct ) ) @@ -1181,7 +1243,11 @@ def approx_percentile_cont( def approx_percentile_cont_with_weight( arg: Expr, weight: Expr, percentile: Expr, distinct: bool = False ) -> Expr: - """Returns the value that is approximately at a given percentile of a distribution of values with associated weights.""" + """Returns the value of the approximate percentile. + + This function is similar to ``approx_percentile_cont`` except that it uses + the associated associated weights. + """ return Expr( f.approx_percentile_cont_with_weight( arg.expr, weight.expr, percentile.expr, distinct=distinct @@ -1232,7 +1298,10 @@ def covar_samp(y: Expr, x: Expr) -> Expr: def grouping(arg: Expr, distinct: bool = False) -> Expr: - """Returns 1 if the value of the argument in the returned row is a null value.""" + """Indicates if the expression is aggregated or not. + + Returns 1 if the value of the argument is aggregated, 0 if not. + """ return Expr(f.grouping([arg.expr], distinct=distinct)) @@ -1301,17 +1370,23 @@ def var_samp(arg: Expr) -> Expr: def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the average of the independent variable `x` for non-null pairs of the inputs.""" + """Computes the average of the independent variable `x`. + + Only non-null pairs of the inputs are evaluated. + """ return Expr(f.regr_avgx[y.expr, x.expr], distinct) def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Computes the average of the dependent variable `y` for non-null pairs of the inputs.""" + """Computes the average of the dependent variable ``y``. + + Only non-null pairs of the inputs are evaluated. + """ return Expr(f.regr_avgy[y.expr, x.expr], distinct) def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr: - """Counts the number of input rows in which both expressions are not null.""" + """Counts the number of rows in which both expressions are not null.""" return Expr(f.regr_count[y.expr, x.expr], distinct) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 688a6433..4bfbabe6 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""This module provides the user defined functions for evaluation of dataframes.""" +"""Provides the user defined functions for evaluation of dataframes.""" from __future__ import annotations @@ -90,7 +90,10 @@ def __init__( return_type: _R, volatility: Volatility | str, ) -> None: - """Instantiate a scalar user defined function (UDF).""" + """Instantiate a scalar user defined function (UDF). + + See helper method ``udf`` for argument details. 
From fbf3f46d10a1bef6921eda12590457b844fab465 Mon Sep 17 00:00:00 2001
From: Tim Saucer <timsaucer@gmail.com>
Date: Mon, 15 Jul 2024 08:20:55 -0400
Subject: [PATCH 45/55] Add pre-commit check of docstring line length

---
 pyproject.toml | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8f21dc48..d579230c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,12 +67,14 @@ features = ["substrait"]
 
 # Enable docstring linting using the google style guide
 [tool.ruff.lint]
-select = ["E4", "E7", "E9", "F", "D"]
-ignore = ["D417"]
+select = ["E4", "E7", "E9", "F", "D", "W"]
 
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
+[tool.ruff.lint.pycodestyle]
+max-doc-length = 88
+
 # Disable docstring checking for these directories
 [tool.ruff.lint.per-file-ignores]
 "python/datafusion/tests/*" = ["D"]

From d6c6598dc3a1ba618714967f6466416862503c55 Mon Sep 17 00:00:00 2001
From: Tim Saucer <timsaucer@gmail.com>
Date: Tue, 16 Jul 2024 08:34:36 -0400
Subject: [PATCH 46/55] Do not emit doc entry for __init__ of some classes

---
 docs/source/conf.py | 19 +++++++++++++++++++
 pyproject.toml      |  1 +
 2 files changed, 20 insertions(+)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2946efe3..308069b6 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -80,6 +80,25 @@
 
 autosummary_generate = True
 
+
+def autodoc_skip_member(app, what, name, obj, skip, options):
+    exclude_functions = ("__init__",)  # tuple, so the membership test is exact
+    exclude_classes = ("Expr", "DataFrame")
+
+    class_name = ""
+    if hasattr(obj, "__qualname__"):
+        if obj.__qualname__ is not None:
+            class_name = obj.__qualname__.split(".")[0]
+
+    should_exclude = name in exclude_functions and class_name in exclude_classes
+
+    return True if should_exclude else None
+
+
+def setup(app):
+    app.connect("autodoc-skip-member", autodoc_skip_member)
+
+
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages. See the documentation for
diff --git a/pyproject.toml b/pyproject.toml
index d579230c..7c5cd3a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,3 +81,4 @@ max-doc-length = 88
 "examples/*" = ["D"]
 "dev/*" = ["D"]
 "benchmarks/*" = ["D", "F"]
+"docs/*" = ["D"]
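For context on the conf.py hook above: Sphinx calls each connected autodoc-skip-member handler once per candidate member; returning True forces the member to be skipped, returning False forces it to be kept, and returning None defers to autodoc's default behavior. The enclosing class is recovered from the member's qualified name, which can be sketched as::

    class Expr:
        def __init__(self):
            pass

    Expr.__init__.__qualname__                # 'Expr.__init__'
    Expr.__init__.__qualname__.split(".")[0]  # 'Expr'

With the hook registered, ``Expr.__init__`` and ``DataFrame.__init__`` drop out of the generated API pages while every other ``__init__`` keeps its entry.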
From cccf30504ae401ce8069906029dfcdfefa501d4b Mon Sep 17 00:00:00 2001
From: Tim Saucer <timsaucer@gmail.com>
Date: Tue, 16 Jul 2024 08:35:01 -0400
Subject: [PATCH 47/55] Correct errors on code blocks generating in sphinx

---
 python/datafusion/context.py   | 52 ++++++++++++++--------------------
 python/datafusion/dataframe.py | 11 ++++---
 python/datafusion/expr.py      | 26 ++++++++---------
 3 files changed, 38 insertions(+), 51 deletions(-)

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 33c54b7b..a717db10 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -288,9 +288,8 @@ def with_fair_spill_pool(self, size: int) -> RuntimeConfig:
         This pool works best when you know beforehand the query has multiple
         spillable operators that will likely all need to spill. Sometimes it will
         cause spills even when there was sufficient memory (reserved for other operators) to avoid
-        doing so.
+        doing so::
 
-        ```text
         ┌───────────────────────z──────────────────────z───────────────┐
         │                       z                      z               │
         │                       z                      z               │
         │                       z                      z               │
         │                       z                      z               │
         │                       z                      z               │
         └───────────────────────z──────────────────────z───────────────┘
-        ```
 
         Args:
             size: Size of the memory pool in bytes.
 
         Returns:
-            A new `RuntimeConfig` object with the updated setting.
+            A new ``RuntimeConfig`` object with the updated setting.
+
+        Example usage::
 
-        Examples usage:
-        ```python
-        config = RuntimeConfig().with_fair_spill_pool(1024)
-        ```
+            config = RuntimeConfig().with_fair_spill_pool(1024)
         """
         self.config_internal = self.config_internal.with_fair_spill_pool(size)
         return self
@@ -328,10 +325,9 @@ def with_greedy_memory_pool(self, size: int) -> RuntimeConfig:
         Returns:
             A new `RuntimeConfig` object with the updated setting.
 
-        Example usage:
-        ```python
-        config = RuntimeConfig().with_greedy_memory_pool(1024)
-        ```
+        Example usage::
+
+            config = RuntimeConfig().with_greedy_memory_pool(1024)
         """
         self.config_internal = self.config_internal.with_greedy_memory_pool(size)
         return self
@@ -345,10 +341,9 @@ def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeConfig:
         Returns:
             A new `RuntimeConfig` object with the updated setting.
 
-        Example usage:
-        ```python
-        config = RuntimeConfig().with_temp_file_path("/tmp")
-        ```
+        Example usage::
+
+            config = RuntimeConfig().with_temp_file_path("/tmp")
         """
         self.config_internal = self.config_internal.with_temp_file_path(str(path))
         return self
@@ -378,10 +373,9 @@ def with_allow_ddl(self, allow: bool = True) -> SQLOptions:
         Returns:
             A new `SQLOptions` object with the updated setting.
 
-        Example usage:
-        ```python
+        Example usage::
+
             options = SQLOptions().with_allow_ddl(True)
-        ```
         """
         self.options_internal = self.options_internal.with_allow_ddl(allow)
         return self
@@ -397,10 +391,9 @@ def with_allow_dml(self, allow: bool = True) -> SQLOptions:
         Returns:
             A new `SQLOptions` object with the updated setting.
 
-        Example usage:
-        ```python
+        Example usage::
+
             options = SQLOptions().with_allow_dml(True)
-        ```
         """
         self.options_internal = self.options_internal.with_allow_dml(allow)
         return self
@@ -414,10 +407,9 @@ def with_allow_statements(self, allow: bool = True) -> SQLOptions:
         Returns:
             A new `SQLOptions` object with the updated setting.
-        Example usage:
-        ```python
+        Example usage::
+
             options = SQLOptions().with_allow_statements(True)
-        ```
         """
         self.options_internal = self.options_internal.with_allow_statements(allow)
         return self
@@ -446,14 +438,12 @@ def __init__(
 
         Example usage:
 
        The following example demonstrates how to use the context to execute
-        a query against a CSV data source using the `DataFrame` API:
+        a query against a CSV data source using the ``DataFrame`` API::
 
-        ```python
-        from datafusion import SessionContext
+            from datafusion import SessionContext
 
-        ctx = SessionContext()
-        df = ctx.read_csv("data.csv")
-        ```
+            ctx = SessionContext()
+            df = ctx.read_csv("data.csv")
         """
         config = config.config_internal if config is not None else None
         runtime = runtime.config_internal if config is not None else None
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 6b815dad..68e6298f 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -119,11 +119,10 @@ def select(self, *exprs: Expr | str) -> DataFrame:
         The following example will return 3 columns from the original dataframe.
         The first two columns will be the original columns `a` and `b` since the
         string "a" is assumed to refer to column selection. Also a duplicate of
-        column `a` will be returned with the column name `alternate_a`.
+        column `a` will be returned with the column name `alternate_a`::
+
+            df = df.select("a", col("b"), col("a").alias("alternate_a"))
 
-        ```python
-        df = df.select("a", col("b"), col("a").alias("alternate_a"))
-        ```
         """
         exprs = [
             arg.expr if isinstance(arg, Expr) else Expr.column(arg).expr
@@ -243,11 +242,11 @@ def cache(self) -> DataFrame:
     def collect_partitioned(self) -> list[list[pa.RecordBatch]]:
         """Execute this DataFrame and collect all partitioned results.
 
-        This operation returns `pyarrow.RecordBatch`es maintaining the input
+        This operation returns ``RecordBatch`` objects maintaining the input
         partitioning.
 
         Returns:
-            List of list of `pyarrow.RecordBatch`es collected from the
+            List of lists of ``RecordBatch`` collected from the
             DataFrame.
         """
         return self.df.collect_partitioned()
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 1c4f707f..c04a525a 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -89,7 +89,7 @@ class Expr:
     """Expression object.
 
     Expressions are one of the core concepts in DataFusion. See
-    [the online help](https://datafusion.apache.org/python/user-guide/common-operations/expressions.html)
+    https://datafusion.apache.org/python/user-guide/common-operations/expressions.html
     for more information.
     """
 
@@ -364,7 +364,7 @@ def get_upper_bound(self):
 class WindowFrameBound:
     """Defines a single window frame bound.
 
-    ```WindowFrame`` typically requires a start and end bound.
+    ``WindowFrame`` typically requires a start and end bound.
     """
 
     def __init__(self, frame_bound: expr_internal.WindowFrameBound) -> None:
@@ -395,18 +395,16 @@ def is_unbounded(self) -> bool:
 class CaseBuilder:
     """Builder class for constructing case statements.
 
-    An example usage would be as follows:
-
-    ```python
-    import datafusion.functions as f
-    from datafusion import lit, col
-    df.select(
-        f.case(col("column_a")
-        .when(lit(1), lit("One"))
-        .when(lit(2), lit("Two"))
-        .otherwise(lit("Unknown"))
-    )
-    ```
+    An example usage would be as follows::
+
+        import datafusion.functions as f
+        from datafusion import lit, col
+        df.select(
+            f.case(col("column_a"))
+            .when(lit(1), lit("One"))
+            .when(lit(2), lit("Two"))
+            .otherwise(lit("Unknown"))
+        )
     """
 
     def __init__(self, case_builder: expr_internal.CaseBuilder) -> None:
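The pattern behind patch 47: autodoc hands docstrings to Sphinx as reStructuredText, where a paragraph ending in ``::`` introduces the following indented text as a literal code block, while Markdown-style ``` fences are just ordinary characters and render as stray backticks. Schematically (an assumed example in the style of the patch)::

    Example usage::

        ctx = SessionContext()
        df = ctx.read_csv("data.csv")

The trailing ``::`` collapses to a single colon in the rendered output, and the indented lines become a highlighted block.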
From 6579ac556446e308a2246c00833fcef237cf1dff Mon Sep 17 00:00:00 2001
From: Tim Saucer <timsaucer@gmail.com>
Date: Tue, 16 Jul 2024 08:58:17 -0400
Subject: [PATCH 48/55] Resolve conflict with

---
 examples/tpch/_tests.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
index cc201a31..903b5354 100644
--- a/examples/tpch/_tests.py
+++ b/examples/tpch/_tests.py
@@ -78,10 +78,7 @@ def check_q17(df):
         ("q08_market_share", "q8"),
         ("q09_product_type_profit_measure", "q9"),
         ("q10_returned_item_reporting", "q10"),
-        pytest.param(
-            "q11_important_stock_identification",
-            "q11",
-        ),
+        ("q11_important_stock_identification", "q11"),
         ("q12_ship_mode_order_priority", "q12"),
         ("q13_customer_distribution", "q13"),
         ("q14_promotion_effect", "q14"),
@@ -99,8 +96,9 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
     module = import_module(query_code)
     df = module.df
 
-    # Treat q17 as a special case. The answer file does not match the spec. Running at
-    # scale factor 1, we have manually verified this result does match the expected value.
+    # Treat q17 as a special case. The answer file does not match the spec.
+    # Running at scale factor 1, we have manually verified this result does
+    # match the expected value.
     if answer_file == "q17":
         return check_q17(df)

From 62197bca453b6849b10098a4f49340bd00dfeaf0 Mon Sep 17 00:00:00 2001
From: Tim Saucer <timsaucer@gmail.com>
Date: Tue, 16 Jul 2024 13:23:24 -0400
Subject: [PATCH 49/55] Add license info to py.typed

---
 python/datafusion/py.typed | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/python/datafusion/py.typed b/python/datafusion/py.typed
index e69de29b..d216be4d 100644
--- a/python/datafusion/py.typed
+++ b/python/datafusion/py.typed
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
\ No newline at end of file
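A note on the file above: py.typed is the PEP 561 marker that tells type checkers the datafusion package ships inline type annotations. Its content is not interpreted (only the special "partial\n" form used by partially-typed packages carries meaning), so adding a license header is safe. With the marker present, a checker such as mypy will analyze calls into the package; a hypothetical check might look like::

    # check.py -- mypy resolves this because py.typed marks the package as typed
    from datafusion import SessionContext

    ctx = SessionContext()
    reveal_type(ctx)  # mypy: Revealed type is "datafusion.context.SessionContext"

(the revealed module path is an assumption based on the layout in this series).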
From 2821183d04efd56828ee93777554cd54be654f9a Mon Sep 17 00:00:00 2001
From: Tim Saucer <timsaucer@gmail.com>
Date: Tue, 16 Jul 2024 13:30:26 -0400
Subject: [PATCH 50/55] Clean up some docstring too long errors in CI

---
 benchmarks/db-benchmark/join-datafusion.py |  3 ++-
 pyproject.toml                             |  2 +-
 python/datafusion/input/base.py            | 10 ++++++++--
 python/datafusion/input/location.py        |  7 +++++--
 python/datafusion/object_store.py          |  2 +-
 5 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/benchmarks/db-benchmark/join-datafusion.py b/benchmarks/db-benchmark/join-datafusion.py
index 4d59c7dc..03f6bd1e 100755
--- a/benchmarks/db-benchmark/join-datafusion.py
+++ b/benchmarks/db-benchmark/join-datafusion.py
@@ -74,7 +74,8 @@ def ans_shape(batches):
 ctx = df.SessionContext()
 print(ctx)
 
-# TODO we should be applying projections to these table reads to crete relations of different sizes
+# TODO we should be applying projections to these table reads to crete relations
+# of different sizes
 
 x_data = pacsv.read_csv(
     src_jn_x, convert_options=pacsv.ConvertOptions(auto_dict_encode=True)
diff --git a/pyproject.toml b/pyproject.toml
index 7c5cd3a9..a18ef0e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,7 +78,7 @@ max-doc-length = 88
 # Disable docstring checking for these directories
 [tool.ruff.lint.per-file-ignores]
 "python/datafusion/tests/*" = ["D"]
-"examples/*" = ["D"]
+"examples/*" = ["D", "W505"]
 "dev/*" = ["D"]
 "benchmarks/*" = ["D", "F"]
 "docs/*" = ["D"]
diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py
index b91e0a1e..4eba1978 100644
--- a/python/datafusion/input/base.py
+++ b/python/datafusion/input/base.py
@@ -15,7 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""This module provides ``BaseInputSource`` which a user can extend to provide a custom input source."""
+"""This module provides ``BaseInputSource``.
+
+A user can extend this to provide a custom input source.
+"""
 
 from abc import ABC, abstractmethod
 from typing import Any
@@ -24,7 +27,10 @@
 
 
 class BaseInputSource(ABC):
-    """If a consuming library would like to provider their own InputSource this is the class they should extend to write their own.
+    """Base Input Source class.
+
+    If a consuming library would like to provide their own InputSource, this is
+    the class they should extend.
 
     Once completed the Plugin InputSource can be registered with the
     SessionContext to ensure that it will be used in order
diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py
index 7454829d..566a63da 100644
--- a/python/datafusion/input/location.py
+++ b/python/datafusion/input/location.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""This module provides ``LocationInputPlugin`` which is the default input source for DataFusion."""
+"""The default input source for DataFusion."""
 
 import os
 import glob
@@ -26,7 +26,10 @@
 
 
 class LocationInputPlugin(BaseInputSource):
-    """Input Plugin for everything, which can be read in from a file (on disk, remote etc.)."""
+    """Input Plugin for everything.
+
+    This can be read in from a file (on disk, remote etc.).
+ """ def is_correct_input(self, input_item: Any, table_name: str, **kwargs): """Returns `True` if the input is valid.""" diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py index 06db9a25..a9bb83d2 100644 --- a/python/datafusion/object_store.py +++ b/python/datafusion/object_store.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module contains functionality for operating with different types of object stores.""" +"""Object store functionality.""" from ._internal import object_store From c1df7db698e80acf12c6e0de5ffad506aa54f5ee Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 16 Jul 2024 14:09:56 -0400 Subject: [PATCH 51/55] Correct ruff complain in unit tests --- python/datafusion/tests/test_context.py | 10 +++++----- python/datafusion/tests/test_dataframe.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index fb60360d..8373659b 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -153,7 +153,7 @@ def test_from_arrow_table(ctx): assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -215,7 +215,7 @@ def test_from_pylist(ctx): assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -229,7 +229,7 @@ def test_from_pydict(ctx): assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -245,7 +245,7 @@ def test_from_pandas(ctx): assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -261,7 +261,7 @@ def test_from_polars(ctx): assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index f5db9fdb..25875da7 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -634,7 +634,7 @@ def test_to_pandas(df): # Convert datafusion dataframe to pandas dataframe pandas_df = df.to_pandas() - assert type(pandas_df) == pd.DataFrame + assert isinstance(pandas_df, pd.DataFrame) assert pandas_df.shape == (3, 3) assert set(pandas_df.columns) == {"a", "b", "c"} @@ -645,7 +645,7 @@ def test_empty_to_pandas(df): # Convert empty datafusion dataframe to pandas dataframe pandas_df = df.limit(0).to_pandas() - assert type(pandas_df) == pd.DataFrame + assert isinstance(pandas_df, pd.DataFrame) assert pandas_df.shape == (0, 3) assert set(pandas_df.columns) == {"a", "b", "c"} @@ -656,7 +656,7 @@ def test_to_polars(df): # Convert datafusion dataframe to polars dataframe polars_df = df.to_polars() - assert type(polars_df) == pl.DataFrame + assert isinstance(polars_df, pl.DataFrame) assert polars_df.shape == (3, 3) assert set(polars_df.columns) == {"a", "b", "c"} @@ -667,7 +667,7 @@ def test_empty_to_polars(df): # Convert empty datafusion 
dataframe to polars dataframe polars_df = df.limit(0).to_polars() - assert type(polars_df) == pl.DataFrame + assert isinstance(polars_df, pl.DataFrame) assert polars_df.shape == (0, 3) assert set(polars_df.columns) == {"a", "b", "c"} @@ -675,7 +675,7 @@ def test_empty_to_polars(df): def test_to_arrow_table(df): # Convert datafusion dataframe to pyarrow Table pyarrow_table = df.to_arrow_table() - assert type(pyarrow_table) == pa.Table + assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} @@ -715,7 +715,7 @@ def test_execute_stream_partitioned(df): def test_empty_to_arrow_table(df): # Convert empty datafusion dataframe to pyarrow Table pyarrow_table = df.limit(0).to_arrow_table() - assert type(pyarrow_table) == pa.Table + assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (0, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} From 461e7b5c26a0421128cece282e4d6d4b48342529 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 16 Jul 2024 17:58:24 -0400 Subject: [PATCH 52/55] Temporarily install google test to get clippy to pass --- .github/workflows/test.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4f47dc98..c9a365bb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -55,6 +55,10 @@ jobs: version: '3.20.2' repo-token: ${{ secrets.GITHUB_TOKEN }} + # To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved + - name: Install gtest + uses: MarkusJx/googletest-installer@v1.1 + - name: Setup Python uses: actions/setup-python@v5 with: From 4af541e1abb259f86dc307574eb15843a1861d73 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 16 Jul 2024 18:24:29 -0400 Subject: [PATCH 53/55] Adding gmock to build step due to upstream error --- .github/workflows/build.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 350be46d..a37abe53 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -89,6 +89,10 @@ jobs: name: python-wheel-license path: . 
+ # To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved + - name: Install gtest + uses: MarkusJx/googletest-installer@v1.1 + - name: Install Protoc uses: arduino/setup-protoc@v1 with: From 5588f282edeb0b0427ccfca9b3f48b43a2c2a6ba Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 16 Jul 2024 18:09:40 -0400 Subject: [PATCH 54/55] Add type_extensions to conda meta file --- conda/recipes/meta.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/conda/recipes/meta.yaml b/conda/recipes/meta.yaml index 72ac7f50..b0784253 100644 --- a/conda/recipes/meta.yaml +++ b/conda/recipes/meta.yaml @@ -51,6 +51,7 @@ requirements: run: - python - pyarrow >=11.0.0 + - typing_extensions test: imports: From 39f01fbfe74baae78fa8008619ac52f2d3c03606 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 17 Jul 2024 07:55:13 -0400 Subject: [PATCH 55/55] Small comment suggestions from PR --- benchmarks/db-benchmark/join-datafusion.py | 2 +- examples/substrait.py | 2 -- python/datafusion/__init__.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmarks/db-benchmark/join-datafusion.py b/benchmarks/db-benchmark/join-datafusion.py index 03f6bd1e..811ad870 100755 --- a/benchmarks/db-benchmark/join-datafusion.py +++ b/benchmarks/db-benchmark/join-datafusion.py @@ -74,7 +74,7 @@ def ans_shape(batches): ctx = df.SessionContext() print(ctx) -# TODO we should be applying projections to these table reads to crete relations +# TODO we should be applying projections to these table reads to create relations # of different sizes x_data = pacsv.read_csv( diff --git a/examples/substrait.py b/examples/substrait.py index 66f8a30d..fd4d0f9c 100644 --- a/examples/substrait.py +++ b/examples/substrait.py @@ -18,8 +18,6 @@ from datafusion import SessionContext from datafusion import substrait as ss -# TODO add user changing interface note to PR that datafusion.substrait.substrait is simplified to datafusion.substrait - # Create a DataFusion context ctx = SessionContext() diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 6fd1a887..59bc8e30 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -18,7 +18,7 @@ """DataFusion python package. This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. -See https://datafusion.apache.org/python/index.html for more information. +See https://datafusion.apache.org/python for more information. """ try:
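A closing note on the isinstance rewrites in the unit tests earlier in this series: with "E7" selected in the ruff configuration, rule E721 flags direct type comparisons, which ignore inheritance. A small illustration (the subclass is hypothetical)::

    import pandas as pd

    class TaggedFrame(pd.DataFrame):
        """A trivial pandas subclass for demonstration."""

    df = TaggedFrame({"a": [1, 2, 3]})

    print(type(df) == pd.DataFrame)      # False: compares exact types only
    print(isinstance(df, pd.DataFrame))  # True: respects subclassing

This is why ``isinstance`` is the more robust assertion for the tests.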