Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timezone
from unittest import mock

import pytest
Expand All @@ -28,18 +28,16 @@


class TestExponentialBackoffRetry:
@mock.patch("airflow.utils.timezone.utcnow")
def test_exponential_backoff_retry_base_case(self, mock_utcnow):
mock_utcnow.return_value = datetime(2023, 1, 1, 12, 0, 5)
def test_exponential_backoff_retry_base_case(self, time_machine):
time_machine.move_to(datetime(2023, 1, 1, 12, 0, 5))
mock_callable_function = mock.Mock()
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=0,
callable_function=mock_callable_function,
)
mock_callable_function.assert_called_once()

@mock.patch("airflow.utils.timezone.utcnow")
@pytest.mark.parametrize(
"attempt_number, utcnow_value, expected_calls",
[
Expand Down Expand Up @@ -111,92 +109,88 @@ def test_exponential_backoff_retry_base_case(self, mock_utcnow):
],
)
def test_exponential_backoff_retry_parameterized(
self, mock_utcnow, attempt_number, utcnow_value, expected_calls
self, attempt_number, utcnow_value, expected_calls, time_machine
):
time_machine.move_to(utcnow_value)
mock_callable_function = mock.Mock()
mock_callable_function.__name__ = "test_callable_function"
mock_callable_function.side_effect = Exception()
mock_utcnow.return_value = utcnow_value

exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=attempt_number,
callable_function=mock_callable_function,
)
assert mock_callable_function.call_count == expected_calls

@mock.patch("airflow.utils.timezone.utcnow")
def test_exponential_backoff_retry_fail_success(self, mock_utcnow, caplog):
def test_exponential_backoff_retry_fail_success(self, time_machine, caplog):
mock_callable_function = mock.Mock()
mock_callable_function.__name__ = "test_callable_function"
mock_callable_function.side_effect = [Exception(), True]
mock_utcnow.return_value = datetime(2023, 1, 1, 12, 0, 2)
time_machine.move_to(datetime(2023, 1, 1, 12, 0, 2))
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=0,
callable_function=mock_callable_function,
)
mock_callable_function.assert_called_once()
assert any("Error calling" in log for log in caplog.messages)
caplog.clear() # clear messages so that we have clean logs for the next call

mock_utcnow.return_value = datetime(2023, 1, 1, 12, 0, 6)
time_machine.move_to(datetime(2023, 1, 1, 12, 0, 6))
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=1,
callable_function=mock_callable_function,
)
assert all("Error calling" not in log for log in caplog.messages)

@mock.patch("airflow.utils.timezone.utcnow")
def test_exponential_backoff_retry_max_delay(self, mock_utcnow):
def test_exponential_backoff_retry_max_delay(self, time_machine):
mock_callable_function = mock.Mock()
mock_callable_function.__name__ = "test_callable_function"
mock_callable_function.return_value = Exception()
mock_utcnow.return_value = datetime(2023, 1, 1, 12, 4, 15)
time_machine.move_to(datetime(2023, 1, 1, 12, 4, 15))
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=4,
callable_function=mock_callable_function,
max_delay=60 * 5,
)
mock_callable_function.assert_not_called() # delay is 256 seconds; no calls made
mock_utcnow.return_value = datetime(2023, 1, 1, 12, 4, 16)
time_machine.move_to(datetime(2023, 1, 1, 12, 4, 16))
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=4,
callable_function=mock_callable_function,
max_delay=60 * 5,
)
mock_callable_function.assert_called_once()

mock_utcnow.return_value = datetime(2023, 1, 1, 12, 5, 0)
time_machine.move_to(datetime(2023, 1, 1, 12, 5, 0))
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=5,
callable_function=mock_callable_function,
max_delay=60 * 5,
)
# delay should be 4^5=1024 seconds, but max_delay is 60*5=300 seconds
assert mock_callable_function.call_count == 2

