diff --git a/common/datavalues/src/types/deserializations/array.rs b/common/datavalues/src/types/deserializations/array.rs index c9ef65a7e5e4..bc45c21e67f2 100644 --- a/common/datavalues/src/types/deserializations/array.rs +++ b/common/datavalues/src/types/deserializations/array.rs @@ -85,8 +85,37 @@ impl TypeDeserializer for ArrayDeserializer { } let mut values = Vec::with_capacity(idx); for _ in 0..idx { - let value = self.inner.pop_data_value().unwrap(); - values.push(value); + values.push(self.inner.pop_data_value()?); + } + values.reverse(); + self.builder.append_value(ArrayValue::new(values)); + Ok(()) + } + + fn de_text_csv( + &mut self, + reader: &mut CheckpointReader, + format: &FormatSettings, + ) -> Result<()> { + reader.must_ignore_byte(b'[')?; + let mut idx = 0; + loop { + let _ = reader.ignore_white_spaces()?; + if let Ok(res) = reader.ignore_byte(b']') { + if res { + break; + } + } + if idx != 0 { + let _ = reader.must_ignore_byte(b',')?; + } + let _ = reader.ignore_white_spaces()?; + self.inner.de_text_csv(reader, format)?; + idx += 1; + } + let mut values = Vec::with_capacity(idx); + for _ in 0..idx { + values.push(self.inner.pop_data_value()?); } values.reverse(); self.builder.append_value(ArrayValue::new(values)); diff --git a/common/datavalues/src/types/deserializations/boolean.rs b/common/datavalues/src/types/deserializations/boolean.rs index 3bac9f8171b5..7c86f99b9623 100644 --- a/common/datavalues/src/types/deserializations/boolean.rs +++ b/common/datavalues/src/types/deserializations/boolean.rs @@ -49,6 +49,14 @@ impl TypeDeserializer for BooleanDeserializer { Ok(()) } + fn de_json(&mut self, value: &serde_json::Value, _format: &FormatSettings) -> Result<()> { + match value { + serde_json::Value::Bool(v) => self.builder.append_value(*v), + _ => return Err(ErrorCode::BadBytes("Incorrect boolean value")), + } + Ok(()) + } + fn de_whole_text(&mut self, reader: &[u8], _format: &FormatSettings) -> Result<()> { if reader.eq_ignore_ascii_case(b"true") { self.builder.append_value(true); @@ -77,14 +85,6 @@ impl TypeDeserializer for BooleanDeserializer { Ok(()) } - fn de_json(&mut self, value: &serde_json::Value, _format: &FormatSettings) -> Result<()> { - match value { - serde_json::Value::Bool(v) => self.builder.append_value(*v), - _ => return Err(ErrorCode::BadBytes("Incorrect boolean value")), - } - Ok(()) - } - fn append_data_value(&mut self, value: DataValue, _format: &FormatSettings) -> Result<()> { self.builder.append_value(value.as_bool()?); Ok(()) diff --git a/common/datavalues/src/types/deserializations/number.rs b/common/datavalues/src/types/deserializations/number.rs index e6cb80095cfd..95e91b966a76 100644 --- a/common/datavalues/src/types/deserializations/number.rs +++ b/common/datavalues/src/types/deserializations/number.rs @@ -71,6 +71,10 @@ where } } + fn de_null(&mut self, _format: &FormatSettings) -> bool { + false + } + fn de_whole_text(&mut self, reader: &[u8], _format: &FormatSettings) -> Result<()> { let mut reader = BufferReader::new(reader); let v: T = if !T::FLOATING { @@ -98,8 +102,25 @@ where Ok(()) } - fn de_null(&mut self, _format: &FormatSettings) -> bool { - false + fn de_text_csv( + &mut self, + reader: &mut CheckpointReader, + _settings: &FormatSettings, + ) -> Result<()> { + let maybe_quote = reader.ignore(|f| f == b'\'' || f == b'"')?; + + let v: T = if !T::FLOATING { + reader.read_int_text() + } else { + reader.read_float_text() + }?; + + if maybe_quote { + reader.must_ignore(|f| f == b'\'' || f == b'"')?; + } + + 
self.builder.append_value(v); + Ok(()) } fn append_data_value(&mut self, value: DataValue, _format: &FormatSettings) -> Result<()> { diff --git a/common/datavalues/src/types/deserializations/string.rs b/common/datavalues/src/types/deserializations/string.rs index a3e4f9c13df0..7df6807b17bd 100644 --- a/common/datavalues/src/types/deserializations/string.rs +++ b/common/datavalues/src/types/deserializations/string.rs @@ -79,33 +79,168 @@ impl TypeDeserializer for StringDeserializer { } } - fn de_text_quoted( + fn de_whole_text(&mut self, reader: &[u8], _format: &FormatSettings) -> Result<()> { + self.builder.append_value(reader); + Ok(()) + } + + fn de_text( &mut self, reader: &mut CheckpointReader, _format: &FormatSettings, ) -> Result<()> { self.buffer.clear(); - reader.read_quoted_text(&mut self.buffer, b'\'')?; + reader.read_escaped_string_text(&mut self.buffer)?; self.builder.append_value(self.buffer.as_slice()); Ok(()) } - fn de_whole_text(&mut self, reader: &[u8], _format: &FormatSettings) -> Result<()> { - self.builder.append_value(reader); - Ok(()) - } - - fn de_text( + fn de_text_quoted( &mut self, reader: &mut CheckpointReader, _format: &FormatSettings, ) -> Result<()> { self.buffer.clear(); - reader.read_escaped_string_text(&mut self.buffer)?; + reader.read_quoted_text(&mut self.buffer, b'\'')?; self.builder.append_value(self.buffer.as_slice()); Ok(()) } + fn de_text_csv( + &mut self, + reader: &mut CheckpointReader, + settings: &FormatSettings, + ) -> Result<()> { + let mut read_buffer = reader.fill_buf()?; + + if read_buffer.is_empty() { + return Err(ErrorCode::BadBytes("Read string after eof.")); + } + + let maybe_quote = read_buffer[0]; + if maybe_quote == b'\'' || maybe_quote == b'"' { + let mut index = 1; + let mut bytes = 0; + + loop { + let begin = index; + while index < read_buffer.len() { + if read_buffer[index] == maybe_quote { + self.builder + .values_mut() + .extend_from_slice(&read_buffer[begin..index]); + self.builder.add_offset(bytes + index - begin); + reader.consume(index + 1); + return Ok(()); + } + + index += 1; + } + + bytes += index - begin; + self.builder + .values_mut() + .extend_from_slice(&read_buffer[begin..]); + reader.consume(index - begin); + + index = 0; + read_buffer = reader.fill_buf()?; + + if read_buffer.is_empty() { + break; + } + } + + Err(ErrorCode::BadBytes(format!( + "Not found '{}' before eof in parse string.", + maybe_quote as char + ))) + } else { + // Unquoted case. Look for field_delimiter or record_delimiter. 
+ let mut field_delimiter = b','; + + if !settings.field_delimiter.is_empty() { + field_delimiter = settings.field_delimiter[0]; + } + + if settings.record_delimiter.is_empty() + || settings.record_delimiter[0] == b'\r' + || settings.record_delimiter[0] == b'\n' + { + let mut index = 0; + let mut bytes = 0; + + 'outer1: loop { + while index < read_buffer.len() { + if read_buffer[index] == field_delimiter + || read_buffer[index] == b'\r' + || read_buffer[index] == b'\n' + { + break 'outer1; + } + index += 1; + } + + bytes += index; + self.builder + .values_mut() + .extend_from_slice(&read_buffer[..index]); + reader.consume(index); + + index = 0; + read_buffer = reader.fill_buf()?; + + if read_buffer.is_empty() { + break 'outer1; + } + } + + self.builder + .values_mut() + .extend_from_slice(&read_buffer[..index]); + self.builder.add_offset(bytes + index); + reader.consume(index); + } else { + let record_delimiter = settings.record_delimiter[0]; + + let mut index = 0; + let mut bytes = 0; + + 'outer2: loop { + while index < read_buffer.len() { + if read_buffer[index] == field_delimiter + || read_buffer[index] == record_delimiter + { + break 'outer2; + } + index += 1; + } + + bytes += index; + self.builder + .values_mut() + .extend_from_slice(&read_buffer[..index]); + reader.consume(index); + + index = 0; + read_buffer = reader.fill_buf()?; + + if read_buffer.is_empty() { + break 'outer2; + } + } + + self.builder + .values_mut() + .extend_from_slice(&read_buffer[..index]); + self.builder.add_offset(bytes + index); + reader.consume(index); + } + + Ok(()) + } + } + fn append_data_value(&mut self, value: DataValue, _format: &FormatSettings) -> Result<()> { self.builder.append_data_value(value) } diff --git a/common/datavalues/src/types/deserializations/variant.rs b/common/datavalues/src/types/deserializations/variant.rs index af13c92ee3a1..4ad4a37651cf 100644 --- a/common/datavalues/src/types/deserializations/variant.rs +++ b/common/datavalues/src/types/deserializations/variant.rs @@ -75,6 +75,12 @@ impl TypeDeserializer for VariantDeserializer { Ok(()) } + fn de_whole_text(&mut self, reader: &[u8], _format: &FormatSettings) -> Result<()> { + let val = serde_json::from_slice(reader)?; + self.builder.append_value(val); + Ok(()) + } + fn de_text( &mut self, reader: &mut CheckpointReader, @@ -87,19 +93,27 @@ impl TypeDeserializer for VariantDeserializer { Ok(()) } - fn de_whole_text(&mut self, reader: &[u8], _format: &FormatSettings) -> Result<()> { - let val = serde_json::from_slice(reader)?; + fn de_text_quoted( + &mut self, + reader: &mut CheckpointReader, + _format: &FormatSettings, + ) -> Result<()> { + self.buffer.clear(); + reader.read_quoted_text(&mut self.buffer, b'\'')?; + + let val = serde_json::from_slice(self.buffer.as_slice())?; + self.builder.append_value(val); Ok(()) } - fn de_text_quoted( + fn de_text_csv( &mut self, reader: &mut CheckpointReader, _format: &FormatSettings, ) -> Result<()> { self.buffer.clear(); - reader.read_quoted_text(&mut self.buffer, b'\'')?; + reader.read_quoted_text(&mut self.buffer, b'"')?; let val = serde_json::from_slice(self.buffer.as_slice())?; self.builder.append_value(val); diff --git a/common/exception/src/exception_code.rs b/common/exception/src/exception_code.rs index 87e85d24f958..a9b612e5f19c 100644 --- a/common/exception/src/exception_code.rs +++ b/common/exception/src/exception_code.rs @@ -139,6 +139,8 @@ build_exceptions! { // Network error codes. NetworkRequestError(1073), + UnknownFormat(1074), + // Tenant error codes. 
TenantIsEmpty(1101), IndexOutOfBounds(1102), diff --git a/common/io/src/buffer/buffer_read_ext.rs b/common/io/src/buffer/buffer_read_ext.rs index ff6df9336683..53ea92c19c01 100644 --- a/common/io/src/buffer/buffer_read_ext.rs +++ b/common/io/src/buffer/buffer_read_ext.rs @@ -24,6 +24,7 @@ pub trait BufferReadExt: BufferRead { fn ignore_bytes(&mut self, bs: &[u8]) -> Result; fn ignore_insensitive_bytes(&mut self, bs: &[u8]) -> Result; fn ignore_white_spaces(&mut self) -> Result; + fn ignore_white_spaces_and_byte(&mut self, b: u8) -> Result; fn until(&mut self, delim: u8, buf: &mut Vec) -> Result; fn keep_read(&mut self, buf: &mut Vec, f: impl Fn(u8) -> bool) -> Result; @@ -55,6 +56,11 @@ pub trait BufferReadExt: BufferRead { Ok(()) } + fn eof(&mut self) -> Result { + let buffer = self.fill_buf()?; + Ok(buffer.is_empty()) + } + fn must_eof(&mut self) -> Result<()> { let buffer = self.fill_buf()?; if !buffer.is_empty() { @@ -78,6 +84,16 @@ pub trait BufferReadExt: BufferRead { fn must_ignore_byte(&mut self, b: u8) -> Result<()> { if !self.ignore_byte(b)? { + return Err(std::io::Error::new( + ErrorKind::InvalidData, + format!("Expected to have char {}.", b as char), + )); + } + Ok(()) + } + + fn must_ignore_white_spaces_and_byte(&mut self, b: u8) -> Result<()> { + if !self.ignore_white_spaces_and_byte(b)? { return Err(std::io::Error::new( ErrorKind::InvalidData, format!("Expected to have char {}", b as char), @@ -172,6 +188,17 @@ where R: BufferRead Ok(cnt > 0) } + fn ignore_white_spaces_and_byte(&mut self, b: u8) -> Result { + self.ignores(|c: u8| c == b' ')?; + + if self.ignore_byte(b)? { + self.ignores(|c: u8| c == b' ')?; + return Ok(true); + } + + Ok(false) + } + fn until(&mut self, delim: u8, buf: &mut Vec) -> Result { self.read_until(delim, buf) } diff --git a/common/io/src/buffer/buffer_read_number_ext.rs b/common/io/src/buffer/buffer_read_number_ext.rs index 0145bed13ce3..45c9c116c1df 100644 --- a/common/io/src/buffer/buffer_read_number_ext.rs +++ b/common/io/src/buffer/buffer_read_number_ext.rs @@ -66,9 +66,17 @@ where R: BufferRead let _ = self.ignores(|f| (b'0'..=b'9').contains(&f))?; } - FromLexical::from_lexical(buf.as_slice()).map_err_to_code(ErrorCode::BadBytes, || { - format!("Cannot parse value:{:?} to number type", buf) - }) + match buf.is_empty() { + true => Ok(T::default()), + false => match FromLexical::from_lexical(buf.as_slice()) { + Ok(value) => Ok(value), + Err(cause) => Err(ErrorCode::BadBytes(format!( + "Cannot parse value:{:?} to number type, cause: {:?}", + String::from_utf8(buf), + cause + ))), + }, + } } fn read_float_text(&mut self) -> Result { diff --git a/query/src/formats/format.rs b/query/src/formats/format.rs new file mode 100644 index 000000000000..9a7507a7555e --- /dev/null +++ b/query/src/formats/format.rs @@ -0,0 +1,36 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
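
Note: taken together, the deserializer changes and the new BufferReadExt helpers above let a CSV field parser skip an optional surrounding quote and the whitespace around delimiters. Below is a minimal sketch (not part of the patch) that drives the new numeric `de_text_csv` and `ignore_white_spaces_and_byte` directly; every name is copied from the hunks in this diff, and the import paths are assumptions based on how `format_csv.rs` (further down) uses them.

```rust
// Sketch only: parse one quoted numeric CSV field, then consume the trailing
// field delimiter. Names mirror the diff; treat the exact paths as assumptions.
use std::io::Cursor;

use common_datavalues::type_primitive::UInt32Type;
use common_datavalues::DataType;
use common_datavalues::DataTypeImpl;
use common_datavalues::TypeDeserializer;
use common_exception::Result;
use common_io::prelude::BufferReadExt;
use common_io::prelude::BufferReader;
use common_io::prelude::CheckpointReader;
use common_io::prelude::FormatSettings;

fn parse_quoted_number_field() -> Result<()> {
    let data_type = DataTypeImpl::UInt32(UInt32Type::default());
    let mut deserializer = data_type.create_deserializer(1);

    let reader = BufferReader::new(Cursor::new("'42',  next".as_bytes()));
    let mut reader = CheckpointReader::new(reader);

    // The optional opening quote is skipped, the digits are read with
    // read_int_text, and the matching closing quote is then required.
    deserializer.de_text_csv(&mut reader, &FormatSettings::default())?;

    // Skip spaces, the field delimiter, and the spaces after it in one call.
    assert!(reader.ignore_white_spaces_and_byte(b',')?);

    // One value ("42") has been pushed into the column builder.
    let _column = deserializer.finish_to_column();
    Ok(())
}
```
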
+ +use std::any::Any; + +use common_datablocks::DataBlock; +use common_exception::Result; + +pub trait InputState: Send { + fn as_any(&mut self) -> &mut dyn Any; +} + +pub trait InputFormat: Send { + fn support_parallel(&self) -> bool { + false + } + + fn create_state(&self) -> Box; + + fn deserialize_data(&self, state: &mut Box) -> Result; + + fn read_buf(&self, buf: &[u8], state: &mut Box) -> Result; + + fn skip_header(&self, buf: &[u8], state: &mut Box) -> Result; +} diff --git a/query/src/formats/format_csv.rs b/query/src/formats/format_csv.rs new file mode 100644 index 000000000000..891f8ba465ea --- /dev/null +++ b/query/src/formats/format_csv.rs @@ -0,0 +1,285 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::io::Cursor; + +use common_datablocks::DataBlock; +use common_datavalues::DataSchemaRef; +use common_datavalues::DataType; +use common_datavalues::TypeDeserializer; +use common_exception::ErrorCode; +use common_exception::Result; +use common_io::prelude::BufferReadExt; +use common_io::prelude::BufferReader; +use common_io::prelude::CheckpointReader; +use common_io::prelude::FormatSettings; + +use crate::formats::FormatFactory; +use crate::formats::InputFormat; +use crate::formats::InputState; + +pub struct CsvInputState { + pub quotes: u8, + pub memory: Vec, + pub accepted_rows: usize, + pub accepted_bytes: usize, + pub need_more_data: bool, + pub ignore_if_first: Option, +} + +impl InputState for CsvInputState { + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +pub struct CsvInputFormat { + schema: DataSchemaRef, + field_delimiter: u8, + need_skip_header: bool, + row_delimiter: Option, + min_accepted_rows: usize, + min_accepted_bytes: usize, + settings: FormatSettings, +} + +impl CsvInputFormat { + pub fn register(factory: &mut FormatFactory) { + factory.register_input( + "csv", + Box::new( + |name: &str, schema: DataSchemaRef, settings: FormatSettings| { + CsvInputFormat::try_create(name, schema, settings, 8192, 10 * 1024 * 1024) + }, + ), + ) + } + + pub fn try_create( + _name: &str, + schema: DataSchemaRef, + settings: FormatSettings, + min_accepted_rows: usize, + min_accepted_bytes: usize, + ) -> Result> { + let field_delimiter = match settings.field_delimiter.len() { + n if n >= 1 => settings.field_delimiter[0], + _ => b',', + }; + + let mut row_delimiter = None; + + if !settings.record_delimiter.is_empty() + && settings.record_delimiter[0] != b'\n' + && settings.record_delimiter[0] != b'\r' + { + row_delimiter = Some(settings.record_delimiter[0]); + } + + let need_skip_header = settings.skip_header; + + Ok(Box::new(CsvInputFormat { + schema, + settings, + row_delimiter, + field_delimiter, + need_skip_header, + min_accepted_rows, + min_accepted_bytes, + })) + } + + fn find_quotes(buf: &[u8], pos: usize, state: &mut CsvInputState) -> usize { + for (index, byte) in buf.iter().enumerate().skip(pos) { + if *byte == b'"' || *byte == b'\'' { + state.quotes = 0; + return index + 1; + } + } + + 
buf.len() + } + + fn find_delimiter(&self, buf: &[u8], pos: usize, state: &mut CsvInputState) -> usize { + for index in pos..buf.len() { + if buf[index] == b'"' || buf[index] == b'\'' { + state.quotes = buf[index]; + return index + 1; + } + + if let Some(b) = &self.row_delimiter { + if buf[index] == *b { + return self.accept_row::<0>(buf, pos, state, index); + } + } else if buf[index] == b'\r' { + return self.accept_row::(buf, pos, state, index); + } else if buf[index] == b'\n' { + return self.accept_row::(buf, pos, state, index); + } + } + + buf.len() + } + + #[inline(always)] + fn accept_row( + &self, + buf: &[u8], + pos: usize, + state: &mut CsvInputState, + index: usize, + ) -> usize { + state.accepted_rows += 1; + state.accepted_bytes += index - pos; + + if state.accepted_rows >= self.min_accepted_rows + || (state.accepted_bytes + index) >= self.min_accepted_bytes + { + state.need_more_data = false; + } + + if C != 0 { + if buf.len() <= index + 1 { + state.ignore_if_first = Some(C); + } else if buf[index + 1] == C { + return index + 2; + } + } + + index + 1 + } +} + +impl InputFormat for CsvInputFormat { + fn create_state(&self) -> Box { + Box::new(CsvInputState { + quotes: 0, + memory: vec![], + accepted_rows: 0, + accepted_bytes: 0, + need_more_data: false, + ignore_if_first: None, + }) + } + + fn deserialize_data(&self, state: &mut Box) -> Result { + let mut deserializers = Vec::with_capacity(self.schema.num_fields()); + for field in self.schema.fields() { + let data_type = field.data_type(); + deserializers.push(data_type.create_deserializer(self.min_accepted_rows)); + } + + let mut state = std::mem::replace(state, self.create_state()); + let state = state.as_any().downcast_mut::().unwrap(); + let cursor = Cursor::new(&state.memory); + let reader = BufferReader::new(cursor); + let mut checkpoint_reader = CheckpointReader::new(reader); + + for row_index in 0..self.min_accepted_rows { + if checkpoint_reader.eof()? { + break; + } + + for column_index in 0..deserializers.len() { + if checkpoint_reader.ignore_white_spaces_and_byte(self.field_delimiter)? { + deserializers[column_index].de_default(&self.settings); + } else { + deserializers[column_index] + .de_text_csv(&mut checkpoint_reader, &self.settings)?; + + if column_index + 1 != deserializers.len() { + checkpoint_reader + .must_ignore_white_spaces_and_byte(self.field_delimiter)?; + } + } + } + + checkpoint_reader.ignore_white_spaces_and_byte(self.field_delimiter)?; + + if let Some(delimiter) = &self.row_delimiter { + if !checkpoint_reader.ignore_white_spaces_and_byte(*delimiter)? + && !checkpoint_reader.eof()? + { + return Err(ErrorCode::BadBytes(format!( + "Parse csv error at line {}", + row_index + ))); + } + } else { + if (!checkpoint_reader.ignore_white_spaces_and_byte(b'\n')? + & !checkpoint_reader.ignore_white_spaces_and_byte(b'\r')?) + && !checkpoint_reader.eof()? 
+ { + return Err(ErrorCode::BadBytes(format!( + "Parse csv error at line {}", + row_index + ))); + } + + // \r\n + checkpoint_reader.ignore_white_spaces_and_byte(b'\n')?; + } + } + + let mut columns = Vec::with_capacity(deserializers.len()); + for deserializer in &mut deserializers { + columns.push(deserializer.finish_to_column()); + } + + Ok(DataBlock::create(self.schema.clone(), columns)) + } + + fn read_buf(&self, buf: &[u8], state: &mut Box) -> Result { + let mut index = 0; + let state = state.as_any().downcast_mut::().unwrap(); + + if let Some(first) = state.ignore_if_first.take() { + if buf[0] == first { + index += 1; + } + } + + state.need_more_data = true; + while index < buf.len() && state.need_more_data { + index = match state.quotes != 0 { + true => Self::find_quotes(buf, index, state), + false => self.find_delimiter(buf, index, state), + } + } + + state.memory.extend_from_slice(&buf[0..index]); + Ok(index) + } + + fn skip_header(&self, buf: &[u8], state: &mut Box) -> Result { + if self.need_skip_header { + let mut index = 0; + let state = state.as_any().downcast_mut::().unwrap(); + + while index < buf.len() { + index = match state.quotes != 0 { + true => Self::find_quotes(buf, index, state), + false => self.find_delimiter(buf, index, state), + }; + + if state.accepted_rows == 1 { + return Ok(index); + } + } + } + + Ok(0) + } +} diff --git a/query/src/formats/format_factory.rs b/query/src/formats/format_factory.rs new file mode 100644 index 000000000000..e13fc8ac1fa8 --- /dev/null +++ b/query/src/formats/format_factory.rs @@ -0,0 +1,76 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
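
Note: `CsvInputFormat` above implements the `InputFormat` trait introduced in `formats/format.rs`. The real driver is the `SequentialInputFormatSource` added later in this diff, but for a non-parallel format the intended call pattern is roughly the loop sketched here: feed raw buffers through `read_buf`, and whenever it consumes less than the whole slice a block boundary was reached and `deserialize_data` can produce a `DataBlock`. This is only an illustration of the contract, assuming the trait signatures shown above.

```rust
// Illustrative sketch of the InputFormat contract; not part of the patch.
use common_datablocks::DataBlock;
use common_exception::Result;
use databend_query::formats::InputFormat;

fn drive_format(format: &dyn InputFormat, buffers: &[Vec<u8>]) -> Result<Vec<DataBlock>> {
    let mut blocks = vec![];
    let mut state = format.create_state();

    for buf in buffers {
        let mut slice: &[u8] = buf;
        while !slice.is_empty() {
            // read_buf reports how many bytes it took; stopping short of the
            // end means enough rows were accepted for one block.
            let read = format.read_buf(slice, &mut state)?;
            if read < slice.len() {
                blocks.push(format.deserialize_data(&mut state)?);
            }
            slice = &slice[read..];
        }
    }

    // Flush whatever is still buffered in the state.
    blocks.push(format.deserialize_data(&mut state)?);
    Ok(blocks)
}
```
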
+ +use std::collections::HashMap; +use std::sync::Arc; + +use common_datavalues::DataSchemaRef; +use common_exception::ErrorCode; +use common_exception::Result; +use common_io::prelude::FormatSettings; +use once_cell::sync::Lazy; + +use crate::formats::format::InputFormat; +use crate::formats::format_csv::CsvInputFormat; + +pub type InputFormatFactoryCreator = + Box Result> + Send + Sync>; + +pub struct FormatFactory { + case_insensitive_desc: HashMap, +} + +static FORMAT_FACTORY: Lazy> = Lazy::new(|| { + let mut format_factory = FormatFactory::create(); + + CsvInputFormat::register(&mut format_factory); + + Arc::new(format_factory) +}); + +impl FormatFactory { + pub(in crate::formats::format_factory) fn create() -> FormatFactory { + FormatFactory { + case_insensitive_desc: Default::default(), + } + } + + pub fn instance() -> &'static FormatFactory { + FORMAT_FACTORY.as_ref() + } + + pub fn register_input(&mut self, name: &str, creator: InputFormatFactoryCreator) { + let case_insensitive_desc = &mut self.case_insensitive_desc; + case_insensitive_desc.insert(name.to_lowercase(), creator); + } + + pub fn get_input( + &self, + name: impl AsRef, + schema: DataSchemaRef, + settings: FormatSettings, + ) -> Result> { + let origin_name = name.as_ref(); + let lowercase_name = origin_name.to_lowercase(); + + let creator = self + .case_insensitive_desc + .get(&lowercase_name) + .ok_or_else(|| { + ErrorCode::UnknownFormat(format!("Unsupported formats: {}", origin_name)) + })?; + + creator(origin_name, schema, settings) + } +} diff --git a/query/src/formats/mod.rs b/query/src/formats/mod.rs new file mode 100644 index 000000000000..9bcafc2008c1 --- /dev/null +++ b/query/src/formats/mod.rs @@ -0,0 +1,21 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
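
Note: a quick sketch of how callers use the factory registered above (the same lookup the multipart source performs later in this diff). Lookup is case-insensitive and unregistered names map to the new `UnknownFormat` error code. All names come from the diff; the asserts are only illustrative.

```rust
// Sketch only: resolve an input format by name through the global factory.
use std::sync::Arc;

use common_datavalues::DataSchema;
use common_exception::Result;
use common_io::prelude::FormatSettings;
use databend_query::formats::FormatFactory;
use databend_query::formats::InputFormat;

fn lookup_csv_format() -> Result<()> {
    let schema = Arc::new(DataSchema::empty());

    // Any casing works; names are lowercased before the lookup.
    let csv = FormatFactory::instance().get_input("CSV", schema.clone(), FormatSettings::default())?;
    assert!(!csv.support_parallel());

    // Unregistered names surface ErrorCode::UnknownFormat (code 1074).
    assert!(FormatFactory::instance()
        .get_input("not-a-format", schema, FormatSettings::default())
        .is_err());
    Ok(())
}
```
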
+ +pub mod format; +pub mod format_csv; +mod format_factory; + +pub use format::InputFormat; +pub use format::InputState; +pub use format_factory::FormatFactory; diff --git a/query/src/interpreters/interpreter_insert.rs b/query/src/interpreters/interpreter_insert.rs index 87f572e68f4e..490b483f9ec8 100644 --- a/query/src/interpreters/interpreter_insert.rs +++ b/query/src/interpreters/interpreter_insert.rs @@ -284,8 +284,8 @@ impl Interpreter for InsertInterpreter { } fn create_new_pipeline(&self) -> Result { - let new_pipeline = NewPipeline::create(); - Ok(new_pipeline) + let insert_pipeline = NewPipeline::create(); + Ok(insert_pipeline) } fn set_source_pipe_builder(&self, builder: Option) -> Result<()> { diff --git a/query/src/lib.rs b/query/src/lib.rs index c48ee7aca3e2..d76db2afdd84 100644 --- a/query/src/lib.rs +++ b/query/src/lib.rs @@ -25,6 +25,7 @@ pub mod clusters; pub mod common; pub mod configs; pub mod databases; +pub mod formats; pub mod interpreters; pub mod metrics; pub mod optimizers; diff --git a/query/src/servers/http/v1/load.rs b/query/src/servers/http/v1/load.rs index d74d452c2870..1bffb27888f4 100644 --- a/query/src/servers/http/v1/load.rs +++ b/query/src/servers/http/v1/load.rs @@ -12,12 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; use std::sync::Arc; use async_compat::CompatExt; use async_stream::stream; use common_base::base::ProgressValues; +use common_base::base::TrySpawn; +use common_datavalues::DataSchemaRef; use common_exception::ErrorCode; +use common_exception::Result; use common_exception::ToErrorCode; use common_io::prelude::parse_escape_string; use common_io::prelude::FormatSettings; @@ -45,6 +49,8 @@ use crate::interpreters::InterpreterFactory; use crate::pipelines::new::processors::port::OutputPort; use crate::pipelines::new::processors::StreamSourceV2; use crate::pipelines::new::SourcePipeBuilder; +use crate::servers::http::v1::multipart_format::MultipartFormat; +use crate::servers::http::v1::multipart_format::MultipartWorker; use crate::sessions::QueryContext; use crate::sessions::SessionType; use crate::sql::PlanParser; @@ -57,6 +63,74 @@ pub struct LoadResponse { pub error: Option, } +fn get_input_format(node: &PlanNode) -> Result<&str> { + match node { + PlanNode::Insert(insert) => match &insert.source { + InsertInputSource::StreamingWithFormat(format) => Ok(format), + _ => Err(ErrorCode::UnknownFormat("Not found format name in plan")), + }, + _ => Err(ErrorCode::UnknownFormat("Not found format name in plan")), + } +} + +#[allow(clippy::manual_async_fn)] +fn execute_query( + context: Arc, + node: PlanNode, + source_builder: SourcePipeBuilder, +) -> impl Future> { + async move { + let interpreter = InterpreterFactory::get(context, node)?; + + if let Err(cause) = interpreter.start().await { + tracing::error!("interpreter.start error: {:?}", cause); + } + + // TODO(Winter): very hack code. need remove it. + interpreter.set_source_pipe_builder(Option::from(source_builder))?; + + let mut data_stream = interpreter.execute(None).await?; + + while let Some(_block) = data_stream.next().await {} + + // Write Finish to query log table. 
+ if let Err(cause) = interpreter.finish().await { + tracing::error!("interpreter.finish error: {:?}", cause); + } + + Ok(()) + } +} + +async fn new_processor_format( + ctx: &Arc, + node: &PlanNode, + multipart: Multipart, +) -> Result> { + let format = get_input_format(node)?; + let format_settings = ctx.get_format_settings()?; + + let (mut worker, builder) = + format_source_pipe_builder(format, node.schema(), multipart, &format_settings)?; + + let handler = ctx.spawn(execute_query(ctx.clone(), node.clone(), builder)); + + worker.work().await; + + match handler.await { + Ok(Ok(_)) => Ok(()), + Ok(Err(cause)) => Err(cause), + Err(_) => Err(ErrorCode::TokioError("Maybe panic.")), + }?; + + Ok(Json(LoadResponse { + error: None, + state: "SUCCESS".to_string(), + id: uuid::Uuid::new_v4().to_string(), + stats: ctx.get_scan_progress_value(), + })) +} + #[poem::handler] pub async fn streaming_load( ctx: &HttpQueryContext, @@ -114,15 +188,16 @@ pub async fn streaming_load( PlanNode::Insert(insert) => match &insert.source { InsertInputSource::StreamingWithFormat(format) => { if format.to_lowercase().as_str() == "csv" { - csv_source_pipe_builder( - context.clone(), - &plan, - &format_settings, - multipart, - max_block_size, - ) - .await - } else if format.to_lowercase().as_str() == "parquet" { + return match new_processor_format(&context, &plan, multipart).await { + Ok(res) => Ok(res), + Err(cause) => { + println!("catch error {:?}", cause); + Err(InternalServerError(cause)) + } + }; + } + + if format.to_lowercase().as_str() == "parquet" { parquet_source_pipe_builder(context.clone(), &plan, multipart).await } else if format.to_lowercase().as_str() == "ndjson" || format.to_lowercase().as_str() == "jsoneachrow" @@ -328,30 +403,27 @@ fn build_csv_stream( Ok(Box::pin(stream)) } -async fn csv_source_pipe_builder( - ctx: Arc, - plan: &PlanNode, +fn format_source_pipe_builder( + format: &str, + schema: DataSchemaRef, + multipart: Multipart, format_settings: &FormatSettings, - mut multipart: Multipart, - block_size: usize, -) -> PoemResult { - let mut builder = CsvSourceBuilder::create(plan.schema(), format_settings.clone()); - builder.block_size(block_size); +) -> Result<(MultipartWorker, SourcePipeBuilder)> { + let ports = vec![OutputPort::create()]; let mut source_pipe_builder = SourcePipeBuilder::create(); - while let Ok(Some(field)) = multipart.next_field().await { - let bytes = field - .bytes() - .await - .map_err_to_code(ErrorCode::BadBytes, || "Read part to field bytes error") - .unwrap(); - let cursor = Cursor::new(bytes); - let csv_source = builder.build(cursor).unwrap(); - let output_port = OutputPort::create(); - let source = - StreamSourceV2::create(ctx.clone(), Box::new(csv_source), output_port.clone()).unwrap(); - source_pipe_builder.add_source(output_port, source); + let (worker, sources) = MultipartFormat::input_sources( + format, + multipart, + schema, + format_settings.clone(), + ports.clone(), + )?; + + for (index, source) in sources.into_iter().enumerate() { + source_pipe_builder.add_source(ports[index].clone(), source); } - Ok(source_pipe_builder) + + Ok((worker, source_pipe_builder)) } async fn parquet_source_pipe_builder( diff --git a/query/src/servers/http/v1/mod.rs b/query/src/servers/http/v1/mod.rs index 2ee7b1e77c6d..b807b72012f9 100644 --- a/query/src/servers/http/v1/mod.rs +++ b/query/src/servers/http/v1/mod.rs @@ -15,6 +15,7 @@ mod http_query_handlers; pub mod json_block; mod load; +mod multipart_format; mod query; mod stage; diff --git 
a/query/src/servers/http/v1/multipart_format.rs b/query/src/servers/http/v1/multipart_format.rs new file mode 100644 index 000000000000..4746786b6b76 --- /dev/null +++ b/query/src/servers/http/v1/multipart_format.rs @@ -0,0 +1,298 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::mem::replace; +use std::sync::Arc; + +use common_base::base::tokio::io::AsyncReadExt; +use common_base::base::tokio::sync::mpsc::Receiver; +use common_base::base::tokio::sync::mpsc::Sender; +use common_datablocks::DataBlock; +use common_datavalues::DataSchemaRef; +use common_exception::ErrorCode; +use common_exception::Result; +use common_io::prelude::FormatSettings; +use poem::web::Multipart; + +use crate::formats::FormatFactory; +use crate::formats::InputFormat; +use crate::formats::InputState; +use crate::pipelines::new::processors::port::OutputPort; +use crate::pipelines::new::processors::processor::Event; +use crate::pipelines::new::processors::processor::ProcessorPtr; +use crate::pipelines::new::processors::Processor; + +pub struct MultipartFormat; + +pub struct MultipartWorker { + multipart: Multipart, + tx: Option>>>, +} + +impl MultipartWorker { + pub async fn work(&mut self) { + if let Some(tx) = self.tx.take() { + 'outer: loop { + match self.multipart.next_field().await { + Err(cause) => { + if let Err(cause) = tx + .send(Err(ErrorCode::BadBytes(format!( + "Parse multipart error, cause {:?}", + cause + )))) + .await + { + common_tracing::tracing::warn!( + "Multipart channel disconnect. {}", + cause + ); + + break 'outer; + } + } + Ok(None) => { + break 'outer; + } + Ok(Some(field)) => { + if let Err(cause) = tx.send(Ok(vec![])).await { + common_tracing::tracing::warn!( + "Multipart channel disconnect. {}", + cause + ); + + break 'outer; + } + + let mut async_reader = field.into_async_read(); + + 'read: loop { + // 1048576 from clickhouse DBMS_DEFAULT_BUFFER_SIZE + let mut buf = vec![0; 1048576]; + let read_res = async_reader.read(&mut buf[..]).await; + + match read_res { + Ok(0) => { + break 'read; + } + Ok(sz) => { + if sz != buf.len() { + buf = buf[..sz].to_vec(); + } + + if let Err(cause) = tx.send(Ok(buf)).await { + common_tracing::tracing::warn!( + "Multipart channel disconnect. {}", + cause + ); + + break 'outer; + } + } + Err(cause) => { + if let Err(cause) = tx + .send(Err(ErrorCode::BadBytes(format!( + "Read part to field bytes error, cause {:?}", + cause + )))) + .await + { + common_tracing::tracing::warn!( + "Multipart channel disconnect. 
{}", + cause + ); + break 'outer; + } + + break 'outer; + } + } + } + } + } + } + } + } +} + +impl MultipartFormat { + pub fn input_sources( + name: &str, + multipart: Multipart, + schema: DataSchemaRef, + settings: FormatSettings, + ports: Vec>, + ) -> Result<(MultipartWorker, Vec)> { + let input_format = FormatFactory::instance().get_input(name, schema, settings)?; + + if ports.len() != 1 || input_format.support_parallel() { + return Err(ErrorCode::UnImplement( + "Unimplemented parallel input format.", + )); + } + + let (tx, rx) = common_base::base::tokio::sync::mpsc::channel(2); + + Ok(( + MultipartWorker { + multipart, + tx: Some(tx), + }, + vec![SequentialInputFormatSource::create( + ports[0].clone(), + input_format, + rx, + )?], + )) + } +} + +enum State { + NeedReceiveData, + ReceivedData(Vec), + NeedDeserialize, +} + +pub struct SequentialInputFormatSource { + state: State, + finished: bool, + skipped_header: bool, + output: Arc, + data_block: Vec, + input_state: Box, + input_format: Box, + data_receiver: Receiver>>, +} + +impl SequentialInputFormatSource { + pub fn create( + output: Arc, + input_format: Box, + data_receiver: Receiver>>, + ) -> Result { + let input_state = input_format.create_state(); + Ok(ProcessorPtr::create(Box::new( + SequentialInputFormatSource { + output, + input_state, + input_format, + data_receiver, + finished: false, + state: State::NeedReceiveData, + data_block: vec![], + skipped_header: false, + }, + ))) + } +} + +#[async_trait::async_trait] +impl Processor for SequentialInputFormatSource { + fn name(&self) -> &'static str { + "SequentialInputFormatSource" + } + + fn event(&mut self) -> Result { + if self.output.is_finished() { + return Ok(Event::Finished); + } + + if !self.output.can_push() { + return Ok(Event::NeedConsume); + } + + if let Some(data_block) = self.data_block.pop() { + self.output.push_data(Ok(data_block)); + return Ok(Event::NeedConsume); + } + + if self.finished && !matches!(&self.state, State::NeedDeserialize) { + self.output.finish(); + return Ok(Event::Finished); + } + + match &self.state { + State::NeedReceiveData => Ok(Event::Async), + State::ReceivedData(_data) => Ok(Event::Sync), + State::NeedDeserialize => Ok(Event::Sync), + } + } + + fn process(&mut self) -> Result<()> { + match replace(&mut self.state, State::NeedReceiveData) { + State::ReceivedData(data) => { + let mut data_slice: &[u8] = &data; + + if !self.skipped_header { + let len = data_slice.len(); + let skip_size = self + .input_format + .skip_header(data_slice, &mut self.input_state)?; + + data_slice = &data_slice[skip_size..]; + + if skip_size < len { + self.skipped_header = true; + self.input_state = self.input_format.create_state(); + } + } + + while !data_slice.is_empty() { + let len = data_slice.len(); + let read_size = self + .input_format + .read_buf(data_slice, &mut self.input_state)?; + + data_slice = &data_slice[read_size..]; + + if read_size < len { + self.data_block + .push(self.input_format.deserialize_data(&mut self.input_state)?); + } + } + } + State::NeedDeserialize => { + self.data_block + .push(self.input_format.deserialize_data(&mut self.input_state)?); + } + _ => { + return Err(ErrorCode::LogicalError( + "State failure in Multipart format.", + )); + } + } + + Ok(()) + } + + async fn async_process(&mut self) -> Result<()> { + if let State::NeedReceiveData = replace(&mut self.state, State::NeedReceiveData) { + if let Some(receive_res) = self.data_receiver.recv().await { + let receive_bytes = receive_res?; + + if !receive_bytes.is_empty() { + 
self.state = State::ReceivedData(receive_bytes); + } else { + self.skipped_header = false; + self.state = State::NeedDeserialize; + } + + return Ok(()); + } + } + + self.finished = true; + self.state = State::NeedDeserialize; + Ok(()) + } +} diff --git a/query/tests/it/formats/format_csv.rs b/query/tests/it/formats/format_csv.rs new file mode 100644 index 000000000000..cc160893b188 --- /dev/null +++ b/query/tests/it/formats/format_csv.rs @@ -0,0 +1,188 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_datablocks::assert_blocks_eq; +use common_datavalues::type_primitive::UInt32Type; +use common_datavalues::DataField; +use common_datavalues::DataSchema; +use common_datavalues::DataTypeImpl; +use common_datavalues::StringType; +use common_exception::Result; +use common_io::prelude::FormatSettings; +use databend_query::formats::format_csv::CsvInputFormat; +use databend_query::formats::format_csv::CsvInputState; + +#[test] +fn test_accepted_multi_lines() -> Result<()> { + assert_complete_line("")?; + assert_complete_line("first,second\n")?; + assert_complete_line("first,second\r")?; + assert_complete_line("first,second\r\n")?; + assert_complete_line("first,second\n\r")?; + assert_complete_line("first,\"\n\"second\n")?; + assert_complete_line("first,\"\r\"second\n")?; + + assert_broken_line("first", 5)?; + assert_broken_line("first,", 6)?; + assert_broken_line("first,s", 7)?; + assert_broken_line("first,s\"\n", 9)?; + assert_broken_line("first,s\"\r", 9)?; + assert_broken_line("first,second\ns", 13)?; + + let csv_input_format = CsvInputFormat::try_create( + "csv", + Arc::new(DataSchema::empty()), + FormatSettings::default(), + 2, + 10 * 1024 * 1024, + )?; + + let mut csv_input_state = csv_input_format.create_state(); + + let bytes = "first,second\nfirst,".as_bytes(); + assert_eq!( + bytes.len(), + csv_input_format.read_buf(bytes, &mut csv_input_state)? 
+ ); + assert_eq!( + bytes, + &csv_input_state + .as_any() + .downcast_mut::() + .unwrap() + .memory + ); + + let bytes = "second\nfirst,".as_bytes(); + assert_eq!(7, csv_input_format.read_buf(bytes, &mut csv_input_state)?); + assert_eq!( + "first,second\nfirst,second\n".as_bytes(), + csv_input_state + .as_any() + .downcast_mut::() + .unwrap() + .memory + ); + Ok(()) +} + +#[test] +fn test_deserialize_multi_lines() -> Result<()> { + let csv_input_format = CsvInputFormat::try_create( + "csv", + Arc::new(DataSchema::new(vec![ + DataField::new("a", DataTypeImpl::UInt32(UInt32Type::default())), + DataField::new("b", DataTypeImpl::String(StringType::default())), + ])), + FormatSettings::default(), + 1, + 10 * 1024 * 1024, + )?; + + let mut csv_input_state = csv_input_format.create_state(); + + csv_input_format.read_buf("1,\"second\"\n".as_bytes(), &mut csv_input_state)?; + assert_blocks_eq( + vec![ + "+---+--------+", + "| a | b |", + "+---+--------+", + "| 1 | second |", + "+---+--------+", + ], + &[csv_input_format.deserialize_data(&mut csv_input_state)?], + ); + + let csv_input_format = CsvInputFormat::try_create( + "csv", + Arc::new(DataSchema::new(vec![ + DataField::new("a", DataTypeImpl::UInt32(UInt32Type::default())), + DataField::new("b", DataTypeImpl::String(StringType::default())), + ])), + FormatSettings::default(), + 2, + 10 * 1024 * 1024, + )?; + + let mut csv_input_state = csv_input_format.create_state(); + + csv_input_format.read_buf("1,\"second\"\n".as_bytes(), &mut csv_input_state)?; + assert_blocks_eq( + vec![ + "+---+--------+", + "| a | b |", + "+---+--------+", + "| 1 | second |", + "+---+--------+", + ], + &[csv_input_format.deserialize_data(&mut csv_input_state)?], + ); + Ok(()) +} + +fn assert_complete_line(content: &str) -> Result<()> { + let csv_input_format = CsvInputFormat::try_create( + "csv", + Arc::new(DataSchema::empty()), + FormatSettings::default(), + 1, + 10 * 1024 * 1024, + )?; + + let mut csv_input_state = csv_input_format.create_state(); + + let bytes = content.as_bytes(); + assert_eq!( + bytes.len(), + csv_input_format.read_buf(bytes, &mut csv_input_state)? + ); + assert_eq!( + bytes, + &csv_input_state + .as_any() + .downcast_mut::() + .unwrap() + .memory + ); + Ok(()) +} + +fn assert_broken_line(content: &str, assert_size: usize) -> Result<()> { + let csv_input_format = CsvInputFormat::try_create( + "csv", + Arc::new(DataSchema::empty()), + FormatSettings::default(), + 1, + 10 * 1024 * 1024, + )?; + + let mut csv_input_state = csv_input_format.create_state(); + + let bytes = content.as_bytes(); + assert_eq!( + assert_size, + csv_input_format.read_buf(bytes, &mut csv_input_state)? + ); + assert_eq!( + &bytes[0..assert_size], + &csv_input_state + .as_any() + .downcast_mut::() + .unwrap() + .memory + ); + Ok(()) +} diff --git a/query/tests/it/formats/mod.rs b/query/tests/it/formats/mod.rs new file mode 100644 index 000000000000..0e948f332fc7 --- /dev/null +++ b/query/tests/it/formats/mod.rs @@ -0,0 +1,15 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod format_csv;
diff --git a/query/tests/it/main.rs b/query/tests/it/main.rs
index c1b3ec9555b3..41670b9fa564 100644
--- a/query/tests/it/main.rs
+++ b/query/tests/it/main.rs
@@ -16,6 +16,7 @@ mod catalogs;
 mod clusters;
 mod common;
 mod configs;
+mod formats;
 mod functions;
 mod interpreters;
 mod metrics;
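
Note: one behavior worth calling out alongside the tests above: `deserialize_data` probes each column for an immediately following field delimiter and falls back to `de_default`, and `read_int_text` now tolerates an empty numeric field, so sparse rows such as `,second` parse instead of erroring. A hedged sketch in the style of the tests above, reusing only names from this diff (the exact rendered block is not asserted here):

```rust
// Sketch only: an empty leading field is filled via de_default (the type's
// default, 0 for UInt32), and the remaining field parses normally.
use std::sync::Arc;

use common_datavalues::type_primitive::UInt32Type;
use common_datavalues::DataField;
use common_datavalues::DataSchema;
use common_datavalues::DataTypeImpl;
use common_datavalues::StringType;
use common_exception::Result;
use common_io::prelude::FormatSettings;
use databend_query::formats::format_csv::CsvInputFormat;

fn parse_row_with_empty_field() -> Result<()> {
    let csv_input_format = CsvInputFormat::try_create(
        "csv",
        Arc::new(DataSchema::new(vec![
            DataField::new("a", DataTypeImpl::UInt32(UInt32Type::default())),
            DataField::new("b", DataTypeImpl::String(StringType::default())),
        ])),
        FormatSettings::default(),
        1,
        10 * 1024 * 1024,
    )?;

    let mut state = csv_input_format.create_state();

    // The leading comma leaves column `a` empty; the field-delimiter probe
    // detects it and appends the default value instead of parsing.
    csv_input_format.read_buf(",second\n".as_bytes(), &mut state)?;
    let block = csv_input_format.deserialize_data(&mut state)?;
    assert_eq!(1, block.num_rows());
    Ok(())
}
```
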