File-based CDK + Source S3 (v4): Pass configured file encoding to stream reader #29110

Merged: 24 commits, Aug 9, 2023. The diff below shows changes from 10 of those commits.
@@ -95,7 +95,7 @@ def validate_quote_char(cls, v: str) -> str:

@validator("escape_char")
def validate_escape_char(cls, v: str) -> str:
- if len(v) != 1:
+ if v is not None and len(v) != 1:
Review comment (Contributor Author): escape_char is an optional field.

raise ValueError("escape_char should only be one character")
return v
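
As a quick illustration of the relaxed validator (not part of the PR; it assumes CsvFormat's other fields all have defaults, which the tests later in this diff suggest):

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat

CsvFormat(escape_char=None)   # accepted: the length check is skipped when escape_char is omitted
CsvFormat(escape_char="\\")   # accepted: exactly one character
# CsvFormat(escape_char="ab") would fail validation with "escape_char should only be one character"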

@@ -41,7 +41,7 @@ def config(self, value: AbstractFileBasedSpec) -> None:
...

@abstractmethod
- def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
+ def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
"""
Return a file handle for reading.

@@ -39,6 +39,8 @@


class AvroParser(FileTypeParser):
+ ENCODING = None

async def infer_schema(
self,
config: FileBasedStreamConfig,
@@ -50,7 +52,7 @@ async def infer_schema(
if not isinstance(avro_format, AvroFormat):
raise ValueError(f"Expected ParquetFormat, got {avro_format}")

- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
avro_reader = fastavro.reader(fp)
avro_schema = avro_reader.writer_schema
if not avro_schema["type"] == "record":
@@ -135,7 +137,7 @@ def parse_records(
if not isinstance(avro_format, AvroFormat):
raise ValueError(f"Expected ParquetFormat, got {avro_format}")

- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
avro_reader = fastavro.reader(fp)
schema = avro_reader.writer_schema
schema_field_name_to_type = {field["name"]: field["type"] for field in schema["fields"]}
@@ -47,7 +47,7 @@ async def infer_schema(
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, config_format.encoding, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
headers = self._get_headers(fp, config_format, dialect_name)
@@ -78,7 +78,7 @@ def parse_records(
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, config_format.encoding, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
self._skip_rows_before_header(fp, config_format.skip_rows_before_header)
@@ -16,6 +16,7 @@
class JsonlParser(FileTypeParser):

MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE = 1_000_000
ENCODING = "utf8"
Review comment (Contributor Author): Encoding isn't configurable in the legacy S3 source; we can move this to a config if needed. (A sketch of that option follows this file's diff.)

async def infer_schema(
self,
@@ -31,7 +32,7 @@ async def infer_schema(
inferred_schema: Dict[str, Any] = {}
read_bytes = 0

- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
for line in fp:
if read_bytes < self.MAX_BYTES_PER_FILE_FOR_SCHEMA_INFERENCE:
line_schema = self.infer_schema_for_record(json.loads(line))
@@ -53,7 +54,7 @@ def parse_records(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Iterable[Dict[str, Any]]:
- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
for line in fp:
yield json.loads(line)
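
Regarding the reviewer note above about making the JSONL encoding configurable: a rough sketch of what that could look like, assuming a pydantic format model analogous to CsvFormat. The model and its encoding field are sketched from scratch here and are hypothetical, not part of this PR or the CDK's actual JsonlFormat.

from pydantic import BaseModel, Field

class JsonlFormat(BaseModel):
    # Hypothetical field; today JsonlParser hard-codes ENCODING = "utf8".
    encoding: str = Field(default="utf8", description="The character encoding of the JSONL files.")

# The parser would then pass the configured value instead of the class constant, e.g.:
#   with stream_reader.open_file(file, self.file_read_mode, config_format.encoding, logger) as fp: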

@@ -18,6 +18,9 @@


class ParquetParser(FileTypeParser):

+ ENCODING = None

async def infer_schema(
self,
config: FileBasedStreamConfig,
@@ -29,7 +32,7 @@ async def infer_schema(
if not isinstance(parquet_format, ParquetFormat):
raise ValueError(f"Expected ParquetFormat, got {parquet_format}")

- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
parquet_file = pq.ParquetFile(fp)
parquet_schema = parquet_file.schema_arrow

@@ -50,8 +53,8 @@ def parse_records(
) -> Iterable[Dict[str, Any]]:
parquet_format = config.format[config.file_type] if config.format else ParquetFormat()
if not isinstance(parquet_format, ParquetFormat):
raise ValueError(f"Expected ParquetFormat, got {parquet_format}") # FIXME test this branch!
- with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+ raise ValueError(f"Expected ParquetFormat, got {parquet_format}")
+ with stream_reader.open_file(file, self.file_read_mode, self.ENCODING, logger) as fp:
reader = pq.ParquetFile(fp)
partition_columns = {x.split("=")[0]: x.split("=")[1] for x in self._extract_partitions(file.uri)}
for row_group in range(reader.num_row_groups):
@@ -1,14 +1,21 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import asyncio
import io
import logging
from datetime import datetime
from typing import Iterable, List, Optional
from unittest.mock import MagicMock, Mock

import pytest
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.csv_format import DEFAULT_FALSE_VALUES, DEFAULT_TRUE_VALUES, CsvFormat
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.exceptions import RecordParseError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
from airbyte_cdk.sources.file_based.file_types.csv_parser import CsvParser, _cast_types
from airbyte_cdk.sources.file_based.remote_file import RemoteFile

PROPERTY_TYPES = {
"col1": "null",
@@ -96,3 +103,40 @@ def test_read_and_cast_types(reader_values, expected_rows):
list(parser._read_and_cast_types(reader, schema, config_format, logger))
else:
assert expected_rows == list(parser._read_and_cast_types(reader, schema, config_format, logger))


class MockFileBasedStreamReader(AbstractFileBasedStreamReader):
def __init__(self, expected_encoding: Optional[str]):
self._expected_encoding = expected_encoding

def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> io.IOBase:
assert encoding == self._expected_encoding
Review comment (Contributor Author): Not a great test, but the actual decoding is done outside of the CDK.

Review comment (Contributor): Why are we defining a MockFileBasedStreamReader class when we could just use Mock(spec=AbstractFileBasedStreamReader)? Having the mock would allow us to assert calls (and therefore validate arguments, like we're doing with this assert). A sketch of that approach follows this file's diff.

return io.StringIO("c1,c2\nv1,v2")

def get_matching_files(self, globs: List[str], logger: logging.Logger) -> Iterable[RemoteFile]:
pass

@property
def config(self) -> Optional[AbstractFileBasedSpec]:
return None

@config.setter
def config(self, value: AbstractFileBasedSpec) -> None:
pass


def test_encoding_is_passed_to_stream_reader():
parser = CsvParser()
encoding = "ascii"
stream_reader = MockFileBasedStreamReader(encoding)
file = RemoteFile(uri="s3://bucket/key.csv", last_modified=datetime.now())
config = FileBasedStreamConfig(
name="test",
validation_policy="emit_record",
file_type="csv",
format={"csv": CsvFormat(encoding=encoding)}
)
list(parser.parse_records(config, file, stream_reader, logger))

loop = asyncio.get_event_loop()
loop.run_until_complete(parser.infer_schema(config, file, stream_reader, logger))
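
Following up on the Mock(spec=AbstractFileBasedStreamReader) suggestion above, a rough sketch of the assertion-based variant. It is illustrative only: it reuses the imports already added at the top of this test module and assumes the parser passes the open_file arguments positionally, as the diff above shows.

def test_encoding_is_passed_to_stream_reader_with_mock():
    parser = CsvParser()
    encoding = "ascii"
    # MagicMock (rather than Mock) so the context-manager protocol used by parse_records works out of the box.
    stream_reader = MagicMock(spec=AbstractFileBasedStreamReader)
    stream_reader.open_file.return_value.__enter__.return_value = io.StringIO("c1,c2\nv1,v2")
    file = RemoteFile(uri="s3://bucket/key.csv", last_modified=datetime.now())
    config = FileBasedStreamConfig(
        name="test",
        validation_policy="emit_record",
        file_type="csv",
        format={"csv": CsvFormat(encoding=encoding)},
    )
    list(parser.parse_records(config, file, stream_reader, logger))
    # Assert on the call itself instead of asserting inside a hand-written reader.
    stream_reader.open_file.assert_called_once_with(file, parser.file_read_mode, encoding, logger)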
@@ -43,7 +43,7 @@ def get_matching_files(


class TestErrorOpenFileInMemoryFilesStreamReader(InMemoryFilesStreamReader):
- def open_file(self, file: RemoteFile, file_read_mode: FileReadMode, logger: logging.Logger) -> IOBase:
+ def open_file(self, file: RemoteFile, file_read_mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
raise Exception("Error opening file")


@@ -92,7 +92,7 @@ def get_matching_files(
for f, data in self.files.items()
], globs)

- def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
+ def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
if self.file_type == "csv":
return self._make_csv_file_contents(file.uri)
elif self.file_type == "jsonl":
@@ -153,7 +153,7 @@ class TemporaryParquetFilesStreamReader(InMemoryFilesStreamReader):
A file reader that writes RemoteFiles to a temporary file and then reads them back.
"""

- def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
+ def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
return io.BytesIO(self._create_file(file.uri))

def _create_file(self, file_name: str) -> bytes:
@@ -174,7 +174,7 @@ class TemporaryAvroFilesStreamReader(InMemoryFilesStreamReader):
A file reader that writes RemoteFiles to a temporary file and then reads them back.
"""

- def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
+ def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
return io.BytesIO(self._make_file_contents(file.uri))

def _make_file_contents(self, file_name: str) -> bytes:
@@ -6,6 +6,7 @@
import decimal

import pyarrow as pa
+ from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError
from unit_tests.sources.file_based.in_memory_files_source import TemporaryParquetFilesStreamReader
from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder

@@ -629,3 +630,34 @@
}
)
).build()

parquet_with_invalid_config_scenario = (
TestScenarioBuilder()
.set_name("parquet_with_invalid_config_scenario")
.set_config(
{
"streams": [
{
"name": "stream1",
"file_type": "parquet",
"globs": ["*"],
"validation_policy": "emit_record",
"format": {
"parquet": {
"filetype": "csv",
}
}
}
]
}
)
.set_stream_reader(TemporaryParquetFilesStreamReader(files=_single_parquet_file, file_type="parquet"))
.set_file_type("parquet")
.set_expected_read_error(ConfigValidationError, "Error creating stream config object. Contact Support if you need assistance.")
Review comment (Contributor): I guess this is invalid because of the "csv" filetype, right? Could we have a better error message? Also, could this be a unit test? (A rough unit-test sketch follows this scenario.)

.set_expected_discover_error(ConfigValidationError, "Error creating stream config object. Contact Support if you need assistance.")
.set_expected_records(
[
# No records were read
]
)
).build()
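
On the unit-test question above: a rough sketch of a direct unit test, assuming the mismatched filetype is rejected when the stream config object is built. The exact exception type and the place where validation happens are assumptions, not verified against the CDK.

import pytest
from pydantic import ValidationError
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig

def test_parquet_stream_config_rejects_csv_filetype():
    # Assumption: a parquet stream whose format block declares filetype "csv" fails
    # validation at config-construction time; the concrete exception may differ.
    with pytest.raises((ValidationError, ValueError)):
        FileBasedStreamConfig(
            name="stream1",
            file_type="parquet",
            globs=["*"],
            validation_policy="emit_record",
            format={"parquet": {"filetype": "csv"}},
        )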
@@ -99,6 +99,7 @@
parquet_file_with_decimal_legacy_config_scenario,
parquet_file_with_decimal_no_config_scenario,
parquet_various_types_scenario,
+ parquet_with_invalid_config_scenario,
single_parquet_scenario,
)
from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenario
@@ -197,6 +198,7 @@
avro_file_with_decimal_as_float_scenario,
csv_newline_in_values_not_quoted_scenario,
csv_autogenerate_column_names_scenario,
+ parquet_with_invalid_config_scenario
]


@@ -89,15 +89,15 @@ def get_matching_files(self, globs: List[str], logger: logging.Logger) -> Iterable[RemoteFile]:
) from exc

@contextmanager
- def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
+ def open_file(self, file: RemoteFile, mode: FileReadMode, encoding: Optional[str], logger: logging.Logger) -> IOBase:
try:
params = {"client": self.s3_client}
except Exception as exc:
raise exc

logger.debug(f"try to open {file.uri}")
try:
- result = smart_open.open(f"s3://{self.config.bucket}/{file.uri}", transport_params=params, mode=mode.value)
+ result = smart_open.open(f"s3://{self.config.bucket}/{file.uri}", transport_params=params, mode=mode.value, encoding=encoding)
Review comment (Contributor Author): It's a little tricky to test this until we have the adapter. I tested it by merging this branch into the codebase, running the v4 source, and reading s3://airbyte-acceptance-test-source-s3/csv_tests/csv_encoded_as_cp1252.csv, which is encoded as cp1252. (A sketch of a mock-based check follows this diff.)

except OSError:
logger.warning(
f"We don't have access to {file.uri}. The file appears to have become unreachable during sync."
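
Following up on the note above: one interim way to cover this before the adapter lands could be to patch smart_open and assert that the configured encoding is forwarded. A rough sketch only; SourceS3StreamReader is a placeholder name for the v4 stream reader, and the config/client wiring is glossed over with mocks.

import logging
from datetime import datetime
from unittest.mock import MagicMock, patch

from airbyte_cdk.sources.file_based.file_based_stream_reader import FileReadMode
from airbyte_cdk.sources.file_based.remote_file import RemoteFile

def test_open_file_forwards_encoding_to_smart_open():
    reader = SourceS3StreamReader()                   # placeholder class name
    reader.config = MagicMock(bucket="test-bucket")   # assumes the spec exposes .bucket
    file = RemoteFile(uri="csv_tests/csv_encoded_as_cp1252.csv", last_modified=datetime.now())

    # Patch out the boto client plumbing and smart_open itself.
    with patch.object(type(reader), "s3_client", new=MagicMock()), patch("smart_open.open") as mock_open:
        with reader.open_file(file, FileReadMode.READ, "cp1252", logging.getLogger()):
            pass

    _, kwargs = mock_open.call_args
    assert kwargs["encoding"] == "cp1252"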