@mock.patch("airflow.utils.timezone.utcnow")
def test_exponential_backoff_retry_max_attempts(self, mock_utcnow, caplog):
def test_exponential_backoff_retry_max_attempts(self, time_machine, caplog):
mock_callable_function = mock.Mock()
mock_callable_function.__name__ = "test_callable_function"
mock_callable_function.return_value = Exception()
mock_utcnow.return_value = datetime(2023, 1, 1, 12, 55, 0)
time_machine.move_to(datetime(2023, 1, 1, 12, 55, 0))
for i in range(10):
exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=i,
callable_function=mock_callable_function,
max_attempts=3,
)
assert any("Max attempts reached." in log for log in caplog.messages)
assert mock_callable_function.call_count == 3

@mock.patch("airflow.utils.timezone.utcnow")
@pytest.mark.parametrize(
"attempt_number, utcnow_value, expected_calls",
[
Expand Down Expand Up @@ -268,15 +262,15 @@ def test_exponential_backoff_retry_max_attempts(self, mock_utcnow, caplog):
],
)
def test_exponential_backoff_retry_exponent_base_parameterized(
self, mock_utcnow, attempt_number, utcnow_value, expected_calls
self, time_machine, attempt_number, utcnow_value, expected_calls
):
mock_callable_function = mock.Mock()
mock_callable_function.__name__ = "test_callable_function"
mock_callable_function.side_effect = Exception()
mock_utcnow.return_value = utcnow_value
time_machine.move_to(utcnow_value)

exponential_backoff_retry(
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0),
last_attempt_time=datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
attempts_since_last_successful=attempt_number,
callable_function=mock_callable_function,
exponent_base=3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,12 @@ def test_create_ol_event_pair_success(mock_generate_uuid, is_successful):

@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_uuid, mock_version):
def test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid, mock_version, time_machine):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
mock_generate_uuid.return_value = fake_uuid

default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
mock_now.return_value = default_event_time
time_machine.move_to(default_event_time, tick=False)

query_ids = ["query1", "query2", "query3"]
original_query_ids = copy.deepcopy(query_ids)
Expand Down Expand Up @@ -523,15 +522,14 @@ def test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_

@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries_without_metadata(
mock_now, mock_generate_uuid, mock_version
mock_generate_uuid, mock_version, time_machine
):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
mock_generate_uuid.return_value = fake_uuid

default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
mock_now.return_value = default_event_time
time_machine.move_to(default_event_time, tick=False)

query_ids = ["query1"]
original_query_ids = copy.deepcopy(query_ids)
Expand Down Expand Up @@ -642,15 +640,14 @@ def test_emit_openlineage_events_for_databricks_queries_without_metadata(

@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids(
mock_now, mock_generate_uuid, mock_version
mock_generate_uuid, mock_version, time_machine
):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
mock_generate_uuid.return_value = fake_uuid

default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
mock_now.return_value = default_event_time
time_machine.move_to(default_event_time, tick=False)

query_ids = ["query1"]
hook = mock.MagicMock()
Expand Down Expand Up @@ -765,15 +762,14 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
)
@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace(
mock_now, mock_generate_uuid, mock_version, mock_parser
mock_generate_uuid, mock_version, mock_parser, time_machine
):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
mock_generate_uuid.return_value = fake_uuid

default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
mock_now.return_value = default_event_time
time_machine.move_to(default_event_time, tick=False)

query_ids = ["query1"]
hook = mock.MagicMock()
Expand Down Expand Up @@ -884,15 +880,14 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i

@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns(
mock_now, mock_generate_uuid, mock_version
mock_generate_uuid, mock_version, time_machine
):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
mock_generate_uuid.return_value = fake_uuid

default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
mock_now.return_value = default_event_time
time_machine.move_to(default_event_time, tick=False)

query_ids = ["query1"]
hook = DatabricksHook()
Expand Down Expand Up @@ -1004,15 +999,14 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i

@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_query_ids(
mock_now, mock_generate_uuid, mock_version
mock_generate_uuid, mock_version, time_machine
):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
mock_generate_uuid.return_value = fake_uuid

default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
mock_now.return_value = default_event_time
time_machine.move_to(default_event_time, tick=False)

hook = DatabricksSqlHook()
hook.query_ids = ["query2", "query3"]
Expand Down
Loading