Skip to content

Commit

Permalink
Feat: Upgrade to Pydantic 2.0 and CDK 2.0 (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers authored Jul 7, 2024
1 parent 42e9c85 commit 8676125
Show file tree
Hide file tree
Showing 18 changed files with 278 additions and 145 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/python_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ jobs:
fail-fast: false

runs-on: "${{ matrix.os }}-latest"
env:
# Enforce UTF-8 encoding so Windows runners don't fail inside the connector code.
# TODO: See if we can fully enforce this within PyAirbyte itself.
PYTHONIOENCODING: utf-8
steps:
# Common steps:
- name: Checkout code
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/test-pr-command.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ jobs:
Windows,
]
fail-fast: false

runs-on: "${{ matrix.os }}-latest"
env:
# Enforce UTF-8 encoding so Windows runners don't fail inside the connector code.
# TODO: See if we can fully enforce this within PyAirbyte itself.
PYTHONIOENCODING: utf-8
steps:

# Custom steps to fetch the PR and checkout the code:
Expand Down
1 change: 1 addition & 0 deletions airbyte/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _stream_from_subprocess(args: list[str]) -> Generator[Iterable[str], None, N
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
encoding="utf-8",
)

def _stream_from_file(file: IO[str]) -> Generator[str, Any, None]:
Expand Down
2 changes: 1 addition & 1 deletion airbyte/_future_cdk/record_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _airbyte_messages_from_buffer(
buffer: io.TextIOBase,
) -> Iterator[AirbyteMessage]:
"""Yield messages from a buffer."""
yield from (AirbyteMessage.parse_raw(line) for line in buffer)
yield from (AirbyteMessage.model_validate_json(line) for line in buffer)

@final
def process_input_stream(
Expand Down
5 changes: 4 additions & 1 deletion airbyte/_future_cdk/state_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def to_state_input_file_text(self) -> str:
return (
"["
+ "\n, ".join(
[state_artifact.json() for state_artifact in (self._state_message_artifacts or [])]
[
state_artifact.model_dump_json()
for state_artifact in (self._state_message_artifacts or [])
]
)
+ "]"
)
Expand Down
2 changes: 1 addition & 1 deletion airbyte/_future_cdk/state_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def write_state(
state_message: AirbyteStateMessage,
) -> None:
"""Save or 'write' a state artifact."""
print(state_message.json())
print(state_message.model_dump_json())
8 changes: 4 additions & 4 deletions airbyte/_util/document_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _to_title_case(name: str, /) -> str:
class CustomRenderingInstructions(BaseModel):
"""Instructions for rendering a stream's records as documents."""

title_property: Optional[str]
title_property: Optional[str] = None
content_properties: list[str]
frontmatter_properties: list[str]
metadata_properties: list[str]
Expand All @@ -37,9 +37,9 @@ class CustomRenderingInstructions(BaseModel):
class DocumentRenderer(BaseModel):
"""Instructions for rendering a stream's records as documents."""

title_property: Optional[str]
content_properties: Optional[list[str]]
metadata_properties: Optional[list[str]]
title_property: Optional[str] = None
content_properties: Optional[list[str]] = None
metadata_properties: Optional[list[str]] = None
render_metadata: bool = False

# TODO: Add primary key and cursor key support:
Expand Down
4 changes: 2 additions & 2 deletions airbyte/caches/_state_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def write_state(
source_name=self.source_name,
stream_name=stream_name,
table_name=table_prefix + stream_name,
state_json=state_message.json(),
state_json=state_message.model_dump_json(),
)
)
session.commit()
Expand Down Expand Up @@ -170,7 +170,7 @@ def get_state_provider(

return StaticInputState(
from_state_messages=[
AirbyteStateMessage.parse_raw(state.state_json) for state in states
AirbyteStateMessage.model_validate_json(state.state_json) for state in states
]
)

Expand Down
2 changes: 1 addition & 1 deletion airbyte/caches/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@
class SnowflakeCache(SnowflakeConfig, CacheBase):
"""Configuration for the Snowflake cache."""

dedupe_mode = RecordDedupeMode.APPEND
dedupe_mode: RecordDedupeMode = RecordDedupeMode.APPEND

_sql_processor_class = PrivateAttr(default=SnowflakeSqlProcessor)
52 changes: 51 additions & 1 deletion airbyte/secrets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import cast
from typing import TYPE_CHECKING, Any, cast

from pydantic_core import CoreSchema, core_schema

from airbyte import exceptions as exc


if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationInfo
from pydantic.json_schema import JsonSchemaValue


class SecretSourceEnum(str, Enum):
ENV = "env"
DOTENV = "dotenv"
Expand Down Expand Up @@ -65,6 +72,49 @@ def parse_json(self) -> dict:
},
) from None

# Pydantic compatibility

@classmethod
def validate(
cls,
v: Any, # noqa: ANN401 # Must allow `Any` to match Pydantic signature
info: ValidationInfo,
) -> SecretString:
"""Validate the input value is valid as a secret string."""
_ = info # Unused
if not isinstance(v, str):
raise exc.PyAirbyteInputError(
message="A valid `str` or `SecretString` object is required.",
)
return cls(v)

