Fix stream read given stream doesn't have any slice #28746

Merged (3 commits) on Jul 27, 2023
Changes from 1 commit
@@ -31,7 +31,6 @@
    AirbyteMessage,
    AirbyteTraceMessage,
    ConfiguredAirbyteCatalog,
    Level,
    OrchestratorType,
    TraceType,
)
@@ -126,15 +125,14 @@ def _get_message_groups(
        current_slice_pages: List[StreamReadPages] = []
        current_page_request: Optional[HttpRequest] = None
        current_page_response: Optional[HttpResponse] = None
        had_error = False

        while records_count < limit and (message := next(messages, None)):
            json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None
            if json_object is not None and not isinstance(json_object, dict):
                raise ValueError(f"Expected log message to be a dict, got {json_object} of type {type(json_object)}")
            json_message: Optional[Dict[str, JsonType]] = json_object
            if self._need_to_close_page(at_least_one_page_in_group, message, json_message):
                self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records, True)
                self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
                current_page_request = None
                current_page_response = None

@@ -172,12 +170,9 @@ def _get_message_groups(
                        current_page_request = self._create_request_from_log_message(json_message)
                        current_page_response = self._create_response_from_log_message(json_message)
                else:
                    if message.log.level == Level.ERROR:
                        had_error = True
                    yield message.log
            elif message.type == MessageType.TRACE:
                if message.trace.type == TraceType.ERROR:
                    had_error = True
                yield message.trace
            elif message.type == MessageType.RECORD:
                current_page_records.append(message.record.data)
@@ -187,7 +182,7 @@
            elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG:
                yield message.control
        else:
            self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records, validate_page_complete=not had_error)
            self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
            yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor)

    @staticmethod
@@ -224,15 +219,10 @@ def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool:
        return is_http and message.get("http", {}).get("is_auxiliary", False)

    @staticmethod
    def _close_page(current_page_request: Optional[HttpRequest], current_page_response: Optional[HttpResponse], current_slice_pages: List[StreamReadPages], current_page_records: List[Mapping[str, Any]], validate_page_complete: bool) -> None:
    def _close_page(current_page_request: Optional[HttpRequest], current_page_response: Optional[HttpResponse], current_slice_pages: List[StreamReadPages], current_page_records: List[Mapping[str, Any]]) -> None:
        """
        Close a page when parsing message groups
        @param validate_page_complete: in some cases, we expect the CDK to not return a response. As of today, this will only happen before
            an uncaught exception and therefore, the assumption is that `validate_page_complete=True` only on the last page that is being closed
        """
        if validate_page_complete and (not current_page_request or not current_page_response):
            raise ValueError("Every message grouping should have at least one request and response")

        current_slice_pages.append(
            StreamReadPages(request=current_page_request, response=current_page_response, records=deepcopy(current_page_records))  # type: ignore
        )
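In short, `_close_page` loses its `validate_page_complete` flag: it no longer raises when a page is closed without a request/response pair, which is exactly what happens when a stream read produces no slice. Below is a minimal sketch of the before/after behavior, with standalone functions and `StreamReadPages` reduced to a plain dict; it is an illustration of the change, not the actual module.

```python
from copy import deepcopy
from typing import Any, List, Mapping, Optional


def close_page_before(request: Optional[Any], response: Optional[Any],
                      slice_pages: List[Any], page_records: List[Mapping[str, Any]],
                      validate_page_complete: bool) -> None:
    # Old behavior: raise if the page has no request/response pair, which also
    # fired when the stream simply never produced a slice.
    if validate_page_complete and (not request or not response):
        raise ValueError("Every message grouping should have at least one request and response")
    slice_pages.append({"request": request, "response": response, "records": deepcopy(page_records)})


def close_page_after(request: Optional[Any], response: Optional[Any],
                     slice_pages: List[Any], page_records: List[Mapping[str, Any]]) -> None:
    # New behavior: always close the page, even with no request/response recorded.
    slice_pages.append({"request": request, "response": response, "records": deepcopy(page_records)})


if __name__ == "__main__":
    pages: List[Any] = []
    close_page_after(None, None, pages, [])  # fine under the new behavior
    try:
        close_page_before(None, None, pages, [], validate_page_complete=True)
    except ValueError as exc:
        print(f"old behavior raised: {exc}")
```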
@@ -367,25 +367,6 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None:
        assert actual_page == expected_pages[i]


@patch('airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read')
def test_get_grouped_messages_invalid_group_format(mock_entrypoint_read: Mock) -> None:
    response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'}

    mock_source = make_mock_source(mock_entrypoint_read, iter(
        [
            response_log_message(response),
            record_message("hashiras", {"name": "Shinobu Kocho"}),
            record_message("hashiras", {"name": "Muichiro Tokito"}),
        ]
    )
    )

    api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES)

    with pytest.raises(ValueError):
        api.get_message_groups(source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"))


@pytest.mark.parametrize(
    "log_message, expected_response",
    [
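The removed test existed only to assert the old ValueError. A hypothetical counterpart, not part of this PR and shown only to illustrate the new behavior, would assert that the same message sequence is now grouped without raising; it reuses helpers already defined in the same test module (make_mock_source, response_log_message, record_message, create_configured_catalog, and the module constants).

```python
@patch('airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read')
def test_get_grouped_messages_response_without_request(mock_entrypoint_read: Mock) -> None:
    # Hypothetical counterpart to the removed test (name and assertion are illustrative):
    # a response log with no matching request log should no longer raise.
    response = {"status_code": 200, "headers": {"field": "value"}, "body": '{"name": "field"}'}
    mock_source = make_mock_source(mock_entrypoint_read, iter(
        [
            response_log_message(response),
            record_message("hashiras", {"name": "Shinobu Kocho"}),
            record_message("hashiras", {"name": "Muichiro Tokito"}),
        ]
    ))

    api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES)

    # Should complete without raising under the new _close_page behavior.
    api.get_message_groups(source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"))
```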