From ee37e854fed41dbc2e50c858564d3be62acbe75e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 11 Jul 2024 14:57:57 -0700 Subject: [PATCH 1/8] fix lance-io --- Cargo.toml | 42 +- rust/lance-core/src/utils/testing.rs | 24 +- rust/lance-io/src/object_store.rs | 28 +- rust/lance-io/src/object_store/gcs_wrapper.rs | 509 ------------------ rust/lance-io/src/object_store/tracing.rs | 92 ++-- rust/lance-io/src/object_writer.rs | 431 +++++++++++++-- rust/lance-io/src/scheduler.rs | 3 +- rust/lance-io/src/testing.rs | 15 +- rust/lance-io/src/utils.rs | 9 +- 9 files changed, 480 insertions(+), 673 deletions(-) delete mode 100644 rust/lance-io/src/object_store/gcs_wrapper.rs diff --git a/Cargo.toml b/Cargo.toml index 691a6b13a1..a2b54297b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,17 +59,17 @@ lance-test-macros = { version = "=0.14.2", path = "./rust/lance-test-macros" } lance-testing = { version = "=0.14.2", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow -arrow = { version = "51.0.0", optional = false, features = ["prettyprint"] } -arrow-arith = "51.0" -arrow-array = "51.0" -arrow-buffer = "51.0" -arrow-cast = "51.0" -arrow-data = "51.0" -arrow-ipc = { version = "51.0", features = ["zstd"] } -arrow-ord = "51.0" -arrow-row = "51.0" -arrow-schema = "51.0" -arrow-select = "51.0" +arrow = { version = "52.0", optional = false, features = ["prettyprint"] } +arrow-arith = "52.0" +arrow-array = "52.0" +arrow-buffer = "52.0" +arrow-cast = "52.0" +arrow-data = "52.0" +arrow-ipc = { version = "52.0", features = ["zstd"] } +arrow-ord = "52.0" +arrow-row = "52.0" +arrow-schema = "52.0" +arrow-select = "52.0" async-recursion = "1.0" async-trait = "0.1" aws-config = "0.57" @@ -93,17 +93,17 @@ criterion = { version = "0.5", features = [ "html_reports", ] } crossbeam-queue = "0.3" -datafusion = { version = "37.1", default-features = false, features = [ +datafusion = { version = "39.0", default-features = false, features = [ "array_expressions", "regex_expressions", ] } -datafusion-common = "37.1" -datafusion-functions = { version = "37.1", features = ["regex_expressions"] } -datafusion-sql = "37.1" -datafusion-expr = "37.1" -datafusion-execution = "37.1" -datafusion-optimizer = "37.1" -datafusion-physical-expr = { version = "37.1", features = [ +datafusion-common = "39.0" +datafusion-functions = { version = "39.0", features = ["regex_expressions"] } +datafusion-sql = "39.0" +datafusion-expr = "39.0" +datafusion-execution = "39.0" +datafusion-optimizer = "39.0" +datafusion-physical-expr = { version = "39.0", features = [ "regex_expressions", ] } deepsize = "0.2.0" @@ -119,8 +119,8 @@ mock_instant = { version = "0.3.1", features = ["sync"] } moka = "0.11" num-traits = "0.2" num_cpus = "1.0" -object_store = { version = "0.9.0" } -parquet = "51.0" +object_store = { version = "0.10.1" } +parquet = "52.0" pin-project = "1.0" path_abs = "0.5" pprof = { version = "0.13", features = ["flamegraph", "criterion"] } diff --git a/rust/lance-core/src/utils/testing.rs b/rust/lance-core/src/utils/testing.rs index 3d76538e0e..87cbb4d33d 100644 --- a/rust/lance-core/src/utils/testing.rs +++ b/rust/lance-core/src/utils/testing.rs @@ -11,15 +11,14 @@ use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::{ - Error as OSError, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, - PutOptions, PutResult, Result as OSResult, + Error as OSError, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutMultipartOpts, PutOptions, PutPayload, PutResult, Result as OSResult, }; use std::collections::HashMap; use std::fmt::Debug; use std::future; use std::ops::Range; use std::sync::{Arc, Mutex, MutexGuard}; -use tokio::io::AsyncWrite; // A policy function takes in the name of the operation (e.g. "put") and the location // that is being accessed / modified and returns an optional error. @@ -125,32 +124,23 @@ impl std::fmt::Display for ProxyObjectStore { #[async_trait] impl ObjectStore for ProxyObjectStore { - async fn put(&self, location: &Path, bytes: Bytes) -> OSResult { - self.before_method("put", location)?; - self.target.put(location, bytes).await - } - async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, opts: PutOptions, ) -> OSResult { self.before_method("put", location)?; self.target.put_opts(location, bytes, opts).await } - async fn put_multipart( + async fn put_multipart_opts( &self, location: &Path, - ) -> OSResult<(MultipartId, Box)> { + opts: PutMultipartOpts, + ) -> OSResult> { self.before_method("put_multipart", location)?; - self.target.put_multipart(location).await - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> OSResult<()> { - self.before_method("abort_multipart", location)?; - self.target.abort_multipart(location, multipart_id).await + self.target.put_multipart_opts(location, opts).await } async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { diff --git a/rust/lance-io/src/object_store.rs b/rust/lance-io/src/object_store.rs index b1a6237e9f..0948d32fbf 100644 --- a/rust/lance-io/src/object_store.rs +++ b/rust/lance-io/src/object_store.rs @@ -31,9 +31,7 @@ use tokio::{io::AsyncWriteExt, sync::RwLock}; use url::Url; use super::local::LocalObjectReader; -mod gcs_wrapper; mod tracing; -use self::gcs_wrapper::PatchedGoogleCloudStorage; use self::tracing::ObjectStoreTracingExt; use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader}; use lance_core::{Error, Result}; @@ -85,6 +83,7 @@ pub struct ObjectStore { pub inner: Arc, scheme: String, block_size: usize, + pub use_constant_size_upload_parts: bool, } impl DeepSizeOf for ObjectStore { @@ -396,6 +395,7 @@ impl ObjectStore { inner: Arc::new(LocalFileSystem::new()).traced(), scheme: String::from(scheme), block_size: 4 * 1024, // 4KB block size + use_constant_size_upload_parts: false, }, Path::from_absolute_path(expanded_path.as_path())?, )) @@ -415,6 +415,7 @@ impl ObjectStore { inner: Arc::new(LocalFileSystem::new()).traced(), scheme: String::from("file"), block_size: 4 * 1024, // 4KB block size + use_constant_size_upload_parts: false, } } @@ -424,6 +425,7 @@ impl ObjectStore { inner: Arc::new(InMemory::new()).traced(), scheme: String::from("memory"), block_size: 64 * 1024, + use_constant_size_upload_parts: false, } } @@ -489,11 +491,10 @@ impl ObjectStore { /// Create a new file. pub async fn create(&self, path: &Path) -> Result { - ObjectWriter::new(self.inner.as_ref(), path).await + ObjectWriter::new(self, path).await } /// A helper function to create a file and write content to it. - /// pub async fn put(&self, path: &Path, content: &[u8]) -> Result<()> { let mut writer = self.create(path).await?; writer.write_all(content).await?; @@ -714,6 +715,12 @@ async fn configure_store(url: &str, options: ObjectStoreParams) -> Result Result { @@ -746,15 +754,13 @@ async fn configure_store(url: &str, options: ObjectStoreParams) -> Result { @@ -766,6 +772,7 @@ async fn configure_store(url: &str, options: ObjectStoreParams) -> Result Result { + unknown_scheme => { let err = lance_core::Error::from(object_store::Error::NotSupported { - source: format!("Unsupported URI scheme: {}", unknow_scheme).into(), + source: format!("Unsupported URI scheme: {}", unknown_scheme).into(), }); Err(err) } @@ -796,6 +804,7 @@ impl ObjectStore { location: Url, block_size: Option, wrapper: Option>, + use_constant_size_upload_parts: bool, ) -> Self { let scheme = location.scheme(); let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme)); @@ -809,6 +818,7 @@ impl ObjectStore { inner: store, scheme: scheme.into(), block_size, + use_constant_size_upload_parts, } } } diff --git a/rust/lance-io/src/object_store/gcs_wrapper.rs b/rust/lance-io/src/object_store/gcs_wrapper.rs deleted file mode 100644 index 2fb0a5d76a..0000000000 --- a/rust/lance-io/src/object_store/gcs_wrapper.rs +++ /dev/null @@ -1,509 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Wrappers around object_store that apply tracing - -use std::io; -use std::ops::Range; -use std::pin::Pin; -use std::sync::{Arc, OnceLock}; -use std::task::Poll; - -use async_trait::async_trait; -use bytes::Bytes; -use futures::future::BoxFuture; -use futures::stream::{BoxStream, FuturesUnordered}; -use futures::{FutureExt, StreamExt}; -use object_store::gcp::GoogleCloudStorage; -use object_store::multipart::{MultiPartStore, PartId}; -use object_store::path::Path; -use object_store::{ - Error as OSError, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, - PutOptions, PutResult, Result as OSResult, -}; -use rand::Rng; -use tokio::io::AsyncWrite; - -fn max_upload_parallelism() -> usize { - static MAX_UPLOAD_PARALLELISM: OnceLock = OnceLock::new(); - *MAX_UPLOAD_PARALLELISM.get_or_init(|| { - std::env::var("LANCE_UPLOAD_CONCURRENCY") - .ok() - .and_then(|s| s.parse::().ok()) - .unwrap_or(10) - }) -} - -fn max_conn_reset_retries() -> u16 { - static MAX_CONN_RESET_RETRIES: OnceLock = OnceLock::new(); - *MAX_CONN_RESET_RETRIES.get_or_init(|| { - std::env::var("LANCE_CONN_RESET_RETRIES") - .ok() - .and_then(|s| s.parse::().ok()) - .unwrap_or(20) - }) -} - -/// Wrapper around GoogleCloudStorage with a larger maximum upload size. -/// -/// This will be obsolete once object_store 0.10.0 is released. -#[derive(Debug)] -pub struct PatchedGoogleCloudStorage(pub Arc); - -impl std::fmt::Display for PatchedGoogleCloudStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "PatchedGoogleCloudStorage({})", self.0) - } -} - -#[async_trait] -impl ObjectStore for PatchedGoogleCloudStorage { - async fn put(&self, location: &Path, bytes: Bytes) -> OSResult { - self.0.put(location, bytes).await - } - - async fn put_opts( - &self, - location: &Path, - bytes: Bytes, - opts: PutOptions, - ) -> OSResult { - self.0.put_opts(location, bytes, opts).await - } - - async fn put_multipart( - &self, - location: &Path, - ) -> OSResult<(MultipartId, Box)> { - // We don't return a real multipart id here. This will be addressed - // in object_store 0.10.0. - Upload::new(self.0.clone(), location.clone()) - .map(|upload| (MultipartId::default(), Box::new(upload) as _)) - } - - async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> OSResult<()> { - // TODO: Once we fix the API above, we can support this. - return Err(OSError::NotSupported { - source: "abort_multipart is not supported for Google Cloud Storage".into(), - }); - } - - async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { - self.0.get_opts(location, options).await - } - - async fn get_range(&self, location: &Path, range: Range) -> OSResult { - self.0.get_range(location, range).await - } - - async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> OSResult> { - self.0.get_ranges(location, ranges).await - } - - async fn head(&self, location: &Path) -> OSResult { - self.0.head(location).await - } - - async fn delete(&self, location: &Path) -> OSResult<()> { - self.0.delete(location).await - } - - fn delete_stream<'a>( - &'a self, - locations: BoxStream<'a, OSResult>, - ) -> BoxStream<'a, OSResult> { - self.0.delete_stream(locations) - } - - fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, OSResult> { - self.0.list(prefix) - } - - async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult { - self.0.list_with_delimiter(prefix).await - } - - async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> { - self.0.copy(from, to).await - } - - async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> { - self.0.rename(from, to).await - } - - async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { - self.0.copy_if_not_exists(from, to).await - } -} - -enum UploadState { - /// The writer has been opened but no data has been written yet. Will be in - /// this state until the buffer is full or the writer is shut down. - Started, - /// The writer is in the process of creating a multipart upload. - CreatingUpload(BoxFuture<'static, OSResult>), - /// The writer is in the process of uploading parts. - InProgress { - multipart_id: Arc, - part_idx: u16, - futures: FuturesUnordered< - BoxFuture<'static, std::result::Result<(u16, PartId), UploadPutError>>, - >, - part_ids: Vec>, - }, - /// The writer is in the process of uploading data in a single PUT request. - /// This happens when shutdown is called before the buffer is full. - PuttingSingle(BoxFuture<'static, OSResult<()>>), - /// The writer is in the process of completing the multipart upload. - Completing(BoxFuture<'static, OSResult<()>>), - /// The writer has been shut down and all data has been written. - Done, -} - -/// Start at 5MB. -const INITIAL_UPLOAD_SIZE: usize = 1024 * 1024 * 5; - -struct Upload { - store: Arc, - path: Arc, - buffer: Vec, - state: UploadState, - connection_resets: u16, -} - -impl Upload { - fn new(store: Arc, path: Path) -> OSResult { - Ok(Self { - store, - path: Arc::new(path), - buffer: Vec::with_capacity(INITIAL_UPLOAD_SIZE), - state: UploadState::Started, - connection_resets: 0, - }) - } - - /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`. - /// The new capacity of `buffer` is determined by the current part index. - fn next_part_buffer(buffer: &mut Vec, part_idx: u16) -> Bytes { - // Increase the upload size every 100 parts. This gives maximum part size of 2.5TB. - let new_capacity = ((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_SIZE; - let new_buffer = Vec::with_capacity(new_capacity); - let part = std::mem::replace(buffer, new_buffer); - Bytes::from(part) - } - - fn put_part( - path: Arc, - store: Arc, - buffer: Bytes, - part_idx: u16, - multipart_id: Arc, - sleep: Option, - ) -> BoxFuture<'static, std::result::Result<(u16, PartId), UploadPutError>> { - Box::pin(async move { - if let Some(sleep) = sleep { - tokio::time::sleep(sleep).await; - } - let part_id = store - .put_part( - path.as_ref(), - multipart_id.as_ref(), - part_idx as usize, - buffer.clone(), - ) - .await - .map_err(|source| UploadPutError { - part_idx, - buffer, - source, - })?; - Ok((part_idx, part_id)) - }) - } - - fn poll_tasks( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Result<(), io::Error> { - let mut_self = &mut *self; - loop { - match &mut mut_self.state { - UploadState::Started | UploadState::Done => break, - UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) { - Poll::Ready(Ok(multipart_id)) => { - let futures = FuturesUnordered::new(); - let multipart_id = Arc::new(multipart_id); - - let data = Self::next_part_buffer(&mut mut_self.buffer, 0); - futures.push(Self::put_part( - mut_self.path.clone(), - mut_self.store.clone(), - data, - 0, - multipart_id.clone(), - None, - )); - - mut_self.state = UploadState::InProgress { - multipart_id, - part_idx: 1, // We just used 0 - futures, - part_ids: Vec::new(), - }; - } - Poll::Ready(Err(e)) => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, e)) - } - Poll::Pending => break, - }, - UploadState::InProgress { - futures, - part_ids, - multipart_id, - .. - } => { - while let Poll::Ready(Some(res)) = futures.poll_next_unpin(cx) { - match res { - Ok((part_idx, part_id)) => { - let total_parts = part_ids.len(); - part_ids.resize(total_parts.max(part_idx as usize + 1), None); - part_ids[part_idx as usize] = Some(part_id); - } - Err(UploadPutError { - source: OSError::Generic { source, .. }, - part_idx, - buffer, - }) if source - .to_string() - .to_lowercase() - .contains("connection reset by peer") => - { - if mut_self.connection_resets < max_conn_reset_retries() { - // Retry, but only up to max_conn_reset_retries of them. - mut_self.connection_resets += 1; - - // Resubmit with random jitter - let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000); - let sleep_time = - std::time::Duration::from_millis(sleep_time_ms); - - futures.push(Self::put_part( - mut_self.path.clone(), - mut_self.store.clone(), - buffer, - part_idx, - multipart_id.clone(), - Some(sleep_time), - )); - } else { - return Err(io::Error::new( - io::ErrorKind::ConnectionReset, - Box::new(ConnectionResetError { - message: format!( - "Hit max retries ({}) for connection reset", - max_conn_reset_retries() - ), - source, - }), - )); - } - } - Err(err) => return Err(err.source.into()), - } - } - break; - } - UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => { - match fut.poll_unpin(cx) { - Poll::Ready(Ok(())) => mut_self.state = UploadState::Done, - Poll::Ready(Err(e)) => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, e)) - } - Poll::Pending => break, - } - } - } - } - Ok(()) - } -} - -#[derive(Debug)] -struct ConnectionResetError { - message: String, - source: Box, -} - -impl std::error::Error for ConnectionResetError {} - -impl std::fmt::Display for ConnectionResetError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}: {}", self.message, self.source) - } -} - -impl AsyncWrite for Upload { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - self.as_mut().poll_tasks(cx)?; - - // Fill buffer up to remaining capacity. - let remaining_capacity = self.buffer.capacity() - self.buffer.len(); - let bytes_to_write = std::cmp::min(remaining_capacity, buf.len()); - self.buffer.extend_from_slice(&buf[..bytes_to_write]); - - // Rust needs a little help to borrow self mutably and immutably at the same time - // through a Pin. - let mut_self = &mut *self; - - // Instantiate next request, if available. - if mut_self.buffer.capacity() == mut_self.buffer.len() { - match &mut mut_self.state { - UploadState::Started => { - let store = self.store.clone(); - let path = self.path.clone(); - let fut = Box::pin(async move { store.create_multipart(path.as_ref()).await }); - self.state = UploadState::CreatingUpload(fut); - } - UploadState::InProgress { - multipart_id, - part_idx, - futures, - .. - } => { - // TODO: Make max concurrency configurable. - if futures.len() < max_upload_parallelism() { - let data = Self::next_part_buffer(&mut mut_self.buffer, *part_idx); - futures.push(Self::put_part( - mut_self.path.clone(), - mut_self.store.clone(), - data, - *part_idx, - multipart_id.clone(), - None, - )); - *part_idx += 1; - } - } - _ => {} - } - } - - self.poll_tasks(cx)?; - - match bytes_to_write { - 0 => Poll::Pending, - _ => Poll::Ready(Ok(bytes_to_write)), - } - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.as_mut().poll_tasks(cx)?; - - match &self.state { - UploadState::Started | UploadState::Done => Poll::Ready(Ok(())), - UploadState::CreatingUpload(_) - | UploadState::Completing(_) - | UploadState::PuttingSingle(_) => Poll::Pending, - UploadState::InProgress { futures, .. } => { - if futures.is_empty() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - } - } - - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - loop { - self.as_mut().poll_tasks(cx)?; - - // Rust needs a little help to borrow self mutably and immutably at the same time - // through a Pin. - let mut_self = &mut *self; - match &mut mut_self.state { - UploadState::Done => return Poll::Ready(Ok(())), - UploadState::CreatingUpload(_) - | UploadState::PuttingSingle(_) - | UploadState::Completing(_) => return Poll::Pending, - UploadState::Started => { - // If we didn't start a multipart upload, we can just do a single put. - let part = Bytes::from(std::mem::take(&mut mut_self.buffer)); - let path = mut_self.path.clone(); - let store = mut_self.store.clone(); - let fut = Box::pin(async move { - store.put(&path, part).await?; - Ok(()) - }); - self.state = UploadState::PuttingSingle(fut); - } - UploadState::InProgress { - futures, - part_ids, - multipart_id, - part_idx, - } => { - // Flush final batch - if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() { - // We can just use `take` since we don't need the buffer anymore. - let data = Bytes::from(std::mem::take(&mut mut_self.buffer)); - futures.push(Self::put_part( - mut_self.path.clone(), - mut_self.store.clone(), - data, - *part_idx, - multipart_id.clone(), - None, - )); - // We need to go back to beginning of loop to poll the - // new feature and get the waker registered on the ctx. - continue; - } - - // We handle the transition from in progress to completing here. - if futures.is_empty() { - let part_ids = std::mem::take(part_ids) - .into_iter() - .map(|maybe_id| { - maybe_id.ok_or_else(|| { - io::Error::new(io::ErrorKind::Other, "missing part id") - }) - }) - .collect::>>()?; - let path = mut_self.path.clone(); - let store = mut_self.store.clone(); - let multipart_id = multipart_id.clone(); - let fut = Box::pin(async move { - store - .complete_multipart(&path, &multipart_id, part_ids) - .await?; - Ok(()) - }); - self.state = UploadState::Completing(fut); - } else { - return Poll::Pending; - } - } - } - } - } -} - -/// Returned error from trying to upload a part. -/// Has the part_idx and buffer so we can pass -/// them to the retry logic. -struct UploadPutError { - part_idx: u16, - buffer: Bytes, - source: OSError, -} diff --git a/rust/lance-io/src/object_store/tracing.rs b/rust/lance-io/src/object_store/tracing.rs index 0fc0036af4..2de8241c2b 100644 --- a/rust/lance-io/src/object_store/tracing.rs +++ b/rust/lance-io/src/object_store/tracing.rs @@ -10,52 +10,36 @@ use bytes::Bytes; use futures::stream::BoxStream; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, PutOptions, PutResult, - Result as OSResult, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, PutMultipartOpts, PutOptions, + PutPayload, PutResult, Result as OSResult, UploadPart, }; -use pin_project::pin_project; -use tokio::io::AsyncWrite; use tracing::{debug_span, instrument, Span}; -#[pin_project] -pub struct TracedAsyncWrite { +#[derive(Debug)] +pub struct TracedMultipartUpload { write_span: Span, - finish_span: Option, - #[pin] - target: Box, + target: Box, } -impl AsyncWrite for TracedAsyncWrite { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - let this = self.project(); - let _guard = this.write_span.enter(); - this.target.poll_write(cx, buf) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - let _guard = this.write_span.enter(); - this.target.poll_flush(cx) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - // TODO: Replace with get_or_insert_with when - let _guard = this - .finish_span - .get_or_insert_with(|| debug_span!("put_multipart_finish")) - .enter(); - this.target.poll_shutdown(cx) +#[async_trait::async_trait] +impl MultipartUpload for TracedMultipartUpload { + fn put_part(&mut self, data: PutPayload) -> UploadPart { + let write_span = self.write_span.clone(); + let fut = self.target.put_part(data); + Box::pin(async move { + let _guard = write_span.enter(); + fut.await + }) + } + + #[instrument(level = "debug")] + async fn complete(&mut self) -> OSResult { + self.target.complete().await + } + + #[instrument(level = "debug")] + async fn abort(&mut self) -> OSResult<()> { + self.target.abort().await } } @@ -73,7 +57,7 @@ impl std::fmt::Display for TracedObjectStore { #[async_trait::async_trait] impl object_store::ObjectStore for TracedObjectStore { #[instrument(level = "debug", skip(self, bytes))] - async fn put(&self, location: &Path, bytes: Bytes) -> OSResult { + async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { self.target.put(location, bytes).await } @@ -81,30 +65,22 @@ impl object_store::ObjectStore for TracedObjectStore { async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, opts: PutOptions, ) -> OSResult { self.target.put_opts(location, bytes, opts).await } - async fn put_multipart( + async fn put_multipart_opts( &self, location: &Path, - ) -> OSResult<(MultipartId, Box)> { - let (multipart_id, async_write) = self.target.put_multipart(location).await?; - Ok(( - multipart_id, - Box::new(TracedAsyncWrite { - write_span: debug_span!("put_multipart"), - finish_span: None, - target: async_write, - }) as Box, - )) - } - - #[instrument(level = "debug", skip(self))] - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> OSResult<()> { - self.target.abort_multipart(location, multipart_id).await + opts: PutMultipartOpts, + ) -> OSResult> { + let upload = self.target.put_multipart_opts(location, opts).await?; + Ok(Box::new(TracedMultipartUpload { + target: upload, + write_span: debug_span!("put_multipart_opts"), + })) } #[instrument(level = "debug", skip(self, options))] diff --git a/rust/lance-io/src/object_writer.rs b/rust/lance-io/src/object_writer.rs index ad5b675312..d242de0fc7 100644 --- a/rust/lance-io/src/object_writer.rs +++ b/rust/lance-io/src/object_writer.rs @@ -1,54 +1,270 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::io; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::sync::{Arc, OnceLock}; +use std::task::Poll; +use crate::object_store::ObjectStore as LanceObjectStore; use async_trait::async_trait; -use object_store::{path::Path, MultipartId, ObjectStore}; -use pin_project::pin_project; -use snafu::{location, Location}; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::FutureExt; +use object_store::MultipartUpload; +use object_store::{path::Path, Error as OSError, ObjectStore, Result as OSResult}; +use rand::Rng; use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::task::JoinSet; use lance_core::{Error, Result}; use crate::traits::Writer; +use snafu::{location, Location}; + +/// Start at 5MB. +const INITIAL_UPLOAD_SIZE: usize = 1024 * 1024 * 5; + +fn max_upload_parallelism() -> usize { + static MAX_UPLOAD_PARALLELISM: OnceLock = OnceLock::new(); + *MAX_UPLOAD_PARALLELISM.get_or_init(|| { + std::env::var("LANCE_UPLOAD_CONCURRENCY") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(10) + }) +} + +fn max_conn_reset_retries() -> u16 { + static MAX_CONN_RESET_RETRIES: OnceLock = OnceLock::new(); + *MAX_CONN_RESET_RETRIES.get_or_init(|| { + std::env::var("LANCE_CONN_RESET_RETRIES") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(20) + }) +} -/// AsyncWrite with the capability to tell the position the data is written. +/// Writer to an object in an object store. +/// +/// If the object is small enough, the writer will upload the object in a single +/// PUT request. If the object is larger, the writer will create a multipart +/// upload and upload parts in parallel. /// -#[pin_project] +/// This implements the `AsyncWrite` trait. pub struct ObjectWriter { - // TODO: wrap writer with a BufWriter. - #[pin] - writer: Box, + state: UploadState, + path: Arc, + cursor: usize, + connection_resets: u16, + buffer: Vec, + // TODO: use constant size to support R2 + use_constant_size_upload_parts: bool, +} - // TODO: pub(crate) - pub multipart_id: MultipartId, - path: Path, +enum UploadState { + /// The writer has been opened but no data has been written yet. Will be in + /// this state until the buffer is full or the writer is shut down. + Started(Arc), + /// The writer is in the process of creating a multipart upload. + CreatingUpload(BoxFuture<'static, OSResult>>), + /// The writer is in the process of uploading parts. + InProgress { + part_idx: u16, + upload: Box, + futures: JoinSet>, + }, + /// The writer is in the process of uploading data in a single PUT request. + /// This happens when shutdown is called before the buffer is full. + PuttingSingle(BoxFuture<'static, OSResult<()>>), + /// The writer is in the process of completing the multipart upload. + Completing(BoxFuture<'static, OSResult<()>>), + /// The writer has been shut down and all data has been written. + Done, +} - cursor: usize, +/// Methods for state transitions. +impl UploadState { + fn started_to_completing(&mut self, path: Arc, buffer: Vec) { + // To get owned self, we temporarily swap with Done. + let this = std::mem::replace(self, Self::Done); + *self = match this { + Self::Started(store) => { + let fut = async move { + store.put(&path, buffer.into()).await?; + Ok(()) + }; + Self::PuttingSingle(Box::pin(fut)) + } + _ => unreachable!(), + } + } + + fn in_progress_to_completing(&mut self) { + // To get owned self, we temporarily swap with Done. + let this = std::mem::replace(self, Self::Done); + *self = match this { + Self::InProgress { + mut upload, + futures, + .. + } => { + debug_assert!(futures.is_empty()); + let fut = async move { + upload.complete().await?; + Ok(()) + }; + Self::Completing(Box::pin(fut)) + } + _ => unreachable!(), + }; + } } impl ObjectWriter { - pub async fn new(object_store: &dyn ObjectStore, path: &Path) -> Result { - let (multipart_id, writer) = object_store.put_multipart(path).await.map_err(|e| { - Error::io( - format!("failed to create object writer for {}: {}", path, e), - // and wrap it in here. - location!(), - ) - })?; - + pub async fn new(object_store: &LanceObjectStore, path: &Path) -> Result { Ok(Self { - writer, - multipart_id, + state: UploadState::Started(object_store.inner.clone()), cursor: 0, - path: path.clone(), + path: Arc::new(path.clone()), + connection_resets: 0, + buffer: Vec::with_capacity(INITIAL_UPLOAD_SIZE), + use_constant_size_upload_parts: object_store.use_constant_size_upload_parts, + }) + } + + /// Returns the contents of `buffer` as a `Bytes` object and resets `buffer`. + /// The new capacity of `buffer` is determined by the current part index. + fn next_part_buffer(buffer: &mut Vec, part_idx: u16, constant_upload_size: bool) -> Bytes { + let new_capacity = if constant_upload_size { + // The store does not support variable part sizes, so use the initial size. + INITIAL_UPLOAD_SIZE + } else { + // Increase the upload size every 100 parts. This gives maximum part size of 2.5TB. + ((part_idx / 100) as usize + 1) * INITIAL_UPLOAD_SIZE + }; + let new_buffer = Vec::with_capacity(new_capacity); + let part = std::mem::replace(buffer, new_buffer); + Bytes::from(part) + } + + fn put_part( + upload: &mut dyn MultipartUpload, + buffer: Bytes, + part_idx: u16, + sleep: Option, + ) -> BoxFuture<'static, std::result::Result<(), UploadPutError>> { + let fut = upload.put_part(buffer.clone().into()); + Box::pin(async move { + if let Some(sleep) = sleep { + tokio::time::sleep(sleep).await; + } + fut.await.map_err(|source| UploadPutError { + part_idx, + buffer, + source, + })?; + Ok(()) }) } + fn poll_tasks( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::result::Result<(), io::Error> { + let mut_self = &mut *self; + loop { + match &mut mut_self.state { + UploadState::Started(_) | UploadState::Done => break, + UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) { + Poll::Ready(Ok(mut upload)) => { + let mut futures = JoinSet::new(); + + let data = Self::next_part_buffer( + &mut mut_self.buffer, + 0, + mut_self.use_constant_size_upload_parts, + ); + futures.spawn(Self::put_part(upload.as_mut(), data, 0, None)); + + mut_self.state = UploadState::InProgress { + part_idx: 1, // We just used 0 + futures, + upload, + }; + } + Poll::Ready(Err(e)) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, e)) + } + Poll::Pending => break, + }, + UploadState::InProgress { + upload, futures, .. + } => { + while let Poll::Ready(Some(res)) = futures.poll_join_next(cx) { + match res { + Ok(Ok(())) => {} + Err(err) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, err)) + } + Ok(Err(UploadPutError { + source: OSError::Generic { source, .. }, + part_idx, + buffer, + })) if source + .to_string() + .to_lowercase() + .contains("connection reset by peer") => + { + if mut_self.connection_resets < max_conn_reset_retries() { + // Retry, but only up to max_conn_reset_retries of them. + mut_self.connection_resets += 1; + + // Resubmit with random jitter + let sleep_time_ms = rand::thread_rng().gen_range(2_000..8_000); + let sleep_time = + std::time::Duration::from_millis(sleep_time_ms); + + futures.spawn(Self::put_part( + upload.as_mut(), + buffer, + part_idx, + Some(sleep_time), + )); + } else { + return Err(io::Error::new( + io::ErrorKind::ConnectionReset, + Box::new(ConnectionResetError { + message: format!( + "Hit max retries ({}) for connection reset", + max_conn_reset_retries() + ), + source, + }), + )); + } + } + Ok(Err(err)) => return Err(err.source.into()), + } + } + break; + } + UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => { + match fut.poll_unpin(cx) { + Poll::Ready(Ok(())) => mut_self.state = UploadState::Done, + Poll::Ready(Err(e)) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, e)) + } + Poll::Pending => break, + } + } + } + } + Ok(()) + } + pub async fn shutdown(&mut self) -> Result<()> { - self.writer.as_mut().shutdown().await.map_err(|e| { + AsyncWriteExt::shutdown(self).await.map_err(|e| { Error::io( format!("failed to shutdown object writer for {}: {}", self.path, e), // and wrap it in here. @@ -58,44 +274,169 @@ impl ObjectWriter { } } -#[async_trait] -impl Writer for ObjectWriter { - async fn tell(&mut self) -> Result { - Ok(self.cursor) +/// Returned error from trying to upload a part. +/// Has the part_idx and buffer so we can pass +/// them to the retry logic. +struct UploadPutError { + part_idx: u16, + buffer: Bytes, + source: OSError, +} + +#[derive(Debug)] +struct ConnectionResetError { + message: String, + source: Box, +} + +impl std::error::Error for ConnectionResetError {} + +impl std::fmt::Display for ConnectionResetError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.message, self.source) } } + impl AsyncWrite for ObjectWriter { fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> Poll> { - let mut this = self.project(); - this.writer.as_mut().poll_write(cx, buf).map_ok(|n| { - *this.cursor += n; - n - }) + ) -> std::task::Poll> { + self.as_mut().poll_tasks(cx)?; + + // Fill buffer up to remaining capacity. + let remaining_capacity = self.buffer.capacity() - self.buffer.len(); + let bytes_to_write = std::cmp::min(remaining_capacity, buf.len()); + self.buffer.extend_from_slice(&buf[..bytes_to_write]); + + // Rust needs a little help to borrow self mutably and immutably at the same time + // through a Pin. + let mut_self = &mut *self; + + // Instantiate next request, if available. + if mut_self.buffer.capacity() == mut_self.buffer.len() { + match &mut mut_self.state { + UploadState::Started(store) => { + let path = mut_self.path.clone(); + let store = store.clone(); + let fut = Box::pin(async move { store.put_multipart(path.as_ref()).await }); + self.state = UploadState::CreatingUpload(fut); + } + UploadState::InProgress { + upload, + part_idx, + futures, + .. + } => { + // TODO: Make max concurrency configurable from storage options. + if futures.len() < max_upload_parallelism() { + let data = Self::next_part_buffer( + &mut mut_self.buffer, + *part_idx, + mut_self.use_constant_size_upload_parts, + ); + futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None)); + *part_idx += 1; + } + } + _ => {} + } + } + + self.poll_tasks(cx)?; + + match bytes_to_write { + 0 => Poll::Pending, + _ => Poll::Ready(Ok(bytes_to_write)), + } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().writer.as_mut().poll_flush(cx) + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.as_mut().poll_tasks(cx)?; + + match &self.state { + UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())), + UploadState::CreatingUpload(_) + | UploadState::Completing(_) + | UploadState::PuttingSingle(_) => Poll::Pending, + UploadState::InProgress { futures, .. } => { + if futures.is_empty() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + } + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + loop { + self.as_mut().poll_tasks(cx)?; + + // Rust needs a little help to borrow self mutably and immutably at the same time + // through a Pin. + let mut_self = &mut *self; + match &mut mut_self.state { + UploadState::Done => return Poll::Ready(Ok(())), + UploadState::CreatingUpload(_) + | UploadState::PuttingSingle(_) + | UploadState::Completing(_) => return Poll::Pending, + UploadState::Started(_) => { + // If we didn't start a multipart upload, we can just do a single put. + let part = std::mem::take(&mut mut_self.buffer); + let path = mut_self.path.clone(); + self.state.started_to_completing(path, part); + } + UploadState::InProgress { + upload, + futures, + part_idx, + } => { + // Flush final batch + if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() { + // We can just use `take` since we don't need the buffer anymore. + let data = Bytes::from(std::mem::take(&mut mut_self.buffer)); + futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None)); + // We need to go back to beginning of loop to poll the + // new feature and get the waker registered on the ctx. + continue; + } + + // We handle the transition from in progress to completing here. + if futures.is_empty() { + self.state.in_progress_to_completing(); + } else { + return Poll::Pending; + } + } + } + } } +} - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().writer.as_mut().poll_shutdown(cx) +#[async_trait] +impl Writer for ObjectWriter { + async fn tell(&mut self) -> Result { + Ok(self.cursor) } } #[cfg(test)] mod tests { - - use object_store::memory::InMemory; + use tokio::io::AsyncWriteExt; use super::*; #[tokio::test] async fn test_write() { - let store = InMemory::new(); + let store = LanceObjectStore::memory(); let mut object_writer = ObjectWriter::new(&store, &Path::from("/foo")) .await diff --git a/rust/lance-io/src/scheduler.rs b/rust/lance-io/src/scheduler.rs index ad26f354b9..5def8fd42d 100644 --- a/rust/lance-io/src/scheduler.rs +++ b/rust/lance-io/src/scheduler.rs @@ -350,7 +350,7 @@ mod tests { let some_path = Path::parse("foo").unwrap(); let base_store = Arc::new(InMemory::new()); base_store - .put(&some_path, Bytes::from(vec![0; 1000])) + .put(&some_path, vec![0; 1000].into()) .await .unwrap(); @@ -374,6 +374,7 @@ mod tests { Url::parse("mem://").unwrap(), None, None, + false, )); let scan_scheduler = ScanScheduler::new(obj_store, 1); diff --git a/rust/lance-io/src/testing.rs b/rust/lance-io/src/testing.rs index a0f6bd2e0c..226964868f 100644 --- a/rust/lance-io/src/testing.rs +++ b/rust/lance-io/src/testing.rs @@ -3,27 +3,26 @@ use std::fmt::{self, Display, Formatter}; use async_trait::async_trait; -use bytes::Bytes; use futures::stream::BoxStream; use mockall::mock; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore as OSObjectStore, PutOptions, PutResult, Result as OSResult, + path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore as OSObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, + Result as OSResult, }; use std::future::Future; -use tokio::io::AsyncWrite; mock! { pub ObjectStore {} #[async_trait] impl OSObjectStore for ObjectStore { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> OSResult; - async fn put_multipart( + async fn put_opts(&self, location: &Path, bytes: PutPayload, opts: PutOptions) -> OSResult; + async fn put_multipart_opts( &self, location: &Path, - ) -> OSResult<(MultipartId, Box)>; - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> OSResult<()>; + opts: PutMultipartOpts, + ) -> OSResult>; fn get_opts<'life0, 'life1, 'async_trait>( &'life0 self, location: &'life1 Path, diff --git a/rust/lance-io/src/utils.rs b/rust/lance-io/src/utils.rs index 4e9309b933..2c4cb0900a 100644 --- a/rust/lance-io/src/utils.rs +++ b/rust/lance-io/src/utils.rs @@ -178,13 +178,12 @@ pub fn read_struct_from_buf< #[cfg(test)] mod tests { - use std::sync::Arc; - use bytes::Bytes; - use object_store::{memory::InMemory, path::Path}; + use object_store::path::Path; use crate::{ object_reader::CloudObjectReader, + object_store::ObjectStore, object_writer::ObjectWriter, traits::{ProtoStruct, WriteExt, Writer}, utils::read_struct, @@ -215,7 +214,7 @@ mod tests { #[tokio::test] async fn test_write_proto_structs() { - let store = InMemory::new(); + let store = ObjectStore::memory(); let path = Path::from("/foo"); let mut object_writer = ObjectWriter::new(&store, &path).await.unwrap(); @@ -227,7 +226,7 @@ mod tests { assert_eq!(pos, 0); object_writer.shutdown().await.unwrap(); - let object_reader = CloudObjectReader::new(Arc::new(store), path, 1024, None).unwrap(); + let object_reader = CloudObjectReader::new(store.inner, path, 1024, None).unwrap(); let actual: BytesWrapper = read_struct(&object_reader, pos).await.unwrap(); assert_eq!(some_message, actual); } From 395dfc0716796879c5529ca2ff2709b6757bb10f Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 12 Jul 2024 15:12:55 -0700 Subject: [PATCH 2/8] fix all errors and tests --- rust/lance-datafusion/Cargo.toml | 2 +- rust/lance-datafusion/src/exec.rs | 2 +- rust/lance-datafusion/src/expr.rs | 2 +- rust/lance-datagen/src/generator.rs | 15 ++-- rust/lance-encoding-datafusion/Cargo.toml | 2 + rust/lance-encoding-datafusion/src/zone.rs | 1 + rust/lance-file/src/v2/writer.rs | 4 - rust/lance-file/src/writer.rs | 5 -- rust/lance-file/src/writer/statistics.rs | 13 ++- rust/lance-index/src/scalar/btree.rs | 14 +++ rust/lance-index/src/scalar/expression.rs | 6 +- rust/lance-io/src/object_store.rs | 9 +- rust/lance-io/src/object_writer.rs | 16 ++++ rust/lance-table/src/io/commit.rs | 32 ++++--- .../src/io/commit/external_manifest.rs | 6 +- rust/lance/src/datafusion/logical_expr.rs | 34 ++----- rust/lance/src/dataset.rs | 4 +- rust/lance/src/dataset/builder.rs | 1 + rust/lance/src/dataset/cleanup.rs | 88 ++----------------- rust/lance/src/dataset/fragment/write.rs | 4 +- rust/lance/src/dataset/progress.rs | 4 +- rust/lance/src/dataset/scanner.rs | 10 +-- rust/lance/src/dataset/write.rs | 15 +--- rust/lance/src/io/exec/knn.rs | 14 +-- rust/lance/src/io/exec/optimizer.rs | 17 ++-- rust/lance/src/io/exec/planner.rs | 76 +++++++--------- rust/lance/src/io/exec/projection.rs | 4 +- rust/lance/src/io/exec/pushdown_scan.rs | 3 +- rust/lance/src/io/exec/scalar_index.rs | 8 +- rust/lance/src/io/exec/scan.rs | 2 +- rust/lance/src/io/exec/take.rs | 4 +- rust/lance/src/io/exec/testing.rs | 2 +- rust/lance/src/io/exec/utils.rs | 4 +- rust/lance/src/utils/test.rs | 20 ++--- 34 files changed, 176 insertions(+), 267 deletions(-) diff --git a/rust/lance-datafusion/Cargo.toml b/rust/lance-datafusion/Cargo.toml index 11be3a96c2..9bf6475c04 100644 --- a/rust/lance-datafusion/Cargo.toml +++ b/rust/lance-datafusion/Cargo.toml @@ -19,7 +19,7 @@ datafusion.workspace = true datafusion-common.workspace = true datafusion-functions.workspace = true datafusion-physical-expr.workspace = true -datafusion-substrait = { version = "37.1", optional = true } +datafusion-substrait = { version = "39.0", optional = true } futures.workspace = true lance-arrow.workspace = true lance-core = { workspace = true, features = ["datafusion"] } diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index dbed261761..6baacbba1e 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -105,7 +105,7 @@ impl ExecutionPlan for OneShotExec { self.schema.clone() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/rust/lance-datafusion/src/expr.rs b/rust/lance-datafusion/src/expr.rs index dd2c646adb..61de7b2e81 100644 --- a/rust/lance-datafusion/src/expr.rs +++ b/rust/lance-datafusion/src/expr.rs @@ -561,7 +561,7 @@ pub async fn parse_substrait(expr: &[u8], input_schema: Arc) -> Result { - if table == "dummy" { + if table.as_ref() == "dummy" { Ok(Transformed::yes(Expr::Column(Column { relation: None, name: column.name, diff --git a/rust/lance-datagen/src/generator.rs b/rust/lance-datagen/src/generator.rs index 76119202ea..fcb3ba62b2 100644 --- a/rust/lance-datagen/src/generator.rs +++ b/rust/lance-datagen/src/generator.rs @@ -1171,9 +1171,7 @@ const MS_PER_DAY: i64 = 86400000; pub mod array { - use arrow::datatypes::{ - Int16Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, - }; + use arrow::datatypes::{Int16Type, Int64Type, Int8Type}; use arrow_array::types::{ Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type, @@ -1384,7 +1382,7 @@ pub mod array { _ => panic!(), }; - let data_type = DataType::Time32(resolution.clone()); + let data_type = DataType::Time32(*resolution); let size = ByteCount::from(data_type.primitive_width().unwrap() as u64); let dist = Uniform::new(start, end); let sample_fn = move |rng: &mut _| dist.sample(rng); @@ -1412,7 +1410,7 @@ pub mod array { _ => panic!(), }; - let data_type = DataType::Time64(resolution.clone()); + let data_type = DataType::Time64(*resolution); let size = ByteCount::from(data_type.primitive_width().unwrap() as u64); let dist = Uniform::new(start, end); let sample_fn = move |rng: &mut _| dist.sample(rng); @@ -1664,8 +1662,11 @@ pub mod array { TimeUnit::Nanosecond => rand::(), }, DataType::Interval(unit) => match unit { - IntervalUnit::DayTime => rand::(), - IntervalUnit::MonthDayNano => rand::(), + // TODO: fix these. In Arrow they changed to have specialized + // Native types, which don't support Distribution. + // IntervalUnit::DayTime => rand::(), + // IntervalUnit::MonthDayNano => rand::(), + IntervalUnit::DayTime | IntervalUnit::MonthDayNano => todo!(), IntervalUnit::YearMonth => rand::(), }, DataType::Date32 => rand_date32(), diff --git a/rust/lance-encoding-datafusion/Cargo.toml b/rust/lance-encoding-datafusion/Cargo.toml index d8e890206f..6610312472 100644 --- a/rust/lance-encoding-datafusion/Cargo.toml +++ b/rust/lance-encoding-datafusion/Cargo.toml @@ -14,6 +14,7 @@ rust-version.workspace = true [dependencies] lance-core = { workspace = true, features = ["datafusion"] } lance-datafusion = { workspace = true, features = ["substrait"] } + lance-encoding.workspace = true lance-file.workspace = true lance-io.workspace = true @@ -23,6 +24,7 @@ arrow-schema.workspace = true bytes.workspace = true datafusion-common.workspace = true datafusion-expr.workspace = true +datafusion-functions.workspace = true datafusion-optimizer.workspace = true datafusion-physical-expr.workspace = true futures.workspace = true diff --git a/rust/lance-encoding-datafusion/src/zone.rs b/rust/lance-encoding-datafusion/src/zone.rs index 8961b2a7ce..14ffb1b92a 100644 --- a/rust/lance-encoding-datafusion/src/zone.rs +++ b/rust/lance-encoding-datafusion/src/zone.rs @@ -15,6 +15,7 @@ use datafusion_expr::{ simplify::SimplifyContext, Accumulator, Expr, }; +use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; use futures::{future::BoxFuture, FutureExt}; diff --git a/rust/lance-file/src/v2/writer.rs b/rust/lance-file/src/v2/writer.rs index 3a585a497f..6f32eecd16 100644 --- a/rust/lance-file/src/v2/writer.rs +++ b/rust/lance-file/src/v2/writer.rs @@ -471,10 +471,6 @@ impl FileWriter { Ok(self.rows_written) } - pub fn multipart_id(&self) -> &str { - &self.writer.multipart_id - } - pub async fn tell(&mut self) -> Result { Ok(self.writer.tell().await? as u64) } diff --git a/rust/lance-file/src/writer.rs b/rust/lance-file/src/writer.rs index b5c877be1f..84b807145e 100644 --- a/rust/lance-file/src/writer.rs +++ b/rust/lance-file/src/writer.rs @@ -246,11 +246,6 @@ impl FileWriter { self.object_writer.tell().await } - /// Returns the in-flight multipart ID. - pub fn multipart_id(&self) -> &str { - &self.object_writer.multipart_id - } - /// Return the id of the next batch to be written. pub fn next_batch_id(&self) -> i32 { self.batch_id diff --git a/rust/lance-file/src/writer/statistics.rs b/rust/lance-file/src/writer/statistics.rs index e8c674bb3d..39425a407a 100644 --- a/rust/lance-file/src/writer/statistics.rs +++ b/rust/lance-file/src/writer/statistics.rs @@ -1564,7 +1564,7 @@ mod tests { }, TestCase { source_arrays: vec![Arc::new(FixedSizeBinaryArray::from(vec![ - min_binary_value.clone().as_ref(), + min_binary_value.as_slice(), ]))], stats: StatisticsRow { null_count: 0, @@ -1579,12 +1579,9 @@ mod tests { }, }, TestCase { - source_arrays: vec![Arc::new(FixedSizeBinaryArray::from(vec![vec![ - 0xFFu8; - BINARY_PREFIX_LENGTH - + 7 - ] - .as_ref()]))], + source_arrays: vec![Arc::new(FixedSizeBinaryArray::from(vec![ + &[0xFFu8; BINARY_PREFIX_LENGTH + 7], + ]))], stats: StatisticsRow { null_count: 0, min_value: ScalarValue::FixedSizeBinary( @@ -2033,7 +2030,7 @@ mod tests { let timeunits = [TimeUnit::Second, TimeUnit::Millisecond, TimeUnit::Microsecond, TimeUnit::Nanosecond]; let timezone = timezones[timezone_index].clone(); - let timeunit = timeunits[timeunit_index].clone(); + let timeunit = timeunits[timeunit_index]; let value = match timeunit { TimeUnit::Second => ScalarValue::TimestampSecond(value, timezone), TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, timezone), diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index edfb8ae575..72ffa69aee 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -160,6 +160,20 @@ impl Ord for OrderableScalarValue { } } (Float64(_), _) => panic!("Attempt to compare f64 with non-f64"), + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => f1.total_cmp(f2), + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (None, None) => Ordering::Equal, + }, + (Float16(v1), Null) => { + if v1.is_none() { + Ordering::Equal + } else { + Ordering::Greater + } + } + (Float16(_), _) => panic!("Attempt to compare f16 with non-f16"), (Int8(v1), Int8(v2)) => v1.cmp(v2), (Int8(v1), Null) => { if v1.is_none() { diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index d7d0164dc5..2393a38c5c 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -739,15 +739,15 @@ mod tests { todo!() } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { todo!() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { todo!() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { todo!() } } diff --git a/rust/lance-io/src/object_store.rs b/rust/lance-io/src/object_store.rs index 0948d32fbf..33a9406567 100644 --- a/rust/lance-io/src/object_store.rs +++ b/rust/lance-io/src/object_store.rs @@ -27,7 +27,8 @@ use object_store::{parse_url_opts, ClientOptions, DynObjectStore, StaticCredenti use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore}; use shellexpand::tilde; use snafu::{location, Location}; -use tokio::{io::AsyncWriteExt, sync::RwLock}; +use tokio::io::AsyncWriteExt; +use tokio::sync::RwLock; use url::Url; use super::local::LocalObjectReader; @@ -307,6 +308,11 @@ pub struct ObjectStoreParams { pub aws_credentials: Option, pub object_store_wrapper: Option>, pub storage_options: Option>, + /// Use constant size upload parts for multipart uploads. Only necessary + /// for Cloudflare R2, which doesn't support variable size parts. When this + /// is false, max upload size is 2.5TB. When this is true, the max size is + /// 50GB. + pub use_constant_size_upload_parts: bool, } impl Default for ObjectStoreParams { @@ -318,6 +324,7 @@ impl Default for ObjectStoreParams { aws_credentials: None, object_store_wrapper: None, storage_options: None, + use_constant_size_upload_parts: false, } } } diff --git a/rust/lance-io/src/object_writer.rs b/rust/lance-io/src/object_writer.rs index d242de0fc7..ce2e4331b7 100644 --- a/rust/lance-io/src/object_writer.rs +++ b/rust/lance-io/src/object_writer.rs @@ -274,6 +274,21 @@ impl ObjectWriter { } } +impl Drop for ObjectWriter { + fn drop(&mut self) { + // If there is a multipart upload started but not finished, we should abort it. + if matches!(self.state, UploadState::InProgress { .. }) { + // Take ownership of the state. + let state = std::mem::replace(&mut self.state, UploadState::Done); + if let UploadState::InProgress { mut upload, .. } = state { + tokio::task::spawn(async move { + let _ = upload.abort().await; + }); + } + } + } +} + /// Returned error from trying to upload a part. /// Has the part_idx and buffer so we can pass /// them to the retry logic. @@ -309,6 +324,7 @@ impl AsyncWrite for ObjectWriter { let remaining_capacity = self.buffer.capacity() - self.buffer.len(); let bytes_to_write = std::cmp::min(remaining_capacity, buf.len()); self.buffer.extend_from_slice(&buf[..bytes_to_write]); + self.cursor += bytes_to_write; // Rust needs a little help to borrow self mutably and immutably at the same time // through a Pin. diff --git a/rust/lance-table/src/io/commit.rs b/rust/lance-table/src/io/commit.rs index cd1a5a76c2..5c3a2536ae 100644 --- a/rust/lance-table/src/io/commit.rs +++ b/rust/lance-table/src/io/commit.rs @@ -63,7 +63,7 @@ const MANIFEST_EXTENSION: &str = "manifest"; /// Function that writes the manifest to the object store. pub type ManifestWriter = for<'a> fn( - object_store: &'a dyn OSObjectStore, + object_store: &'a ObjectStore, manifest: &'a mut Manifest, indices: Option>, path: &'a Path, @@ -311,7 +311,7 @@ pub trait CommitHandler: Debug + Send + Sync { manifest: &mut Manifest, indices: Option>, base_path: &Path, - object_store: &dyn OSObjectStore, + object_store: &ObjectStore, manifest_writer: ManifestWriter, ) -> std::result::Result<(), CommitError>; } @@ -525,7 +525,7 @@ impl CommitHandler for UnsafeCommitHandler { manifest: &mut Manifest, indices: Option>, base_path: &Path, - object_store: &dyn OSObjectStore, + object_store: &ObjectStore, manifest_writer: ManifestWriter, ) -> std::result::Result<(), CommitError> { // Log a one-time warning @@ -538,12 +538,12 @@ impl CommitHandler for UnsafeCommitHandler { } let version_path = self - .resolve_version(base_path, manifest.version, object_store) + .resolve_version(base_path, manifest.version, &object_store.inner) .await?; // Write the manifest naively manifest_writer(object_store, manifest, indices, &version_path).await?; - write_latest_manifest(&version_path, base_path, object_store).await?; + write_latest_manifest(&version_path, base_path, &object_store.inner).await?; Ok(()) } @@ -588,18 +588,18 @@ impl CommitHandler for T { manifest: &mut Manifest, indices: Option>, base_path: &Path, - object_store: &dyn OSObjectStore, + object_store: &ObjectStore, manifest_writer: ManifestWriter, ) -> std::result::Result<(), CommitError> { let path = self - .resolve_version(base_path, manifest.version, object_store) + .resolve_version(base_path, manifest.version, &object_store.inner) .await?; // NOTE: once we have the lease we cannot use ? to return errors, since // we must release the lease before returning. let lease = self.lock(manifest.version).await?; // Head the location and make sure it's not already committed - match object_store.head(&path).await { + match object_store.inner.head(&path).await { Ok(_) => { // The path already exists, so it's already committed // Release the lock @@ -618,7 +618,7 @@ impl CommitHandler for T { } let res = manifest_writer(object_store, manifest, indices, &path).await; - write_latest_manifest(&path, base_path, object_store).await?; + write_latest_manifest(&path, base_path, &object_store.inner).await?; // Release the lock lease.release(res.is_ok()).await?; @@ -634,7 +634,7 @@ impl CommitHandler for Arc { manifest: &mut Manifest, indices: Option>, base_path: &Path, - object_store: &dyn OSObjectStore, + object_store: &ObjectStore, manifest_writer: ManifestWriter, ) -> std::result::Result<(), CommitError> { self.as_ref() @@ -655,14 +655,14 @@ impl CommitHandler for RenameCommitHandler { manifest: &mut Manifest, indices: Option>, base_path: &Path, - object_store: &dyn OSObjectStore, + object_store: &ObjectStore, manifest_writer: ManifestWriter, ) -> std::result::Result<(), CommitError> { // Create a temporary object, then use `rename_if_not_exists` to commit. // If failed, clean up the temporary object. let path = self - .resolve_version(base_path, manifest.version, object_store) + .resolve_version(base_path, manifest.version, &object_store.inner) .await?; // Add .tmp_ prefix to the path @@ -680,7 +680,11 @@ impl CommitHandler for RenameCommitHandler { // Write the manifest to the temporary path manifest_writer(object_store, manifest, indices, &tmp_path).await?; - let res = match object_store.rename_if_not_exists(&tmp_path, &path).await { + let res = match object_store + .inner + .rename_if_not_exists(&tmp_path, &path) + .await + { Ok(_) => Ok(()), Err(ObjectStoreError::AlreadyExists { .. }) => { // Another transaction has already been committed @@ -695,7 +699,7 @@ impl CommitHandler for RenameCommitHandler { } }; - write_latest_manifest(&path, base_path, object_store).await?; + write_latest_manifest(&path, base_path, &object_store.inner).await?; res } diff --git a/rust/lance-table/src/io/commit/external_manifest.rs b/rust/lance-table/src/io/commit/external_manifest.rs index a7620da7b7..7285518e95 100644 --- a/rust/lance-table/src/io/commit/external_manifest.rs +++ b/rust/lance-table/src/io/commit/external_manifest.rs @@ -233,7 +233,7 @@ impl CommitHandler for ExternalManifestCommitHandler { manifest: &mut Manifest, indices: Option>, base_path: &Path, - object_store: &dyn object_store::ObjectStore, + object_store: &ObjectStore, manifest_writer: ManifestWriter, ) -> std::result::Result<(), CommitError> { // path we get here is the path to the manifest we want to write @@ -253,7 +253,7 @@ impl CommitHandler for ExternalManifestCommitHandler { .map_err(|_| CommitError::CommitConflict {})?; // step 4: copy the manifest to the final location - object_store.copy( + object_store.inner.copy( &staging_path, &path, ).await.map_err(|e| CommitError::OtherError( @@ -264,7 +264,7 @@ impl CommitHandler for ExternalManifestCommitHandler { ))?; // update the _latest.manifest pointer - write_latest_manifest(&path, base_path, object_store).await?; + write_latest_manifest(&path, base_path, &object_store.inner).await?; // step 5: flip the external store to point to the final location self.external_manifest_store diff --git a/rust/lance/src/datafusion/logical_expr.rs b/rust/lance/src/datafusion/logical_expr.rs index 757c0dc626..ef1ad3668d 100644 --- a/rust/lance/src/datafusion/logical_expr.rs +++ b/rust/lance/src/datafusion/logical_expr.rs @@ -5,11 +5,8 @@ use arrow_schema::DataType; -use datafusion::logical_expr::ScalarFunctionDefinition; use datafusion::logical_expr::ScalarUDFImpl; -use datafusion::logical_expr::{ - expr::ScalarFunction, BinaryExpr, GetFieldAccess, GetIndexedField, Operator, -}; +use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator}; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_functions::core::getfield::GetFieldFunc; @@ -68,18 +65,6 @@ pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option { return None; } } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - if let GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(name)), - } = field - { - field_path.push(name); - } else { - // We don't support other kinds of access right now. - return None; - } - current_expr = expr.as_ref(); - } _ => return None, } } @@ -144,10 +129,7 @@ pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result { } } Expr::InList(in_list) => { - if matches!( - in_list.expr.as_ref(), - Expr::Column(_) | Expr::GetIndexedField(_) - ) { + if matches!(in_list.expr.as_ref(), Expr::Column(_)) { if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) { let resolved_values = in_list .list @@ -200,11 +182,8 @@ pub fn coerce_filter_type_to_boolean(expr: Expr) -> Result { match &expr { // TODO: consider making this dispatch more generic, i.e. fun.output_type -> coerce // instead of hardcoding coerce method for each function - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(udf), - .. - }) => { - if udf.name() == "regexp_match" { + Expr::ScalarFunction(ScalarFunction { func, .. }) => { + if func.name() == "regexp_match" { Ok(Expr::IsNotNull(Box::new(expr))) } else { Ok(expr) @@ -222,6 +201,7 @@ pub mod tests { use arrow_schema::{Field, Schema as ArrowSchema}; use datafusion::logical_expr::ScalarUDF; + use datafusion_functions::core::expr_ext::FieldAccessor; // As part of the DF 37 release there are now two different ways to // represent a nested field access in `Expr`. The old way is to use @@ -242,9 +222,7 @@ pub mod tests { impl ExprExt for Expr { fn field_newstyle(&self, name: &str) -> Expr { Self::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - GetFieldFunc::default(), - ))), + func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), args: vec![ self.clone(), Self::Literal(ScalarValue::Utf8(Some(name.to_string()))), diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 138c74ff5b..34b7b363d0 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -1299,7 +1299,7 @@ pub(crate) async fn write_manifest_file( manifest, indices, base_path, - &object_store.inner, + object_store, write_manifest_file_to_path, ) .await?; @@ -1308,7 +1308,7 @@ pub(crate) async fn write_manifest_file( } fn write_manifest_file_to_path<'a>( - object_store: &'a dyn object_store::ObjectStore, + object_store: &'a ObjectStore, manifest: &'a mut Manifest, indices: Option>, path: &'a Path, diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 7c3073c078..68292fc782 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -206,6 +206,7 @@ impl DatasetBuilder { store.1.clone(), self.options.block_size, self.options.object_store_wrapper, + self.options.use_constant_size_upload_parts, ), Path::from(store.1.path()), commit_handler, diff --git a/rust/lance/src/dataset/cleanup.rs b/rust/lance/src/dataset/cleanup.rs index 4568197183..fb30f1e5fa 100644 --- a/rust/lance/src/dataset/cleanup.rs +++ b/rust/lance/src/dataset/cleanup.rs @@ -35,8 +35,7 @@ use chrono::{DateTime, TimeDelta, Utc}; use futures::{stream, StreamExt, TryStreamExt}; -use lance_core::{Error, Result}; -use lance_io::object_store::ObjectStore; +use lance_core::Result; use lance_table::{ format::{Index, Manifest}, io::{ @@ -394,61 +393,20 @@ pub async fn cleanup_old_versions( cleanup.run().await } -/// Force cleanup of specific partial writes. -/// -/// These files can be cleaned up easily with [cleanup_old_versions()] after 7 days, -/// but if you know specific partial writes have been made, you can call this -/// function to clean them up immediately. -/// -/// To find partial writes, you can use the -/// [crate::dataset::progress::WriteFragmentProgress] trait to track which files -/// have been started but never finished. -pub async fn cleanup_partial_writes( - store: &ObjectStore, - base_path: &Path, - objects: impl IntoIterator, -) -> Result<()> { - futures::stream::iter(objects) - .map(Ok) - .try_for_each_concurrent(num_cpus::get() * 2, |(path, multipart_id)| async move { - let path: Path = base_path - .child("data") - .parts() - .chain(path.parts()) - .collect(); - match store.inner.abort_multipart(&path, multipart_id).await { - Ok(_) => Ok(()), - // We don't care if it's not there. - // TODO: once this issue is addressed, we should just use the error - // variant. https://github.com/apache/arrow-rs/issues/4749 - // Err(object_store::Error::NotFound { .. }) => { - Err(e) - if e.to_string().contains("No such file or directory") - || e.to_string().contains("cannot find the file") => - { - log::warn!("Partial write not found: {} {}", path, multipart_id); - Ok(()) - } - Err(e) => Err(Error::from(e)), - } - }) - .await?; - Ok(()) -} - #[cfg(test)] mod tests { use std::{collections::HashMap, sync::Arc}; - use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; - use arrow_array::{RecordBatchIterator, RecordBatchReader}; - use lance_core::utils::testing::{MockClock, ProxyObjectStore, ProxyObjectStorePolicy}; + use arrow_array::RecordBatchReader; + use lance_core::{ + utils::testing::{MockClock, ProxyObjectStore, ProxyObjectStorePolicy}, + Error, + }; use lance_index::{DatasetIndexExt, IndexType}; - use lance_io::object_store::{ObjectStoreParams, WrappingObjectStore}; + use lance_io::object_store::{ObjectStore, ObjectStoreParams, WrappingObjectStore}; use lance_linalg::distance::MetricType; use lance_testing::datagen::{some_batch, BatchGenerator, IncrementingInt32}; use snafu::{location, Location}; - use tokio::io::AsyncWriteExt; use crate::{ dataset::{builder::DatasetBuilder, ReadParams, WriteMode, WriteParams}, @@ -1098,36 +1056,4 @@ mod tests { assert_eq!(after_count.num_data_files, 1); assert_eq!(after_count.num_manifest_files, 2); } - - #[tokio::test] - async fn test_cleanup_partial_writes() { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - - let schema = ArrowSchema::new(vec![Field::new("a", DataType::Int32, false)]); - let reader = RecordBatchIterator::new(vec![], Arc::new(schema)); - let dataset = Dataset::write(reader, test_uri, Default::default()) - .await - .unwrap(); - let store = dataset.object_store(); - - // Create a partial write - let path1 = dataset.base.child("data").child("test"); - let (multipart_id, mut writer) = store.inner.put_multipart(&path1).await.unwrap(); - writer.write_all(b"test").await.unwrap(); - - // paths are relative to the store data path - let path1 = Path::from("test"); - // Add a non-existant path and id - let path2 = Path::from("test2"); - let non_existent_multipart_id = "non-existant-id".to_string(); - let objects = vec![ - (&path1, &multipart_id), - (&path2, &non_existent_multipart_id), - ]; - - cleanup_partial_writes(dataset.object_store(), &dataset.base, objects) - .await - .unwrap(); - } } diff --git a/rust/lance/src/dataset/fragment/write.rs b/rust/lance/src/dataset/fragment/write.rs index 0219ff9db3..cd97b646d4 100644 --- a/rust/lance/src/dataset/fragment/write.rs +++ b/rust/lance/src/dataset/fragment/write.rs @@ -95,7 +95,7 @@ impl<'a> FragmentCreateBuilder<'a> { FileWriterOptions::default(), )?; - progress.begin(&fragment, writer.multipart_id()).await?; + progress.begin(&fragment).await?; let break_limit = (128 * 1024).min(params.max_rows_per_file); @@ -166,7 +166,7 @@ impl<'a> FragmentCreateBuilder<'a> { ) .await?; - progress.begin(&fragment, writer.multipart_id()).await?; + progress.begin(&fragment).await?; let mut buffered_reader = chunk_stream(stream, params.max_rows_per_group); while let Some(batched_chunk) = buffered_reader.next().await { diff --git a/rust/lance/src/dataset/progress.rs b/rust/lance/src/dataset/progress.rs index a77608d28a..1e489d0b76 100644 --- a/rust/lance/src/dataset/progress.rs +++ b/rust/lance/src/dataset/progress.rs @@ -20,7 +20,7 @@ use crate::Result; #[async_trait] pub trait WriteFragmentProgress: std::fmt::Debug + Sync + Send { /// Indicate the beginning of writing a [Fragment], with the in-flight multipart ID. - async fn begin(&self, fragment: &Fragment, multipart_id: &str) -> Result<()>; + async fn begin(&self, fragment: &Fragment) -> Result<()>; /// Complete writing a [Fragment]. async fn complete(&self, fragment: &Fragment) -> Result<()>; @@ -39,7 +39,7 @@ impl NoopFragmentWriteProgress { #[async_trait] impl WriteFragmentProgress for NoopFragmentWriteProgress { #[inline] - async fn begin(&self, _fragment: &Fragment, _multipart_id: &str) -> Result<()> { + async fn begin(&self, _fragment: &Fragment) -> Result<()> { Ok(()) } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 142b0ee6c8..0decaec57d 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -4250,7 +4250,7 @@ mod test { .project(&["_rowid", "_distance"]) }, "Projection: fields=[_rowid, _distance] - SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST] + SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST]... ANNSubIndex: name=idx, k=32, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1", ) @@ -4266,7 +4266,7 @@ mod test { .project(&["_rowid", "_distance"]) }, "Projection: fields=[_rowid, _distance] - SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST] + SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST]... ANNSubIndex: name=idx, k=33, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1", ) @@ -4283,17 +4283,17 @@ mod test { }, "Projection: fields=[_rowid, _distance] FilterExec: _distance@2 IS NOT NULL - SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST] + SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST]... KNNVectorDistance: metric=l2 RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2 UnionExec Projection: fields=[_distance, _rowid, vec] FilterExec: _distance@2 IS NOT NULL - SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST] + SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST]... KNNVectorDistance: metric=l2 LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false Take: columns=\"_distance, _rowid, vec\" - SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST] + SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST]... ANNSubIndex: name=idx, k=34, deltas=1 ANNIvfPartition: uuid=..., nprobes=1, deltas=1", ) diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index 2caa25c725..9152c72e49 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -235,10 +235,7 @@ pub async fn write_fragments_internal( if writer.is_none() { let (new_writer, new_fragment) = writer_generator.new_writer().await?; - // rustc has a hard time analyzing the lifetime of the &str returned - // by multipart_id(), so we convert it to an owned value here. - let multipart_id = new_writer.multipart_id().to_string(); - params.progress.begin(&new_fragment, &multipart_id).await?; + params.progress.begin(&new_fragment).await?; writer = Some(new_writer); fragments.push(new_fragment); } @@ -274,10 +271,6 @@ pub async fn write_fragments_internal( #[async_trait::async_trait] pub trait GenericWriter: Send { - /// Get a unique id associated with the fragment being written - /// - /// This is used for progress reporting - fn multipart_id(&self) -> &str; /// Write the given batches to the file async fn write(&mut self, batches: &[RecordBatch]) -> Result<()>; /// Get the current position in the file @@ -291,9 +284,6 @@ pub trait GenericWriter: Send { #[async_trait::async_trait] impl GenericWriter for (FileWriter, String) { - fn multipart_id(&self) -> &str { - self.0.multipart_id() - } async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> { self.0.write(batches).await } @@ -315,9 +305,6 @@ struct V2WriterAdapter { #[async_trait::async_trait] impl GenericWriter for V2WriterAdapter { - fn multipart_id(&self) -> &str { - self.writer.multipart_id() - } async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> { for batch in batches { self.writer.write_batch(batch).await?; diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index d17c1af0f4..09d151d485 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -148,8 +148,8 @@ impl ExecutionPlan for KNNVectorDistanceExec { self.output_schema.clone() } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -404,7 +404,7 @@ impl ExecutionPlan for ANNIvfPartitionExec { &self.properties } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -549,11 +549,11 @@ impl ExecutionPlan for ANNIvfSubIndexExec { KNN_INDEX_SCHEMA.clone() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { match &self.prefilter_source { - PreFilterSource::None => vec![self.input.clone()], - PreFilterSource::FilteredRowIds(src) => vec![self.input.clone(), src.clone()], - PreFilterSource::ScalarIndexQuery(src) => vec![self.input.clone(), src.clone()], + PreFilterSource::None => vec![&self.input], + PreFilterSource::FilteredRowIds(src) => vec![&self.input, &src], + PreFilterSource::ScalarIndexQuery(src) => vec![&self.input, &src], } } diff --git a/rust/lance/src/io/exec/optimizer.rs b/rust/lance/src/io/exec/optimizer.rs index fb350b1a99..3b72c064bd 100644 --- a/rust/lance/src/io/exec/optimizer.rs +++ b/rust/lance/src/io/exec/optimizer.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use super::TakeExec; use datafusion::{ common::tree_node::{Transformed, TreeNode}, config::ConfigOptions, @@ -14,8 +15,6 @@ use datafusion::{ }; use datafusion_physical_expr::expressions::Column; -use super::TakeExec; - /// Rule that eliminates [TakeExec] nodes that are immediately followed by another [TakeExec]. pub struct CoalesceTake; @@ -26,12 +25,12 @@ impl PhysicalOptimizerRule for CoalesceTake { _config: &ConfigOptions, ) -> DFResult> { Ok(plan - .transform_down(&|plan| { + .transform_down(|plan| { if let Some(take) = plan.as_any().downcast_ref::() { - let child = &take.children()[0]; + let child = take.children()[0]; if let Some(exec_child) = child.as_any().downcast_ref::() { - let upstream_plan = exec_child.children(); - return Ok(Transformed::yes(plan.with_new_children(upstream_plan)?)); + let inner_child = exec_child.children()[0].clone(); + return Ok(Transformed::yes(plan.with_new_children(vec![inner_child])?)); } } Ok(Transformed::no(plan)) @@ -59,14 +58,14 @@ impl PhysicalOptimizerRule for SimplifyProjection { _config: &ConfigOptions, ) -> DFResult> { Ok(plan - .transform_down(&|plan| { + .transform_down(|plan| { if let Some(proj) = plan.as_any().downcast_ref::() { - let children = &proj.children(); + let children = proj.children(); if children.len() != 1 { return Ok(Transformed::no(plan)); } - let input = &children[0]; + let input = children[0]; // TODO: we could try to coalesce consecutive projections, something for later // For now, we just keep things simple and only remove NoOp projections diff --git a/rust/lance/src/io/exec/planner.rs b/rust/lance/src/io/exec/planner.rs index 9446d56afb..84c22f2e74 100644 --- a/rust/lance/src/io/exec/planner.rs +++ b/rust/lance/src/io/exec/planner.rs @@ -20,15 +20,15 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ - AggregateUDF, ColumnarValue, GetFieldAccess, GetIndexedField, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, WindowUDF, + AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowUDF, }; use datafusion::optimizer::simplify_expressions::SimplifyContext; use datafusion::physical_optimizer::optimizer::PhysicalOptimizer; use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; use datafusion::sql::sqlparser::ast::{ Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo, Expr as SQLExpr, - Function, FunctionArg, FunctionArgExpr, Ident, TimezoneInfo, UnaryOperator, Value, + Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, TimezoneInfo, UnaryOperator, + Value, }; use datafusion::{ common::Column, @@ -229,15 +229,15 @@ impl ContextProvider for LanceContextProvider { &self.options } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { self.state.window_functions().keys().cloned().collect() } } @@ -259,9 +259,7 @@ impl Planner { column, Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))), ], - func_def: datafusion::logical_expr::ScalarFunctionDefinition::UDF(Arc::new( - ScalarUDF::new_from_impl(GetFieldFunc::default()), - )), + func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), }); } column @@ -354,20 +352,13 @@ impl Planner { Ok(match value { Value::Number(v, _) => self.number(v.as_str())?, Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))), - Value::DollarQuotedString(_) => todo!(), - Value::EscapedStringLiteral(_) => todo!(), - Value::NationalStringLiteral(_) => todo!(), Value::HexStringLiteral(hsl) => { Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl))) } Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))), Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))), Value::Null => Expr::Literal(ScalarValue::Null), - Value::Placeholder(_) => todo!(), - Value::UnQuotedString(_) => todo!(), - Value::SingleQuotedByteStringLiteral(_) => todo!(), - Value::DoubleQuotedByteStringLiteral(_) => todo!(), - Value::RawStringLiteral(_) => todo!(), + _ => todo!(), }) } @@ -388,15 +379,23 @@ impl Planner { // this is a function that comes from duckdb. Datafusion does not consider is_valid to be a function // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially. fn legacy_parse_function(&self, func: &Function) -> Result { - if func.args.len() != 1 { - return Err(Error::io( - format!("is_valid only support 1 args, got {}", func.args.len()), + match &func.args { + FunctionArguments::List(args) => { + if func.name.0.len() != 1 { + return Err(Error::io( + format!("Function name must have 1 part, got: {:?}", func.name.0), + location!(), + )); + } + Ok(Expr::IsNotNull(Box::new( + self.parse_function_args(&args.args[0])?, + ))) + } + _ => Err(Error::io( + format!("Unsupported function args: {:?}", &func.args), location!(), - )); + )), } - Ok(Expr::IsNotNull(Box::new( - self.parse_function_args(&func.args[0])?, - ))) } fn parse_function(&self, function: SQLExpr) -> Result { @@ -622,7 +621,7 @@ impl Planner { *negated, Box::new(self.parse_sql_expr(expr)?), Box::new(self.parse_sql_expr(pattern)?), - *escape_char, + escape_char.as_ref().and_then(|c| c.chars().next()), true, ))), SQLExpr::Like { @@ -634,7 +633,7 @@ impl Planner { *negated, Box::new(self.parse_sql_expr(expr)?), Box::new(self.parse_sql_expr(pattern)?), - *escape_char, + escape_char.as_ref().and_then(|c| c.chars().next()), false, ))), SQLExpr::Cast { @@ -723,7 +722,7 @@ impl Planner { datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context); let expr = simplifier.simplify(expr.clone())?; - let expr = simplifier.coerce(expr, df_schema.clone())?; + let expr = simplifier.coerce(expr, &df_schema)?; Ok(expr) } @@ -794,7 +793,7 @@ struct ColumnCapturingVisitor { columns: BTreeSet, } -impl TreeNodeVisitor for ColumnCapturingVisitor { +impl TreeNodeVisitor<'_> for ColumnCapturingVisitor { type Node = Expr; fn f_down(&mut self, node: &Self::Node) -> DFResult { @@ -819,12 +818,6 @@ impl TreeNodeVisitor for ColumnCapturingVisitor { self.current_path.clear(); } } - Expr::GetIndexedField(GetIndexedField { - expr: _, - field: GetFieldAccess::NamedStructField { name }, - }) => { - self.current_path.push_front(name.to_string()); - } _ => { self.current_path.clear(); } @@ -847,7 +840,8 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow_schema::{DataType, Fields, Schema}; - use datafusion::logical_expr::{lit, Cast, ScalarFunctionDefinition}; + use datafusion::logical_expr::{lit, Cast}; + use datafusion_functions::core::expr_ext::FieldAccessor; #[test] fn test_parse_filter_simple() { @@ -968,9 +962,7 @@ mod tests { assert_column_eq(&planner, "`s0`", &expected); let expected = Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - GetFieldFunc::default(), - ))), + func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), args: vec![ Expr::Column(Column { relation: None, @@ -984,14 +976,10 @@ mod tests { assert_column_eq(&planner, "st.`s1`", &expected); let expected = Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - GetFieldFunc::default(), - ))), + func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), args: vec![ Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - GetFieldFunc::default(), - ))), + func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), args: vec![ Expr::Column(Column { relation: None, diff --git a/rust/lance/src/io/exec/projection.rs b/rust/lance/src/io/exec/projection.rs index d0a8f29bab..a666e03eae 100644 --- a/rust/lance/src/io/exec/projection.rs +++ b/rust/lance/src/io/exec/projection.rs @@ -166,8 +166,8 @@ impl ExecutionPlan for ProjectionExec { arrow_schema.into() } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( diff --git a/rust/lance/src/io/exec/pushdown_scan.rs b/rust/lance/src/io/exec/pushdown_scan.rs index 09377d1ab2..f354f5dbb9 100644 --- a/rust/lance/src/io/exec/pushdown_scan.rs +++ b/rust/lance/src/io/exec/pushdown_scan.rs @@ -24,6 +24,7 @@ use datafusion::{ }, prelude::Expr, }; +use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_physical_expr::EquivalenceProperties; use futures::{FutureExt, Stream, StreamExt, TryStreamExt}; use lance_arrow::{RecordBatchExt, SchemaExt}; @@ -154,7 +155,7 @@ impl ExecutionPlan for LancePushdownScanExec { self.output_schema.clone() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/rust/lance/src/io/exec/scalar_index.rs b/rust/lance/src/io/exec/scalar_index.rs index c78ce382c0..78e852dccd 100644 --- a/rust/lance/src/io/exec/scalar_index.rs +++ b/rust/lance/src/io/exec/scalar_index.rs @@ -116,7 +116,7 @@ impl ExecutionPlan for ScalarIndexExec { SCALAR_INDEX_SCHEMA.clone() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -276,8 +276,8 @@ impl ExecutionPlan for MapIndexExec { INDEX_LOOKUP_SCHEMA.clone() } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -545,7 +545,7 @@ impl ExecutionPlan for MaterializeIndexExec { MATERIALIZE_INDEX_SCHEMA.clone() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs index 7638241818..67b636ecd7 100644 --- a/rust/lance/src/io/exec/scan.rs +++ b/rust/lance/src/io/exec/scan.rs @@ -293,7 +293,7 @@ impl ExecutionPlan for LanceScanExec { } /// Scan is the leaf node, so returns an empty vector. - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index 5fe5629441..fea67ddce4 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -257,8 +257,8 @@ impl ExecutionPlan for TakeExec { ArrowSchema::from(&self.output_schema).into() } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } /// This preserves the output schema. diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs index 9ab2794590..25529cce30 100644 --- a/rust/lance/src/io/exec/testing.rs +++ b/rust/lance/src/io/exec/testing.rs @@ -55,7 +55,7 @@ impl ExecutionPlan for TestingExec { self.batches[0].schema() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/rust/lance/src/io/exec/utils.rs b/rust/lance/src/io/exec/utils.rs index 4889279e7e..171cc99f52 100644 --- a/rust/lance/src/io/exec/utils.rs +++ b/rust/lance/src/io/exec/utils.rs @@ -131,8 +131,8 @@ impl ExecutionPlan for ReplayExec { self.input.schema() } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( diff --git a/rust/lance/src/utils/test.rs b/rust/lance/src/utils/test.rs index 1fa32968a9..6bcb78bede 100644 --- a/rust/lance/src/utils/test.rs +++ b/rust/lance/src/utils/test.rs @@ -15,12 +15,11 @@ use lance_io::object_store::WrappingObjectStore; use lance_table::format::Fragment; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, - Result as OSResult, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOpts, + PutOptions, PutPayload, PutResult, Result as OSResult, }; use rand::prelude::SliceRandom; use rand::{Rng, SeedableRng}; -use tokio::io::AsyncWrite; use crate::dataset::fragment::write::FragmentCreateBuilder; use crate::dataset::transaction::Operation; @@ -304,28 +303,25 @@ impl IoTrackingStore { #[async_trait::async_trait] impl ObjectStore for IoTrackingStore { - async fn put(&self, location: &Path, bytes: Bytes) -> OSResult { + async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { self.target.put(location, bytes).await } async fn put_opts( &self, location: &Path, - bytes: Bytes, + bytes: PutPayload, opts: PutOptions, ) -> OSResult { self.target.put_opts(location, bytes, opts).await } - async fn put_multipart( + async fn put_multipart_opts( &self, location: &Path, - ) -> OSResult<(MultipartId, Box)> { - self.target.put_multipart(location).await - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> OSResult<()> { - self.target.abort_multipart(location, multipart_id).await + opts: PutMultipartOpts, + ) -> OSResult> { + self.target.put_multipart_opts(location, opts).await } async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { From 5b5926c0cd297f2923c2132e3e74fce80df1fd14 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 12 Jul 2024 19:52:29 -0700 Subject: [PATCH 3/8] upgrade python --- python/Cargo.toml | 15 +++--- python/python/lance/cleanup.py | 17 ------- python/python/lance/progress.py | 76 ++-------------------------- python/python/tests/helper.py | 5 +- python/python/tests/test_fragment.py | 6 --- python/src/dataset.rs | 56 ++++++++++---------- python/src/fragment.rs | 41 +++------------ python/src/lib.rs | 3 +- python/src/utils.rs | 12 ++--- rust/lance/src/utils/tfrecord.rs | 5 +- 10 files changed, 55 insertions(+), 181 deletions(-) diff --git a/python/Cargo.toml b/python/Cargo.toml index c001008624..81050c0b0c 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -12,11 +12,11 @@ name = "lance" crate-type = ["cdylib"] [dependencies] -arrow = { version = "51.0.0", features = ["pyarrow"] } -arrow-array = "51.0" -arrow-data = "51.0" -arrow-schema = "51.0" -object_store = "0.9.0" +arrow = { version = "52.0", features = ["pyarrow"] } +arrow-array = "52.0" +arrow-data = "52.0" +arrow-schema = "52.0" +object_store = "0.10.1" async-trait = "0.1" chrono = "0.4.31" env_logger = "0.10" @@ -42,7 +42,7 @@ lance-table = { path = "../rust/lance-table" } lazy_static = "1" log = "0.4" prost = "0.12.2" -pyo3 = { version = "0.20", features = ["extension-module", "abi3-py39"] } +pyo3 = { version = "0.21", features = ["extension-module", "abi3-py39", "gil-refs"] } tokio = { version = "1.23", features = ["rt-multi-thread"] } uuid = "1.3.0" serde_json = "1" @@ -55,9 +55,6 @@ tracing-subscriber = "0.3.17" tracing = "0.1.37" url = "2.5.0" -# Prevent dynamic linking of lzma, which comes from datafusion -lzma-sys = { version = "*", features = ["static"] } - [features] datagen = ["lance-datagen"] fp16kernels = ["lance/fp16kernels"] diff --git a/python/python/lance/cleanup.py b/python/python/lance/cleanup.py index 4dff6edf27..addbf06e4b 100644 --- a/python/python/lance/cleanup.py +++ b/python/python/lance/cleanup.py @@ -1,20 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -from typing import List, Tuple - -from .lance import _cleanup_partial_writes - - -def cleanup_partial_writes(objects: List[Tuple[str, str]]): - """Cleans up partial writes from a list of objects. - - These writes can be discovered using the - :class:`lance.progress.FragmentWriteProgress` class. - - Parameters - ---------- - objects : List[Tuple[str, str]] - A list of tuples of (fragment_id, multipart_id) to clean up. - """ - _cleanup_partial_writes(objects) diff --git a/python/python/lance/progress.py b/python/python/lance/progress.py index fb0e85c21d..0c42863864 100644 --- a/python/python/lance/progress.py +++ b/python/python/lance/progress.py @@ -10,8 +10,6 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Optional -from .lance import _cleanup_partial_writes - if TYPE_CHECKING: # We don't import directly because of circular import from .fragment import FragmentMetadata @@ -25,19 +23,15 @@ class FragmentWriteProgress(ABC): This tracking class is experimental and may change in the future. """ - def _do_begin( - self, fragment_json: str, multipart_id: Optional[str] = None, **kwargs - ): + def _do_begin(self, fragment_json: str, **kwargs): """Called when a new fragment is created""" from .fragment import FragmentMetadata fragment = FragmentMetadata.from_json(fragment_json) - return self.begin(fragment, multipart_id, **kwargs) + return self.begin(fragment, **kwargs) @abstractmethod - def begin( - self, fragment: "FragmentMetadata", multipart_id: Optional[str] = None, **kwargs - ) -> None: + def begin(self, fragment: "FragmentMetadata", **kwargs) -> None: """Called when a new fragment is about to be written. Parameters @@ -45,9 +39,6 @@ def begin( fragment : FragmentMetadata The fragment that is open to write to. The fragment id might not yet be assigned at this point. - multipart_id : str, optional - The multipart id that will be uploaded to cloud storage. This may be - used later to abort incomplete uploads if this fragment write fails. kwargs: dict, optional Extra keyword arguments to pass to the implementation. @@ -84,9 +75,7 @@ class NoopFragmentWriteProgress(FragmentWriteProgress): This is the default implementation. """ - def begin( - self, fragment: "FragmentMetadata", multipart_id: Optional[str] = None, **kargs - ): + def begin(self, fragment: "FragmentMetadata", **kargs): pass def complete(self, fragment: "FragmentMetadata", **kwargs): @@ -135,17 +124,13 @@ def _in_progress_path(self, fragment: "FragmentMetadata") -> str: def _fragment_file(self, fragment: "FragmentMetadata") -> str: return os.path.join(self._base_path, f"fragment_{fragment.id}.json") - def begin( - self, fragment: "FragmentMetadata", multipart_id: Optional[str] = None, **kwargs - ): + def begin(self, fragment: "FragmentMetadata", **kwargs): """Called when a new fragment is created. Parameters ---------- fragment : FragmentMetadata The fragment that is open to write to. - multipart_id : str, optional - The multipart id to upload this fragment to cloud storage. """ self._fs.create_dir(self._base_path, recursive=True) @@ -153,7 +138,6 @@ def begin( with self._fs.open_output_stream(self._in_progress_path(fragment)) as out: progress_data = { "fragment_id": fragment.id, - "multipart_id": multipart_id if multipart_id else "", "metadata": self._metadata, } out.write(json.dumps(progress_data).encode("utf-8")) @@ -164,53 +148,3 @@ def begin( def complete(self, fragment: "FragmentMetadata", **kwargs): """Called when a fragment is completed""" self._fs.delete_file(self._in_progress_path(fragment)) - - def cleanup_partial_writes(self, dataset_uri: str) -> int: - """ - Finds all in-progress files and cleans up any partially written data - files. This is useful for cleaning up after a failed write. - - Parameters - ---------- - dataset_uri : str - The URI of the table to clean up. - - Returns - ------- - int - The number of partial writes cleaned up. - """ - from pyarrow.fs import FileSelector - - from .fragment import FragmentMetadata - - marker_paths = [] - objects = [] - selector = FileSelector(self._base_path) - for info in self._fs.get_file_info(selector): - path = info.path - if path.endswith(self.PROGRESS_EXT): - marker_paths.append(path) - with self._fs.open_input_stream(path) as f: - progress_data = json.loads(f.read().decode("utf-8")) - - json_path = path.rstrip(self.PROGRESS_EXT) + ".json" - with self._fs.open_input_stream(json_path) as f: - fragment_metadata = FragmentMetadata.from_json( - f.read().decode("utf-8") - ) - objects.append( - ( - fragment_metadata.data_files()[0].path(), - progress_data["multipart_id"], - ) - ) - - _cleanup_partial_writes(dataset_uri, objects) - - for path in marker_paths: - self._fs.delete_file(path) - json_path = path.rstrip(self.PROGRESS_EXT) + ".json" - self._fs.delete_file(json_path) - - return len(objects) diff --git a/python/python/tests/helper.py b/python/python/tests/helper.py index 47d6962f59..5a9357305a 100644 --- a/python/python/tests/helper.py +++ b/python/python/tests/helper.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -from typing import Optional from lance.fragment import FragmentMetadata from lance.progress import FragmentWriteProgress @@ -13,9 +12,7 @@ def __init__(self): self.begin_called = 0 self.complete_called = 0 - def begin( - self, fragment: FragmentMetadata, multipart_id: Optional[str] = None, **kwargs - ): + def begin(self, fragment: FragmentMetadata, **kwargs): self.begin_called += 1 def complete(self, fragment: FragmentMetadata): diff --git a/python/python/tests/test_fragment.py b/python/python/tests/test_fragment.py index d92e387ab9..64fb717cd3 100644 --- a/python/python/tests/test_fragment.py +++ b/python/python/tests/test_fragment.py @@ -217,7 +217,6 @@ def test_dataset_progress(tmp_path: Path): with open(progress_uri / "fragment_1.in_progress") as f: progress_data = json.load(f) assert progress_data["fragment_id"] == 1 - assert isinstance(progress_data["multipart_id"], str) # progress contains custom metadata assert progress_data["metadata"]["test_key"] == "test_value" @@ -226,11 +225,6 @@ def test_dataset_progress(tmp_path: Path): metadata = json.load(f) assert metadata["id"] == 1 - progress.cleanup_partial_writes(str(dataset_uri)) - - assert not (progress_uri / "fragment_1.json").exists() - assert not (progress_uri / "fragment_1.in_progress").exists() - def test_fragment_meta(): # Intentionally leaving off column_indices / version fields to make sure diff --git a/python/src/dataset.rs b/python/src/dataset.rs index c94f459280..7ee64b0691 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -59,7 +59,7 @@ use pyo3::types::{PyBytes, PyList, PySet, PyString}; use pyo3::{ exceptions::{PyIOError, PyKeyError, PyValueError}, pyclass, - types::{IntoPyDict, PyBool, PyDict, PyInt, PyLong}, + types::{IntoPyDict, PyDict}, PyObject, PyResult, }; use snafu::{location, Location}; @@ -178,7 +178,7 @@ impl MergeInsertBuilder { Ok(slf) } - pub fn execute(&mut self, new_data: &PyAny) -> PyResult { + pub fn execute(&mut self, new_data: &Bound) -> PyResult { let py = new_data.py(); let new_data: Box = if new_data.is_instance_of::() { @@ -188,7 +188,7 @@ impl MergeInsertBuilder { .map_err(|err| PyValueError::new_err(err.to_string()))?, ) } else { - Box::new(ArrowArrayStreamReader::from_pyarrow(new_data)?) + Box::new(ArrowArrayStreamReader::from_pyarrow_bound(new_data)?) }; let job = self @@ -424,7 +424,7 @@ impl Dataset { prefilter: Option, limit: Option, offset: Option, - nearest: Option<&PyDict>, + nearest: Option<&Bound>, batch_size: Option, batch_readahead: Option, fragment_readahead: Option, @@ -522,7 +522,7 @@ impl Dataset { let qval = nearest .get_item("q")? .ok_or_else(|| PyKeyError::new_err("Need q for nearest"))?; - let data = ArrayData::from_pyarrow(qval)?; + let data = ArrayData::from_pyarrow_bound(&qval)?; let q = Float32Array::from(data); let k: usize = if let Some(k) = nearest.get_item("k")? { @@ -530,7 +530,7 @@ impl Dataset { // Use limit if k is not specified, default to 10. limit.unwrap_or(10) as usize } else { - PyAny::downcast::(k)?.extract()? + k.extract()? } } else { 10 @@ -540,7 +540,7 @@ impl Dataset { if nprobes.is_none() { DEFAULT_NPROBS } else { - PyAny::downcast::(nprobes)?.extract()? + nprobes.extract()? } } else { DEFAULT_NPROBS @@ -567,14 +567,14 @@ impl Dataset { if rf.is_none() { None } else { - PyAny::downcast::(rf)?.extract()? + rf.extract()? } } else { None }; let use_index: bool = if let Some(idx) = nearest.get_item("use_index")? { - PyAny::downcast::(idx)?.extract()? + idx.extract()? } else { true }; @@ -583,7 +583,7 @@ impl Dataset { if ef.is_none() { None } else { - PyAny::downcast::(ef)?.extract()? + ef.extract()? } } else { None @@ -930,7 +930,7 @@ impl Dataset { name: Option, replace: Option, storage_options: Option>, - kwargs: Option<&PyDict>, + kwargs: Option<&Bound>, ) -> PyResult<()> { let index_type = index_type.to_uppercase(); let idx_type = match index_type.as_str() { @@ -1145,7 +1145,7 @@ impl Dataset { } #[pyfunction(name = "_write_dataset")] -pub fn write_dataset(reader: &PyAny, uri: String, options: &PyDict) -> PyResult { +pub fn write_dataset(reader: &Bound, uri: String, options: &PyDict) -> PyResult { let params = get_write_params(options)?; let py = options.py(); let ds = if reader.is_instance_of::() { @@ -1157,7 +1157,7 @@ pub fn write_dataset(reader: &PyAny, uri: String, options: &PyDict) -> PyResult< RT.block_on(Some(py), LanceDataset::write(batches, &uri, params))? .map_err(|err| PyIOError::new_err(err.to_string()))? } else { - let batches = ArrowArrayStreamReader::from_pyarrow(reader)?; + let batches = ArrowArrayStreamReader::from_pyarrow_bound(reader)?; RT.block_on(Some(py), LanceDataset::write(batches, &uri, params))? .map_err(|err| PyIOError::new_err(err.to_string()))? }; @@ -1250,7 +1250,7 @@ fn prepare_vector_index_params( index_type: &str, column_type: &DataType, storage_options: Option>, - kwargs: Option<&PyDict>, + kwargs: Option<&Bound>, ) -> PyResult> { let mut m_type = MetricType::L2; let mut ivf_params = IvfBuildParams::default(); @@ -1267,7 +1267,7 @@ fn prepare_vector_index_params( // Parse sample rate if let Some(sample_rate) = kwargs.get_item("sample_rate")? { - let sample_rate = PyAny::downcast::(sample_rate)?.extract()?; + let sample_rate: usize = sample_rate.extract()?; ivf_params.sample_rate = sample_rate; pq_params.sample_rate = sample_rate; sq_params.sample_rate = sample_rate; @@ -1275,15 +1275,15 @@ fn prepare_vector_index_params( // Parse IVF params if let Some(n) = kwargs.get_item("num_partitions")? { - ivf_params.num_partitions = PyAny::downcast::(n)?.extract()? + ivf_params.num_partitions = n.extract()? }; if let Some(n) = kwargs.get_item("shuffle_partition_concurrency")? { - ivf_params.shuffle_partition_concurrency = PyAny::downcast::(n)?.extract()? + ivf_params.shuffle_partition_concurrency = n.extract()? }; if let Some(c) = kwargs.get_item("ivf_centroids")? { - let batch = RecordBatch::from_pyarrow(c)?; + let batch = RecordBatch::from_pyarrow_bound(&c)?; if "_ivf_centroids" != batch.schema().field(0).name() { return Err(PyValueError::new_err( "Expected '_ivf_centroids' as the first column name.", @@ -1327,7 +1327,7 @@ fn prepare_vector_index_params( e )) })?; - let list = PyAny::downcast::(l)? + let list = l.downcast::()? .iter() .map(|f| f.to_string()) .collect(); @@ -1343,28 +1343,28 @@ fn prepare_vector_index_params( // Parse HNSW params if let Some(max_level) = kwargs.get_item("max_level")? { - hnsw_params.max_level = PyAny::downcast::(max_level)?.extract()?; + hnsw_params.max_level = max_level.extract()?; } if let Some(m) = kwargs.get_item("m")? { - hnsw_params.m = PyAny::downcast::(m)?.extract()?; + hnsw_params.m = m.extract()?; } if let Some(ef_c) = kwargs.get_item("ef_construction")? { - hnsw_params.ef_construction = PyAny::downcast::(ef_c)?.extract()?; + hnsw_params.ef_construction = ef_c.extract()?; } // Parse PQ params if let Some(n) = kwargs.get_item("num_bits")? { - pq_params.num_bits = PyAny::downcast::(n)?.extract()? + pq_params.num_bits = n.extract()? }; if let Some(n) = kwargs.get_item("num_sub_vectors")? { - pq_params.num_sub_vectors = PyAny::downcast::(n)?.extract()? + pq_params.num_sub_vectors = n.extract()? }; if let Some(c) = kwargs.get_item("pq_codebook")? { - let batch = RecordBatch::from_pyarrow(c)?; + let batch = RecordBatch::from_pyarrow_bound(&c)?; if "_pq_codebook" != batch.schema().field(0).name() { return Err(PyValueError::new_err( "Expected '_pq_codebook' as the first column name.", @@ -1415,14 +1415,12 @@ impl PyWriteProgress { #[async_trait] impl WriteFragmentProgress for PyWriteProgress { - async fn begin(&self, fragment: &Fragment, multipart_id: &str) -> lance::Result<()> { + async fn begin(&self, fragment: &Fragment) -> lance::Result<()> { let json_str = serde_json::to_string(fragment)?; Python::with_gil(|py| -> PyResult<()> { - let kwargs = PyDict::new(py); - kwargs.set_item("multipart_id", multipart_id)?; self.py_obj - .call_method(py, "_do_begin", (json_str,), Some(kwargs))?; + .call_method(py, "_do_begin", (json_str,), None)?; Ok(()) }) .map_err(|e| { diff --git a/python/src/fragment.rs b/python/src/fragment.rs index 8bf400b21b..4f5c55e5f1 100644 --- a/python/src/fragment.rs +++ b/python/src/fragment.rs @@ -22,10 +22,8 @@ use arrow_schema::Schema as ArrowSchema; use futures::TryFutureExt; use lance::dataset::fragment::FileFragment as LanceFragment; use lance::datatypes::Schema; -use lance_io::object_store::ObjectStore; use lance_table::format::{DataFile as LanceDataFile, Fragment as LanceFragmentMetadata}; use lance_table::io::deletion::deletion_file_path; -use object_store::path::Path; use pyo3::prelude::*; use pyo3::{exceptions::*, pyclass::CompareOp, types::PyDict}; @@ -99,7 +97,7 @@ impl FileFragment { fn create( dataset_uri: &str, fragment_id: Option, - reader: &PyAny, + reader: &Bound, kwargs: Option<&PyDict>, ) -> PyResult { let params = if let Some(kw_params) = kwargs { @@ -430,40 +428,11 @@ impl FragmentMetadata { } } -#[pyfunction(name = "_cleanup_partial_writes")] -pub fn cleanup_partial_writes(base_uri: &str, files: Vec<(String, String)>) -> PyResult<()> { - let (store, base_path) = RT - .runtime - .block_on(ObjectStore::from_uri(base_uri)) - .map_err(|err| PyIOError::new_err(format!("Failed to create object store: {}", err)))?; - - let files: Vec<(Path, String)> = files - .into_iter() - .map(|(path, multipart_id)| (Path::from(path.as_str()), multipart_id)) - .collect(); - - #[allow(clippy::map_identity)] - async fn inner( - store: ObjectStore, - base_path: Path, - files: Vec<(Path, String)>, - ) -> Result<(), ::lance::Error> { - let files_iter = files - .iter() - .map(|(path, multipart_id)| (path, multipart_id)); - lance::dataset::cleanup::cleanup_partial_writes(&store, &base_path, files_iter).await - } - - RT.runtime - .block_on(inner(store, base_path, files)) - .map_err(|err| PyIOError::new_err(format!("Failed to cleanup files: {}", err))) -} - #[pyfunction(name = "_write_fragments")] #[pyo3(signature = (dataset_uri, reader, **kwargs))] pub fn write_fragments( dataset_uri: &str, - reader: &PyAny, + reader: &Bound, kwargs: Option<&PyDict>, ) -> PyResult> { let batches = convert_reader(reader)?; @@ -485,7 +454,7 @@ pub fn write_fragments( .collect() } -fn convert_reader(reader: &PyAny) -> PyResult> { +fn convert_reader(reader: &Bound) -> PyResult> { if reader.is_instance_of::() { let scanner: Scanner = reader.extract()?; let reader = RT.block_on( @@ -496,6 +465,8 @@ fn convert_reader(reader: &PyAny) -> PyResult PyResult<()> { m.add_wrapped(wrap_pyfunction!(json_to_schema))?; m.add_wrapped(wrap_pyfunction!(infer_tfrecord_schema))?; m.add_wrapped(wrap_pyfunction!(read_tfrecord))?; - m.add_wrapped(wrap_pyfunction!(cleanup_partial_writes))?; m.add_wrapped(wrap_pyfunction!(trace_to_chrome))?; m.add_wrapped(wrap_pyfunction!(manifest_needs_migration))?; // Debug functions diff --git a/python/src/utils.rs b/python/src/utils.rs index d8b98402e8..2e18890bfd 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -96,8 +96,8 @@ impl KMeans { } /// Train the model - fn fit(&mut self, _py: Python, arr: &PyAny) -> PyResult<()> { - let data = ArrayData::from_pyarrow(arr)?; + fn fit(&mut self, _py: Python, arr: &Bound) -> PyResult<()> { + let data = ArrayData::from_pyarrow_bound(arr)?; if !matches!(data.data_type(), DataType::FixedSizeList(_, _)) { return Err(PyValueError::new_err("Must be a FixedSizeList")); } @@ -113,11 +113,11 @@ impl KMeans { Ok(()) } - fn predict(&self, py: Python, array: &PyAny) -> PyResult { + fn predict(&self, py: Python, array: &Bound) -> PyResult { let Some(kmeans) = self.trained_kmeans.as_ref() else { return Err(PyRuntimeError::new_err("KMeans must fit (train) first")); }; - let data = ArrayData::from_pyarrow(array)?; + let data = ArrayData::from_pyarrow_bound(array)?; if !matches!(data.data_type(), DataType::FixedSizeList(_, _)) { return Err(PyValueError::new_err("Must be a FixedSizeList")); } @@ -178,7 +178,7 @@ impl Hnsw { distance_type="l2", ))] fn build( - vectors_array: &PyIterator, + vectors_array: &Bound, max_level: u16, m: usize, ef_construction: usize, @@ -191,7 +191,7 @@ impl Hnsw { let mut data: Vec> = Vec::new(); for vectors in vectors_array { - let vectors = ArrayData::from_pyarrow(vectors?)?; + let vectors = ArrayData::from_pyarrow_bound(&vectors?)?; if !matches!(vectors.data_type(), DataType::FixedSizeList(_, _)) { return Err(PyValueError::new_err("Must be a FixedSizeList")); } diff --git a/rust/lance/src/utils/tfrecord.rs b/rust/lance/src/utils/tfrecord.rs index 878e2da9fc..a076d72813 100644 --- a/rust/lance/src/utils/tfrecord.rs +++ b/rust/lance/src/utils/tfrecord.rs @@ -9,6 +9,7 @@ use arrow::buffer::OffsetBuffer; use arrow_array::builder::PrimitiveBuilder; use arrow_array::{ArrayRef, FixedSizeListArray, ListArray}; +use arrow_buffer::ArrowNativeType; use arrow_buffer::ScalarBuffer; use datafusion::error::DataFusionError; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -764,13 +765,13 @@ fn append_primitive_from_slice( // TensorProto to tell us the original endianness, so it's possible there // could be a mismatch here. let (prefix, middle, suffix) = unsafe { slice.align_to::() }; - for val in prefix.chunks_exact(T::get_byte_width()) { + for val in prefix.chunks_exact(T::Native::get_byte_width()) { builder.append_value(parse_val(val)); } builder.append_slice(middle); - for val in suffix.chunks_exact(T::get_byte_width()) { + for val in suffix.chunks_exact(T::Native::get_byte_width()) { builder.append_value(parse_val(val)); } } From 66dec81ea67917198b454bb4c1c26678d55e023e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 12 Jul 2024 20:12:12 -0700 Subject: [PATCH 4/8] fix rebase errors --- rust/lance-core/src/utils/mask.rs | 4 ++-- rust/lance-index/src/scalar.rs | 7 ++++--- rust/lance-index/src/scalar/expression.rs | 10 +++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index 764e949e10..e2cb8af901 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -678,8 +678,8 @@ impl<'a> Extend<&'a u64> for RowIdTreeMap { } // Extending with RowIdTreeMap is basically a cumulative set union -impl Extend for RowIdTreeMap { - fn extend>(&mut self, iter: T) { +impl Extend for RowIdTreeMap { + fn extend>(&mut self, iter: T) { for other in iter { for (fragment, set) in other.inner { match self.inner.get_mut(&fragment) { diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 61b5ee5165..30ff7130c0 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -10,11 +10,12 @@ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow_array::{ListArray, RecordBatch}; use arrow_schema::{Field, Schema}; use async_trait::async_trait; +use datafusion::functions_array::array_has; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::{scalar::ScalarValue, Column}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{Expr, ScalarFunctionDefinition}; +use datafusion_expr::Expr; use deepsize::DeepSizeOf; use lance_core::utils::mask::RowIdTreeMap; use lance_core::Result; @@ -257,7 +258,7 @@ impl AnyQuery for LabelListQuery { .unwrap(); let labels_arr = Arc::new(labels_list); Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::Name("array_contains_all".into()), + func: Arc::new(array_has::ArrayHasAll::new().into()), args: vec![ Expr::Column(Column::new_unqualified(col)), Expr::Literal(ScalarValue::List(labels_arr)), @@ -277,7 +278,7 @@ impl AnyQuery for LabelListQuery { .unwrap(); let labels_arr = Arc::new(labels_list); Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::Name("array_contains_any".into()), + func: Arc::new(array_has::ArrayHasAny::new().into()), args: vec![ Expr::Column(Column::new_unqualified(col)), Expr::Literal(ScalarValue::List(labels_arr)), diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 2393a38c5c..cf2b1c662c 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use datafusion_common::ScalarValue; use datafusion_expr::{ expr::{InList, ScalarFunction}, - Between, BinaryExpr, Expr, Operator, ScalarFunctionDefinition, + Between, BinaryExpr, Expr, Operator, ScalarUDF, }; use futures::join; @@ -75,7 +75,7 @@ pub trait ScalarQueryParser: std::fmt::Debug + Send + Sync { &self, column: &str, data_type: &DataType, - func: &ScalarFunctionDefinition, + func: &ScalarUDF, args: &[Expr], ) -> Option; } @@ -146,7 +146,7 @@ impl ScalarQueryParser for SargableQueryParser { &self, _: &str, _: &DataType, - _: &ScalarFunctionDefinition, + _: &ScalarUDF, _: &[Expr], ) -> Option { None @@ -181,7 +181,7 @@ impl ScalarQueryParser for LabelListQueryParser { &self, column: &str, data_type: &DataType, - func: &ScalarFunctionDefinition, + func: &ScalarUDF, args: &[Expr], ) -> Option { if args.len() != 2 { @@ -628,7 +628,7 @@ fn visit_scalar_fn( return None; } let (col, data_type, query_parser) = maybe_indexed_column(&scalar_fn.args[0], index_info)?; - query_parser.visit_scalar_function(col, data_type, &scalar_fn.func_def, &scalar_fn.args) + query_parser.visit_scalar_function(col, data_type, &scalar_fn.func, &scalar_fn.args) } fn visit_node(expr: &Expr, index_info: &dyn IndexInformationProvider) -> Option { From 8d7b61dd87076007163c4b87cf898a212d6caa99 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 12 Jul 2024 20:20:48 -0700 Subject: [PATCH 5/8] version policy --- Cargo.toml | 30 +++++++++++++++--------------- python/Cargo.toml | 2 +- python/src/utils.rs | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a2b54297b6..c2bf634f44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.14.2" +version = "0.15.0" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -43,20 +43,20 @@ categories = [ rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.14.2", path = "./rust/lance" } -lance-arrow = { version = "=0.14.2", path = "./rust/lance-arrow" } -lance-core = { version = "=0.14.2", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.14.2", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.14.2", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.14.2", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.14.2", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.14.2", path = "./rust/lance-file" } -lance-index = { version = "=0.14.2", path = "./rust/lance-index" } -lance-io = { version = "=0.14.2", path = "./rust/lance-io" } -lance-linalg = { version = "=0.14.2", path = "./rust/lance-linalg" } -lance-table = { version = "=0.14.2", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.14.2", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.14.2", path = "./rust/lance-testing" } +lance = { version = "=0.15.0", path = "./rust/lance" } +lance-arrow = { version = "=0.15.0", path = "./rust/lance-arrow" } +lance-core = { version = "=0.15.0", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.15.0", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.15.0", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.15.0", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.15.0", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.15.0", path = "./rust/lance-file" } +lance-index = { version = "=0.15.0", path = "./rust/lance-index" } +lance-io = { version = "=0.15.0", path = "./rust/lance-io" } +lance-linalg = { version = "=0.15.0", path = "./rust/lance-linalg" } +lance-table = { version = "=0.15.0", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.15.0", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.15.0", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "52.0", optional = false, features = ["prettyprint"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 81050c0b0c..1bbf34c39c 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.14.2" +version = "0.15.0" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/python/src/utils.rs b/python/src/utils.rs index 2e18890bfd..9b8420e781 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -63,10 +63,10 @@ impl KMeans { k: usize, metric_type: &str, max_iters: u32, - centroids_arr: Option<&PyAny>, + centroids_arr: Option<&Bound>, ) -> PyResult { let trained_kmeans = if let Some(arr) = centroids_arr { - let data = ArrayData::from_pyarrow(arr)?; + let data = ArrayData::from_pyarrow_bound(arr)?; if !matches!(data.data_type(), DataType::FixedSizeList(_, _)) { return Err(PyValueError::new_err("Must be a FixedSizeList")); } From 1770d83b74a467dd68c378bc83f4df52e1d7aeac Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 12 Jul 2024 20:59:05 -0700 Subject: [PATCH 6/8] debug ci job --- ci/check_versions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci/check_versions.py b/ci/check_versions.py index d194ca9b98..a16246d3b8 100644 --- a/ci/check_versions.py +++ b/ci/check_versions.py @@ -37,10 +37,12 @@ def parse_version(version: str) -> tuple[int, int, int]: if __name__ == "__main__": new_version = parse_version(get_versions()) + print(f"New version: {new_version}") repo = Github().get_repo(os.environ["GITHUB_REPOSITORY"]) latest_release = repo.get_latest_release() last_version = parse_version(latest_release.tag_name[1:]) + print(f"Last version: {last_version}") # Check for a breaking-change label in the PRs between the last release and the current commit. commits = repo.compare(latest_release.tag_name, os.environ["GITHUB_SHA"]).commits From 0f1dded0dfba8a5562860e5aea565301162eeabc Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 15 Jul 2024 10:40:24 -0700 Subject: [PATCH 7/8] upgrade to 40 --- Cargo.toml | 38 ++++++------- python/Cargo.toml | 8 +-- python/python/lance/cleanup.py | 3 -- rust/lance-datafusion/Cargo.toml | 2 +- rust/lance-datafusion/src/exec.rs | 4 ++ rust/lance-datagen/src/generator.rs | 71 +++++++++++++++++++++---- rust/lance-index/src/scalar/btree.rs | 48 +++++++++-------- rust/lance/src/dataset/scanner.rs | 12 +++-- rust/lance/src/io/exec/knn.rs | 12 +++++ rust/lance/src/io/exec/planner.rs | 1 + rust/lance/src/io/exec/projection.rs | 4 ++ rust/lance/src/io/exec/pushdown_scan.rs | 14 ++--- rust/lance/src/io/exec/scalar_index.rs | 12 +++++ rust/lance/src/io/exec/scan.rs | 4 ++ rust/lance/src/io/exec/take.rs | 4 ++ rust/lance/src/io/exec/testing.rs | 4 ++ rust/lance/src/io/exec/utils.rs | 4 ++ 17 files changed, 174 insertions(+), 71 deletions(-) delete mode 100644 python/python/lance/cleanup.py diff --git a/Cargo.toml b/Cargo.toml index c2bf634f44..b9cbce6748 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,17 +59,17 @@ lance-test-macros = { version = "=0.15.0", path = "./rust/lance-test-macros" } lance-testing = { version = "=0.15.0", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow -arrow = { version = "52.0", optional = false, features = ["prettyprint"] } -arrow-arith = "52.0" -arrow-array = "52.0" -arrow-buffer = "52.0" -arrow-cast = "52.0" -arrow-data = "52.0" -arrow-ipc = { version = "52.0", features = ["zstd"] } -arrow-ord = "52.0" -arrow-row = "52.0" -arrow-schema = "52.0" -arrow-select = "52.0" +arrow = { version = "52.1", optional = false, features = ["prettyprint"] } +arrow-arith = "52.1" +arrow-array = "52.1" +arrow-buffer = "52.1" +arrow-cast = "52.1" +arrow-data = "52.1" +arrow-ipc = { version = "52.1", features = ["zstd"] } +arrow-ord = "52.1" +arrow-row = "52.1" +arrow-schema = "52.1" +arrow-select = "52.1" async-recursion = "1.0" async-trait = "0.1" aws-config = "0.57" @@ -93,17 +93,17 @@ criterion = { version = "0.5", features = [ "html_reports", ] } crossbeam-queue = "0.3" -datafusion = { version = "39.0", default-features = false, features = [ +datafusion = { version = "40.0", default-features = false, features = [ "array_expressions", "regex_expressions", ] } -datafusion-common = "39.0" -datafusion-functions = { version = "39.0", features = ["regex_expressions"] } -datafusion-sql = "39.0" -datafusion-expr = "39.0" -datafusion-execution = "39.0" -datafusion-optimizer = "39.0" -datafusion-physical-expr = { version = "39.0", features = [ +datafusion-common = "40.0" +datafusion-functions = { version = "40.0", features = ["regex_expressions"] } +datafusion-sql = "40.0" +datafusion-expr = "40.0" +datafusion-execution = "40.0" +datafusion-optimizer = "40.0" +datafusion-physical-expr = { version = "40.0", features = [ "regex_expressions", ] } deepsize = "0.2.0" diff --git a/python/Cargo.toml b/python/Cargo.toml index 1bbf34c39c..6ec415a978 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -12,10 +12,10 @@ name = "lance" crate-type = ["cdylib"] [dependencies] -arrow = { version = "52.0", features = ["pyarrow"] } -arrow-array = "52.0" -arrow-data = "52.0" -arrow-schema = "52.0" +arrow = { version = "52.1", features = ["pyarrow"] } +arrow-array = "52.1" +arrow-data = "52.1" +arrow-schema = "52.1" object_store = "0.10.1" async-trait = "0.1" chrono = "0.4.31" diff --git a/python/python/lance/cleanup.py b/python/python/lance/cleanup.py deleted file mode 100644 index addbf06e4b..0000000000 --- a/python/python/lance/cleanup.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright The Lance Authors - diff --git a/rust/lance-datafusion/Cargo.toml b/rust/lance-datafusion/Cargo.toml index 9bf6475c04..059698810f 100644 --- a/rust/lance-datafusion/Cargo.toml +++ b/rust/lance-datafusion/Cargo.toml @@ -19,7 +19,7 @@ datafusion.workspace = true datafusion-common.workspace = true datafusion-functions.workspace = true datafusion-physical-expr.workspace = true -datafusion-substrait = { version = "39.0", optional = true } +datafusion-substrait = { version = "40.0", optional = true } futures.workspace = true lance-arrow.workspace = true lance-core = { workspace = true, features = ["datafusion"] } diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index 6baacbba1e..bc22ac6937 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -97,6 +97,10 @@ impl DisplayAs for OneShotExec { } impl ExecutionPlan for OneShotExec { + fn name(&self) -> &str { + "OneShotExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/rust/lance-datagen/src/generator.rs b/rust/lance-datagen/src/generator.rs index fcb3ba62b2..e9dc81d32c 100644 --- a/rust/lance-datagen/src/generator.rs +++ b/rust/lance-datagen/src/generator.rs @@ -6,7 +6,7 @@ use std::{iter, marker::PhantomData, sync::Arc}; use arrow::{ array::{ArrayData, AsArray}, buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer}, - datatypes::{ArrowPrimitiveType, Int32Type}, + datatypes::{ArrowPrimitiveType, Int32Type, IntervalDayTime, IntervalMonthDayNano}, }; use arrow_array::{ make_array, @@ -14,7 +14,7 @@ use arrow_array::{ Array, FixedSizeBinaryArray, FixedSizeListArray, ListArray, PrimitiveArray, RecordBatch, RecordBatchOptions, RecordBatchReader, StringArray, StructArray, }; -use arrow_schema::{ArrowError, DataType, Field, Fields, Schema, SchemaRef}; +use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef}; use futures::{stream::BoxStream, StreamExt}; use rand::{distributions::Uniform, Rng, RngCore, SeedableRng}; @@ -596,6 +596,58 @@ impl ArrayGenerator for RandomFixedSizeBinaryGenerator { } } +pub struct RandomIntervalGenerator { + unit: IntervalUnit, + data_type: DataType, +} + +impl RandomIntervalGenerator { + pub fn new(unit: IntervalUnit) -> Self { + Self { + unit, + data_type: DataType::Interval(unit), + } + } +} + +impl ArrayGenerator for RandomIntervalGenerator { + fn generate( + &mut self, + length: RowCount, + rng: &mut rand_xoshiro::Xoshiro256PlusPlus, + ) -> Result, ArrowError> { + match self.unit { + IntervalUnit::YearMonth => { + let months = (0..length.0).map(|_| rng.gen::()).collect::>(); + Ok(Arc::new(arrow_array::IntervalYearMonthArray::from(months))) + } + IntervalUnit::MonthDayNano => { + let day_time_array = (0..length.0) + .map(|_| IntervalMonthDayNano::new(rng.gen(), rng.gen(), rng.gen())) + .collect::>(); + Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from( + day_time_array, + ))) + } + IntervalUnit::DayTime => { + let day_time_array = (0..length.0) + .map(|_| IntervalDayTime::new(rng.gen(), rng.gen())) + .collect::>(); + Ok(Arc::new(arrow_array::IntervalDayTimeArray::from( + day_time_array, + ))) + } + } + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn element_size_bytes(&self) -> Option { + Some(ByteCount::from(12)) + } +} pub struct RandomBinaryGenerator { bytes_per_element: ByteCount, scale_to_utf8: bool, @@ -1175,7 +1227,7 @@ pub mod array { use arrow_array::types::{ Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type, - IntervalYearMonthType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow_array::{ ArrowNativeTypeOp, Date32Array, Date64Array, Time32MillisecondArray, Time32SecondArray, @@ -1459,6 +1511,10 @@ pub mod array { Box::new(RandomFixedSizeBinaryGenerator::new(size)) } + pub fn rand_interval(unit: IntervalUnit) -> Box { + Box::new(RandomIntervalGenerator::new(unit)) + } + /// Create a generator of randomly sampled date32 values /// /// Instead of sampling the entire range, all values will be drawn from the last year as this @@ -1661,14 +1717,7 @@ pub mod array { TimeUnit::Microsecond => rand::(), TimeUnit::Nanosecond => rand::(), }, - DataType::Interval(unit) => match unit { - // TODO: fix these. In Arrow they changed to have specialized - // Native types, which don't support Distribution. - // IntervalUnit::DayTime => rand::(), - // IntervalUnit::MonthDayNano => rand::(), - IntervalUnit::DayTime | IntervalUnit::MonthDayNano => todo!(), - IntervalUnit::YearMonth => rand::(), - }, + DataType::Interval(unit) => rand_interval(*unit), DataType::Date32 => rand_date32(), DataType::Date64 => rand_date64(), DataType::Time32(resolution) => rand_time32(resolution), diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 72ffa69aee..c2d5e4dd91 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -246,33 +246,33 @@ impl Ord for OrderableScalarValue { } } (UInt64(_), _) => panic!("Attempt to compare Int16 with non-UInt64"), - (Utf8(v1), Utf8(v2)) => v1.cmp(v2), - (Utf8(v1), Null) => { - if v1.is_none() { - Ordering::Equal - } else { - Ordering::Greater - } + (Utf8(v1) | Utf8View(v1) | LargeUtf8(v1), Utf8(v2) | Utf8View(v2) | LargeUtf8(v2)) => { + v1.cmp(v2) } - (Utf8(_), _) => panic!("Attempt to compare Utf8 with non-Utf8"), - (LargeUtf8(v1), LargeUtf8(v2)) => v1.cmp(v2), - (LargeUtf8(v1), Null) => { + (Utf8(v1) | Utf8View(v1) | LargeUtf8(v1), Null) => { if v1.is_none() { Ordering::Equal } else { Ordering::Greater } } - (LargeUtf8(_), _) => panic!("Attempt to compare LargeUtf8 with non-LargeUtf8"), - (Binary(v1), Binary(v2)) => v1.cmp(v2), - (Binary(v1), Null) => { + (Utf8(_) | Utf8View(_) | LargeUtf8(_), _) => { + panic!("Attempt to compare Utf8 with non-Utf8") + } + ( + Binary(v1) | LargeBinary(v1) | BinaryView(v1), + Binary(v2) | LargeBinary(v2) | BinaryView(v2), + ) => v1.cmp(v2), + (Binary(v1) | LargeBinary(v1) | BinaryView(v1), Null) => { if v1.is_none() { Ordering::Equal } else { Ordering::Greater } } - (Binary(_), _) => panic!("Attempt to compare Binary with non-Binary"), + (Binary(_) | LargeBinary(_) | BinaryView(_), _) => { + panic!("Attempt to compare Binary with non-Binary") + } (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.cmp(v2), (FixedSizeBinary(_, v1), Null) => { if v1.is_none() { @@ -284,15 +284,6 @@ impl Ord for OrderableScalarValue { (FixedSizeBinary(_, _), _) => { panic!("Attempt to compare FixedSizeBinary with non-FixedSizeBinary") } - (LargeBinary(v1), LargeBinary(v2)) => v1.cmp(v2), - (LargeBinary(v1), Null) => { - if v1.is_none() { - Ordering::Equal - } else { - Ordering::Greater - } - } - (LargeBinary(_), _) => panic!("Attempt to compare LargeBinary with non-LargeBinary"), (FixedSizeList(left), FixedSizeList(right)) => { if left.eq(right) { todo!() @@ -324,6 +315,17 @@ impl Ord for OrderableScalarValue { panic!("Attempt to compare List with non-List") } (LargeList(_), _) => todo!(), + (Map(_), Map(_)) => todo!(), + (Map(left), Null) => { + if left.is_null(0) { + Ordering::Equal + } else { + Ordering::Greater + } + } + (Map(_), _) => { + panic!("Attempt to compare Map with non-Map") + } (Date32(v1), Date32(v2)) => v1.cmp(v2), (Date32(v1), Null) => { if v1.is_none() { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 0decaec57d..dfde581c09 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -11,7 +11,8 @@ use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, SchemaR use arrow_select::concat::concat_batches; use async_recursion::async_recursion; use datafusion::common::DFSchema; -use datafusion::logical_expr::{AggregateFunction, Expr}; +use datafusion::functions_aggregate::count::count_udaf; +use datafusion::logical_expr::{lit, Expr}; use datafusion::physical_expr::PhysicalSortExpr; use datafusion::physical_plan::expressions; use datafusion::physical_plan::projection::ProjectionExec as DFProjectionExec; @@ -19,10 +20,11 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, display::DisplayableExecutionPlan, - expressions::{create_aggregate_expr, Literal}, + expressions::Literal, filter::FilterExec, limit::GlobalLimitExec, repartition::RepartitionExec, + udaf::create_aggregate_expr, union::UnionExec, ExecutionPlan, SendableRecordBatchStream, }; @@ -738,13 +740,15 @@ impl Scanner { // Datafusion interprets COUNT(*) as COUNT(1) let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); let count_expr = create_aggregate_expr( - &AggregateFunction::Count, - false, + &count_udaf(), &[one], + &[lit(1)], + &[], &[], &plan.schema(), "", false, + false, )?; let plan_schema = plan.schema().clone(); let count_plan = Arc::new(AggregateExec::try_new( diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 09d151d485..69bd4e43dc 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -139,6 +139,10 @@ impl KNNVectorDistanceExec { } impl ExecutionPlan for KNNVectorDistanceExec { + fn name(&self) -> &str { + "KNNVectorDistanceExec" + } + fn as_any(&self) -> &dyn Any { self } @@ -385,6 +389,10 @@ impl DisplayAs for ANNIvfPartitionExec { } impl ExecutionPlan for ANNIvfPartitionExec { + fn name(&self) -> &str { + "ANNIVFPartitionExec" + } + fn as_any(&self) -> &dyn Any { self } @@ -541,6 +549,10 @@ impl DisplayAs for ANNIvfSubIndexExec { } impl ExecutionPlan for ANNIvfSubIndexExec { + fn name(&self) -> &str { + "ANNSubIndexExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/rust/lance/src/io/exec/planner.rs b/rust/lance/src/io/exec/planner.rs index 84c22f2e74..0e20018183 100644 --- a/rust/lance/src/io/exec/planner.rs +++ b/rust/lance/src/io/exec/planner.rs @@ -410,6 +410,7 @@ impl Planner { ParserOptions { parse_float_as_decimal: false, enable_ident_normalization: false, + support_varchar_with_length: false, }, ); let mut planner_context = PlannerContext::default(); diff --git a/rust/lance/src/io/exec/projection.rs b/rust/lance/src/io/exec/projection.rs index a666e03eae..177d07ca4b 100644 --- a/rust/lance/src/io/exec/projection.rs +++ b/rust/lance/src/io/exec/projection.rs @@ -157,6 +157,10 @@ impl ProjectionExec { } impl ExecutionPlan for ProjectionExec { + fn name(&self) -> &str { + "ProjectionExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/rust/lance/src/io/exec/pushdown_scan.rs b/rust/lance/src/io/exec/pushdown_scan.rs index f354f5dbb9..989ebbb263 100644 --- a/rust/lance/src/io/exec/pushdown_scan.rs +++ b/rust/lance/src/io/exec/pushdown_scan.rs @@ -103,10 +103,9 @@ impl LancePushdownScanExec { ) -> Result { // This should be infallible. let columns: Vec<_> = predicate - .to_columns() - .unwrap() + .column_refs() .into_iter() - .map(|col| col.name) + .map(|col| col.name.as_str()) .collect(); let dataset_schema = dataset.schema(); let predicate_projection = Arc::new(dataset_schema.project(&columns) @@ -147,6 +146,10 @@ impl LancePushdownScanExec { } impl ExecutionPlan for LancePushdownScanExec { + fn name(&self) -> &str { + "LancePushdownScanExec" + } + fn as_any(&self) -> &dyn Any { self } @@ -362,10 +365,9 @@ impl FragmentScanner { // 1. Load needed filter columns, which might be a subset of all filter // columns if statistics obviated the need for some columns. let columns: Vec<_> = predicate - .to_columns() - .unwrap() + .column_refs() .into_iter() - .map(|col| col.name) + .map(|col| col.name.as_str()) .collect(); let predicate_projection = Arc::new(self.fragment.dataset().schema().project(&columns).unwrap()); diff --git a/rust/lance/src/io/exec/scalar_index.rs b/rust/lance/src/io/exec/scalar_index.rs index 78e852dccd..10e657f5b5 100644 --- a/rust/lance/src/io/exec/scalar_index.rs +++ b/rust/lance/src/io/exec/scalar_index.rs @@ -108,6 +108,10 @@ impl ScalarIndexExec { } impl ExecutionPlan for ScalarIndexExec { + fn name(&self) -> &str { + "ScalarIndexExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } @@ -268,6 +272,10 @@ impl MapIndexExec { } impl ExecutionPlan for MapIndexExec { + fn name(&self) -> &str { + "MapIndexExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } @@ -537,6 +545,10 @@ async fn retain_fragments( } impl ExecutionPlan for MaterializeIndexExec { + fn name(&self) -> &str { + "MaterializeIndexExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs index 67b636ecd7..fe5586ae90 100644 --- a/rust/lance/src/io/exec/scan.rs +++ b/rust/lance/src/io/exec/scan.rs @@ -284,6 +284,10 @@ impl LanceScanExec { } impl ExecutionPlan for LanceScanExec { + fn name(&self) -> &str { + "LanceScanExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index fea67ddce4..e72988b6e4 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -249,6 +249,10 @@ impl TakeExec { } impl ExecutionPlan for TakeExec { + fn name(&self) -> &str { + "TakeExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs index 25529cce30..6135ef1d14 100644 --- a/rust/lance/src/io/exec/testing.rs +++ b/rust/lance/src/io/exec/testing.rs @@ -47,6 +47,10 @@ impl DisplayAs for TestingExec { } impl ExecutionPlan for TestingExec { + fn name(&self) -> &str { + "TestingExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/rust/lance/src/io/exec/utils.rs b/rust/lance/src/io/exec/utils.rs index 171cc99f52..c1be2ab54c 100644 --- a/rust/lance/src/io/exec/utils.rs +++ b/rust/lance/src/io/exec/utils.rs @@ -123,6 +123,10 @@ impl RecordBatchStream for ShareableRecordBatchStreamAdapter { } impl ExecutionPlan for ReplayExec { + fn name(&self) -> &str { + "ReplayExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } From 84967b2d7435d398de042a8bd2fa10b21ceae2e7 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 15 Jul 2024 12:46:10 -0700 Subject: [PATCH 8/8] fix planner bug --- Cargo.toml | 1 + rust/lance/src/io/exec/planner.rs | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b9cbce6748..9aab5d3950 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,6 +96,7 @@ crossbeam-queue = "0.3" datafusion = { version = "40.0", default-features = false, features = [ "array_expressions", "regex_expressions", + "unicode_expressions", ] } datafusion-common = "40.0" datafusion-functions = { version = "40.0", features = ["regex_expressions"] } diff --git a/rust/lance/src/io/exec/planner.rs b/rust/lance/src/io/exec/planner.rs index 0e20018183..777d46e42d 100644 --- a/rust/lance/src/io/exec/planner.rs +++ b/rust/lance/src/io/exec/planner.rs @@ -18,6 +18,7 @@ use datafusion::error::Result as DFResult; use datafusion::execution::config::SessionConfig; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::{ AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowUDF, @@ -405,7 +406,7 @@ impl Planner { } } let context_provider = LanceContextProvider::default(); - let sql_to_rel = SqlToRel::new_with_options( + let mut sql_to_rel = SqlToRel::new_with_options( &context_provider, ParserOptions { parse_float_as_decimal: false, @@ -413,6 +414,12 @@ impl Planner { support_varchar_with_length: false, }, ); + // These planners are not automatically propagated. + // See: https://github.com/apache/datafusion/issues/11477 + for planner in context_provider.state.expr_planners() { + sql_to_rel = sql_to_rel.with_user_defined_planner(planner.clone()); + } + let mut planner_context = PlannerContext::default(); let schema = DFSchema::try_from(self.schema.as_ref().clone())?; Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)