@classmethod
def __get_pydantic_core_schema__( # noqa: PLW3201 # Pydantic dunder
cls,
source_type: Any, # noqa: ANN401 # Must allow `Any` to match Pydantic signature
handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.with_info_after_validator_function(
function=cls.validate, schema=handler(str), field_name=handler.field_name
)

@classmethod
def __get_pydantic_json_schema__( # noqa: PLW3201 # Pydantic dunder method
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""
Return a modified JSON schema for the secret string.
- `writeOnly=True` is the official way to prevent secrets from being exposed inadvertently.
- `Format=password` is a popular and readable convention to indicate the field is sensitive.
"""
_ = _core_schema, handler # Unused
return {
"type": "string",
"format": "password",
"writeOnly": True,
}


class SecretManager(ABC):
"""Abstract base class for secret managers.
Expand Down
14 changes: 7 additions & 7 deletions airbyte/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ def _discover(self) -> AirbyteCatalog:
"""Call discover on the connector.
This involves the following steps:
* Write the config to a temporary file
* execute the connector with discover --config <config_file>
* Listen to the messages and return the first AirbyteCatalog that comes along.
* Make sure the subprocess is killed when the function returns.
- Write the config to a temporary file
- execute the connector with discover --config <config_file>
- Listen to the messages and return the first AirbyteCatalog that comes along.
- Make sure the subprocess is killed when the function returns.
"""
with as_temp_files([self._config]) as [config_file]:
for msg in self._execute(["discover", "--config", config_file]):
Expand Down Expand Up @@ -298,7 +298,7 @@ def _yaml_spec(self) -> str:
for each connector.
"""
spec_obj: ConnectorSpecification = self._get_spec()
spec_dict = spec_obj.dict(exclude_unset=True)
spec_dict: dict[str, Any] = spec_obj.model_dump(exclude_unset=True)
# convert to a yaml string
return yaml.dump(spec_dict)

Expand Down Expand Up @@ -537,7 +537,7 @@ def _read_with_catalog(
with as_temp_files(
[
self._config,
catalog.json(),
catalog.model_dump_json(),
state.to_state_input_file_text() if state else "[]",
]
) as [
Expand Down Expand Up @@ -579,7 +579,7 @@ def _execute(self, args: list[str]) -> Iterator[AirbyteMessage]:
self._last_log_messages = []
for line in self.executor.execute(args):
try:
message = AirbyteMessage.parse_raw(line)
message: AirbyteMessage = AirbyteMessage.model_validate_json(json_data=line)
if message.type is Type.RECORD:
self._processed_records += 1
if message.type == Type.LOG:
Expand Down
18 changes: 18 additions & 0 deletions airbyte/sources/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
from __future__ import annotations

import json
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, cast

import pydantic

from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource

Expand All @@ -19,6 +22,18 @@
from collections.abc import Iterator


def _suppress_cdk_pydantic_deprecation_warnings() -> None:
"""Suppress deprecation warnings from Pydantic in the CDK.
CDK has deprecated uses of `json()` and `parse_obj()`, and we don't want users
to see these warnings.
"""
warnings.filterwarnings(
"ignore",
category=pydantic.warnings.PydanticDeprecatedSince20,
)


class DeclarativeExecutor(Executor):
"""An executor for declarative sources."""

Expand All @@ -32,6 +47,7 @@ def __init__(
- If `manifest` is a string, it will be parsed as an HTTP path.
- If `manifest` is a dict, it will be used as is.
"""
_suppress_cdk_pydantic_deprecation_warnings()
self._manifest_dict: dict
if isinstance(manifest, Path):
self._manifest_dict = cast(dict, json.loads(manifest.read_text()))
Expand Down Expand Up @@ -91,6 +107,8 @@ def __init__(
manifest: The manifest for the declarative source. This can be a path to a yaml file, a
yaml string, or a dict.
"""
_suppress_cdk_pydantic_deprecation_warnings()

# TODO: Conform manifest to a dict or str (TBD)
self.manifest = manifest

Expand Down
1 change: 1 addition & 0 deletions airbyte/sources/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"source-salesloft",
"source-slack",
"source-surveymonkey",
"source-tiktok-marketing",
"source-the-guardian-api",
"source-trello",
"source-typeform",
Expand Down
5 changes: 1 addition & 4 deletions examples/run_faker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
import airbyte as ab


SCALE = 50_000 # Number of records to generate between users and purchases.
SCALE = 200_000 # Number of records to generate between users and purchases.
FORCE_FULL_REFRESH = True # Whether to force a full refresh on the source.

# This is a dummy secret, just to test functionality.
DUMMY_SECRET = ab.get_secret("DUMMY_SECRET")


print("Initializing cache...")
cache = ab.get_default_cache()
Expand Down
5 changes: 3 additions & 2 deletions examples/run_pokeapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from __future__ import annotations

import airbyte as ab
from airbyte.experimental import get_source


source = ab.get_source(
source = get_source(
"source-pokeapi",
config={"pokemon_name": "bulbasaur"},
install_if_missing=True,
source_manifest=True,
)
source.check()

Expand Down
Loading

0 comments on commit 8676125

Please sign in to comment.