From 8ef287286d6671f25ba95c2f97fa0b4974e9f860 Mon Sep 17 00:00:00 2001 From: Catherine Noll Date: Wed, 11 Jan 2023 10:22:11 -0500 Subject: [PATCH] [Connector-builder server] Allow client to specify record limit and enforce max of 1000 (#20575) --- .../connector_builder/entrypoint.py | 8 +- .../models/stream_read_request_body.py | 12 + .../connector_builder/impl/adapter.py | 34 +++ .../connector_builder/impl/default_api.py | 30 +- .../impl/low_code_cdk_adapter.py | 20 +- airbyte-connector-builder-server/setup.py | 2 +- .../src/main/openapi/openapi.yaml | 5 + .../impl/test_default_api.py | 288 +++++++++++++----- 8 files changed, 297 insertions(+), 102 deletions(-) create mode 100644 airbyte-connector-builder-server/connector_builder/impl/adapter.py diff --git a/airbyte-connector-builder-server/connector_builder/entrypoint.py b/airbyte-connector-builder-server/connector_builder/entrypoint.py index 642a742b534b7..c277501551785 100644 --- a/airbyte-connector-builder-server/connector_builder/entrypoint.py +++ b/airbyte-connector-builder-server/connector_builder/entrypoint.py @@ -2,11 +2,11 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - from connector_builder.generated.apis.default_api_interface import initialize_router from connector_builder.impl.default_api import DefaultApiImpl +from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapter +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware app = FastAPI( title="Connector Builder Server API", @@ -22,4 +22,4 @@ allow_headers=["*"], ) -app.include_router(initialize_router(DefaultApiImpl())) +app.include_router(initialize_router(DefaultApiImpl(LowCodeSourceAdapter))) diff --git a/airbyte-connector-builder-server/connector_builder/generated/models/stream_read_request_body.py b/airbyte-connector-builder-server/connector_builder/generated/models/stream_read_request_body.py index e57be491017d8..78bddf225b007 100644 --- a/airbyte-connector-builder-server/connector_builder/generated/models/stream_read_request_body.py +++ b/airbyte-connector-builder-server/connector_builder/generated/models/stream_read_request_body.py @@ -20,11 +20,23 @@ class StreamReadRequestBody(BaseModel): stream: The stream of this StreamReadRequestBody. config: The config of this StreamReadRequestBody. state: The state of this StreamReadRequestBody [Optional]. + record_limit: The record_limit of this StreamReadRequestBody [Optional]. """ manifest: Dict[str, Any] stream: str config: Dict[str, Any] state: Optional[Dict[str, Any]] = None + record_limit: Optional[int] = None + + @validator("record_limit") + def record_limit_max(cls, value): + assert value <= 1000 + return value + + @validator("record_limit") + def record_limit_min(cls, value): + assert value >= 1 + return value StreamReadRequestBody.update_forward_refs() diff --git a/airbyte-connector-builder-server/connector_builder/impl/adapter.py b/airbyte-connector-builder-server/connector_builder/impl/adapter.py new file mode 100644 index 0000000000000..840e0996fbaaf --- /dev/null +++ b/airbyte-connector-builder-server/connector_builder/impl/adapter.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterator, List + +from airbyte_cdk.models import AirbyteMessage +from airbyte_cdk.sources.streams.http import HttpStream + + +class CdkAdapter(ABC): + """ + Abstract base class for the connector builder's CDK adapter. + """ + + @abstractmethod + def get_http_streams(self, config: Dict[str, Any]) -> List[HttpStream]: + """ + Gets a list of HTTP streams. + + :param config: The user-provided configuration as specified by the source's spec. + :return: A list of `HttpStream`s. + """ + + @abstractmethod + def read_stream(self, stream: str, config: Dict[str, Any]) -> Iterator[AirbyteMessage]: + """ + Reads data from the specified stream. + + :param stream: stream + :param config: The user-provided configuration as specified by the source's spec. + :return: An iterator over `AirbyteMessage` objects. + """ diff --git a/airbyte-connector-builder-server/connector_builder/impl/default_api.py b/airbyte-connector-builder-server/connector_builder/impl/default_api.py index e80a4cd02da96..0dc86ecb51384 100644 --- a/airbyte-connector-builder-server/connector_builder/impl/default_api.py +++ b/airbyte-connector-builder-server/connector_builder/impl/default_api.py @@ -6,7 +6,7 @@ import logging import traceback from json import JSONDecodeError -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union from urllib.parse import parse_qs, urljoin, urlparse from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Type @@ -22,7 +22,7 @@ from connector_builder.generated.models.streams_list_read import StreamsListRead from connector_builder.generated.models.streams_list_read_streams import StreamsListReadStreams from connector_builder.generated.models.streams_list_request_body import StreamsListRequestBody -from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapter +from connector_builder.impl.adapter import CdkAdapter from fastapi import Body, HTTPException from jsonschema import ValidationError @@ -30,6 +30,11 @@ class DefaultApiImpl(DefaultApi): logger = logging.getLogger("airbyte.connector-builder") + def __init__(self, adapter_cls: Callable[[Dict[str, Any]], CdkAdapter], max_record_limit: int = 1000): + self.adapter_cls = adapter_cls + self.max_record_limit = max_record_limit + super().__init__() + async def get_manifest_template(self) -> str: return """version: "0.1.0" definitions: @@ -107,17 +112,24 @@ async def read_stream(self, stream_read_request_body: StreamReadRequestBody = Bo Using the provided manifest and config, invokes a sync for the specified stream and returns groups of Airbyte messages that are produced during the read operation :param stream_read_request_body: Input parameters to trigger the read operation for a stream + :param limit: The maximum number of records requested by the client (must be within the range [1, self.max_record_limit]) :return: Airbyte record messages produced by the sync grouped by slice and page """ adapter = self._create_low_code_adapter(manifest=stream_read_request_body.manifest) schema_inferrer = SchemaInferrer() + if stream_read_request_body.record_limit is None: + record_limit = self.max_record_limit + else: + record_limit = min(stream_read_request_body.record_limit, self.max_record_limit) + single_slice = StreamReadSlices(pages=[]) log_messages = [] try: for message_group in self._get_message_groups( adapter.read_stream(stream_read_request_body.stream, stream_read_request_body.config), - schema_inferrer + schema_inferrer, + record_limit, ): if isinstance(message_group, AirbyteLogMessage): log_messages.append({"message": message_group.message}) @@ -132,11 +144,11 @@ async def read_stream(self, stream_read_request_body: StreamReadRequestBody = Bo return StreamRead(logs=log_messages, slices=[single_slice], inferred_schema=schema_inferrer.get_stream_schema(stream_read_request_body.stream)) - def _get_message_groups(self, messages: Iterable[AirbyteMessage], schema_inferrer: SchemaInferrer) -> Iterable[Union[StreamReadPages, AirbyteLogMessage]]: + def _get_message_groups(self, messages: Iterator[AirbyteMessage], schema_inferrer: SchemaInferrer, limit: int) -> Iterable[Union[StreamReadPages, AirbyteLogMessage]]: """ Message groups are partitioned according to when request log messages are received. Subsequent response log messages and record messages belong to the prior request log message and when we encounter another request, append the latest - message group. + message group, until records have been read. Messages received from the CDK read operation will always arrive in the following order: {type: LOG, log: {message: "request: ..."}} @@ -152,7 +164,8 @@ def _get_message_groups(self, messages: Iterable[AirbyteMessage], schema_inferre current_records = [] current_page_request: Optional[HttpRequest] = None current_page_response: Optional[HttpResponse] = None - for message in messages: + + while len(current_records) < limit and (message := next(messages, None)): if first_page and message.type == Type.LOG and message.log.message.startswith("request:"): first_page = False request = self._create_request_from_log_message(message.log) @@ -209,10 +222,9 @@ def _create_response_from_log_message(self, log_message: AirbyteLogMessage) -> O self.logger.warning(f"Failed to parse log message into response object with error: {error}") return None - @staticmethod - def _create_low_code_adapter(manifest: Dict[str, Any]) -> LowCodeSourceAdapter: + def _create_low_code_adapter(self, manifest: Dict[str, Any]) -> CdkAdapter: try: - return LowCodeSourceAdapter(manifest=manifest) + return self.adapter_cls(manifest=manifest) except ValidationError as error: # TODO: We're temporarily using FastAPI's default exception model. Ideally we should use exceptions defined in the OpenAPI spec raise HTTPException( diff --git a/airbyte-connector-builder-server/connector_builder/impl/low_code_cdk_adapter.py b/airbyte-connector-builder-server/connector_builder/impl/low_code_cdk_adapter.py index 6c3ce29fbee35..580d0d7c42e72 100644 --- a/airbyte-connector-builder-server/connector_builder/impl/low_code_cdk_adapter.py +++ b/airbyte-connector-builder-server/connector_builder/impl/low_code_cdk_adapter.py @@ -2,19 +2,17 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterator, List -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level -from airbyte_cdk.models import ConfiguredAirbyteCatalog +from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteCatalog, Level from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream -from airbyte_cdk.sources.declarative.yaml_declarative_source import \ - ManifestDeclarativeSource +from airbyte_cdk.sources.declarative.yaml_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.streams.http import HttpStream +from connector_builder.impl.adapter import CdkAdapter -class LowCodeSourceAdapter: - +class LowCodeSourceAdapter(CdkAdapter): def __init__(self, manifest: Dict[str, Any]): # Request and response messages are only emitted for a sources that have debug turned on self._source = ManifestDeclarativeSource(manifest, debug=True) @@ -27,13 +25,15 @@ def get_http_streams(self, config: Dict[str, Any]) -> List[HttpStream]: http_streams.append(stream.retriever) else: raise TypeError( - f"A declarative stream should only have a retriever of type HttpStream, but received: {stream.retriever.__class__}") + f"A declarative stream should only have a retriever of type HttpStream, but received: {stream.retriever.__class__}" + ) else: raise TypeError( - f"A declarative source should only contain streams of type DeclarativeStream, but received: {stream.__class__}") + f"A declarative source should only contain streams of type DeclarativeStream, but received: {stream.__class__}" + ) return http_streams - def read_stream(self, stream: str, config: Dict[str, Any]) -> Iterable[AirbyteMessage]: + def read_stream(self, stream: str, config: Dict[str, Any]) -> Iterator[AirbyteMessage]: configured_catalog = ConfiguredAirbyteCatalog.parse_obj( { "streams": [ diff --git a/airbyte-connector-builder-server/setup.py b/airbyte-connector-builder-server/setup.py index 7b9e3d5eb8258..e54489923c8f7 100644 --- a/airbyte-connector-builder-server/setup.py +++ b/airbyte-connector-builder-server/setup.py @@ -41,7 +41,7 @@ }, packages=find_packages(exclude=("unit_tests", "integration_tests", "docs")), package_data={}, - install_requires=["airbyte-cdk~=0.8", "fastapi", "uvicorn"], + install_requires=["airbyte-cdk~=0.15", "fastapi", "uvicorn"], python_requires=">=3.9.11", extras_require={ "tests": [ diff --git a/airbyte-connector-builder-server/src/main/openapi/openapi.yaml b/airbyte-connector-builder-server/src/main/openapi/openapi.yaml index 7c2b663da3fdd..61a828e3fbf0e 100644 --- a/airbyte-connector-builder-server/src/main/openapi/openapi.yaml +++ b/airbyte-connector-builder-server/src/main/openapi/openapi.yaml @@ -147,6 +147,11 @@ components: type: object description: The AirbyteStateMessage object to use as the starting state for this read # $ref: "#/components/schemas/AirbyteProtocol/definitions/AirbyteStateMessage" + record_limit: + type: integer + minimum: 1 + maximum: 1000 + description: Number of records that will be returned to the client from the connector builder (max of 1000) # --- Potential addition for a later phase --- # numPages: # type: integer diff --git a/airbyte-connector-builder-server/unit_tests/connector_builder/impl/test_default_api.py b/airbyte-connector-builder-server/unit_tests/connector_builder/impl/test_default_api.py index 1a01ed57b7c23..d9bbf6a07fa17 100644 --- a/airbyte-connector-builder-server/unit_tests/connector_builder/impl/test_default_api.py +++ b/airbyte-connector-builder-server/unit_tests/connector_builder/impl/test_default_api.py @@ -4,7 +4,8 @@ import asyncio import json -from unittest.mock import MagicMock, patch +from typing import Iterator +from unittest.mock import MagicMock import pytest from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteRecordMessage, Level, Type @@ -17,7 +18,9 @@ from connector_builder.generated.models.streams_list_read_streams import StreamsListReadStreams from connector_builder.generated.models.streams_list_request_body import StreamsListRequestBody from connector_builder.impl.default_api import DefaultApiImpl +from connector_builder.impl.low_code_cdk_adapter import LowCodeSourceAdapter from fastapi import HTTPException +from pydantic.error_wrappers import ValidationError MANIFEST = { "version": "0.1.0", @@ -96,7 +99,7 @@ def test_list_streams(): StreamsListReadStreams(name="breathing-techniques", url="https://demonslayers.com/api/v1/breathing_techniques"), ] - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) streams_list_request_body = StreamsListRequestBody(manifest=MANIFEST, config=CONFIG) loop = asyncio.get_event_loop() actual_streams = loop.run_until_complete(api.list_streams(streams_list_request_body)) @@ -130,7 +133,7 @@ def test_list_streams_with_interpolated_urls(): expected_streams = StreamsListRead(streams=[StreamsListReadStreams(name="demons", url="https://upper-six.muzan.com/api/v1/demons")]) - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) streams_list_request_body = StreamsListRequestBody(manifest=manifest, config=CONFIG) loop = asyncio.get_event_loop() actual_streams = loop.run_until_complete(api.list_streams(streams_list_request_body)) @@ -164,7 +167,7 @@ def test_list_streams_with_unresolved_interpolation(): # The interpolated string {{ config['not_in_config'] }} doesn't resolve to anything so it ends up blank during interpolation expected_streams = StreamsListRead(streams=[StreamsListReadStreams(name="demons", url="https://.muzan.com/api/v1/demons")]) - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) streams_list_request_body = StreamsListRequestBody(manifest=manifest, config=CONFIG) loop = asyncio.get_event_loop() @@ -180,7 +183,7 @@ def test_read_stream(): "http_method": "GET", "body": {"custom": "field"}, } - response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'} + response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}', "http_method": "GET"} expected_schema = {"$schema": "http://json-schema.org/schema#", "properties": {"name": {"type": "string"}}, "type": "object"} expected_pages = [ StreamReadPages( @@ -207,30 +210,31 @@ def test_read_stream(): ), ] - mock_source_adapter = MagicMock() - mock_source_adapter.read_stream.return_value = [ - request_log_message(request), - response_log_message(response), - record_message("hashiras", {"name": "Shinobu Kocho"}), - record_message("hashiras", {"name": "Muichiro Tokito"}), - request_log_message(request), - response_log_message(response), - record_message("hashiras", {"name": "Mitsuri Kanroji"}), - ] - - with patch.object(DefaultApiImpl, "_create_low_code_adapter", return_value=mock_source_adapter): - api = DefaultApiImpl() - - loop = asyncio.get_event_loop() - actual_response: StreamRead = loop.run_until_complete( - api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Shinobu Kocho"}), + record_message("hashiras", {"name": "Muichiro Tokito"}), + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Mitsuri Kanroji"}), + ] ) + ) + + api = DefaultApiImpl(mock_source_adapter_cls) - assert actual_response.inferred_schema == expected_schema + loop = asyncio.get_event_loop() + actual_response: StreamRead = loop.run_until_complete( + api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) + ) + assert actual_response.inferred_schema == expected_schema - single_slice = actual_response.slices[0] - for i, actual_page in enumerate(single_slice.pages): - assert actual_page == expected_pages[i] + single_slice = actual_response.slices[0] + for i, actual_page in enumerate(single_slice.pages): + assert actual_page == expected_pages[i] def test_read_stream_with_logs(): @@ -271,31 +275,147 @@ def test_read_stream_with_logs(): {"message": "log message after the response"}, ] - mock_source_adapter = MagicMock() - mock_source_adapter.read_stream.return_value = [ - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message before the request")), - request_log_message(request), - response_log_message(response), - record_message("hashiras", {"name": "Shinobu Kocho"}), - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message during the page")), - record_message("hashiras", {"name": "Muichiro Tokito"}), - AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message after the response")), - ] + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message before the request")), + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Shinobu Kocho"}), + AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message during the page")), + record_message("hashiras", {"name": "Muichiro Tokito"}), + AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="log message after the response")), + ] + ) + ) + + api = DefaultApiImpl(mock_source_adapter_cls) + + loop = asyncio.get_event_loop() + actual_response: StreamRead = loop.run_until_complete( + api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) + ) + + single_slice = actual_response.slices[0] + for i, actual_page in enumerate(single_slice.pages): + assert actual_page == expected_pages[i] + + for i, actual_log in enumerate(actual_response.logs): + assert actual_log == expected_logs[i] + + +@pytest.mark.parametrize( + "request_record_limit, max_record_limit", + [ + pytest.param(1, 3, id="test_create_request_with_record_limit"), + pytest.param(3, 1, id="test_create_request_record_limit_exceeds_max"), + ], +) +def test_read_stream_record_limit(request_record_limit, max_record_limit): + request = { + "url": "https://demonslayers.com/api/v1/hashiras?era=taisho", + "headers": {"Content-Type": "application/json"}, + "body": {"custom": "field"}, + } + response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'} + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Shinobu Kocho"}), + record_message("hashiras", {"name": "Muichiro Tokito"}), + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Mitsuri Kanroji"}), + response_log_message(response), + ] + ) + ) + n_records = 2 + record_limit = min(request_record_limit, max_record_limit) + + api = DefaultApiImpl(mock_source_adapter_cls, max_record_limit=max_record_limit) + loop = asyncio.get_event_loop() + actual_response: StreamRead = loop.run_until_complete( + api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras", record_limit=request_record_limit)) + ) + single_slice = actual_response.slices[0] + total_records = 0 + for i, actual_page in enumerate(single_slice.pages): + total_records += len(actual_page.records) + assert total_records == min([record_limit, n_records]) - with patch.object(DefaultApiImpl, "_create_low_code_adapter", return_value=mock_source_adapter): - api = DefaultApiImpl() - loop = asyncio.get_event_loop() - actual_response: StreamRead = loop.run_until_complete( - api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) +@pytest.mark.parametrize( + "max_record_limit", + [ + pytest.param(2, id="test_create_request_no_record_limit"), + pytest.param(1, id="test_create_request_no_record_limit_n_records_exceed_max"), + ], +) +def test_read_stream_default_record_limit(max_record_limit): + request = { + "url": "https://demonslayers.com/api/v1/hashiras?era=taisho", + "headers": {"Content-Type": "application/json"}, + "body": {"custom": "field"}, + } + response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'} + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Shinobu Kocho"}), + record_message("hashiras", {"name": "Muichiro Tokito"}), + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Mitsuri Kanroji"}), + response_log_message(response), + ] ) + ) + n_records = 2 - single_slice = actual_response.slices[0] - for i, actual_page in enumerate(single_slice.pages): - assert actual_page == expected_pages[i] + api = DefaultApiImpl(mock_source_adapter_cls, max_record_limit=max_record_limit) + loop = asyncio.get_event_loop() + actual_response: StreamRead = loop.run_until_complete( + api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) + ) + single_slice = actual_response.slices[0] + total_records = 0 + for i, actual_page in enumerate(single_slice.pages): + total_records += len(actual_page.records) + assert total_records == min([max_record_limit, n_records]) + + +def test_read_stream_limit_0(): + request = { + "url": "https://demonslayers.com/api/v1/hashiras?era=taisho", + "headers": {"Content-Type": "application/json"}, + "body": {"custom": "field"}, + } + response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'} + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Shinobu Kocho"}), + record_message("hashiras", {"name": "Muichiro Tokito"}), + request_log_message(request), + response_log_message(response), + record_message("hashiras", {"name": "Mitsuri Kanroji"}), + response_log_message(response), + ] + ) + ) + api = DefaultApiImpl(mock_source_adapter_cls) + loop = asyncio.get_event_loop() - for i, actual_log in enumerate(actual_response.logs): - assert actual_log == expected_logs[i] + with pytest.raises(ValidationError): + loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras", record_limit=0))) + loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))) def test_read_stream_no_records(): @@ -331,25 +451,27 @@ def test_read_stream_no_records(): ), ] - mock_source_adapter = MagicMock() - mock_source_adapter.read_stream.return_value = [ - request_log_message(request), - response_log_message(response), - request_log_message(request), - response_log_message(response), - ] + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + request_log_message(request), + response_log_message(response), + request_log_message(request), + response_log_message(response), + ] + ) + ) - with patch.object(DefaultApiImpl, "_create_low_code_adapter", return_value=mock_source_adapter): - api = DefaultApiImpl() + api = DefaultApiImpl(mock_source_adapter_cls) - loop = asyncio.get_event_loop() - actual_response: StreamRead = loop.run_until_complete( - api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) - ) + loop = asyncio.get_event_loop() + actual_response: StreamRead = loop.run_until_complete( + api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras")) + ) - single_slice = actual_response.slices[0] - for i, actual_page in enumerate(single_slice.pages): - assert actual_page == expected_pages[i] + single_slice = actual_response.slices[0] + for i, actual_page in enumerate(single_slice.pages): + assert actual_page == expected_pages[i] def test_invalid_manifest(): @@ -377,7 +499,7 @@ def test_invalid_manifest(): expected_status_code = 400 - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) loop = asyncio.get_event_loop() with pytest.raises(HTTPException) as actual_exception: loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=invalid_manifest, config={}, stream="hashiras"))) @@ -388,27 +510,29 @@ def test_invalid_manifest(): def test_read_stream_invalid_group_format(): response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'} - mock_source_adapter = MagicMock() - mock_source_adapter.read_stream.return_value = [ - response_log_message(response), - record_message("hashiras", {"name": "Shinobu Kocho"}), - record_message("hashiras", {"name": "Muichiro Tokito"}), - ] + mock_source_adapter_cls = make_mock_adapter_cls( + iter( + [ + response_log_message(response), + record_message("hashiras", {"name": "Shinobu Kocho"}), + record_message("hashiras", {"name": "Muichiro Tokito"}), + ] + ) + ) - with patch.object(DefaultApiImpl, "_create_low_code_adapter", return_value=mock_source_adapter): - api = DefaultApiImpl() + api = DefaultApiImpl(mock_source_adapter_cls) - loop = asyncio.get_event_loop() - with pytest.raises(HTTPException) as actual_exception: - loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))) + loop = asyncio.get_event_loop() + with pytest.raises(HTTPException) as actual_exception: + loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config=CONFIG, stream="hashiras"))) - assert actual_exception.value.status_code == 400 + assert actual_exception.value.status_code == 400 def test_read_stream_returns_error_if_stream_does_not_exist(): expected_status_code = 400 - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) loop = asyncio.get_event_loop() with pytest.raises(HTTPException) as actual_exception: loop.run_until_complete(api.read_stream(StreamReadRequestBody(manifest=MANIFEST, config={}, stream="not_in_manifest"))) @@ -458,7 +582,7 @@ def test_read_stream_returns_error_if_stream_does_not_exist(): ) def test_create_request_from_log_message(log_message, expected_request): airbyte_log_message = AirbyteLogMessage(level=Level.INFO, message=log_message) - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) actual_request = api._create_request_from_log_message(airbyte_log_message) assert actual_request == expected_request @@ -493,7 +617,15 @@ def test_create_response_from_log_message(log_message, expected_response): response_message = f"response:{json.dumps(log_message)}" airbyte_log_message = AirbyteLogMessage(level=Level.INFO, message=response_message) - api = DefaultApiImpl() + api = DefaultApiImpl(LowCodeSourceAdapter) actual_response = api._create_response_from_log_message(airbyte_log_message) assert actual_response == expected_response + + +def make_mock_adapter_cls(return_value: Iterator) -> MagicMock: + mock_source_adapter_cls = MagicMock() + mock_source_adapter = MagicMock() + mock_source_adapter.read_stream.return_value = return_value + mock_source_adapter_cls.return_value = mock_source_adapter + return mock_source_adapter_cls