diff --git a/Cargo.toml b/Cargo.toml
index 0df68fc2cd3..b491f71f320 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,9 +26,13 @@ hash_hasher = "^2.0.3"
 
 # For SIMD utf8 validation
 simdutf8 = "0.1.3"
+# for csv io
 csv = { version = "^1.1", optional = true }
+
+# for csv async io
+csv-async = { version = "^1.1", optional = true }
+
 regex = { version = "^1.3", optional = true }
-lazy_static = { version = "^1.4", optional = true }
 streaming-iterator = { version = "0.1", optional = true }
 
 serde = { version = "^1.0", features = ["rc"], optional = true }
@@ -78,6 +82,9 @@ criterion = "0.3"
 flate2 = "1"
 doc-comment = "0.3"
 crossbeam-channel = "0.5.1"
+# used to test async readers
+tokio = { version = "1", features = ["macros", "rt", "fs"] }
+tokio-util = { version = "0.6", features = ["compat"] }
 # used to run formal property testing
 proptest = { version = "1", default_features = false, features = ["std"] }
 
@@ -89,6 +96,7 @@ rustdoc-args = ["--cfg", "docsrs"]
 default = []
 full = [
     "io_csv",
+    "io_csv_async",
    "io_json",
     "io_ipc",
     "io_flight",
@@ -106,7 +114,9 @@ full = [
 ]
 merge_sort = ["itertools"]
 io_csv = ["io_csv_read", "io_csv_write"]
+io_csv_async = ["io_csv_read_async"]
 io_csv_read = ["csv", "lexical-core"]
+io_csv_read_async = ["csv-async", "lexical-core", "futures"]
 io_csv_write = ["csv", "streaming-iterator", "lexical-core"]
 io_json = ["serde", "serde_json", "indexmap"]
 io_ipc = ["arrow-format"]
@@ -146,6 +156,8 @@ skip_feature_sets = [
     ["io_csv"],
     ["io_csv_read"],
     ["io_csv_write"],
+    ["io_csv_async"],
+    ["io_csv_read_async"],
     ["io_avro"],
     ["io_json"],
     ["io_flight"],
diff --git a/examples/csv_read_async.rs b/examples/csv_read_async.rs
new file mode 100644
index 00000000000..fa0e9481fc5
--- /dev/null
+++ b/examples/csv_read_async.rs
@@ -0,0 +1,36 @@
+use std::sync::Arc;
+
+use futures::io::Cursor;
+use tokio::fs::File;
+use tokio_util::compat::*;
+
+use arrow2::array::*;
+use arrow2::error::Result;
+use arrow2::io::csv::read_async::*;
+
+#[tokio::main(flavor = "current_thread")]
+async fn main() -> Result<()> {
+    use std::env;
+    let args: Vec<String> = env::args().collect();
+
+    let file_path = &args[1];
+
+    let file = File::open(file_path).await?.compat();
+
+    let mut reader = AsyncReaderBuilder::new().create_reader(file);
+
+    let schema = Arc::new(infer_schema(&mut reader, None, true, &infer).await?);
+
+    let mut rows = vec![ByteRecord::default(); 100];
+    let rows_read = read_rows(&mut reader, 0, &mut rows).await?;
+
+    let batch = deserialize_batch(
+        &rows[..rows_read],
+        schema.fields(),
+        None,
+        0,
+        deserialize_column,
+    )?;
+    println!("{}", batch.column(0));
+    Ok(())
+}
diff --git a/guide/src/io/csv_reader.md b/guide/src/io/csv_reader.md
index b38bc9cb19f..6b9d3161209 100644
--- a/guide/src/io/csv_reader.md
+++ b/guide/src/io/csv_reader.md
@@ -30,6 +30,19 @@ thereby maximizing IO throughput. The example below shows how to do just that:
 {{#include ../../../examples/csv_read_parallel.rs}}
 ```
 
+## Async
+
+This crate also supports reading CSV asynchronously through the `csv-async` crate.
+The example below demonstrates this:
+
+```rust
+{{#include ../../../examples/csv_read_async.rs}}
+```
+
+Note that the deserialization _should_ be performed on a separate thread to not
+block (see also [here](https://ryhl.io/blog/async-what-is-blocking/)), which this
+example does not show.
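+
+A minimal sketch of how that could look, assuming an application running on a tokio
+runtime and reusing the `rows`, `rows_read` and `schema` values from the example above:
+
+```rust,ignore
+// Hypothetical: hand the CPU-bound deserialization to tokio's blocking thread pool.
+// The fields are cloned so that the closure owns everything it needs.
+let fields = schema.fields().to_vec();
+let batch = tokio::task::spawn_blocking(move || {
+    deserialize_batch(&rows[..rows_read], &fields, None, 0, deserialize_column)
+})
+.await
+.expect("deserialization task panicked")?;
+```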
+ ## Customization In the code above, `parser` and `infer` allow for customization: they declare diff --git a/src/io/csv/mod.rs b/src/io/csv/mod.rs index f06e6ffa85e..00b14185051 100644 --- a/src/io/csv/mod.rs +++ b/src/io/csv/mod.rs @@ -1,12 +1,16 @@ #![deny(missing_docs)] -//! Transfer data between the Arrow memory format and CSV (comma-separated values). +//! Convert data between the Arrow and CSV (comma-separated values). use crate::error::ArrowError; -pub use csv::Error as CSVError; +#[cfg(any(feature = "io_csv_read_async", feature = "io_csv_read"))] +mod read_utils; +#[cfg(any(feature = "io_csv_read_async", feature = "io_csv_read"))] +mod utils; -impl From for ArrowError { - fn from(error: CSVError) -> Self { +#[cfg(any(feature = "io_csv_read", feature = "io_csv_write"))] +impl From for ArrowError { + fn from(error: csv::Error) -> Self { ArrowError::External("".to_string(), Box::new(error)) } } @@ -23,3 +27,7 @@ pub mod read; #[cfg(feature = "io_csv_write")] #[cfg_attr(docsrs, doc(cfg(feature = "io_csv_write")))] pub mod write; + +#[cfg(feature = "io_csv_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_csv_read_async")))] +pub mod read_async; diff --git a/src/io/csv/read/deserialize.rs b/src/io/csv/read/deserialize.rs index 7a6bcaf2bb9..37a8a45f813 100644 --- a/src/io/csv/read/deserialize.rs +++ b/src/io/csv/read/deserialize.rs @@ -1,83 +1,22 @@ use std::sync::Arc; -use chrono::Datelike; use csv::ByteRecord; use crate::{ - array::*, - datatypes::*, - error::{ArrowError, Result}, + array::Array, + datatypes::{DataType, Field}, + error::Result, record_batch::RecordBatch, - temporal_conversions, - types::{NativeType, NaturalDataType}, }; -use super::infer_schema::RFC3339; - -fn deserialize_primitive( - rows: &[ByteRecord], - column: usize, - datatype: DataType, - op: F, -) -> Arc -where - T: NativeType + NaturalDataType + lexical_core::FromLexical, - F: Fn(&[u8]) -> Option, -{ - let iter = rows.iter().map(|row| match row.get(column) { - Some(bytes) => { - if bytes.is_empty() { - return None; - } - op(bytes) - } - None => None, - }); - Arc::new(PrimitiveArray::::from_trusted_len_iter(iter).to(datatype)) -} - -fn deserialize_boolean(rows: &[ByteRecord], column: usize, op: F) -> Arc -where - F: Fn(&[u8]) -> Option, -{ - let iter = rows.iter().map(|row| match row.get(column) { - Some(bytes) => { - if bytes.is_empty() { - return None; - } - op(bytes) - } - None => None, - }); - Arc::new(BooleanArray::from_trusted_len_iter(iter)) -} - -fn deserialize_utf8(rows: &[ByteRecord], column: usize) -> Arc { - let iter = rows.iter().map(|row| match row.get(column) { - Some(bytes) => simdutf8::basic::from_utf8(bytes).ok(), - None => None, - }); - Arc::new(Utf8Array::::from_trusted_len_iter(iter)) -} - -fn deserialize_binary(rows: &[ByteRecord], column: usize) -> Arc { - let iter = rows.iter().map(|row| row.get(column)); - Arc::new(BinaryArray::::from_trusted_len_iter(iter)) -} +use super::super::read_utils::{ + deserialize_batch as deserialize_batch_gen, deserialize_column as deserialize_column_gen, + ByteRecordGeneric, +}; -#[inline] -fn deserialize_datetime(string: &str, tz: &T) -> Option { - let mut parsed = chrono::format::Parsed::new(); - let fmt = chrono::format::StrftimeItems::new(RFC3339); - if chrono::format::parse(&mut parsed, string, fmt).is_ok() { - parsed - .to_datetime() - .map(|x| x.naive_utc()) - .map(|x| tz.from_utc_datetime(&x)) - .map(|x| x.timestamp_nanos()) - .ok() - } else { - None +impl ByteRecordGeneric for ByteRecord { + fn get(&self, index: usize) -> Option<&[u8]> { + 
self.get(index) } } @@ -86,114 +25,9 @@ pub fn deserialize_column( rows: &[ByteRecord], column: usize, datatype: DataType, - _line_number: usize, + line_number: usize, ) -> Result> { - use DataType::*; - Ok(match datatype { - Boolean => deserialize_boolean(rows, column, |bytes| { - if bytes.eq_ignore_ascii_case(b"false") { - Some(false) - } else if bytes.eq_ignore_ascii_case(b"true") { - Some(true) - } else { - None - } - }), - Int8 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - Int16 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - Int32 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - Int64 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - UInt8 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - UInt16 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - UInt32 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - UInt64 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - Float32 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - Float64 => deserialize_primitive(rows, column, datatype, |bytes| { - lexical_core::parse::(bytes).ok() - }), - Date32 => deserialize_primitive(rows, column, datatype, |bytes| { - simdutf8::basic::from_utf8(bytes) - .ok() - .and_then(|x| x.parse::().ok()) - .map(|x| x.num_days_from_ce() - temporal_conversions::EPOCH_DAYS_FROM_CE) - }), - Date64 => deserialize_primitive(rows, column, datatype, |bytes| { - simdutf8::basic::from_utf8(bytes) - .ok() - .and_then(|x| x.parse::().ok()) - .map(|x| x.timestamp_millis()) - }), - Timestamp(TimeUnit::Nanosecond, None) => { - deserialize_primitive(rows, column, datatype, |bytes| { - simdutf8::basic::from_utf8(bytes) - .ok() - .and_then(|x| x.parse::().ok()) - .map(|x| x.timestamp_nanos()) - }) - } - Timestamp(TimeUnit::Microsecond, None) => { - deserialize_primitive(rows, column, datatype, |bytes| { - simdutf8::basic::from_utf8(bytes) - .ok() - .and_then(|x| x.parse::().ok()) - .map(|x| x.timestamp_nanos() / 1000) - }) - } - Timestamp(time_unit, None) => deserialize_primitive(rows, column, datatype, |bytes| { - simdutf8::basic::from_utf8(bytes) - .ok() - .and_then(|x| x.parse::().ok()) - .map(|x| x.timestamp_nanos()) - .map(|x| match time_unit { - TimeUnit::Second => x / 1_000_000_000, - TimeUnit::Millisecond => x / 1_000_000, - TimeUnit::Microsecond => x / 1_000, - TimeUnit::Nanosecond => x, - }) - }), - Timestamp(time_unit, Some(ref tz)) => { - let tz = temporal_conversions::parse_offset(tz)?; - deserialize_primitive(rows, column, datatype, |bytes| { - simdutf8::basic::from_utf8(bytes) - .ok() - .and_then(|x| deserialize_datetime(x, &tz)) - .map(|x| match time_unit { - TimeUnit::Second => x / 1_000_000_000, - TimeUnit::Millisecond => x / 1_000_000, - TimeUnit::Microsecond => x / 1_000, - TimeUnit::Nanosecond => x, - }) - }) - } - Utf8 => deserialize_utf8::(rows, column), - LargeUtf8 => deserialize_utf8::(rows, column), - Binary => deserialize_binary::(rows, column), - LargeBinary => deserialize_binary::(rows, column), - other => { - return Err(ArrowError::NotYetImplemented(format!( - "Deserializing type \"{:?}\" is not implemented", - other - ))) - } - }) + 
deserialize_column_gen(rows, column, datatype, line_number) } /// Deserializes rows [`ByteRecord`] into a [`RecordBatch`]. @@ -209,27 +43,5 @@ pub fn deserialize_batch( where F: Fn(&[ByteRecord], usize, DataType, usize) -> Result>, { - let projection: Vec = match projection { - Some(v) => v.to_vec(), - None => fields.iter().enumerate().map(|(i, _)| i).collect(), - }; - let projected_fields: Vec = projection.iter().map(|i| fields[*i].clone()).collect(); - - let schema = Arc::new(Schema::new(projected_fields)); - - if rows.is_empty() { - return Ok(RecordBatch::new_empty(schema)); - } - - let columns = projection - .iter() - .map(|column| { - let column = *column; - let field = &fields[column]; - let data_type = field.data_type(); - deserialize_column(rows, column, data_type.clone(), line_number) - }) - .collect::>>()?; - - RecordBatch::try_new(schema, columns) + deserialize_batch_gen(rows, fields, projection, line_number, deserialize_column) } diff --git a/src/io/csv/read/infer_schema.rs b/src/io/csv/read/infer_schema.rs index 1b8fd060d0b..b56fb8fc598 100644 --- a/src/io/csv/read/infer_schema.rs +++ b/src/io/csv/read/infer_schema.rs @@ -3,87 +3,11 @@ use std::{ io::{Read, Seek}, }; -use super::{ByteRecord, Reader}; - -use crate::datatypes::{DataType, TimeUnit}; -use crate::datatypes::{Field, Schema}; +use crate::datatypes::{DataType, Schema}; use crate::error::Result; -pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; - -fn is_boolean(bytes: &[u8]) -> bool { - bytes.eq_ignore_ascii_case(b"true") | bytes.eq_ignore_ascii_case(b"false") -} - -fn is_float(bytes: &[u8]) -> bool { - lexical_core::parse::(bytes).is_ok() -} - -fn is_integer(bytes: &[u8]) -> bool { - lexical_core::parse::(bytes).is_ok() -} - -fn is_date(string: &str) -> bool { - string.parse::().is_ok() -} - -fn is_time(string: &str) -> bool { - string.parse::().is_ok() -} - -fn is_naive_datetime(string: &str) -> bool { - string.parse::().is_ok() -} - -fn is_datetime(string: &str) -> Option { - let mut parsed = chrono::format::Parsed::new(); - let fmt = chrono::format::StrftimeItems::new(RFC3339); - if chrono::format::parse(&mut parsed, string, fmt).is_ok() { - parsed.offset.map(|x| { - let hours = x / 60 / 60; - let minutes = x / 60 - hours * 60; - format!("{:03}:{:02}", hours, minutes) - }) - } else { - None - } -} - -/// Infers [`DataType`] from `bytes` -/// # Implementation -/// * case insensitive "true" or "false" are mapped to [`DataType::Boolean`] -/// * parsable to integer is mapped to [`DataType::Int64`] -/// * parsable to float is mapped to [`DataType::Float64`] -/// * parsable to date is mapped to [`DataType::Date32`] -/// * parsable to time is mapped to [`DataType::Time32(TimeUnit::Millisecond)`] -/// * parsable to naive datetime is mapped to [`DataType::Timestamp(TimeUnit::Millisecond, None)`] -/// * parsable to time-aware datetime is mapped to [`DataType::Timestamp`] of milliseconds and parsed offset. 
-/// * other utf8 is mapped to [`DataType::Utf8`] -/// * invalid utf8 is mapped to [`DataType::Binary`] -pub fn infer(bytes: &[u8]) -> DataType { - if is_boolean(bytes) { - DataType::Boolean - } else if is_integer(bytes) { - DataType::Int64 - } else if is_float(bytes) { - DataType::Float64 - } else if let Ok(string) = simdutf8::basic::from_utf8(bytes) { - if is_date(string) { - DataType::Date32 - } else if is_time(string) { - DataType::Time32(TimeUnit::Millisecond) - } else if is_naive_datetime(string) { - DataType::Timestamp(TimeUnit::Millisecond, None) - } else if let Some(offset) = is_datetime(string) { - DataType::Timestamp(TimeUnit::Millisecond, Some(offset)) - } else { - DataType::Utf8 - } - } else { - // invalid utf8 - DataType::Binary - } -} +use super::super::utils::merge_schema; +use super::{ByteRecord, Reader}; /// Infers a [`Schema`] of a CSV file by reading through the first n records up to `max_rows`. /// Seeks back to the begining of the file _after_ the header @@ -128,31 +52,7 @@ pub fn infer_schema DataType>( } } - // build schema from inference results - let fields = headers - .iter() - .zip(column_types.into_iter()) - .map(|(field_name, mut possibilities)| { - // determine data type based on possible types - // if there are incompatible types, use DataType::Utf8 - let data_type = match possibilities.len() { - 1 => possibilities.drain().next().unwrap(), - 2 => { - if possibilities.contains(&DataType::Int64) - && possibilities.contains(&DataType::Float64) - { - // we have an integer and double, fall down to double - DataType::Float64 - } else { - // default to Utf8 for conflicting datatypes (e.g bool and int) - DataType::Utf8 - } - } - _ => DataType::Utf8, - }; - Field::new(field_name, data_type, true) - }) - .collect(); + let fields = merge_schema(&headers, &mut column_types); // return the reader seek back to the start reader.seek(position)?; diff --git a/src/io/csv/read/mod.rs b/src/io/csv/read/mod.rs index 48881040abc..b47de105c01 100644 --- a/src/io/csv/read/mod.rs +++ b/src/io/csv/read/mod.rs @@ -7,6 +7,7 @@ pub use csv::{ByteRecord, Reader, ReaderBuilder}; mod infer_schema; +pub use super::utils::infer; pub use deserialize::{deserialize_batch, deserialize_column}; -pub use infer_schema::{infer, infer_schema}; +pub use infer_schema::infer_schema; pub use reader::*; diff --git a/src/io/csv/read/reader.rs b/src/io/csv/read/reader.rs index e2b95a8e405..d6b4bf1cabb 100644 --- a/src/io/csv/read/reader.rs +++ b/src/io/csv/read/reader.rs @@ -2,25 +2,9 @@ use std::io::Read; use super::{ByteRecord, Reader}; -use crate::{ - datatypes::*, - error::{ArrowError, Result}, -}; +use crate::error::{ArrowError, Result}; -/// Returns a new [`Schema`] whereby the fields are selected based on `projection`. -pub fn projected_schema(schema: &Schema, projection: Option<&[usize]>) -> Schema { - match &projection { - Some(projection) => { - let fields = schema.fields(); - let projected_fields: Vec = - projection.iter().map(|i| fields[*i].clone()).collect(); - Schema::new_from(projected_fields, schema.metadata().clone()) - } - None => schema.clone(), - } -} - -/// Reads `len` rows from `reader` into `row`, skiping `skip`. +/// Reads `len` rows from `reader` into `row`, skiping the first `skip`. /// This operation has minimal CPU work and is thus the fastest way to read through a CSV /// without deserializing the contents to Arrow. 
 pub fn read_rows<R: Read>(
diff --git a/src/io/csv/read_async/deserialize.rs b/src/io/csv/read_async/deserialize.rs
new file mode 100644
index 00000000000..41074b99844
--- /dev/null
+++ b/src/io/csv/read_async/deserialize.rs
@@ -0,0 +1,47 @@
+use std::sync::Arc;
+
+use csv_async::ByteRecord;
+
+use crate::{
+    array::Array,
+    datatypes::{DataType, Field},
+    error::Result,
+    record_batch::RecordBatch,
+};
+
+use super::super::read_utils::{
+    deserialize_batch as deserialize_batch_gen, deserialize_column as deserialize_column_gen,
+    ByteRecordGeneric,
+};
+
+impl ByteRecordGeneric for ByteRecord {
+    fn get(&self, index: usize) -> Option<&[u8]> {
+        self.get(index)
+    }
+}
+
+/// Deserializes `column` of `rows` into an [`Array`] of [`DataType`] `datatype`.
+pub fn deserialize_column(
+    rows: &[ByteRecord],
+    column: usize,
+    datatype: DataType,
+    line_number: usize,
+) -> Result<Arc<dyn Array>> {
+    deserialize_column_gen(rows, column, datatype, line_number)
+}
+
+/// Deserializes rows [`ByteRecord`] into a [`RecordBatch`].
+/// Note that this is a convenience function: column deserialization
+/// is trivially parallelizable (e.g. rayon).
+pub fn deserialize_batch<F>(
+    rows: &[ByteRecord],
+    fields: &[Field],
+    projection: Option<&[usize]>,
+    line_number: usize,
+    deserialize_column: F,
+) -> Result<RecordBatch>
+where
+    F: Fn(&[ByteRecord], usize, DataType, usize) -> Result<Arc<dyn Array>>,
+{
+    deserialize_batch_gen(rows, fields, projection, line_number, deserialize_column)
+}
diff --git a/src/io/csv/read_async/infer_schema.rs b/src/io/csv/read_async/infer_schema.rs
new file mode 100644
index 00000000000..7fae717873e
--- /dev/null
+++ b/src/io/csv/read_async/infer_schema.rs
@@ -0,0 +1,69 @@
+use std::collections::HashSet;
+
+use super::{AsyncReader, ByteRecord};
+
+use crate::datatypes::{DataType, Schema};
+use crate::error::Result;
+use crate::io::csv::utils::merge_schema;
+
+use futures::{AsyncRead, AsyncSeek};
+
+/// Infers a [`Schema`] of a CSV file by reading through the first n records up to `max_rows`.
+/// Seeks back to the beginning of the file _after_ the header.
+pub async fn infer_schema<R, F>(
+    reader: &mut AsyncReader<R>,
+    max_rows: Option<usize>,
+    has_header: bool,
+    infer: &F,
+) -> Result<Schema>
+where
+    R: AsyncRead + AsyncSeek + Unpin + Send + Sync,
+    F: Fn(&[u8]) -> DataType,
+{
+    // get or create header names
+    // when has_header is false, creates default column names with column_ prefix
+    let headers: Vec<String> = if has_header {
+        reader
+            .headers()
+            .await?
+            .iter()
+            .map(|s| s.to_string())
+            .collect()
+    } else {
+        let first_record_count = &reader.headers().await?.len();
+        (0..*first_record_count)
+            .map(|i| format!("column_{}", i + 1))
+            .collect()
+    };
+
+    // save the csv reader position after reading headers
+    let position = reader.position().clone();
+
+    let header_length = headers.len();
+    // keep track of inferred field types
+    let mut column_types: Vec<HashSet<DataType>> = vec![HashSet::new(); header_length];
+
+    let mut records_count = 0;
+
+    let mut record = ByteRecord::new();
+    let max_records = max_rows.unwrap_or(usize::MAX);
+    while records_count < max_records {
+        if !reader.read_byte_record(&mut record).await? {
+            break;
+        }
+        records_count += 1;
+
+        for (i, column) in column_types.iter_mut().enumerate() {
+            if let Some(string) = record.get(i) {
+                column.insert(infer(string));
+            }
+        }
+    }
+
+    let fields = merge_schema(&headers, &mut column_types);
+
+    // seek the reader back to the start
+    reader.seek(position).await?;
+
+    Ok(Schema::new(fields))
+}
diff --git a/src/io/csv/read_async/mod.rs b/src/io/csv/read_async/mod.rs
new file mode 100644
index 00000000000..155416b7460
--- /dev/null
+++ b/src/io/csv/read_async/mod.rs
@@ -0,0 +1,21 @@
+//! Asynchronous reading of CSV
+
+// Re-export for usage by consumers.
+pub use csv_async::{AsyncReader, AsyncReaderBuilder, ByteRecord};
+
+mod deserialize;
+mod infer_schema;
+mod reader;
+
+pub use super::utils::infer;
+pub use deserialize::{deserialize_batch, deserialize_column};
+pub use infer_schema::infer_schema;
+pub use reader::*;
+
+pub use csv_async::Error as CSVError;
+
+impl From<CSVError> for crate::error::ArrowError {
+    fn from(error: CSVError) -> Self {
+        crate::error::ArrowError::External("".to_string(), Box::new(error))
+    }
+}
diff --git a/src/io/csv/read_async/reader.rs b/src/io/csv/read_async/reader.rs
new file mode 100644
index 00000000000..0f145c4986f
--- /dev/null
+++ b/src/io/csv/read_async/reader.rs
@@ -0,0 +1,38 @@
+use futures::AsyncRead;
+
+use super::{AsyncReader, ByteRecord};
+
+use crate::error::{ArrowError, Result};
+
+/// Asynchronously reads `len` rows from `reader` into `rows`, skipping the first `skip`.
+/// This operation has minimal CPU work and is thus the fastest way to read through a CSV
+/// without deserializing the contents to Arrow.
+pub async fn read_rows<R>(
+    reader: &mut AsyncReader<R>,
+    skip: usize,
+    rows: &mut [ByteRecord],
+) -> Result<usize>
+where
+    R: AsyncRead + Unpin + Send + Sync,
+{
+    // skip the first `skip` rows.
+    let mut row = ByteRecord::new();
+    for _ in 0..skip {
+        let res = reader.read_byte_record(&mut row).await;
+        if !res.unwrap_or(false) {
+            break;
+        }
+    }
+
+    let mut row_number = 0;
+    for row in rows.iter_mut() {
+        let has_more = reader.read_byte_record(row).await.map_err(|e| {
+            ArrowError::External(format!(" at line {}", skip + row_number), Box::new(e))
+        })?;
+        if !has_more {
+            break;
+        }
+        row_number += 1;
+    }
+    Ok(row_number)
+}
diff --git a/src/io/csv/read_utils.rs b/src/io/csv/read_utils.rs
new file mode 100644
index 00000000000..df6b4e07bbd
--- /dev/null
+++ b/src/io/csv/read_utils.rs
@@ -0,0 +1,245 @@
+use std::sync::Arc;
+
+use chrono::Datelike;
+
+// Ideally this trait should not be needed and both `csv` and `csv_async` crates would share
+// the same `ByteRecord` struct. Unfortunately, they do not and thus we must use generics
+// over this trait and materialize the generics for each struct.
+pub(crate) trait ByteRecordGeneric { + fn get(&self, index: usize) -> Option<&[u8]>; +} + +use crate::{ + array::*, + datatypes::*, + error::{ArrowError, Result}, + record_batch::RecordBatch, + temporal_conversions, + types::{NativeType, NaturalDataType}, +}; + +use super::utils::RFC3339; + +fn deserialize_primitive( + rows: &[B], + column: usize, + datatype: DataType, + op: F, +) -> Arc +where + T: NativeType + NaturalDataType + lexical_core::FromLexical, + F: Fn(&[u8]) -> Option, +{ + let iter = rows.iter().map(|row| match row.get(column) { + Some(bytes) => { + if bytes.is_empty() { + return None; + } + op(bytes) + } + None => None, + }); + Arc::new(PrimitiveArray::::from_trusted_len_iter(iter).to(datatype)) +} + +fn deserialize_boolean(rows: &[B], column: usize, op: F) -> Arc +where + B: ByteRecordGeneric, + F: Fn(&[u8]) -> Option, +{ + let iter = rows.iter().map(|row| match row.get(column) { + Some(bytes) => { + if bytes.is_empty() { + return None; + } + op(bytes) + } + None => None, + }); + Arc::new(BooleanArray::from_trusted_len_iter(iter)) +} + +fn deserialize_utf8(rows: &[B], column: usize) -> Arc { + let iter = rows.iter().map(|row| match row.get(column) { + Some(bytes) => simdutf8::basic::from_utf8(bytes).ok(), + None => None, + }); + Arc::new(Utf8Array::::from_trusted_len_iter(iter)) +} + +fn deserialize_binary( + rows: &[B], + column: usize, +) -> Arc { + let iter = rows.iter().map(|row| row.get(column)); + Arc::new(BinaryArray::::from_trusted_len_iter(iter)) +} + +#[inline] +fn deserialize_datetime(string: &str, tz: &T) -> Option { + let mut parsed = chrono::format::Parsed::new(); + let fmt = chrono::format::StrftimeItems::new(RFC3339); + if chrono::format::parse(&mut parsed, string, fmt).is_ok() { + parsed + .to_datetime() + .map(|x| x.naive_utc()) + .map(|x| tz.from_utc_datetime(&x)) + .map(|x| x.timestamp_nanos()) + .ok() + } else { + None + } +} + +/// Deserializes `column` of `rows` into an [`Array`] of [`DataType`] `datatype`. 
+pub(crate) fn deserialize_column( + rows: &[B], + column: usize, + datatype: DataType, + _line_number: usize, +) -> Result> { + use DataType::*; + Ok(match datatype { + Boolean => deserialize_boolean(rows, column, |bytes| { + if bytes.eq_ignore_ascii_case(b"false") { + Some(false) + } else if bytes.eq_ignore_ascii_case(b"true") { + Some(true) + } else { + None + } + }), + Int8 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + Int16 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + Int32 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + Int64 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + UInt8 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + UInt16 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + UInt32 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + UInt64 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + Float32 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + Float64 => deserialize_primitive(rows, column, datatype, |bytes| { + lexical_core::parse::(bytes).ok() + }), + Date32 => deserialize_primitive(rows, column, datatype, |bytes| { + simdutf8::basic::from_utf8(bytes) + .ok() + .and_then(|x| x.parse::().ok()) + .map(|x| x.num_days_from_ce() - temporal_conversions::EPOCH_DAYS_FROM_CE) + }), + Date64 => deserialize_primitive(rows, column, datatype, |bytes| { + simdutf8::basic::from_utf8(bytes) + .ok() + .and_then(|x| x.parse::().ok()) + .map(|x| x.timestamp_millis()) + }), + Timestamp(TimeUnit::Nanosecond, None) => { + deserialize_primitive(rows, column, datatype, |bytes| { + simdutf8::basic::from_utf8(bytes) + .ok() + .and_then(|x| x.parse::().ok()) + .map(|x| x.timestamp_nanos()) + }) + } + Timestamp(TimeUnit::Microsecond, None) => { + deserialize_primitive(rows, column, datatype, |bytes| { + simdutf8::basic::from_utf8(bytes) + .ok() + .and_then(|x| x.parse::().ok()) + .map(|x| x.timestamp_nanos() / 1000) + }) + } + Timestamp(time_unit, None) => deserialize_primitive(rows, column, datatype, |bytes| { + simdutf8::basic::from_utf8(bytes) + .ok() + .and_then(|x| x.parse::().ok()) + .map(|x| x.timestamp_nanos()) + .map(|x| match time_unit { + TimeUnit::Second => x / 1_000_000_000, + TimeUnit::Millisecond => x / 1_000_000, + TimeUnit::Microsecond => x / 1_000, + TimeUnit::Nanosecond => x, + }) + }), + Timestamp(time_unit, Some(ref tz)) => { + let tz = temporal_conversions::parse_offset(tz)?; + deserialize_primitive(rows, column, datatype, |bytes| { + simdutf8::basic::from_utf8(bytes) + .ok() + .and_then(|x| deserialize_datetime(x, &tz)) + .map(|x| match time_unit { + TimeUnit::Second => x / 1_000_000_000, + TimeUnit::Millisecond => x / 1_000_000, + TimeUnit::Microsecond => x / 1_000, + TimeUnit::Nanosecond => x, + }) + }) + } + Utf8 => deserialize_utf8::(rows, column), + LargeUtf8 => deserialize_utf8::(rows, column), + Binary => deserialize_binary::(rows, column), + LargeBinary => deserialize_binary::(rows, column), + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Deserializing type \"{:?}\" is not implemented", + other + ))) + } + }) +} + +/// Deserializes rows [`ByteRecord`] into a 
[`RecordBatch`]. +/// Note that this is a convenience function: column deserialization +/// is trivially parallelizable (e.g. rayon). +pub(crate) fn deserialize_batch( + rows: &[B], + fields: &[Field], + projection: Option<&[usize]>, + line_number: usize, + deserialize_column: F, +) -> Result +where + F: Fn(&[B], usize, DataType, usize) -> Result>, +{ + let projection: Vec = match projection { + Some(v) => v.to_vec(), + None => fields.iter().enumerate().map(|(i, _)| i).collect(), + }; + let projected_fields: Vec = projection.iter().map(|i| fields[*i].clone()).collect(); + + let schema = Arc::new(Schema::new(projected_fields)); + + if rows.is_empty() { + return Ok(RecordBatch::new_empty(schema)); + } + + let columns = projection + .iter() + .map(|column| { + let column = *column; + let field = &fields[column]; + let data_type = field.data_type(); + deserialize_column(rows, column, data_type.clone(), line_number) + }) + .collect::>>()?; + + RecordBatch::try_new(schema, columns) +} diff --git a/src/io/csv/utils.rs b/src/io/csv/utils.rs new file mode 100644 index 00000000000..266801d940a --- /dev/null +++ b/src/io/csv/utils.rs @@ -0,0 +1,111 @@ +use std::collections::HashSet; + +use crate::datatypes::{DataType, Field, TimeUnit}; + +pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; + +fn is_boolean(bytes: &[u8]) -> bool { + bytes.eq_ignore_ascii_case(b"true") | bytes.eq_ignore_ascii_case(b"false") +} + +fn is_float(bytes: &[u8]) -> bool { + lexical_core::parse::(bytes).is_ok() +} + +fn is_integer(bytes: &[u8]) -> bool { + lexical_core::parse::(bytes).is_ok() +} + +fn is_date(string: &str) -> bool { + string.parse::().is_ok() +} + +fn is_time(string: &str) -> bool { + string.parse::().is_ok() +} + +fn is_naive_datetime(string: &str) -> bool { + string.parse::().is_ok() +} + +fn is_datetime(string: &str) -> Option { + let mut parsed = chrono::format::Parsed::new(); + let fmt = chrono::format::StrftimeItems::new(RFC3339); + if chrono::format::parse(&mut parsed, string, fmt).is_ok() { + parsed.offset.map(|x| { + let hours = x / 60 / 60; + let minutes = x / 60 - hours * 60; + format!("{:03}:{:02}", hours, minutes) + }) + } else { + None + } +} + +/// Infers [`DataType`] from `bytes` +/// # Implementation +/// * case insensitive "true" or "false" are mapped to [`DataType::Boolean`] +/// * parsable to integer is mapped to [`DataType::Int64`] +/// * parsable to float is mapped to [`DataType::Float64`] +/// * parsable to date is mapped to [`DataType::Date32`] +/// * parsable to time is mapped to [`DataType::Time32(TimeUnit::Millisecond)`] +/// * parsable to naive datetime is mapped to [`DataType::Timestamp(TimeUnit::Millisecond, None)`] +/// * parsable to time-aware datetime is mapped to [`DataType::Timestamp`] of milliseconds and parsed offset. 
+/// * other utf8 is mapped to [`DataType::Utf8`] +/// * invalid utf8 is mapped to [`DataType::Binary`] +pub fn infer(bytes: &[u8]) -> DataType { + if is_boolean(bytes) { + DataType::Boolean + } else if is_integer(bytes) { + DataType::Int64 + } else if is_float(bytes) { + DataType::Float64 + } else if let Ok(string) = simdutf8::basic::from_utf8(bytes) { + if is_date(string) { + DataType::Date32 + } else if is_time(string) { + DataType::Time32(TimeUnit::Millisecond) + } else if is_naive_datetime(string) { + DataType::Timestamp(TimeUnit::Millisecond, None) + } else if let Some(offset) = is_datetime(string) { + DataType::Timestamp(TimeUnit::Millisecond, Some(offset)) + } else { + DataType::Utf8 + } + } else { + // invalid utf8 + DataType::Binary + } +} + +fn merge_fields(field_name: &str, possibilities: &mut HashSet) -> Field { + // determine data type based on possible types + // if there are incompatible types, use DataType::Utf8 + let data_type = match possibilities.len() { + 1 => possibilities.drain().next().unwrap(), + 2 => { + if possibilities.contains(&DataType::Int64) + && possibilities.contains(&DataType::Float64) + { + // we have an integer and double, fall down to double + DataType::Float64 + } else { + // default to Utf8 for conflicting datatypes (e.g bool and int) + DataType::Utf8 + } + } + _ => DataType::Utf8, + }; + Field::new(field_name, data_type, true) +} + +pub(crate) fn merge_schema( + headers: &[String], + column_types: &mut [HashSet], +) -> Vec { + headers + .iter() + .zip(column_types.iter_mut()) + .map(|(field_name, possibilities)| merge_fields(field_name, possibilities)) + .collect() +} diff --git a/src/io/mod.rs b/src/io/mod.rs index c86b6655fd2..a8d68d66951 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -1,7 +1,18 @@ //! Contains modules to interface with other formats such as [`csv`], //! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`]. 
-#[cfg(any(feature = "io_csv_read", feature = "io_csv_write"))] -#[cfg_attr(docsrs, doc(cfg(feature = "io_csv")))] +#[cfg(any( + feature = "io_csv_read", + feature = "io_csv_read_async", + feature = "io_csv_write", +))] +#[cfg_attr( + docsrs, + doc(cfg(any( + feature = "io_csv_read", + feature = "io_csv_read_async", + feature = "io_csv_write", + ))) +)] pub mod csv; #[cfg(feature = "io_json")] diff --git a/tests/it/io/csv/mod.rs b/tests/it/io/csv/mod.rs index f7cdac66078..b0efd3ed704 100644 --- a/tests/it/io/csv/mod.rs +++ b/tests/it/io/csv/mod.rs @@ -1,2 +1,6 @@ +#[cfg(feature = "io_csv_read")] mod read; +#[cfg(feature = "io_csv_read_async")] +mod read_async; +#[cfg(feature = "io_csv_write")] mod write; diff --git a/tests/it/io/csv/read_async.rs b/tests/it/io/csv/read_async.rs new file mode 100644 index 00000000000..6d22004ad7c --- /dev/null +++ b/tests/it/io/csv/read_async.rs @@ -0,0 +1,62 @@ +use futures::io::Cursor; +use std::sync::Arc; + +use arrow2::array::*; +use arrow2::error::Result; +use arrow2::io::csv::read_async::*; + +#[tokio::test] +async fn read() -> Result<()> { + let data = r#"city,lat,lng +"Elgin, Scotland, the UK",57.653484,-3.335724 +"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404 +"Solihull, Birmingham, UK",52.412811,-1.778197 +"Cardiff, Cardiff county, UK",51.481583,-3.179090 +"Eastbourne, East Sussex, UK",50.768036,0.290472 +"Oxford, Oxfordshire, UK",51.752022,-1.257677 +"London, UK",51.509865,-0.118092 +"Swindon, Swindon, UK",51.568535,-1.772232 +"Gravesend, Kent, UK",51.441883,0.370759 +"Northampton, Northamptonshire, UK",52.240479,-0.902656 +"Rugby, Warwickshire, UK",52.370876,-1.265032 +"Sutton Coldfield, West Midlands, UK",52.570385,-1.824042 +"Harlow, Essex, UK",51.772938,0.102310 +"Aberdeen, Aberdeen City, UK",57.149651,-2.099075"#; + let mut reader = AsyncReaderBuilder::new().create_reader(Cursor::new(data.as_bytes())); + + let schema = Arc::new(infer_schema(&mut reader, None, true, &infer).await?); + + let mut rows = vec![ByteRecord::default(); 100]; + let rows_read = read_rows(&mut reader, 0, &mut rows).await?; + + let batch = deserialize_batch( + &rows[..rows_read], + schema.fields(), + None, + 0, + deserialize_column, + )?; + + let batch_schema = batch.schema(); + + assert_eq!(&schema, batch_schema); + assert_eq!(14, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!((57.653484 - lat.value(0)).abs() < f64::EPSILON); + + let city = batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!("Elgin, Scotland, the UK", city.value(0)); + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + Ok(()) +} diff --git a/tests/it/io/mod.rs b/tests/it/io/mod.rs index 5009e99015f..75296ecf766 100644 --- a/tests/it/io/mod.rs +++ b/tests/it/io/mod.rs @@ -13,5 +13,9 @@ mod parquet; #[cfg(feature = "io_avro")] mod avro; -#[cfg(feature = "io_csv")] +#[cfg(any( + feature = "io_csv_read", + feature = "io_csv_write", + feature = "io_csv_read_async" +))] mod csv;