-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Source S3: basic structure using file-based CDK #28786
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from .config import Config | ||
from .stream_reader import SourceS3StreamReader | ||
|
||
__all__ = ["Config", "SourceS3StreamReader"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from typing import Optional | ||
|
||
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec | ||
from pydantic import AnyUrl, Field, ValidationError, root_validator | ||
|
||
|
||
class Config(AbstractFileBasedSpec):
    """
    Connector spec for the file-based S3 source.

    Adds the S3-specific connection settings (bucket name, AWS credential
    pair, optional S3-compatible endpoint) on top of the source-agnostic
    keys defined by AbstractFileBasedSpec.
    """

    config_version: str = "0.1"

    @classmethod
    def documentation_url(cls) -> AnyUrl:
        return AnyUrl("https://docs.airbyte.com/integrations/sources/s3", scheme="https")

    bucket: str = Field(title="Bucket", description="Name of the S3 bucket where the file(s) exist.", order=0)

    aws_access_key_id: Optional[str] = Field(
        title="AWS Access Key ID",
        default=None,
        description="In order to access private Buckets stored on AWS S3, this connector requires credentials with the proper "
        "permissions. If accessing publicly available data, this field is not necessary.",
        airbyte_secret=True,
        order=1,
    )

    aws_secret_access_key: Optional[str] = Field(
        title="AWS Secret Access Key",
        default=None,
        description="In order to access private Buckets stored on AWS S3, this connector requires credentials with the proper "
        "permissions. If accessing publicly available data, this field is not necessary.",
        airbyte_secret=True,
        order=2,
    )

    endpoint: Optional[str] = Field(
        "", title="Endpoint", description="Endpoint to an S3 compatible service. Leave empty to use AWS.", order=4
    )

    @root_validator
    def validate_optional_args(cls, values):
        """
        Validate cross-field constraints:
          - the AWS key pair must be provided together, and
          - a custom `endpoint` is mutually exclusive with AWS credentials.

        Bug fix: raise ValueError instead of constructing ValidationError
        directly. Pydantic v1's ValidationError expects a list of error
        wrappers as its first argument, not a message string, so the original
        calls would themselves fail when pydantic tried to render the error.
        A ValueError raised inside a validator is wrapped by pydantic into a
        proper ValidationError, so callers catching ValidationError see the
        same exception type as before.
        """
        aws_access_key_id = values.get("aws_access_key_id")
        aws_secret_access_key = values.get("aws_secret_access_key")
        endpoint = values.get("endpoint")
        if aws_access_key_id or aws_secret_access_key:
            if not (aws_access_key_id and aws_secret_access_key):
                raise ValueError("`aws_access_key_id` and `aws_secret_access_key` are both required to authenticate with AWS.")
            if endpoint:
                raise ValueError("Either `aws_access_key_id` and `aws_secret_access_key` or `endpoint` must be set, but not both.")
        return values
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
import logging | ||
from contextlib import contextmanager | ||
from io import IOBase | ||
from typing import Iterable, List, Optional, Set | ||
|
||
import boto3.session | ||
import smart_open | ||
from airbyte_cdk.sources.file_based.exceptions import ErrorListingFiles, FileBasedSourceError | ||
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode | ||
from airbyte_cdk.sources.file_based.remote_file import RemoteFile | ||
from botocore.client import BaseClient | ||
from botocore.client import Config as ClientConfig | ||
from source_s3.v4.config import Config | ||
|
||
|
||
class SourceS3StreamReader(AbstractFileBasedStreamReader):
    """
    File-based CDK stream reader for S3: lists objects matching the
    configured glob patterns and opens them via smart_open.
    """

    def __init__(self):
        super().__init__()
        # boto3 client is created lazily on first access; see `s3_client`.
        self._s3_client = None

    @property
    def config(self) -> Config:
        return self._config

    @config.setter
    def config(self, value: Config):
        """
        FileBasedSource reads the config from disk and parses it, and once parsed, the source sets the config on its StreamReader.

        Note: FileBasedSource only requires the keys defined in the abstract config, whereas concrete implementations of StreamReader
        will require keys that (for example) allow it to authenticate with the 3rd party.

        Therefore, concrete implementations of AbstractFileBasedStreamReader's config setter should assert that `value` is of the correct
        config type for that type of StreamReader.
        """
        assert isinstance(value, Config)
        self._config = value

    @property
    def s3_client(self) -> BaseClient:
        """Return the boto3 S3 client, creating it on first use. Raises ValueError if config was never set."""
        if self.config is None:
            # We shouldn't hit this; config should always get set before attempting to
            # list or read files.
            raise ValueError("Source config is missing; cannot create the S3 client.")
        if self._s3_client is None:
            if self.config.endpoint:
                # Custom endpoint: talk to an S3-compatible service instead of AWS.
                client_kv_args = _get_s3_compatible_client_args(self.config)
                self._s3_client = boto3.client("s3", **client_kv_args)
            else:
                self._s3_client = boto3.client(
                    "s3",
                    aws_access_key_id=self.config.aws_access_key_id,
                    aws_secret_access_key=self.config.aws_secret_access_key,
                )
        return self._s3_client

    def get_matching_files(self, globs: List[str], logger: logging.Logger) -> Iterable[RemoteFile]:
        """
        Get all files matching the specified glob patterns.

        Yields RemoteFile objects, deduplicated across prefixes via `seen`.
        Any listing failure is re-raised as ErrorListingFiles with context.
        """
        s3 = self.s3_client
        prefixes = self.get_prefixes_from_globs(globs)
        seen = set()

        try:
            if prefixes:
                for prefix in prefixes:
                    yield from self._page(s3, globs, self.config.bucket, prefix, seen, logger)
            else:
                yield from self._page(s3, globs, self.config.bucket, None, seen, logger)

            # `seen` holds exactly the unique files that matched the globs and were yielded.
            logger.info(f"Finished listing objects from S3. Found {len(seen)} unique files matching the glob patterns.")
        except Exception as exc:
            raise ErrorListingFiles(
                FileBasedSourceError.ERROR_LISTING_FILES,
                source="s3",
                bucket=self.config.bucket,
                globs=globs,
                endpoint=self.config.endpoint,
            ) from exc

    @contextmanager
    def open_file(self, file: RemoteFile, mode: FileReadMode, logger: logging.Logger) -> IOBase:
        """
        Open `file` from S3 in the requested mode, yielding the file object
        and guaranteeing it is closed when the context exits.
        """
        params = {"client": self.s3_client}
        logger.debug(f"try to open {file.uri}")
        try:
            result = smart_open.open(f"s3://{self.config.bucket}/{file.uri}", transport_params=params, mode=mode.value)
        except OSError:
            logger.warning(
                f"We don't have access to {file.uri}. The file appears to have become unreachable during sync."
                f"Check whether key {file.uri} exists in `{self.config.bucket}` bucket and/or has proper ACL permissions"
            )
            # Bug fix: the original swallowed the OSError and then hit a
            # NameError on the unbound `result` below; re-raise so the caller
            # sees the real error.
            raise
        # see https://docs.python.org/3/library/contextlib.html#contextlib.contextmanager for why we do this
        try:
            yield result
        finally:
            result.close()

    @staticmethod
    def _is_folder(file) -> bool:
        # S3 has no real directories; keys ending in "/" are zero-byte folder placeholders.
        return file["Key"].endswith("/")

    def _page(
        self, s3: BaseClient, globs: List[str], bucket: str, prefix: Optional[str], seen: Set[str], logger: logging.Logger
    ) -> Iterable[RemoteFile]:
        """
        Page through lists of S3 objects, yielding files that match `globs`
        and have not been seen before.
        """
        total_n_keys_for_prefix = 0
        kwargs = {"Bucket": bucket}
        if prefix is not None:
            kwargs["Prefix"] = prefix
        while True:
            # Bug fix: pass **kwargs so the ContinuationToken stored below is
            # actually sent. The original called list_objects_v2(Bucket=...,
            # Prefix=...) every iteration, re-fetching the first page forever
            # on buckets with more than one page of results.
            response = s3.list_objects_v2(**kwargs)
            # Guard against a missing KeyCount so the += below can't see None.
            key_count = response.get("KeyCount") or 0
            total_n_keys_for_prefix += key_count
            logger.info(f"Received {key_count} objects from S3 for prefix '{prefix}'.")

            if "Contents" in response:
                for file in response["Contents"]:
                    if self._is_folder(file):
                        continue
                    remote_file = RemoteFile(uri=file["Key"], last_modified=file["LastModified"])
                    if self.file_matches_globs(remote_file, globs) and remote_file.uri not in seen:
                        seen.add(remote_file.uri)
                        yield remote_file
            else:
                logger.warning(f"Invalid response from S3; missing 'Contents' key. kwargs={kwargs}.")

            if next_token := response.get("NextContinuationToken"):
                kwargs["ContinuationToken"] = next_token
            else:
                logger.info(f"Finished listing objects from S3 for prefix={prefix}. Found {total_n_keys_for_prefix} objects.")
                break
|
||
|
||
def _get_s3_compatible_client_args(config: Config) -> dict:
    """
    Build the keyword arguments for a boto3 S3 client pointed at an
    S3-compatible service running at the configured custom endpoint.
    """
    return {
        "config": ClientConfig(s3={"addressing_style": "auto"}),
        "endpoint_url": config.endpoint,
        "use_ssl": True,
        "verify": True,
    }
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
|
||
import logging | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
from source_s3.v4.config import Config | ||
|
||
logger = logging.Logger("") | ||
|
||
|
||
@pytest.mark.parametrize(
    "kwargs,expected_error",
    [
        pytest.param({"bucket": "test", "streams": []}, None, id="required-fields"),
        pytest.param(
            {"bucket": "test", "streams": [], "aws_access_key_id": "access_key", "aws_secret_access_key": "secret_access_key"},
            None,
            id="config-created-with-aws-info",
        ),
        pytest.param({"bucket": "test", "streams": [], "endpoint": "http://test.com"}, None, id="config-created-with-endpoint"),
        pytest.param(
            {
                "bucket": "test",
                "streams": [],
                "aws_access_key_id": "access_key",
                "aws_secret_access_key": "secret_access_key",
                "endpoint": "http://test.com",
            },
            ValidationError,
            id="cannot-have-endpoint-and-aws-info",
        ),
        pytest.param({"streams": []}, ValidationError, id="missing-bucket"),
    ],
)
def test_config(kwargs, expected_error):
    # When no error is expected, Config construction must simply succeed;
    # otherwise it must raise the expected validation error.
    if expected_error is None:
        Config(**kwargs)
    else:
        with pytest.raises(expected_error):
            Config(**kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious — doesn't mypy flag this, since the interface method expects
AbstractFileBasedSpec
even though Config
is a subclass of it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, mypy doesn't seem to mind it.