From 16f5a8c742c77a7a02c872b8525ab2c5a120a91c Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Mon, 11 Nov 2024 15:47:59 -0800 Subject: [PATCH] [FEAT] Add initial Spark Connect support (#3261) Implements basic Spark Connect functionality in Daft with the following: - Add `daft-connect` crate to handle Spark Connect protocol - Implement configuration management via Spark Connect API - Add Python bindings with `connect_start()` function - Add integration tests for config operations Currently supports: - Basic session management - Config operations (Set, Get, GetWithDefault, GetOption, GetAll, Unset) - Error handling and status reporting Notable changes: - New dependency on `spark-connect` protocol - Added `tracing` for debugging and monitoring - Integration with existing Daft infrastructure Some operations like execute_plan, analyze_plan, and artifact handling are currently unimplemented and will return appropriate error messages. --- Cargo.lock | 203 +++++++++++++-- Cargo.toml | 39 ++- daft/daft/__init__.pyi | 4 + requirements-dev.txt | 5 + src/daft-connect/Cargo.toml | 23 ++ src/daft-connect/src/config.rs | 222 ++++++++++++++++ src/daft-connect/src/err.rs | 15 ++ src/daft-connect/src/lib.rs | 243 ++++++++++++++++++ src/daft-connect/src/main.rs | 42 +++ src/daft-connect/src/session.rs | 41 +++ src/daft-connect/src/util.rs | 92 +++++++ src/daft-local-execution/src/lib.rs | 2 +- src/daft-table/src/lib.rs | 10 + src/lib.rs | 1 + .../spark-connect-generation/Cargo.lock | 2 +- tests/connect/__init__.py | 0 tests/connect/test_config_simple.py | 124 +++++++++ 17 files changed, 1033 insertions(+), 35 deletions(-) create mode 100644 src/daft-connect/Cargo.toml create mode 100644 src/daft-connect/src/config.rs create mode 100644 src/daft-connect/src/err.rs create mode 100644 src/daft-connect/src/lib.rs create mode 100644 src/daft-connect/src/main.rs create mode 100644 src/daft-connect/src/session.rs create mode 100644 src/daft-connect/src/util.rs create mode 100644 tests/connect/__init__.py create mode 100644 tests/connect/test_config_simple.py diff --git a/Cargo.lock b/Cargo.lock index 983385050d..9108356529 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1845,6 +1845,7 @@ dependencies = [ "common-tracing", "common-version", "daft-compression", + "daft-connect", "daft-core", "daft-csv", "daft-dsl", @@ -1886,6 +1887,23 @@ dependencies = [ "url", ] +[[package]] +name = "daft-connect" +version = "0.3.0-dev0" +dependencies = [ + "dashmap", + "eyre", + "futures", + "pyo3", + "spark-connect", + "tokio", + "tonic", + "tracing", + "tracing-subscriber", + "tracing-tracy", + "uuid 1.10.0", +] + [[package]] name = "daft-core" version = "0.3.0-dev0" @@ -2507,6 +2525,20 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core 0.9.10", +] + [[package]] name = "deflate64" version = "0.1.9" @@ -2751,6 +2783,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -2855,9 +2897,9 @@ dependencies = [ [[package]] name = "futures" 
-version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -2870,9 +2912,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -2880,15 +2922,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -2897,9 +2939,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -2918,9 +2960,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -2929,15 +2971,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -2947,9 +2989,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -2963,6 +3005,19 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb949699c3e4df3a183b1d2142cb24277057055ed23c68ed58894f76c517223" +dependencies = [ + "cfg-if", + "libc", + "log", + 
"rustversion", + "windows 0.58.0", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -3505,7 +3560,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -3561,6 +3616,12 @@ dependencies = [ "quick-error 2.0.1", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -3917,6 +3978,19 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lz4" version = "1.26.0" @@ -5485,6 +5559,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -6022,7 +6102,7 @@ dependencies = [ "ntapi", "once_cell", "rayon", - "windows", + "windows 0.52.0", ] [[package]] @@ -6281,9 +6361,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -6478,6 +6558,38 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "tracing-tracy" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc775fdaf33c3dfd19dc354729e65e87914bc67dcdc390ca1210807b8bee5902" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracy-client", +] + +[[package]] +name = "tracy-client" +version = "0.17.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "746b078c6a09ebfd5594609049e07116735c304671eaab06ce749854d23435bc" +dependencies = [ + "loom", + "once_cell", + "tracy-client-sys", +] + +[[package]] +name = "tracy-client-sys" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3637e734239e12ab152cd269302500bd063f37624ee210cd04b4936ed671f3b1" +dependencies = [ + "cc", + "windows-targets 0.52.6", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -6907,7 +7019,17 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ - "windows-core", + "windows-core 0.52.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", "windows-targets 0.52.6", ] @@ -6920,6 +7042,41 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.58.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "windows-registry" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index e843958f4d..2830d2a0a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,11 @@ common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} daft-compression = {path = "src/daft-compression", default-features = false} +daft-connect = {path = "src/daft-connect", optional = true} daft-core = {path = "src/daft-core", default-features = false} daft-csv = {path = "src/daft-csv", default-features = false} daft-dsl = {path = "src/daft-dsl", default-features = false} -daft-functions = {path = "src/daft-functions", default-features = false} +daft-functions = {path = "src/daft-functions"} daft-functions-json = {path = "src/daft-functions-json", default-features = false} daft-hash = {path = "src/daft-hash", default-features = false} daft-image = {path = "src/daft-image", default-features = false} @@ -45,6 +46,11 @@ python = [ "common-display/python", "common-resource-request/python", "common-system-info/python", + "common-daft-config/python", + "common-display/python", + "common-resource-request/python", + "common-system-info/python", + "daft-connect/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", @@ -63,14 +69,10 @@ python = [ "daft-scheduler/python", "daft-sql/python", "daft-stats/python", - "daft-table/python", - "daft-functions/python", - "daft-functions-json/python", + "daft-stats/python", "daft-writers/python", - "common-daft-config/python", - "common-system-info/python", - "common-display/python", - "common-resource-request/python", + "daft-table/python", + "dep:daft-connect", "dep:pyo3", "dep:pyo3-log" ] @@ -157,6 +159,7 @@ members = [ "src/daft-table", "src/daft-writers", "src/hyperloglog", + "src/daft-connect", "src/parquet2", # "src/spark-connect-script", "src/generated/spark-connect" @@ -164,6 +167,7 @@ members = [ [workspace.dependencies] ahash = "0.8.11" +anyhow = "1.0.89" approx = "0.5.1" async-compat = "0.2.3" async-compression = {version = "0.4.12", features = [ @@ -177,8 +181,22 @@ bytes = "1.6.0" chrono = "0.4.38" chrono-tz = "0.8.4" comfy-table = "7.1.1" +common-daft-config = {path = "src/common/daft-config"} +common-display = {path = "src/common/display"} common-error = {path = "src/common/error", default-features = false} +common-file-formats = {path = "src/common/file-formats"} +daft-connect = {path = "src/daft-connect", default-features = false} +daft-core = {path = "src/daft-core"} +daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} +daft-local-execution = {path = 
"src/daft-local-execution"} +daft-local-plan = {path = "src/daft-local-plan"} +daft-logical-plan = {path = "src/daft-logical-plan"} +daft-micropartition = {path = "src/daft-micropartition"} +daft-physical-plan = {path = "src/daft-physical-plan"} +daft-schema = {path = "src/daft-schema"} +daft-sql = {path = "src/daft-sql"} +daft-table = {path = "src/daft-table"} derivative = "2.2.0" derive_builder = "0.20.2" divan = "0.1.14" @@ -207,6 +225,7 @@ serde_json = "1.0.116" sha1 = "0.11.0-pre.4" sketches-ddsketch = {version = "0.2.2", features = ["use_serde"]} snafu = {version = "0.7.4", features = ["futures"]} +spark-connect = {path = "src/generated/spark-connect", default-features = false} sqlparser = "0.51.0" sysinfo = "0.30.12" tango-bench = "0.6.0" @@ -236,7 +255,7 @@ path = "src/arrow2" version = "1.3.3" [workspace.dependencies.derive_more] -features = ["display"] +features = ["display", "from", "constructor"] version = "1.0.0" [workspace.dependencies.lazy_static] @@ -324,7 +343,7 @@ uninlined_format_args = "allow" unnecessary_wraps = "allow" unnested_or_patterns = "allow" unreadable_literal = "allow" -# todo: remove? +# todo: remove this at some point unsafe_derive_deserialize = "allow" unused_async = "allow" # used_underscore_items = "allow" # REMOVE diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 8ec06319d3..38582ece15 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1246,6 +1246,10 @@ def sql_expr(sql: str) -> PyExpr: ... def list_sql_functions() -> list[SQLFunctionStub]: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... +def connect_start(addr: str) -> ConnectionHandle: ... + +class ConnectionHandle: + def shutdown(self) -> None: ... # expr numeric ops def abs(expr: PyExpr) -> PyExpr: ... 
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 7ab648ef4c..bc39a80b98 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -91,3 +91,8 @@ sphinx-reredirects>=0.1.1
 sphinx-copybutton>=0.5.2
 sphinx-autosummary-accessors==2023.4.0; python_version >= "3.9"
 sphinx-tabs==3.4.5
+
+# Daft connect testing
+pyspark==3.5.3
+grpcio==1.67.0
+grpcio-status==1.67.0
diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml
new file mode 100644
index 0000000000..c11972fa4c
--- /dev/null
+++ b/src/daft-connect/Cargo.toml
@@ -0,0 +1,23 @@
+[dependencies]
+dashmap = "6.1.0"
+eyre = "0.6.12"
+futures = "0.3.31"
+pyo3 = {workspace = true, optional = true}
+tokio = {version = "1.40.0", features = ["full"]}
+tonic = "0.12.3"
+tracing-subscriber = {version = "0.3.18", features = ["env-filter"]}
+tracing-tracy = "0.11.3"
+uuid = {version = "1.10.0", features = ["v4"]}
+spark-connect.workspace = true
+tracing.workspace = true
+
+[features]
+python = ["dep:pyo3"]
+
+[lints]
+workspace = true
+
+[package]
+edition = {workspace = true}
+name = "daft-connect"
+version = {workspace = true}
diff --git a/src/daft-connect/src/config.rs b/src/daft-connect/src/config.rs
new file mode 100644
index 0000000000..b29215e668
--- /dev/null
+++ b/src/daft-connect/src/config.rs
@@ -0,0 +1,222 @@
+use std::collections::BTreeMap;
+
+use spark_connect::{
+    config_request::{Get, GetAll, GetOption, GetWithDefault, IsModifiable, Set, Unset},
+    ConfigResponse, KeyValue,
+};
+use tonic::Status;
+
+use crate::Session;
+
+impl Session {
+    fn config_response(&self) -> ConfigResponse {
+        ConfigResponse {
+            session_id: self.client_side_session_id().to_string(),
+            server_side_session_id: self.server_side_session_id().to_string(),
+            pairs: vec![],
+            warnings: vec![],
+        }
+    }
+
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn set(&mut self, operation: Set) -> Result<ConfigResponse, Status> {
+        let mut response = self.config_response();
+
+        let span =
+            tracing::info_span!("set", session_id = %self.client_side_session_id(), ?operation);
+        let _enter = span.enter();
+
+        for KeyValue { key, value } in operation.pairs {
+            let Some(value) = value else {
+                let msg = format!("Missing value for key {key}. If you want to unset a value, use the Unset operation");
+                response.warnings.push(msg);
+                continue;
+            };
+
+            let previous = self.config_values_mut().insert(key.clone(), value.clone());
+            if previous.is_some() {
+                tracing::info!("Updated existing configuration value");
+            } else {
+                tracing::info!("Set new configuration value");
+            }
+        }
+
+        Ok(response)
+    }
+
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn get(&self, operation: Get) -> Result<ConfigResponse, Status> {
+        let mut response = self.config_response();
+
+        let span = tracing::info_span!("get", session_id = %self.client_side_session_id());
+        let _enter = span.enter();
+
+        for key in operation.keys {
+            let value = self.config_values().get(&key).cloned();
+            response.pairs.push(KeyValue { key, value });
+        }
+
+        Ok(response)
+    }
+
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn get_with_default(&self, operation: GetWithDefault) -> Result<ConfigResponse, Status> {
+        let mut response = self.config_response();
+
+        let span =
+            tracing::info_span!("get_with_default", session_id = %self.client_side_session_id());
+        let _enter = span.enter();
+
+        for KeyValue {
+            key,
+            value: default_value,
+        } in operation.pairs
+        {
+            let value = self.config_values().get(&key).cloned().or(default_value);
+            response.pairs.push(KeyValue { key, value });
+        }
+
+        Ok(response)
+    }
+
+    /// Needs to be fixed so it has different behavior than [`Session::get`]. Not entirely
+    /// sure how it should work yet.
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn get_option(&self, operation: GetOption) -> Result<ConfigResponse, Status> {
+        let mut response = self.config_response();
+
+        let span = tracing::info_span!("get_option", session_id = %self.client_side_session_id());
+        let _enter = span.enter();
+
+        for key in operation.keys {
+            let value = self.config_values().get(&key).cloned();
+            response.pairs.push(KeyValue { key, value });
+        }
+
+        Ok(response)
+    }
+
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn get_all(&self, operation: GetAll) -> Result<ConfigResponse, Status> {
+        let mut response = self.config_response();
+
+        let span = tracing::info_span!("get_all", session_id = %self.client_side_session_id());
+        let _enter = span.enter();
+
+        let Some(prefix) = operation.prefix else {
+            for (key, value) in self.config_values() {
+                response.pairs.push(KeyValue {
+                    key: key.clone(),
+                    value: Some(value.clone()),
+                });
+            }
+            return Ok(response);
+        };
+
+        for (k, v) in prefix_search(self.config_values(), &prefix) {
+            response.pairs.push(KeyValue {
+                key: k.clone(),
+                value: Some(v.clone()),
+            });
+        }
+
+        Ok(response)
+    }
+
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn unset(&mut self, operation: Unset) -> Result<ConfigResponse, Status> {
+        let mut response = self.config_response();
+
+        let span = tracing::info_span!("unset", session_id = %self.client_side_session_id());
+        let _enter = span.enter();
+
+        for key in operation.keys {
+            if self.config_values_mut().remove(&key).is_none() {
+                let msg = format!("Key {key} not found");
+                response.warnings.push(msg);
+            } else {
+                tracing::info!("Unset configuration value");
+            }
+        }
+
+        Ok(response)
+    }
+
+    #[tracing::instrument(skip(self), level = "trace")]
+    pub fn is_modifiable(&self, _operation: IsModifiable) -> Result<ConfigResponse, Status> {
+        let response = self.config_response();
+
+        let span =
+            tracing::info_span!("is_modifiable", session_id = %self.client_side_session_id());
+        let _enter = span.enter();
+
+        tracing::warn!(session_id = %self.client_side_session_id(), "is_modifiable operation not yet implemented");
+        // todo: need to implement this
+        Ok(response)
+    }
+}
+
+fn prefix_search<'a, V>(
+    map: &'a BTreeMap<String, V>,
+    prefix: &'a str,
+) -> impl Iterator<Item = (&'a String, &'a V)> {
+    let start = map.range(prefix.to_string()..);
+    start.take_while(move |(k, _)| k.starts_with(prefix))
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::BTreeMap;
+
+    use super::*;
+
+    #[test]
+    fn test_prefix_search() {
+        let mut map = BTreeMap::new();
+        map.insert("apple".to_string(), 1);
+        map.insert("application".to_string(), 2);
+        map.insert("banana".to_string(), 3);
+        map.insert("app".to_string(), 4);
+        map.insert("apricot".to_string(), 5);
+
+        // Test with prefix "app"
+        let result: Vec<_> = prefix_search(&map, "app").collect();
+        assert_eq!(
+            result,
+            vec![
+                (&"app".to_string(), &4),
+                (&"apple".to_string(), &1),
+                (&"application".to_string(), &2),
+            ]
+        );
+
+        // Test with prefix "b"
+        let result: Vec<_> = prefix_search(&map, "b").collect();
+        assert_eq!(result, vec![(&"banana".to_string(), &3),]);
+
+        // Test with prefix that doesn't match any keys
+        let result: Vec<_> = prefix_search(&map, "z").collect();
+        assert_eq!(result, vec![]);
+
+        // Test with empty prefix (should return all items)
+        let result: Vec<_> = prefix_search(&map, "").collect();
+        assert_eq!(
+            result,
+            vec![
+                (&"app".to_string(), &4),
+                (&"apple".to_string(), &1),
+                (&"application".to_string(), &2),
+                (&"apricot".to_string(), &5),
+                (&"banana".to_string(), &3),
+            ]
+        );
+
+        // Test with prefix that matches a complete key
+        let result: Vec<_> = prefix_search(&map, "apple").collect();
+        assert_eq!(result, vec![(&"apple".to_string(), &1),]);
+
+        // Test with case sensitivity
+        let result: Vec<_> = prefix_search(&map, "App").collect();
+        assert_eq!(result, vec![]);
+    }
+}
diff --git a/src/daft-connect/src/err.rs b/src/daft-connect/src/err.rs
new file mode 100644
index 0000000000..d210ef8458
--- /dev/null
+++ b/src/daft-connect/src/err.rs
@@ -0,0 +1,15 @@
+#[macro_export]
+macro_rules! invalid_argument {
+    ($arg: tt) => {{
+        let msg = format!($arg);
+        Err(::tonic::Status::invalid_argument(msg))
+    }};
+}
+
+#[macro_export]
+macro_rules! unimplemented_err {
+    ($arg: tt) => {{
+        let msg = format!($arg);
+        Err(::tonic::Status::unimplemented(msg))
+    }};
+}
diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
new file mode 100644
index 0000000000..12c43e6901
--- /dev/null
+++ b/src/daft-connect/src/lib.rs
@@ -0,0 +1,243 @@
+#![feature(iterator_try_collect)]
+#![feature(let_chains)]
+#![feature(try_trait_v2)]
+#![feature(coroutines)]
+#![feature(iter_from_coroutine)]
+#![feature(stmt_expr_attributes)]
+#![feature(try_trait_v2_residual)]
+#![warn(unused)]
+
+use dashmap::DashMap;
+use eyre::Context;
+#[cfg(feature = "python")]
+use pyo3::types::PyModuleMethods;
+use spark_connect::{
+    spark_connect_service_server::{SparkConnectService, SparkConnectServiceServer},
+    AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse,
+    ArtifactStatusesRequest, ArtifactStatusesResponse, ConfigRequest, ConfigResponse,
+    ExecutePlanRequest, ExecutePlanResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse,
+    InterruptRequest, InterruptResponse, ReattachExecuteRequest, ReleaseExecuteRequest,
+    ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse,
+};
+use tonic::{transport::Server, Request, Response, Status};
+use tracing::{info, warn};
+use uuid::Uuid;
+
+use crate::session::Session;
+
+mod config;
+mod err;
+mod session;
+pub mod util;
+
+#[cfg_attr(feature = "python", pyo3::pyclass)]
+pub struct ConnectionHandle {
+    shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
+}
+
+#[cfg_attr(feature = "python", pyo3::pymethods)]
+impl ConnectionHandle {
+    pub fn shutdown(&mut self) {
+        let Some(shutdown_signal) = self.shutdown_signal.take() else {
+            return;
+        };
+        shutdown_signal.send(()).unwrap();
+    }
+}
+
+pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
+    info!("Daft-Connect server listening on {addr}");
+    let addr = util::parse_spark_connect_address(addr)?;
+
+    let service = DaftSparkConnectService::default();
+
+    info!("Daft-Connect server listening on {addr}");
+
+    let (shutdown_signal, shutdown_receiver) = tokio::sync::oneshot::channel();
+
+    let handle = ConnectionHandle {
+        shutdown_signal: Some(shutdown_signal),
+    };
+
+    std::thread::spawn(move || {
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        let result = runtime
+            .block_on(async {
+                tokio::select! {
+                    result = Server::builder()
+                        .add_service(SparkConnectServiceServer::new(service))
+                        .serve(addr) => {
+                        result
+                    }
+                    _ = shutdown_receiver => {
+                        info!("Received shutdown signal");
+                        Ok(())
+                    }
+                }
+            })
+            .wrap_err_with(|| format!("Failed to start server on {addr}"));
+
+        if let Err(e) = result {
+            eprintln!("Daft-Connect server error: {e:?}");
+        }
+
+        eyre::Result::<_>::Ok(())
+    });
+
+    Ok(handle)
+}
+
+#[derive(Default)]
+pub struct DaftSparkConnectService {
+    client_to_session: DashMap<Uuid, Session>, // To track session data
+}
+
+impl DaftSparkConnectService {
+    fn get_session(
+        &self,
+        session_id: &str,
+    ) -> Result<dashmap::mapref::one::RefMut<Uuid, Session>, Status> {
+        let Ok(uuid) = Uuid::parse_str(session_id) else {
+            return Err(Status::invalid_argument(
+                "Invalid session_id format, must be a UUID",
+            ));
+        };
+
+        let res = self
+            .client_to_session
+            .entry(uuid)
+            .or_insert_with(|| Session::new(session_id.to_string()));
+
+        Ok(res)
+    }
+}
+
+#[tonic::async_trait]
+impl SparkConnectService for DaftSparkConnectService {
+    type ExecutePlanStream = std::pin::Pin<
+        Box<
+            dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + Sync + 'static,
+        >,
+    >;
+    type ReattachExecuteStream = std::pin::Pin<
+        Box<
+            dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + Sync + 'static,
+        >,
+    >;
+
+    #[tracing::instrument(skip_all)]
+    async fn execute_plan(
+        &self,
+        _request: Request<ExecutePlanRequest>,
+    ) -> Result<Response<Self::ExecutePlanStream>, Status> {
+        unimplemented_err!("Unsupported plan type")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn config(
+        &self,
+        request: Request<ConfigRequest>,
+    ) -> Result<Response<ConfigResponse>, Status> {
+        let request = request.into_inner();
+
+        let mut session = self.get_session(&request.session_id)?;
+
+        let Some(operation) = request.operation.and_then(|op| op.op_type) else {
+            return Err(Status::invalid_argument("Missing operation"));
+        };
+
+        use spark_connect::config_request::operation::OpType;
+
+        let response = match operation {
+            OpType::Set(op) => session.set(op),
+            OpType::Get(op) => session.get(op),
+            OpType::GetWithDefault(op) => session.get_with_default(op),
+            OpType::GetOption(op) => session.get_option(op),
+            OpType::GetAll(op) => session.get_all(op),
+            OpType::Unset(op) => session.unset(op),
+            OpType::IsModifiable(op) => session.is_modifiable(op),
+        }?;
+
+        Ok(Response::new(response))
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn add_artifacts(
+        &self,
+        _request: Request<tonic::Streaming<AddArtifactsRequest>>,
+    ) -> Result<Response<AddArtifactsResponse>, Status> {
+        unimplemented_err!("add_artifacts operation is not yet implemented")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn analyze_plan(
+        &self,
+        _request: Request<AnalyzePlanRequest>,
+    ) -> Result<Response<AnalyzePlanResponse>, Status> {
+        unimplemented_err!("Analyze plan operation is not yet implemented")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn artifact_status(
+        &self,
+        _request: Request<ArtifactStatusesRequest>,
+    ) -> Result<Response<ArtifactStatusesResponse>, Status> {
+        unimplemented_err!("artifact_status operation is not yet implemented")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn interrupt(
+        &self,
+        _request: Request<InterruptRequest>,
+    ) -> Result<Response<InterruptResponse>, Status> {
+        println!("got interrupt");
+        unimplemented_err!("interrupt operation is not yet implemented")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn reattach_execute(
+        &self,
+        _request: Request<ReattachExecuteRequest>,
+    ) -> Result<Response<Self::ReattachExecuteStream>, Status> {
+        unimplemented_err!("reattach_execute operation is not yet implemented")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn release_execute(
+        &self,
+        _request: Request<ReleaseExecuteRequest>,
+    ) -> Result<Response<ReleaseExecuteResponse>, Status> {
+        unimplemented_err!("release_execute operation is not yet implemented")
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn release_session(
+        &self,
+        _request: Request<ReleaseSessionRequest>,
+    ) -> Result<Response<ReleaseSessionResponse>, Status> {
+        println!("got release session");
unimplemented_err!("release_session operation is not yet implemented") + } + + #[tracing::instrument(skip_all)] + async fn fetch_error_details( + &self, + _request: Request, + ) -> Result, Status> { + println!("got fetch error details"); + unimplemented_err!("fetch_error_details operation is not yet implemented") + } +} +#[cfg(feature = "python")] +#[pyo3::pyfunction] +#[pyo3(name = "connect_start")] +pub fn py_connect_start(addr: &str) -> pyo3::PyResult { + start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}"))) +} + +#[cfg(feature = "python")] +pub fn register_modules(parent: &pyo3::Bound) -> pyo3::PyResult<()> { + parent.add_function(pyo3::wrap_pyfunction_bound!(py_connect_start, parent)?)?; + parent.add_class::()?; + Ok(()) +} diff --git a/src/daft-connect/src/main.rs b/src/daft-connect/src/main.rs new file mode 100644 index 0000000000..249938896c --- /dev/null +++ b/src/daft-connect/src/main.rs @@ -0,0 +1,42 @@ +use daft_connect::DaftSparkConnectService; +use spark_connect::spark_connect_service_server::SparkConnectServiceServer; +use tonic::transport::Server; +use tracing::info; +use tracing_subscriber::{layer::SubscriberExt, Registry}; +use tracing_tracy::TracyLayer; + +fn setup_tracing() { + tracing::subscriber::set_global_default( + Registry::default().with(TracyLayer::default()).with( + tracing_subscriber::fmt::layer() + .with_target(false) + .with_thread_ids(false) + .with_file(true) + .with_line_number(true), + ), + ) + .expect("setup tracing subscribers"); +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + setup_tracing(); + + let addr = "[::1]:50051".parse()?; + let service = DaftSparkConnectService::default(); + + info!("Daft-Connect server listening on {}", addr); + + tokio::select! { + result = Server::builder() + .add_service(SparkConnectServiceServer::new(service)) + .serve(addr) => { + result?; + } + _ = tokio::signal::ctrl_c() => { + info!("\nReceived Ctrl-C, gracefully shutting down server"); + } + } + + Ok(()) +} diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs new file mode 100644 index 0000000000..24f7fabe80 --- /dev/null +++ b/src/daft-connect/src/session.rs @@ -0,0 +1,41 @@ +use std::collections::BTreeMap; + +use uuid::Uuid; + +pub struct Session { + /// so order is preserved, and so we can efficiently do a prefix search + /// + /// Also, + config_values: BTreeMap, + + id: String, + server_side_session_id: String, +} + +impl Session { + pub fn config_values(&self) -> &BTreeMap { + &self.config_values + } + + pub fn config_values_mut(&mut self) -> &mut BTreeMap { + &mut self.config_values + } + + pub fn new(id: String) -> Self { + let server_side_session_id = Uuid::new_v4(); + let server_side_session_id = server_side_session_id.to_string(); + Self { + config_values: Default::default(), + id, + server_side_session_id, + } + } + + pub fn client_side_session_id(&self) -> &str { + &self.id + } + + pub fn server_side_session_id(&self) -> &str { + &self.server_side_session_id + } +} diff --git a/src/daft-connect/src/util.rs b/src/daft-connect/src/util.rs new file mode 100644 index 0000000000..cbec2211b2 --- /dev/null +++ b/src/daft-connect/src/util.rs @@ -0,0 +1,92 @@ +use std::net::ToSocketAddrs; + +pub fn parse_spark_connect_address(addr: &str) -> eyre::Result { + // Check if address starts with "sc://" + if !addr.starts_with("sc://") { + return Err(eyre::eyre!("Address must start with 'sc://'")); + } + + // Remove the "sc://" prefix + let addr = addr.trim_start_matches("sc://"); + + // Resolve the 
+    let addrs = addr.to_socket_addrs()?;
+
+    // Take the first resolved address
+    addrs
+        .into_iter()
+        .next()
+        .ok_or_else(|| eyre::eyre!("No addresses found for hostname"))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_spark_connect_address_valid() {
+        let addr = "sc://localhost:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_missing_prefix() {
+        let addr = "localhost:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("must start with 'sc://'"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_invalid_port() {
+        let addr = "sc://localhost:invalid";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_missing_port() {
+        let addr = "sc://localhost";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_empty() {
+        let addr = "";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("must start with 'sc://'"));
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_only_prefix() {
+        let addr = "sc://";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_err());
+
+        let err = result.unwrap_err().to_string();
+        assert_eq!(err, "invalid socket address");
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_ipv4() {
+        let addr = "sc://127.0.0.1:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_spark_connect_address_ipv6() {
+        let addr = "sc://[::1]:10009";
+        let result = parse_spark_connect_address(addr);
+        assert!(result.is_ok());
+    }
+}
diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs
index 553ad18b40..719da409c4 100644
--- a/src/daft-local-execution/src/lib.rs
+++ b/src/daft-local-execution/src/lib.rs
@@ -11,7 +11,7 @@ mod sources;
 
 use common_error::{DaftError, DaftResult};
 use lazy_static::lazy_static;
-pub use run::NativeExecutor;
+pub use run::{run_local, NativeExecutor};
 use snafu::{futures::TryFutureExt, Snafu};
 
 lazy_static! {
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index 0a450ace70..c68546f318 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -99,6 +99,12 @@
         Ok(Self::new_unchecked(schema, columns?, num_rows))
     }
 
+    pub fn get_inner_arrow_arrays(
+        &self,
+    ) -> impl Iterator<Item = Box<dyn arrow2::array::Array>> + '_ {
+        self.columns.iter().map(|s| s.to_arrow())
+    }
+
     /// Create a new [`Table`] and validate against `num_rows`
     ///
     /// Note that this function is slow. You might instead be looking for [`Table::new_unchecked`] if you've already performed your own validation logic.
@@ -194,6 +200,10 @@ impl Table { self.num_rows } + pub fn num_rows(&self) -> usize { + self.num_rows + } + pub fn is_empty(&self) -> bool { self.len() == 0 } diff --git a/src/lib.rs b/src/lib.rs index 804291ff4a..12c35539c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,6 +118,7 @@ pub mod pylib { daft_sql::register_modules(m)?; daft_functions::register_modules(m)?; daft_functions_json::register_modules(m)?; + daft_connect::register_modules(m)?; m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(build_type))?; diff --git a/src/scripts/spark-connect-generation/Cargo.lock b/src/scripts/spark-connect-generation/Cargo.lock index 39c219603a..759b9bb64f 100644 --- a/src/scripts/spark-connect-generation/Cargo.lock +++ b/src/scripts/spark-connect-generation/Cargo.lock @@ -270,7 +270,7 @@ dependencies = [ ] [[package]] -name = "spark-connect-script" +name = "spark-connect-generation-script" version = "0.0.1" dependencies = [ "tonic-build", diff --git a/tests/connect/__init__.py b/tests/connect/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/connect/test_config_simple.py b/tests/connect/test_config_simple.py new file mode 100644 index 0000000000..de65c7c0f2 --- /dev/null +++ b/tests/connect/test_config_simple.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import time + +import pytest +from pyspark.sql import SparkSession + + +@pytest.fixture +def spark_session(): + """Fixture to create and clean up a Spark session.""" + from daft.daft import connect_start + + # Start Daft Connect server + server = connect_start("sc://localhost:50051") + + # Initialize Spark Connect session + session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate() + + yield session + + # Cleanup + server.shutdown() + session.stop() + time.sleep(2) # Allow time for session cleanup + + +def test_set_operation(spark_session): + """Test the Set operation with various data types and edge cases.""" + configs = { + "spark.test.string": "test_value", + # todo: I check if non-strings are supported. 
+ # "spark.test.boolean": True, + # "spark.test.integer": 42, + # "spark.test.float": 3.14, + "spark.test.empty": "", # Test empty string + "spark.test.special": "!@#$%^&*()", # Test special characters + } + + # Set all configurations + for key, value in configs.items(): + spark_session.conf.set(key, value) + + # Verify all configurations + for key, value in configs.items(): + assert str(spark_session.conf.get(key)) == str(value) + + +def test_get_operations(spark_session): + """Test various Get operations including Get, GetWithDefault, and GetOption.""" + # Setup test data + test_data = {"spark.test.existing": "value", "spark.prefix.one": "1", "spark.prefix.two": "2"} + for key, value in test_data.items(): + spark_session.conf.set(key, value) + + # Test basic Get + assert spark_session.conf.get("spark.test.existing") == "value" + + # Test GetWithDefault + assert spark_session.conf.get("spark.test.nonexistent", "default") == "default" + assert spark_session.conf.get("spark.test.existing", "default") == "value" + + # Test GetOption (if implemented) + # Note: This might need to be adjusted based on actual GetOption implementation + assert spark_session.conf.get("spark.test.nonexistent") is None + + # Test GetAll with prefix + prefix_configs = {key: spark_session.conf.get(key) for key in ["spark.prefix.one", "spark.prefix.two"]} + assert prefix_configs == {"spark.prefix.one": "1", "spark.prefix.two": "2"} + + +def test_unset_operation(spark_session): + """Test the Unset operation with various scenarios.""" + # Setup test data + spark_session.conf.set("spark.test.temp", "value") + + # Test basic unset + spark_session.conf.unset("spark.test.temp") + assert spark_session.conf.get("spark.test.temp") is None + + # Test unset non-existent key (should not raise error) + spark_session.conf.unset("spark.test.nonexistent") + + # Test unset and then set again + key = "spark.test.resettable" + spark_session.conf.set(key, "first") + spark_session.conf.unset(key) + spark_session.conf.set(key, "second") + assert spark_session.conf.get(key) == "second" + + +def test_edge_cases(spark_session): + """Test various edge cases and potential error conditions.""" + # Test very long key and value + long_key = "spark.test." + "x" * 1000 + long_value = "y" * 1000 + spark_session.conf.set(long_key, long_value) + assert spark_session.conf.get(long_key) == long_value + + # Test unicode characters + unicode_key = "spark.test.unicode" + unicode_value = "测试值" + spark_session.conf.set(unicode_key, unicode_value) + assert spark_session.conf.get(unicode_key) == unicode_value + + # Test setting same key multiple times rapidly + key = "spark.test.rapid" + for i in range(100): + spark_session.conf.set(key, f"value{i}") + assert spark_session.conf.get(key) == "value99" + + # Test concurrent modifications (if supported) + # Note: This might need to be adjusted based on concurrency support + from concurrent.futures import ThreadPoolExecutor + + key = "spark.test.concurrent" + + def modify_conf(i: int): + spark_session.conf.set(key, f"value{i}") + + with ThreadPoolExecutor(max_workers=4) as executor: + list(executor.map(modify_conf, range(100))) + + assert spark_session.conf.get(key) is not None # Value should be set to something
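Usage sketch: the flow exercised by tests/connect/test_config_simple.py in this patch boils down to the snippet below. It assumes a local build of the daft wheel with the `python` feature enabled (so `daft.daft.connect_start` is available) and the pyspark/grpcio versions pinned in requirements-dev.txt above; the address and the config calls mirror the test, the rest is illustrative.

    from daft.daft import connect_start
    from pyspark.sql import SparkSession

    # Start the in-process Daft Connect gRPC server; returns a ConnectionHandle.
    handle = connect_start("sc://localhost:50051")

    # Attach a Spark Connect client. Only session management and config
    # operations (Set/Get/GetWithDefault/GetOption/GetAll/Unset) are handled;
    # execute_plan, analyze_plan, and artifact RPCs report "not yet implemented".
    spark = SparkSession.builder.remote("sc://localhost:50051").getOrCreate()
    spark.conf.set("spark.test.key", "value")
    assert spark.conf.get("spark.test.key") == "value"

    spark.stop()
    handle.shutdown()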