Commit
File-based CDK + Source S3 (v4): Pass configured file encoding to stream reader (airbytehq#29110)

* Add encoding to open_file interface

* pass the encoding set in the config

* cleanup

* cleanup

* Automated Commit - Formatting Changes

* Add missing test

* Automated Commit - Formatting Changes

* Update infer_schema too

* Automated Commit - Formatting Changes

* Update unit test

* add a unit test

* fix

* format

* format

* remove newline

* use a mock

* fix

* format

---------

Co-authored-by: girarda <girarda@users.noreply.github.com>
2 people authored and harrytou committed Sep 1, 2023
1 parent 0057193 commit 0a6dda0
Showing 13 changed files with 213 additions and 19 deletions.
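As a standalone illustration of why the configured encoding has to reach the stream reader (this snippet is not part of the commit): the same bytes decode to different text depending on the codec assumed at the point the file is opened.

# A CSV produced with Latin-1 is garbled if the reader assumes UTF-8.
data = "col1,col2\ncafé,naïve\n".encode("iso-8859-1")
print(data.decode("iso-8859-1"))               # round-trips correctly
print(data.decode("utf-8", errors="replace"))  # caf�,na�ve — wrong codec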
----- changed file -----
@@ -98,7 +98,7 @@ def validate_quote_char(cls, v: str) -> str:

@validator("escape_char")
def validate_escape_char(cls, v: str) -> str:
if len(v) != 1:
if v is not None and len(v) != 1:
raise ValueError("escape_char should only be one character")
return v

----- changed file -----
@@ -44,7 +44,7 @@ def config(self, value: AbstractFileBasedSpec) -> None:
...

@abstractmethod
def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
"""
Return a file handle for reading.
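A hypothetical subclass (not part of this commit) illustrates how an implementer would honour the new parameter; here a local-filesystem reader simply forwards it to the built-in open(), with None meaning the platform default for text modes and required for binary modes:

import logging
from io import IOBase
from typing import Optional

from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
from airbyte_cdk.sources.file_based.remote_file import RemoteFile


class LocalFileStreamReader(AbstractFileBasedStreamReader):
    # Sketch only: the other abstract members (config, get_matching_files) are omitted.
    def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
        logger.debug(f"opening {file.uri} with mode={mode.value!r}, encoding={encoding}")
        # encoding must be None for binary modes; text parsers supply their configured value.
        return open(file.uri, mode=mode.value, encoding=encoding)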
----- changed file -----
@@ -39,6 +39,8 @@


class AvroParser(FileTypeParser):
ENCODING = None

async def infer_schema(
self,
config: FileBasedStreamConfig,
@@ -50,7 +52,7 @@ async def infer_schema(
if not isinstance(avro_format, AvroFormat):
raise ValueError(f"Expected ParquetFormat, got {avro_format}")

with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
avro_reader = fastavro.reader(fp)
avro_schema = avro_reader.writer_schema
if not avro_schema["type"] == "record":
@@ -135,7 +137,7 @@ def parse_records(
if not isinstance(avro_format, AvroFormat):
raise ValueError(f"Expected ParquetFormat, got {avro_format}")

with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
avro_reader = fastavro.reader(fp)
schema = avro_reader.writer_schema
schema_field_name_to_type = {field["name"]: field["type"] for field in schema["fields"]}
----- changed file -----
@@ -47,7 +47,7 @@ async def infer_schema(
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, config_format.encoding, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
headers = self._get_headers(fp, config_format, dialect_name)
@@ -78,7 +78,7 @@ def parse_records(
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, config_format.encoding, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
self._skip_rows_before_header(fp, config_format.skip_rows_before_header)
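For context on where that value comes from: a hedged example of a stream configuration that sets a non-default encoding, mirroring the shape used by the new unit test further down (the stream name and encoding value are arbitrary):

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig

config = FileBasedStreamConfig(
    name="orders",
    validation_policy="Emit Record",
    file_type="csv",
    format=CsvFormat(encoding="iso-8859-1"),
)
# CsvParser reads the encoding from this format object and forwards it to
# stream_reader.open_file(file, FileReadMode.READ, config_format.encoding, logger).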
----- changed file -----
@@ -16,6 +16,7 @@
class JsonlParser(FileTypeParser):

MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000
ENCODING = "utf8"

async def infer_schema(
self,
@@ -31,7 +32,7 @@ async def infer_schema(
inferred_schema: Dict[str, Any] = {}
read_bytes = 0

with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
for line in fp:
if read_bytes < self.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE:
line_schema = self.infer_schema_for_record(json.loads(line))
@@ -53,7 +54,7 @@ def parse_records(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Iterable[Dict[str, Any]]:
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
for line in fp:
yield json.loads(line)

----- changed file -----
@@ -11,13 +11,17 @@
import pyarrow as pa
import pyarrow.parquet as pq
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ParquetFormat
from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from pyarrow import Scalar


class ParquetParser(FileTypeParser):

ENCODING = None

async def infer_schema(
self,
config: FileBasedStreamConfig,
@@ -29,7 +33,7 @@ async def infer_schema(
if not isinstance(parquet_format, ParquetFormat):
raise ValueError(f"Expected ParquetFormat, got {parquet_format}")

with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
parquet_file = pq.ParquetFile(fp)
parquet_schema = parquet_file.schema_arrow

@@ -50,8 +54,9 @@ def parse_records(
) -> Iterable[Dict[str, Any]]:
parquet_format = config.format or ParquetFormat()
if not isinstance(parquet_format, ParquetFormat):
raise ValueError(f"Expected ParquetFormat, got {parquet_format}") # FIXME test this branch!
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
logger.info(f"Expected ParquetFormat, got {parquet_format}")
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR)
with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
reader = pq.ParquetFile(fp)
partition_columns = {x.split("=")[0]: x.split("=")[1] for x in self._extract_partitions(file.uri)}
for row_group in range(reader.num_row_groups):
----- changed file -----
@@ -2,13 +2,20 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import asyncio
import io
import logging
from datetime import datetime
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
from airbyte_cdk.sources.file_based.config.csv_format import DEFAULT_FALSE_VALUES, DEFAULT_TRUE_VALUES, CsvFormat
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.exceptions import RecordParseError
from airbyte_cdk.sources.file_based.file_based_stream_reader import FileReadMode
from airbyte_cdk.sources.file_based.file_types.csv_parser import CsvParser, _cast_types
from airbyte_cdk.sources.file_based.remote_file import RemoteFile

PROPERTY_TYPES = {
"col1": "null",
@@ -96,3 +103,38 @@ def test_read_and_cast_types(reader_values, expected_rows):
list(parser._read_and_cast_types(reader, schema, config_format, logger))
else:
assert expected_rows == list(parser._read_and_cast_types(reader, schema, config_format, logger))


def test_encoding_is_passed_to_stream_reader():
parser = CsvParser()
encoding = "ascii"
stream_reader = Mock()
mock_obj = stream_reader.open_file.return_value
mock_obj.__enter__ = Mock(return_value=io.StringIO("c1,c2\nv1,v2"))
mock_obj.__exit__ = Mock(return_value=None)
file = RemoteFile(uri="s3://bucket/key.csv", last_modified=datetime.now())
config = FileBasedStreamConfig(
name="test",
validation_policy="Emit Record",
file_type="csv",
format=CsvFormat(encoding=encoding)
)
list(parser.parse_records(config, file, stream_reader, logger))
stream_reader.open_file.assert_has_calls([
mock.call(file, FileReadMode.READ, encoding, logger),
mock.call().__enter__(),
mock.call().__exit__(None, None, None),
])

mock_obj.__enter__ = Mock(return_value=io.StringIO("c1,c2\nv1,v2"))
loop = asyncio.get_event_loop()
loop.run_until_complete(parser.infer_schema(config, file, stream_reader, logger))
stream_reader.open_file.assert_called_with(file, FileReadMode.READ, encoding, logger)
stream_reader.open_file.assert_has_calls([
mock.call(file, FileReadMode.READ, encoding, logger),
mock.call().__enter__(),
mock.call().__exit__(None, None, None),
mock.call(file, FileReadMode.READ, encoding, logger),
mock.call().__enter__(),
mock.call().__exit__(None, None, None),
])
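A note on the mocking pattern above: a plain Mock, unlike MagicMock, does not support the context-manager protocol out of the box, so the test attaches __enter__/__exit__ to open_file.return_value itself; because every Mock instance gets its own class, `with stream_reader.open_file(...)` then resolves those methods correctly. A stripped-down, standalone version of the same trick:

import io
from unittest import mock
from unittest.mock import Mock

reader = Mock()
handle = reader.open_file.return_value
handle.__enter__ = Mock(return_value=io.StringIO("c1,c2\nv1,v2"))
handle.__exit__ = Mock(return_value=None)

# The `with` statement calls __enter__ on the mocked handle and yields the StringIO.
with reader.open_file("a.csv", "r", "ascii", None) as fp:
    assert fp.read() == "c1,c2\nv1,v2"

# Calls on the handle's dunder mocks propagate to the parent mock's call list.
reader.open_file.assert_has_calls([
    mock.call("a.csv", "r", "ascii", None),
    mock.call().__enter__(),
    mock.call().__exit__(None, None, None),
])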
----- changed file -----
@@ -43,7 +43,7 @@ def get_matching_files(


class TestErrorOpenFileInMemoryFilesStreamReader(InMemoryFilesStreamReader):
def open_file(self, file: RemoteFile, file_read_mode: FileReadMode, logger: logging.Logger) -> IOBase:
def open_file(self, file: RemoteFile, file_read_mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
raise Exception("Error opening file")


----- changed file -----
@@ -92,7 +92,7 @@ def get_matching_files(
for f, data in self.files.items()
], globs)

def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
if self.file_type == "csv":
return self._make_csv_file_contents(file.uri)
elif self.file_type == "jsonl":
@@ -144,7 +144,7 @@ class TemporaryParquetFilesStreamReader(InMemoryFilesStreamReader):
A file reader that writes RemoteFiles to a temporary file and then reads them back.
"""

def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
return io.BytesIO(self._create_file(file.uri))

def _create_file(self, file_name: str) -> bytes:
@@ -165,7 +165,7 @@ class TemporaryAvroFilesStreamReader(InMemoryFilesStreamReader):
A file reader that writes RemoteFiles to a temporary file and then reads them back.
"""

def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
return io.BytesIO(self._make_file_contents(file.uri))

def _make_file_contents(self, file_name: str) -> bytes:
----- changed file -----
@@ -6,6 +6,7 @@
import decimal

import pyarrow as pa
from airbyte_cdk.sources.file_based.exceptions import SchemaInferenceError
from unit_tests.sources.file_based.in_memory_files_source import TemporaryParquetFilesStreamReader
from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder

@@ -644,3 +645,119 @@
}
)
).build()

parquet_file_with_decimal_legacy_config_scenario = (
TestScenarioBuilder()
.set_name("parquet_file_with_decimal_legacy_config")
.set_config(
{
"streams": [
{
"name": "stream1",
"file_type": "parquet",
"format": {
"filetype": "parquet",
},
"globs": ["*"],
"validation_policy": "emit_record",
}
]
}
)
.set_stream_reader(TemporaryParquetFilesStreamReader(files=_parquet_file_with_decimal, file_type="parquet"))
.set_file_type("parquet")
.set_expected_records(
[
{"data": {"col1": 13.00, "_ab_source_file_last_modified": "2023-06-05T03:54:07.000000Z",
"_ab_source_file_url": "a.parquet"}, "stream": "stream1"},
]
)
.set_expected_catalog(
{
"streams": [
{
"default_cursor_field": ["_ab_source_file_last_modified"],
"json_schema": {
"type": "object",
"properties": {
"col1": {
"type": ["null", "number"]
},
"_ab_source_file_last_modified": {
"type": "string"
},
"_ab_source_file_url": {
"type": "string"
},
}
},
"name": "stream1",
"source_defined_cursor": True,
"supported_sync_modes": ["full_refresh", "incremental"],
}
]
}
)
).build()

parquet_with_invalid_config_scenario = (
TestScenarioBuilder()
.set_name("parquet_with_invalid_config")
.set_config(
{
"streams": [
{
"name": "stream1",
"file_type": "parquet",
"globs": ["*"],
"validation_policy": "Emit Record",
"format": {
"filetype": "csv"
}
}
]
}
)
.set_stream_reader(TemporaryParquetFilesStreamReader(files=_single_parquet_file, file_type="parquet"))
.set_file_type("parquet")
.set_expected_records(
[
]
)
.set_expected_logs({"read": [
{
"level": "ERROR",
"message": "Error parsing record"
}
]})
.set_expected_discover_error(SchemaInferenceError, "Error inferring schema from files")
.set_expected_catalog(
{
"streams": [
{
"default_cursor_field": ["_ab_source_file_last_modified"],
"json_schema": {
"type": "object",
"properties": {
"col1": {
"type": ["null", "string"]
},
"col2": {
"type": ["null", "string"]
},
"_ab_source_file_last_modified": {
"type": "string"
},
"_ab_source_file_url": {
"type": "string"
},
}
},
"name": "stream1",
"source_defined_cursor": True,
"supported_sync_modes": ["full_refresh", "incremental"],
}
]
}
)
).build()
----- changed file -----
@@ -98,6 +98,7 @@
parquet_file_with_decimal_as_string_scenario,
parquet_file_with_decimal_no_config_scenario,
parquet_various_types_scenario,
parquet_with_invalid_config_scenario,
single_parquet_scenario,
single_partitioned_parquet_scenario,
)
@@ -194,7 +195,8 @@
avro_file_with_double_as_number_scenario,
csv_newline_in_values_not_quoted_scenario,
csv_autogenerate_column_names_scenario,
single_partitioned_parquet_scenario
parquet_with_invalid_config_scenario,
single_partitioned_parquet_scenario,
]


----- changed file -----
@@ -89,15 +89,15 @@ def get_matching_files(self, globs: List[str], logger: logging.Logger) -> Iterab
) from exc

@contextmanager
def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
try:
params = {"client": self.s3_client}
except Exception as exc:
raise exc

logger.debug(f"try to open {file.uri}")
try:
result = smart_open.open(f"s3://{self.config.bucket}/{file.uri}", transport_params=params, mode=mode.value)
result = smart_open.open(f"s3://{self.config.bucket}/{file.uri}", transport_params=params, mode=mode.value, encoding=encoding)
except OSError:
logger.warning(
f"We don't have access to {file.uri}. The file appears to have become unreachable during sync."
(diff for the remaining changed file was not loaded in this view)