-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CSV options to the CSV parser #28491
Changes from all commits
f744e2c
fb5a57d
3230205
f6a67db
d01200b
b271a9e
7add1c7
e8c88be
9e73b51
84cabeb
79f7748
0ae95da
6324257
0fd42ca
8b54aff
b9a4a71
9982834
32844ce
cd48738
43ce434
df47586
ce9a672
2c03349
c445b02
ff8f5d4
87a3bcb
9a1954f
72caf7d
ecea4e0
cf298b7
8cd05a4
d1fb6ae
124cfcf
a9ee16b
a629ef0
f11a551
da274bc
b373221
ce51b3d
f8d76a1
b857737
e8609c4
59f00be
1cdaf60
d1c4036
903074a
bdfccee
355d596
0426b4c
8b7d519
8a7bcf7
9252651
06157dc
e5a1c0e
9c9dc72
146680a
4cfd721
6f10047
e4986e8
1f57507
0441c28
c2b3a37
d8538f9
16df89d
7d7f6dd
cef6a41
cec32dc
05067a7
bdbd413
bf525b4
d32a94f
69240b0
bfe4d47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,12 +5,13 @@ | |
import csv | ||
import json | ||
import logging | ||
from distutils.util import strtobool | ||
from typing import Any, Dict, Iterable, Mapping, Optional | ||
from functools import partial | ||
from io import IOBase | ||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set | ||
|
||
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, QuotingBehavior | ||
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig | ||
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError | ||
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError | ||
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode | ||
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser | ||
from airbyte_cdk.sources.file_based.remote_file import RemoteFile | ||
|
@@ -34,30 +35,25 @@ async def infer_schema( | |
stream_reader: AbstractFileBasedStreamReader, | ||
logger: logging.Logger, | ||
) -> Dict[str, Any]: | ||
config_format = config.format.get(config.file_type) if config.format else None | ||
if config_format: | ||
if not isinstance(config_format, CsvFormat): | ||
raise ValueError(f"Invalid format config: {config_format}") | ||
dialect_name = config.name + DIALECT_NAME | ||
csv.register_dialect( | ||
dialect_name, | ||
delimiter=config_format.delimiter, | ||
quotechar=config_format.quote_char, | ||
escapechar=config_format.escape_char, | ||
doublequote=config_format.double_quote, | ||
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL), | ||
) | ||
with stream_reader.open_file(file, self.file_read_mode, logger) as fp: | ||
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual | ||
# sources will likely require one. Rather than modify the interface now we can wait until the real use case | ||
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore | ||
schema = {field.strip(): {"type": "string"} for field in next(reader)} | ||
csv.unregister_dialect(dialect_name) | ||
return schema | ||
else: | ||
with stream_reader.open_file(file, self.file_read_mode, logger) as fp: | ||
reader = csv.DictReader(fp) # type: ignore | ||
return {field.strip(): {"type": "string"} for field in next(reader)} | ||
config_format = config.format.get(config.file_type) if config.format else CsvFormat() | ||
if not isinstance(config_format, CsvFormat): | ||
raise ValueError(f"Invalid format config: {config_format}") | ||
dialect_name = config.name + DIALECT_NAME | ||
csv.register_dialect( | ||
dialect_name, | ||
delimiter=config_format.delimiter, | ||
quotechar=config_format.quote_char, | ||
escapechar=config_format.escape_char, | ||
doublequote=config_format.double_quote, | ||
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL), | ||
) | ||
with stream_reader.open_file(file, self.file_read_mode, logger) as fp: | ||
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual | ||
# sources will likely require one. Rather than modify the interface now we can wait until the real use case | ||
headers = self._get_headers(fp, config_format, dialect_name) | ||
schema = {field.strip(): {"type": "string"} for field in headers} | ||
csv.unregister_dialect(dialect_name) | ||
return schema | ||
|
||
def parse_records( | ||
self, | ||
|
@@ -67,38 +63,36 @@ def parse_records( | |
logger: logging.Logger, | ||
) -> Iterable[Dict[str, Any]]: | ||
schema: Mapping[str, Any] = config.input_schema # type: ignore | ||
config_format = config.format.get(config.file_type) if config.format else None | ||
if config_format: | ||
if not isinstance(config_format, CsvFormat): | ||
raise ValueError(f"Invalid format config: {config_format}") | ||
# Formats are configured individually per-stream so a unique dialect should be registered for each stream. | ||
# Wwe don't unregister the dialect because we are lazily parsing each csv file to generate records | ||
dialect_name = config.name + DIALECT_NAME | ||
csv.register_dialect( | ||
dialect_name, | ||
delimiter=config_format.delimiter, | ||
quotechar=config_format.quote_char, | ||
escapechar=config_format.escape_char, | ||
doublequote=config_format.double_quote, | ||
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL), | ||
) | ||
with stream_reader.open_file(file, self.file_read_mode, logger) as fp: | ||
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual | ||
# sources will likely require one. Rather than modify the interface now we can wait until the real use case | ||
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore | ||
yield from self._read_and_cast_types(reader, schema, logger) | ||
else: | ||
with stream_reader.open_file(file, self.file_read_mode, logger) as fp: | ||
reader = csv.DictReader(fp) # type: ignore | ||
yield from self._read_and_cast_types(reader, schema, logger) | ||
config_format = config.format.get(config.file_type) if config.format else CsvFormat() | ||
if not isinstance(config_format, CsvFormat): | ||
raise ValueError(f"Invalid format config: {config_format}") | ||
# Formats are configured individually per-stream so a unique dialect should be registered for each stream. | ||
# We don't unregister the dialect because we are lazily parsing each csv file to generate records | ||
# This will potentially be a problem if we ever process multiple streams concurrently | ||
dialect_name = config.name + DIALECT_NAME | ||
csv.register_dialect( | ||
dialect_name, | ||
delimiter=config_format.delimiter, | ||
quotechar=config_format.quote_char, | ||
escapechar=config_format.escape_char, | ||
doublequote=config_format.double_quote, | ||
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL), | ||
) | ||
with stream_reader.open_file(file, self.file_read_mode, logger) as fp: | ||
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual | ||
# sources will likely require one. Rather than modify the interface now we can wait until the real use case | ||
self._skip_rows_before_header(fp, config_format.skip_rows_before_header) | ||
field_names = self._auto_generate_headers(fp, config_format) if config_format.autogenerate_column_names else None | ||
reader = csv.DictReader(fp, dialect=dialect_name, fieldnames=field_names) # type: ignore | ||
yield from self._read_and_cast_types(reader, schema, config_format, logger) | ||
|
||
@property | ||
def file_read_mode(self) -> FileReadMode: | ||
return FileReadMode.READ | ||
|
||
@staticmethod | ||
def _read_and_cast_types( | ||
reader: csv.DictReader, schema: Optional[Mapping[str, Any]], logger: logging.Logger # type: ignore | ||
reader: csv.DictReader, schema: Optional[Mapping[str, Any]], config_format: CsvFormat, logger: logging.Logger # type: ignore | ||
) -> Iterable[Dict[str, Any]]: | ||
""" | ||
If the user provided a schema, attempt to cast the record values to the associated type. | ||
|
@@ -107,16 +101,65 @@ def _read_and_cast_types( | |
cast it to a string. Downstream, the user's validation policy will determine whether the | ||
record should be emitted. | ||
""" | ||
if not schema: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the branching to ensure we always go through the same row skipping and validation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the skip row and validation is done after the if/else. The if/else only controls how the fields are casted (either using cast_type, or return as-is) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extracted to a separate function to make this clearer |
||
yield from reader | ||
cast_fn = CsvParser._get_cast_function(schema, config_format, logger) | ||
for i, row in enumerate(reader): | ||
if i < config_format.skip_rows_after_header: | ||
continue | ||
# The row was not properly parsed if any of the values are None | ||
if any(val is None for val in row.values()): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have a test for this? It seems odd and I feel it would be worth documenting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is implicitly tested by |
||
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD) | ||
else: | ||
yield CsvParser._to_nullable(cast_fn(row), config_format.null_values) | ||
|
||
else: | ||
@staticmethod | ||
def _get_cast_function( | ||
schema: Optional[Mapping[str, Any]], config_format: CsvFormat, logger: logging.Logger | ||
) -> Callable[[Mapping[str, str]], Mapping[str, str]]: | ||
# Only cast values if the schema is provided | ||
if schema: | ||
property_types = {col: prop["type"] for col, prop in schema["properties"].items()} | ||
for row in reader: | ||
yield cast_types(row, property_types, logger) | ||
return partial(_cast_types, property_types=property_types, config_format=config_format, logger=logger) | ||
else: | ||
# If no schema is provided, yield the rows as they are | ||
return _no_cast | ||
|
||
@staticmethod | ||
def _to_nullable(row: Mapping[str, str], null_values: Set[str]) -> Dict[str, Optional[str]]: | ||
nullable = row | {k: None if v in null_values else v for k, v in row.items()} | ||
return nullable | ||
|
||
@staticmethod | ||
def _skip_rows_before_header(fp: IOBase, rows_to_skip: int) -> None: | ||
""" | ||
Skip rows before the header. This has to be done on the file object itself, not the reader | ||
""" | ||
for _ in range(rows_to_skip): | ||
fp.readline() | ||
|
||
def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]: | ||
# Note that this method assumes the dialect has already been registered if we're parsing the headers | ||
if config_format.autogenerate_column_names: | ||
return self._auto_generate_headers(fp, config_format) | ||
else: | ||
# If we're not autogenerating column names, we need to skip the rows before the header | ||
self._skip_rows_before_header(fp, config_format.skip_rows_before_header) | ||
# Then read the header | ||
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore | ||
return next(reader) # type: ignore | ||
|
||
def _auto_generate_headers(self, fp: IOBase, config_format: CsvFormat) -> List[str]: | ||
""" | ||
Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True. | ||
See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html | ||
""" | ||
next_line = next(fp).strip() | ||
number_of_columns = len(next_line.split(config_format.delimiter)) # type: ignore | ||
# Reset the file pointer to the beginning of the file so that the first row is not skipped | ||
fp.seek(0) | ||
return [f"f{i}" for i in range(number_of_columns)] | ||
|
||
def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logging.Logger) -> Dict[str, Any]: | ||
|
||
def _cast_types(row: Dict[str, str], property_types: Dict[str, Any], config_format: CsvFormat, logger: logging.Logger) -> Dict[str, Any]: | ||
""" | ||
Casts the values in the input 'row' dictionary according to the types defined in the JSON schema. | ||
|
@@ -142,7 +185,7 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg | |
|
||
elif python_type == bool: | ||
try: | ||
cast_value = strtobool(value) | ||
cast_value = _value_to_bool(value, config_format.true_values, config_format.false_values) | ||
except ValueError: | ||
warnings.append(_format_warning(key, value, prop_type)) | ||
|
||
|
@@ -178,5 +221,17 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg | |
return result | ||
|
||
|
||
def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool: | ||
if value in true_values: | ||
return True | ||
if value in false_values: | ||
return False | ||
raise ValueError(f"Value {value} is not a valid boolean value") | ||
|
||
|
||
def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str: | ||
return f"{key}: value={value},expected_type={expected_type}" | ||
|
||
|
||
def _no_cast(row: Mapping[str, str]) -> Mapping[str, str]: | ||
return row |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
import pytest as pytest | ||
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"skip_rows_before_header, autogenerate_column_names, expected_error", | ||
[ | ||
pytest.param(1, True, ValueError, id="test_skip_rows_before_header_and_autogenerate_column_names"), | ||
pytest.param(1, False, None, id="test_skip_rows_before_header_and_no_autogenerate_column_names"), | ||
pytest.param(0, True, None, id="test_no_skip_rows_before_header_and_autogenerate_column_names"), | ||
pytest.param(0, False, None, id="test_no_skip_rows_before_header_and_no_autogenerate_column_names"), | ||
] | ||
) | ||
def test_csv_format(skip_rows_before_header, autogenerate_column_names, expected_error): | ||
if expected_error: | ||
with pytest.raises(expected_error): | ||
CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names) | ||
else: | ||
CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skip the rows before the header