
Stateless persistence used for Airbyte cursor preservation (#7174)
GitOrigin-RevId: 805a606ab7d8c1f461bbbba1a792bf30d8beb054
zxqfd555-pw authored and Manul from Pathway committed Sep 10, 2024
1 parent ff0ad37 commit 5e58dec
Showing 20 changed files with 569 additions and 113 deletions.
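In brief: this commit teaches the Airbyte connector to serialize the source's Airbyte state (its cursor) and report it to the engine as an external offset, so a persisted run can resume from the stored cursor instead of replaying a full data snapshot. The new SnapshotAccess.OFFSETS_ONLY mode exposes this "stateless" persistence, and the S3 integration tests are parametrized to cover both it and the existing FULL mode.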
17 changes: 15 additions & 2 deletions integration_tests/s3/test_s3_generic.py
@@ -7,14 +7,18 @@
import pytest

import pathway as pw
from pathway.internals import api
from pathway.internals.monitoring import MonitoringLevel
from pathway.internals.parse_graph import G
from pathway.tests.utils import get_aws_s3_settings, write_lines

from .base import create_jsonlines, put_aws_object, read_jsonlines_fields


def test_s3_backfilling(tmp_path: pathlib.Path, s3_path: str):
@pytest.mark.parametrize(
"snapshot_access", [api.SnapshotAccess.FULL, api.SnapshotAccess.OFFSETS_ONLY]
)
def test_s3_backfilling(snapshot_access, tmp_path: pathlib.Path, s3_path: str):
pathway_persistent_storage = tmp_path / "PStorage"
s3_input_path = f"{s3_path}/input.csv"

@@ -33,6 +37,7 @@ def test_s3_backfilling(tmp_path: pathlib.Path, s3_path: str):
monitoring_level=MonitoringLevel.NONE,
persistence_config=pw.persistence.Config.simple_config(
pw.persistence.Backend.filesystem(pathway_persistent_storage),
snapshot_access=snapshot_access,
),
)
G.clear()
@@ -52,6 +57,7 @@ def test_s3_backfilling(tmp_path: pathlib.Path, s3_path: str):
monitoring_level=MonitoringLevel.NONE,
persistence_config=pw.persistence.Config.simple_config(
pw.persistence.Backend.filesystem(pathway_persistent_storage),
snapshot_access=snapshot_access,
),
)
G.clear()
@@ -75,6 +81,7 @@ def test_s3_backfilling(tmp_path: pathlib.Path, s3_path: str):
monitoring_level=MonitoringLevel.NONE,
persistence_config=pw.persistence.Config.simple_config(
pw.persistence.Backend.filesystem(pathway_persistent_storage),
snapshot_access=snapshot_access,
),
)

@@ -91,7 +98,12 @@ def test_s3_backfilling(tmp_path: pathlib.Path, s3_path: str):
assert result.equals(expected)


def test_s3_json_read_and_recovery(tmp_path: pathlib.Path, s3_path: str):
@pytest.mark.parametrize(
"snapshot_access", [api.SnapshotAccess.FULL, api.SnapshotAccess.OFFSETS_ONLY]
)
def test_s3_json_read_and_recovery(
snapshot_access, tmp_path: pathlib.Path, s3_path: str
):
pstorage_s3_path = f"{s3_path}/PStorage"
input_s3_path = f"{s3_path}/input"
output_path = tmp_path / "output.json"
@@ -118,6 +130,7 @@ class InputSchema(pw.Schema):
root_path=pstorage_s3_path,
bucket_settings=get_aws_s3_settings(),
),
snapshot_access=snapshot_access,
),
)

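For context, here is a minimal sketch of what the parametrized tests above exercise: the same pipeline run with persistence enabled, where snapshot_access selects between FULL and the newly added OFFSETS_ONLY. The paths and schema are illustrative, not taken from the diff.

```python
import pathway as pw
from pathway.internals import api
from pathway.internals.monitoring import MonitoringLevel


class InputSchema(pw.Schema):
    key: int
    value: str


# Stream a CSV directory; persistence remembers how far this input was read.
table = pw.io.csv.read("./input", schema=InputSchema, mode="streaming")
pw.io.csv.write(table, "./output.csv")

pw.run(
    monitoring_level=MonitoringLevel.NONE,
    persistence_config=pw.persistence.Config.simple_config(
        pw.persistence.Backend.filesystem("./PStorage"),
        # FULL also snapshots operator state; OFFSETS_ONLY (added in this
        # commit) persists only the input offsets.
        snapshot_access=api.SnapshotAccess.OFFSETS_ONLY,
    ),
)
```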
2 changes: 1 addition & 1 deletion integration_tests/webserver/test_rest_connector.py
@@ -236,7 +236,7 @@ class InputSchema(pw.Schema):
pw.io.csv.write(sum, output_path)
pw.io.csv.write(sum_dup, output_path)

with pytest.raises(RuntimeError, match="error while attempting to bind on address"):
with pytest.raises(OSError, match="error while attempting to bind on address"):
pw.run()


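(The updated expectation above reflects that a failure to bind the webserver's address now surfaces as the standard OSError rather than a generic RuntimeError; the match string is unchanged.)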
5 changes: 4 additions & 1 deletion python/pathway/conftest.py
@@ -93,7 +93,7 @@ def tmp_path_with_airbyte_config(tmp_path):
[
"new_source",
"--image",
"airbyte/source-faker:0.1.4",
"airbyte/source-faker:6.2.10",
],
)
assert result.exit_code == 0
@@ -102,9 +102,12 @@ def tmp_path_with_airbyte_config(tmp_path):

with open(tmp_path / AIRBYTE_CONNECTION_REL_PATH, "r") as f:
config = yaml.safe_load(f)

# https://docs.airbyte.com/integrations/sources/faker#reference
config["source"]["config"]["records_per_slice"] = 500
config["source"]["config"]["records_per_sync"] = 500
config["source"]["config"]["count"] = 500
config["source"]["config"]["always_updated"] = False
with open(tmp_path / AIRBYTE_CONNECTION_REL_PATH, "w") as f:
yaml.dump(config, f)

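Per the linked faker reference, these options roughly mean: count caps the total records emitted, records_per_slice and records_per_sync bound a single sync, and always_updated=False makes the source finite so a sync can complete. The resulting connection config, sketched as an equivalent Python dict (the exact key layout of the YAML is assumed):

```python
# Assumed shape of the patched connection YAML, shown as a Python dict.
config = {
    "source": {
        "config": {
            "count": 500,              # total records the faker emits
            "records_per_slice": 500,  # batch size within one sync
            "records_per_sync": 500,   # cap on records in a single sync
            "always_updated": False,   # finite source: the sync terminates
        },
        # ... image/name fields as written by the new_source scaffolding
    },
}
```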
10 changes: 6 additions & 4 deletions python/pathway/engine.pyi
@@ -785,11 +785,13 @@ class SnapshotAccess(Enum):
RECORD: SnapshotAccess
REPLAY: SnapshotAccess
FULL: SnapshotAccess
OFFSETS_ONLY: SnapshotAccess

class DataEventType(Enum):
INSERT: DataEventType
DELETE: DataEventType
UPSERT: DataEventType
class PythonConnectorEventType(Enum):
INSERT: PythonConnectorEventType
DELETE: PythonConnectorEventType
UPSERT: PythonConnectorEventType
EXTERNAL_OFFSET: PythonConnectorEventType

class SessionType(Enum):
NATIVE: SessionType
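A note on the two new enum members: OFFSETS_ONLY lets persistence record just the input offsets without operator-state snapshots, and EXTERNAL_OFFSET appears to be the Python-connector event type that carries an externally supplied offset (here, the serialized Airbyte state reported via _report_offset in logic.py below) alongside regular data events.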
3 changes: 3 additions & 0 deletions python/pathway/io/airbyte/__init__.py
@@ -106,6 +106,7 @@ def read(
gcp_job_name: str | None = None,
enforce_method: str | None = None,
refresh_interval_ms: int = 60000,
persistent_id: int | None = None,
):
"""
Reads a table with an Airbyte connector that supports the \
@@ -315,6 +316,7 @@ def read(

subject = _PathwayAirbyteSubject(
source=source,
streams=streams,
mode=mode,
refresh_interval_ms=refresh_interval_ms,
)
Expand All @@ -324,4 +326,5 @@ def read(
schema=_AirbyteRecordSchema,
autocommit_duration_ms=max(refresh_interval_ms, 1),
name="airbyte",
persistent_id=persistent_id,
)
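A sketch of the resulting user-facing call, with a hypothetical connection file and stream name; passing persistent_id together with a persistence config is what activates cursor preservation for the connector:

```python
import pathway as pw

users = pw.io.airbyte.read(
    "./connections/faker.yaml",  # hypothetical connection config path
    streams=["users"],           # persistence supports exactly one stream
    mode="streaming",
    refresh_interval_ms=60000,
    persistent_id=1,             # new: identifies this input in the snapshot
)

pw.io.jsonlines.write(users, "./users.jsonl")
pw.run(
    persistence_config=pw.persistence.Config.simple_config(
        pw.persistence.Backend.filesystem("./PStorage"),
    ),
)
```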
72 changes: 64 additions & 8 deletions python/pathway/io/airbyte/logic.py
@@ -1,6 +1,7 @@
import json
import logging
import time
from collections.abc import Sequence

from pathway.io._utils import STATIC_MODE_NAME
from pathway.io.python import ConnectorSubject
@@ -18,13 +19,39 @@


class _PathwayAirbyteDestination(BaseAirbyteDestination):
def __init__(self, on_event, *args, **kwargs):
def __init__(self, on_event, on_state, *args, **kwargs):
super().__init__(*args, **kwargs)
self.on_event = on_event
self.on_state = on_state
self._state = {}
self._shared_state = None

def get_state(self):
return self._state
stream_states = []
for stream_name, stream_state in self._state.items():
stream_states.append(
{
"stream_descriptor": {
"name": stream_name,
},
"stream_state": stream_state,
}
)
global_state = {
"stream_states": stream_states,
}
if self._shared_state is not None:
global_state["shared_state"] = self._shared_state
result = {
"type": "GLOBAL",
"global": global_state,
}
return result

def set_state(self, state):
self._state = {}
self._shared_state = None
self._handle_global_state(state)

def _write(self, record_type, records):
if record_type.startswith(AIRBYTE_STREAM_RECORD_PREFIX):
@@ -48,6 +75,7 @@ def _write(self, record_type, records):
logging.warning(
f"Unknown state type: {state_type}. Ignoring state: {full_state}"
)
self.on_state(self.get_state())

def _handle_stream_state(self, full_state):
stream = full_state.get("stream")
@@ -83,12 +111,21 @@ def _handle_legacy_state(self, full_state):
logging.warning("Legacy state doesn't contain 'data' field")

def _handle_global_state(self, full_state):
stream_states = full_state.get("stream_states")
global_ = full_state.get("global")
if global_ is None:
logging.warning("Global state doesn't contain 'global' section")
return
stream_states = global_.get("stream_states")
if stream_states is None:
logging.warning("Global state doesn't contain 'stream_states' section")
return
for stream_state in stream_states:
self._handle_stream_state_inner(stream_state)
shared_state = global_.get("shared_state")
if shared_state is not None:
self._shared_state = shared_state
else:
self._shared_state = None


class _PathwayAirbyteSubject(ConnectorSubject):
@@ -98,23 +135,32 @@ def __init__(
source: AbstractAirbyteSource,
mode: str,
refresh_interval_ms: int,
streams: Sequence[str],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.source = source
self.mode = mode
self.refresh_interval = refresh_interval_ms / 1000.0
self.destination = _PathwayAirbyteDestination(
on_event=lambda payload: self.next_json({"data": payload}),
on_state=self.on_state,
)
self.streams = streams

def on_state(self, state):
self._report_offset(json.dumps(state).encode("utf-8"))
self._enable_commits() # A commit is done here
self._disable_commits() # Disable commits till the next state

def run(self):
destination = _PathwayAirbyteDestination(
on_event=lambda payload: self.next_json({"data": payload})
)
self._disable_commits()
n_times_failed = 0
while True:
time_before_start = time.time()
try:
messages = self.source.extract(destination.get_state())
messages = self.source.extract([self.destination.get_state()])
except Exception:
logging.exception(
"Failed to query airbyte-serverless source, retrying..."
@@ -129,7 +175,7 @@ def run(self):
continue

n_times_failed = 0
destination.load(messages)
self.destination.load(messages)

if self.mode == STATIC_MODE_NAME:
break
@@ -140,3 +186,13 @@ def run(self):

def on_stop(self):
self.source.on_stop()

def _seek(self, state):
self.destination.set_state(json.loads(state.decode("utf-8")))

def on_persisted_run(self):
if len(self.streams) != 1:
raise RuntimeError(
"Persistence in airbyte connector is supported only for the case of a single stream. "
"Please use several airbyte connectors with one stream per connector to persist the state."
)
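To make the round trip concrete, this is the (assumed) shape of the GLOBAL state that get_state() assembles, that on_state() serializes and reports as an offset, and that _seek() feeds back into set_state() on recovery; the stream name and cursor payload are made up:

```python
# Hypothetical example of the state dict produced by get_state().
state = {
    "type": "GLOBAL",
    "global": {
        "stream_states": [
            {
                "stream_descriptor": {"name": "users"},      # made-up stream
                "stream_state": {"updated_at": 1725926400},  # made-up cursor
            },
        ],
        # "shared_state" is added only when the source reported one.
    },
}
```

Note also the commit gating in on_state(): commits stay disabled between states, so a snapshot point is created exactly when a fresh Airbyte cursor is known, which keeps the persisted offset and the source cursor consistent.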
