Skip to content

Commit

Permalink
#1467 source S3: validate CSV read options and convert options
Browse files Browse the repository at this point in the history
  • Loading branch information
davydov-d committed Feb 8, 2023
1 parent aace7ae commit e729abc
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 55 deletions.
2 changes: 1 addition & 1 deletion airbyte-integrations/connectors/source-s3/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ COPY source_s3 ./source_s3
ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py"
ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]

LABEL io.airbyte.version=0.1.30
LABEL io.airbyte.version=0.1.31
LABEL io.airbyte.name=airbyte/source-s3
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import pyarrow
import pyarrow as pa
import six # type: ignore[import]
from airbyte_cdk.models import FailureType
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pyarrow import csv as pa_csv
from source_s3.exceptions import S3Exception
from source_s3.source_files_abstract.file_info import FileInfo
Expand Down Expand Up @@ -51,22 +53,57 @@ def format(self) -> CsvFormat:
self.format_model = CsvFormat.parse_obj(self._format)
return self.format_model

def _validate_field_len(self, config: Mapping[str, Any], field_name: str):
if len(config.get("format", {}).get(field_name)) != 1:
raise ValueError(f"{field_name} should contain 1 character only")
@staticmethod
def _validate_field(
format_: Mapping[str, Any], field_name: str, allow_empty: bool = False, disallow_values: Optional[Tuple[Any, ...]] = None
) -> Optional[str]:
disallow_values = disallow_values or ()
field_value = format_.get(field_name)
if not field_value and allow_empty:
return
if len(format_.get(field_name)) != 1:
return f"{field_name} should contain 1 character only"
if field_value in disallow_values:
return f"{field_name} can not be {field_value}"

@staticmethod
def _validate_read_options(format_: Mapping[str, Any]) -> Optional[str]:
    """Validate the JSON-encoded pyarrow read options from the config.

    :param format_: the "format" section of the connector config.
    :return: an error message string, or None when the options are valid.
    """
    raw_options = format_.get("advanced_options", "{}")
    try:
        # Parse the user-supplied JSON and let pyarrow reject unknown kwargs.
        pa.csv.ReadOptions(**json.loads(raw_options))
    except json.decoder.JSONDecodeError:
        return "Malformed advanced read options!"
    except TypeError as error:
        return f"One or more read options are invalid: {str(error)}"
    return None

@staticmethod
def _validate_convert_options(format_: Mapping[str, Any]) -> Optional[str]:
    """Validate the JSON-encoded pyarrow convert options from the config.

    :param format_: the "format" section of the connector config.
    :return: an error message string, or None when the options are valid.
    """
    options = format_.get("additional_reader_options", "{}")
    try:
        options = json.loads(options)
        pa.csv.ConvertOptions(**options)
    except json.decoder.JSONDecodeError:
        # Fixed copy-paste bug: these messages previously referred to the
        # "read options" validated by _validate_read_options.
        return "Malformed additional reader options!"
    except TypeError as e:
        return f"One or more additional reader options are invalid: {str(e)}"
    return None

def _validate_config(self, config: Mapping[str, Any]):
    """Validate CSV-specific options of the connector config.

    Checks the single-character fields (delimiter, quote_char, escape_char),
    the advanced read/convert option JSON blobs, and the encoding name.

    :param config: full connector config.
    :raises AirbyteTracedException: with failure_type=config_error on the
        first invalid option found.
    """
    if config.get("format", {}).get("filetype") == "csv":
        format_ = config.get("format", {})
        for error_message in (
            self._validate_field(format_, "delimiter", disallow_values=("\r", "\n")),
            self._validate_field(format_, "quote_char"),
            self._validate_field(format_, "escape_char", allow_empty=True),
            self._validate_read_options(format_),
            self._validate_convert_options(format_),
        ):
            if error_message:
                raise AirbyteTracedException(error_message, error_message, failure_type=FailureType.config_error)

        try:
            codecs.lookup(format_.get("encoding"))
        except LookupError:
            # Build a real message here: previously this branch re-used the
            # loop variable `error_message`, which is always falsy/None at
            # this point, producing an empty exception message.
            error_message = f"{format_.get('encoding')} is not a valid encoding"
            raise AirbyteTracedException(error_message, error_message, failure_type=FailureType.config_error)

def _read_options(self) -> Mapping[str, str]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import ConnectorSpecification
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from source_s3 import SourceS3
from source_s3.source_files_abstract.spec import SourceFilesAbstractSpec

Expand Down Expand Up @@ -39,23 +40,27 @@ def test_check_connection_exception(config):


@pytest.mark.parametrize(
"delimiter, quote_char, escape_char, encoding, error_type",
"delimiter, quote_char, escape_char, encoding, read_options, convert_options",
[
("string", "'", None, "utf8", ValueError),
("\n", "'", None, "utf8", ValueError),
(",", ";,", None, "utf8", ValueError),
(",", "'", "escape", "utf8", ValueError),
(",", "'", None, "utf888", LookupError)
("string", "'", None, "utf8", "{}", "{}"),
("\n", "'", None, "utf8", "{}", "{}"),
(",", ";,", None, "utf8", "{}", "{}"),
(",", "'", "escape", "utf8", "{}", "{}"),
(",", "'", None, "utf888", "{}", "{}"),
(",", "'", None, "utf8", "{'compression': true}", "{}"),
(",", "'", None, "utf8", "{}", "{'compression: true}"),
],
ids=[
"long_delimiter",
"forbidden_delimiter_symbol",
"long_quote_char",
"long_escape_char",
"unknown_encoding"
"unknown_encoding",
"invalid read options",
"invalid convert options"
],
)
def test_check_connection_csv_validation_exception(delimiter, quote_char, escape_char, encoding, error_type):
def test_check_connection_csv_validation_exception(delimiter, quote_char, escape_char, encoding, read_options, convert_options):
config = {
"dataset": "test",
"provider": {
Expand All @@ -73,13 +78,15 @@ def test_check_connection_csv_validation_exception(delimiter, quote_char, escape
"quote_char": quote_char,
"escape_char": escape_char,
"encoding": encoding,
"advanced_options": read_options,
"additional_reader_options": convert_options
}
}
ok, error_msg = SourceS3().check_connection(logger, config=config)

assert not ok
assert error_msg
assert isinstance(error_msg, error_type)
assert isinstance(error_msg, AirbyteTracedException)


def test_check_connection(config):
Expand Down
Loading

0 comments on commit e729abc

Please sign in to comment.