Skip to content

Commit

Permalink
allow to read non-standard CSV (#326) (#407)
Browse files Browse the repository at this point in the history
* refactor Reader::from_reader

split into build_csv_reader, from_csv_reader
add escape, quote, terminator arg to build_csv_reader

* add escape,quote,terminator field to ReaderBuilder

schema inference support for non-standard CSV

  add fn infer_file_schema_with_csv_options
  add fn infer_reader_schema_with_csv_options

ReaderBuilder support for non-standard CSV

add escape, quote, terminator field
add fn with_escape, with_quote, with_terminator
change ReaderBuilder::build for non-standard CSV

* minimize API change

* add tests

add #[test] fn test_non_std_quote
add #[test] fn test_non_std_escape
add #[test] fn test_non_std_terminator

* apply cargo fmt

Co-authored-by: kazuhiko kikuchi <kazuk.dll@kazuk.jp>
  • Loading branch information
alamb and kazuk authored Jun 5, 2021
1 parent db581f3 commit c928d57
Showing 1 changed file with 225 additions and 10 deletions.
235 changes: 225 additions & 10 deletions arrow/src/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,38 @@ pub fn infer_file_schema<R: Read + Seek>(
delimiter: u8,
max_read_records: Option<usize>,
has_header: bool,
) -> Result<(Schema, usize)> {
infer_file_schema_with_csv_options(
reader,
delimiter,
max_read_records,
has_header,
None,
None,
None,
)
}

fn infer_file_schema_with_csv_options<R: Read + Seek>(
reader: &mut R,
delimiter: u8,
max_read_records: Option<usize>,
has_header: bool,
escape: Option<u8>,
quote: Option<u8>,
terminator: Option<u8>,
) -> Result<(Schema, usize)> {
let saved_offset = reader.seek(SeekFrom::Current(0))?;

let (schema, records_count) =
infer_reader_schema(reader, delimiter, max_read_records, has_header)?;
let (schema, records_count) = infer_reader_schema_with_csv_options(
reader,
delimiter,
max_read_records,
has_header,
escape,
quote,
terminator,
)?;

// return the reader seek back to the start
reader.seek(SeekFrom::Start(saved_offset))?;
Expand All @@ -131,9 +158,34 @@ pub fn infer_reader_schema<R: Read>(
max_read_records: Option<usize>,
has_header: bool,
) -> Result<(Schema, usize)> {
let mut csv_reader = csv_crate::ReaderBuilder::new()
.delimiter(delimiter)
.from_reader(reader);
infer_reader_schema_with_csv_options(
reader,
delimiter,
max_read_records,
has_header,
None,
None,
None,
)
}

fn infer_reader_schema_with_csv_options<R: Read>(
reader: &mut R,
delimiter: u8,
max_read_records: Option<usize>,
has_header: bool,
escape: Option<u8>,
quote: Option<u8>,
terminator: Option<u8>,
) -> Result<(Schema, usize)> {
let mut csv_reader = Reader::build_csv_reader(
reader,
has_header,
Some(delimiter),
escape,
quote,
terminator,
);

// get or create header names
// when has_header is false, creates default column names with column_ prefix
Expand Down Expand Up @@ -324,15 +376,45 @@ impl<R: Read> Reader<R> {
bounds: Bounds,
projection: Option<Vec<usize>>,
) -> Self {
let csv_reader =
Self::build_csv_reader(reader, has_header, delimiter, None, None, None);
Self::from_csv_reader(
csv_reader, schema, has_header, batch_size, bounds, projection,
)
}

fn build_csv_reader(
reader: R,
has_header: bool,
delimiter: Option<u8>,
escape: Option<u8>,
quote: Option<u8>,
terminator: Option<u8>,
) -> csv_crate::Reader<R> {
let mut reader_builder = csv_crate::ReaderBuilder::new();
reader_builder.has_headers(has_header);

if let Some(c) = delimiter {
reader_builder.delimiter(c);
}
reader_builder.escape(escape);
if let Some(c) = quote {
reader_builder.quote(c);
}
if let Some(t) = terminator {
reader_builder.terminator(csv_crate::Terminator::Any(t));
}
reader_builder.from_reader(reader)
}

let mut csv_reader = reader_builder.from_reader(reader);

fn from_csv_reader(
mut csv_reader: csv_crate::Reader<R>,
schema: SchemaRef,
has_header: bool,
batch_size: usize,
bounds: Bounds,
projection: Option<Vec<usize>>,
) -> Self {
let (start, end) = match bounds {
None => (0, usize::MAX),
Some((start, end)) => (start, end),
Expand Down Expand Up @@ -731,6 +813,12 @@ pub struct ReaderBuilder {
has_header: bool,
/// An optional column delimiter. Defaults to `b','`
delimiter: Option<u8>,
/// An optional escape charactor. Defaults None
escape: Option<u8>,
/// An optional quote charactor. Defaults b'\"'
quote: Option<u8>,
/// An optional record terminator. Defaults CRLF
terminator: Option<u8>,
/// Optional maximum number of records to read during schema inference
///
/// If a number is not provided, all the records are read.
Expand All @@ -751,6 +839,9 @@ impl Default for ReaderBuilder {
schema: None,
has_header: false,
delimiter: None,
escape: None,
quote: None,
terminator: None,
max_records: None,
batch_size: 1024,
bounds: None,
Expand Down Expand Up @@ -805,6 +896,21 @@ impl ReaderBuilder {
self
}

pub fn with_escape(mut self, escape: u8) -> Self {
self.escape = Some(escape);
self
}

pub fn with_quote(mut self, quote: u8) -> Self {
self.quote = Some(quote);
self
}

pub fn with_terminator(mut self, terminator: u8) -> Self {
self.terminator = Some(terminator);
self
}

/// Set the CSV reader to infer the schema of the file
pub fn infer_schema(mut self, max_records: Option<usize>) -> Self {
// remove any schema that is set
Expand Down Expand Up @@ -832,21 +938,31 @@ impl ReaderBuilder {
let schema = match self.schema {
Some(schema) => schema,
None => {
let (inferred_schema, _) = infer_file_schema(
let (inferred_schema, _) = infer_file_schema_with_csv_options(
&mut reader,
delimiter,
self.max_records,
self.has_header,
self.escape,
self.quote,
self.terminator,
)?;

Arc::new(inferred_schema)
}
};
Ok(Reader::from_reader(
let csv_reader = Reader::build_csv_reader(
reader,
schema,
self.has_header,
self.delimiter,
self.escape,
self.quote,
self.terminator,
);
Ok(Reader::from_csv_reader(
csv_reader,
schema,
self.has_header,
self.batch_size,
None,
self.projection.clone(),
Expand Down Expand Up @@ -1383,4 +1499,103 @@ mod tests {
assert_eq!(None, parse_item::<Float64Type>("dd"));
assert_eq!(None, parse_item::<Float64Type>("12.34.56"));
}

#[test]
fn test_non_std_quote() {
let schema = Schema::new(vec![
Field::new("text1", DataType::Utf8, false),
Field::new("text2", DataType::Utf8, false),
]);
let builder = ReaderBuilder::new()
.with_schema(Arc::new(schema))
.has_header(false)
.with_quote(b'~'); // default is ", change to ~

let mut csv_text = Vec::new();
let mut csv_writer = std::io::Cursor::new(&mut csv_text);
for index in 0..10 {
let text1 = format!("id{:}", index);
let text2 = format!("value{:}", index);
csv_writer
.write_fmt(format_args!("~{}~,~{}~\r\n", text1, text2))
.unwrap();
}
let mut csv_reader = std::io::Cursor::new(&csv_text);
let mut reader = builder.build(&mut csv_reader).unwrap();
let batch = reader.next().unwrap().unwrap();
let col0 = batch.column(0);
assert_eq!(col0.len(), 10);
let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(col0_arr.value(0), "id0");
let col1 = batch.column(1);
assert_eq!(col1.len(), 10);
let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(col1_arr.value(5), "value5");
}

#[test]
fn test_non_std_escape() {
let schema = Schema::new(vec![
Field::new("text1", DataType::Utf8, false),
Field::new("text2", DataType::Utf8, false),
]);
let builder = ReaderBuilder::new()
.with_schema(Arc::new(schema))
.has_header(false)
.with_escape(b'\\'); // default is None, change to \

let mut csv_text = Vec::new();
let mut csv_writer = std::io::Cursor::new(&mut csv_text);
for index in 0..10 {
let text1 = format!("id{:}", index);
let text2 = format!("value\\\"{:}", index);
csv_writer
.write_fmt(format_args!("\"{}\",\"{}\"\r\n", text1, text2))
.unwrap();
}
let mut csv_reader = std::io::Cursor::new(&csv_text);
let mut reader = builder.build(&mut csv_reader).unwrap();
let batch = reader.next().unwrap().unwrap();
let col0 = batch.column(0);
assert_eq!(col0.len(), 10);
let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(col0_arr.value(0), "id0");
let col1 = batch.column(1);
assert_eq!(col1.len(), 10);
let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(col1_arr.value(5), "value\"5");
}

#[test]
fn test_non_std_terminator() {
let schema = Schema::new(vec![
Field::new("text1", DataType::Utf8, false),
Field::new("text2", DataType::Utf8, false),
]);
let builder = ReaderBuilder::new()
.with_schema(Arc::new(schema))
.has_header(false)
.with_terminator(b'\n'); // default is CRLF, change to LF

let mut csv_text = Vec::new();
let mut csv_writer = std::io::Cursor::new(&mut csv_text);
for index in 0..10 {
let text1 = format!("id{:}", index);
let text2 = format!("value{:}", index);
csv_writer
.write_fmt(format_args!("\"{}\",\"{}\"\n", text1, text2))
.unwrap();
}
let mut csv_reader = std::io::Cursor::new(&csv_text);
let mut reader = builder.build(&mut csv_reader).unwrap();
let batch = reader.next().unwrap().unwrap();
let col0 = batch.column(0);
assert_eq!(col0.len(), 10);
let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(col0_arr.value(0), "id0");
let col1 = batch.column(1);
assert_eq!(col1.len(), 10);
let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(col1_arr.value(5), "value5");
}
}

0 comments on commit c928d57

Please sign in to comment.