Skip to content

Commit

Permalink
Avoid using Pendulum classes directly
Browse files Browse the repository at this point in the history
Pendulum does not implement caching in the constructor, but with a
wrapper function. This causes issues in tests due to classes such as
Timezone does not implement equality, making asserts difficult.

By using the wrapper functions instead, we can help Pendulum produces
equal values more consistently.
  • Loading branch information
uranusjr committed Dec 7, 2023
1 parent ca75462 commit 8944682
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 20 deletions.
5 changes: 3 additions & 2 deletions airflow/timetables/_cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
import datetime
from typing import TYPE_CHECKING, Any

import pendulum
from cron_descriptor import CasingTypeEnum, ExpressionDescriptor, FormatException, MissingFieldException
from croniter import CroniterBadCronError, CroniterBadDateError, croniter
from pendulum.tz.timezone import Timezone

from airflow.exceptions import AirflowTimetableInvalid
from airflow.utils.dates import cron_presets
from airflow.utils.timezone import convert_to_utc, make_aware, make_naive

if TYPE_CHECKING:
from pendulum import DateTime
from pendulum.tz.timezone import Timezone


def _covers_every_hour(cron: croniter) -> bool:
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(self, cron: str, timezone: str | Timezone) -> None:
self._expression = cron_presets.get(cron, cron)

if isinstance(timezone, str):
timezone = Timezone(timezone)
timezone = pendulum.tz.timezone(timezone)
self._timezone = timezone

try:
Expand Down
4 changes: 2 additions & 2 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
from unittest.mock import ANY, MagicMock
from uuid import uuid4

import pendulum
import pytest
from kubernetes import client
from kubernetes.client import V1EnvVar, V1PodSecurityContext, V1SecurityContext, models as k8s
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from pendulum.tz.timezone import Timezone

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.connection import Connection
Expand All @@ -53,7 +53,7 @@

def create_context(task) -> Context:
dag = DAG(dag_id="dag")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=Timezone("Europe/Amsterdam"))
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=pendulum.tz.timezone("Europe/Amsterdam"))
dag_run = DagRun(
dag_id=dag.dag_id,
execution_date=execution_date,
Expand Down
5 changes: 2 additions & 3 deletions tests/api_connexion/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

from unittest import mock

import pendulum
import pytest
from pendulum import DateTime
from pendulum.tz.timezone import Timezone

from airflow.api_connexion.exceptions import BadRequest
from airflow.api_connexion.parameters import (
Expand Down Expand Up @@ -106,7 +105,7 @@ def test_should_works_with_datetime_formatter(self):

decorated_endpoint(param_a="2020-01-01T0:0:00+00:00")

endpoint.assert_called_once_with(param_a=DateTime(2020, 1, 1, 0, tzinfo=Timezone("UTC")))
endpoint.assert_called_once_with(param_a=pendulum.datetime(2020, 1, 1, 0, tz="UTC"))

def test_should_propagate_exceptions(self):
decorator = format_parameters({"param_a": format_datetime})
Expand Down
17 changes: 9 additions & 8 deletions tests/providers/cncf/kubernetes/utils/test_pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@
from datetime import datetime
from json.decoder import JSONDecodeError
from types import SimpleNamespace
from typing import cast
from typing import TYPE_CHECKING, cast
from unittest import mock
from unittest.mock import MagicMock

import pendulum
import pytest
import time_machine
from kubernetes.client.rest import ApiException
from pendulum import DateTime
from pendulum.tz.timezone import Timezone
from urllib3.exceptions import HTTPError as BaseHTTPError

from airflow.exceptions import AirflowException
Expand All @@ -43,6 +41,9 @@
)
from airflow.utils.timezone import utc

if TYPE_CHECKING:
from pendulum import DateTime


class TestPodManager:
def setup_method(self):
Expand Down Expand Up @@ -269,7 +270,7 @@ def test_fetch_container_logs_returning_last_timestamp(

status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)

assert status.last_log_time == cast(DateTime, pendulum.parse(timestamp_string))
assert status.last_log_time == cast("DateTime", pendulum.parse(timestamp_string))

@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
Expand Down Expand Up @@ -306,7 +307,7 @@ def consumer_iter():
mock_consumer_iter.side_effect = consumer_iter
mock_container_is_running.side_effect = [True, True, False]
status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)
assert status.last_log_time == cast(DateTime, pendulum.parse(last_timestamp_string))
assert status.last_log_time == cast("DateTime", pendulum.parse(last_timestamp_string))
assert self.mock_progress_callback.call_count == expected_call_count

@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
Expand Down Expand Up @@ -461,13 +462,13 @@ def test_fetch_requested_container_logs_invalid(self, container_running, contain
def test_fetch_container_since_time(self, logs_available, container_running, mock_now):
"""If given since_time, should be used."""
mock_pod = MagicMock()
mock_now.return_value = DateTime(2020, 1, 1, 0, 0, 5, tzinfo=Timezone("UTC"))
mock_now.return_value = pendulum.datetime(2020, 1, 1, 0, 0, 5, tz="UTC")
logs_available.return_value = True
container_running.return_value = False
self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
)
since_time = DateTime(2020, 1, 1, tzinfo=Timezone("UTC"))
since_time = pendulum.datetime(2020, 1, 1, tz="UTC")
self.pod_manager.fetch_container_logs(pod=mock_pod, container_name="base", since_time=since_time)
args, kwargs = self.mock_kube_client.read_namespaced_pod_log.call_args_list[0]
assert kwargs["since_seconds"] == 5
Expand All @@ -488,7 +489,7 @@ def test_fetch_container_running_follow(
)
ret = self.pod_manager.fetch_container_logs(pod=mock_pod, container_name="base", follow=follow)
assert len(container_running_mock.call_args_list) == is_running_calls
assert ret.last_log_time == DateTime(2021, 1, 1, tzinfo=Timezone("UTC"))
assert ret.last_log_time == pendulum.datetime(2021, 1, 1, tz="UTC")
assert ret.running is exp_running

@pytest.mark.parametrize(
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/openlineage/plugins/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from json import JSONEncoder
from typing import Any

import pendulum
import pytest
from attrs import define
from openlineage.client.utils import RedactMixin
from pendulum.tz.timezone import Timezone
from pkg_resources import parse_version

from airflow.models import DAG as AIRFLOW_DAG, DagModel
Expand Down Expand Up @@ -86,8 +86,8 @@ def test_get_dagrun_start_end():
state=State.NONE, run_id=run_id, data_interval=dag.get_next_data_interval(dag_model)
)
assert dagrun.data_interval_start is not None
start_date_tz = datetime.datetime(2022, 1, 1, tzinfo=Timezone("UTC"))
end_date_tz = datetime.datetime(2022, 1, 1, hour=2, tzinfo=Timezone("UTC"))
start_date_tz = datetime.datetime(2022, 1, 1, tzinfo=pendulum.tz.timezone("UTC"))
end_date_tz = datetime.datetime(2022, 1, 1, hour=2, tzinfo=pendulum.tz.timezone("UTC"))
assert dagrun.data_interval_start, dagrun.data_interval_end == (start_date_tz, end_date_tz)


Expand Down
4 changes: 2 additions & 2 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import json
from datetime import datetime, timedelta

import pendulum
import pytest
from dateutil import relativedelta
from kubernetes.client import models as k8s
from pendulum.tz.timezone import Timezone

from airflow.datasets import Dataset
from airflow.exceptions import SerializationError
Expand Down Expand Up @@ -142,7 +142,7 @@ def equal_time(a: datetime, b: datetime) -> bool:
(1, None, equals),
(datetime.utcnow(), DAT.DATETIME, equal_time),
(timedelta(minutes=2), DAT.TIMEDELTA, equals),
(Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
(pendulum.tz.timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
(relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
({"test": "dict", "test-1": 1}, None, equals),
(["array_item", 2], None, equals),
Expand Down

0 comments on commit 8944682

Please sign in to comment.