diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 650c593919c7c..dce990ba22ebf 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -2715,7 +2715,7 @@ def test_inlet_asset_extra(self, dag_maker, session, mock_supervisor_comms): AssetEventResponse( id=1, created_dagruns=[], - timestamp=datetime.datetime.now(), + timestamp=timezone.utcnow(), extra={"from": f"write{i}"}, asset=AssetResponse( name="test_inlet_asset_extra", uri="test_inlet_asset_extra", group="asset" @@ -2791,7 +2791,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session, mock_supervisor_comms AssetEventResponse( id=1, created_dagruns=[], - timestamp=datetime.datetime.now(), + timestamp=timezone.utcnow(), extra={"from": f"write{i}"}, asset=AssetResponse( name="test_inlet_asset_extra_ds", uri="test_inlet_asset_extra_ds", group="asset" @@ -2914,7 +2914,7 @@ def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected, moc AssetEventResponse( id=1, created_dagruns=[], - timestamp=datetime.datetime.now(), + timestamp=timezone.utcnow(), extra={"from": i}, asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"), ) @@ -2981,7 +2981,7 @@ def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expecte AssetEventResponse( id=1, created_dagruns=[], - timestamp=datetime.datetime.now(), + timestamp=timezone.utcnow(), extra={"from": i}, asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"), ) diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index 436907571b579..7da47ee59655e 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import datetime as dt import uuid from collections import defaultdict from concurrent.futures import Future @@ -38,7 +37,7 @@ from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.plugins.listener import OpenLineageListener from airflow.providers.openlineage.utils.selective_enable import disable_lineage, enable_lineage -from airflow.utils import types +from airflow.utils import timezone, types from airflow.utils.state import DagRunState, State from tests_common.test_utils.compat import EmptyOperator, PythonOperator @@ -105,7 +104,7 @@ class TestOpenLineageListenerAirflow2: @patch("airflow.models.BaseOperator.render_template") def test_listener_does_not_change_task_instance(self, render_mock, xcom_push_mock): render_mock.return_value = render_df() - date = dt.datetime(2022, 1, 1) + date = timezone.datetime(2022, 1, 1) dag = DAG( "test", schedule=None, @@ -184,7 +183,7 @@ def sample_callable(**kwargs): task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario") # Use task_instance to simulate running a task in a test. """ - date = dt.datetime(2022, 1, 1) + date = timezone.datetime(2022, 1, 1) dag = DAG( f"test_{scenario_name}", schedule=None, @@ -244,7 +243,7 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): task_instance.dag_run.data_interval_start = None task_instance.dag_run.data_interval_end = None task_instance.dag_run.clear_number = 0 - task_instance.dag_run.execution_date = dt.datetime(2020, 1, 1, 1, 1, 1) + task_instance.dag_run.execution_date = timezone.datetime(2020, 1, 1, 1, 1, 1) task_instance.task = mock.Mock() task_instance.task.task_id = "task_id" task_instance.task.dag = mock.Mock() @@ -257,9 +256,9 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): task_instance.run_id = "dag_run_run_id" task_instance.try_number = 1 task_instance.state = State.RUNNING - task_instance.start_date = dt.datetime(2023, 1, 1, 13, 1, 1) - task_instance.end_date = dt.datetime(2023, 1, 3, 13, 1, 1) - task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) + task_instance.start_date = timezone.datetime(2023, 1, 1, 13, 1, 1) + task_instance.end_date = timezone.datetime(2023, 1, 3, 13, 1, 1) + task_instance.logical_date = timezone.datetime(2020, 1, 1, 1, 1, 1) task_instance.map_index = -1 task_instance.next_method = None # Ensure this is None to reach start_task task_instance.get_template_context = mock.MagicMock() # type: ignore[method-assign] @@ -303,12 +302,12 @@ def test_adapter_start_task_is_called_with_proper_arguments( listener.on_task_instance_running(None, task_instance, None) listener.adapter.start_task.assert_called_once_with( - run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + run_id="2020-01-01T01:01:01+00:00.dag_id.task_id.1.-1", job_name="job_name", job_description="Test DAG Description", - event_time="2023-01-01T13:01:01", + event_time="2023-01-01T13:01:01+00:00", parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", + parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0", code_location=None, nominal_start_time=None, nominal_end_time=None, @@ -330,7 +329,7 @@ def test_adapter_start_task_is_called_with_proper_arguments( @mock.patch( "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) - @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + @mock.patch("airflow.utils.timezone.utcnow", return_value=timezone.datetime(2023, 1, 3, 13, 1, 1)) def test_adapter_fail_task_is_called_with_proper_arguments( self, mock_utcnow, @@ -349,7 +348,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments( """ listener, task_instance = self._create_listener_and_task_instance() - task_instance.logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) + task_instance.logical_date = timezone.datetime(2020, 1, 1, 1, 1, 1) mock_get_job_name.return_value = "job_name" mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} @@ -363,11 +362,11 @@ def test_adapter_fail_task_is_called_with_proper_arguments( previous_state=None, task_instance=task_instance, **on_task_failed_listener_kwargs, session=None ) listener.adapter.fail_task.assert_called_once_with( - end_time="2023-01-03T13:01:01", + end_time="2023-01-03T13:01:01+00:00", job_name="job_name", parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0", + run_id="2020-01-01T01:01:01+00:00.dag_id.task_id.1.-1", task=listener.extractor_manager.extract_metadata(), run_facets={ "custom_user_facet": 2, @@ -385,7 +384,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments( @mock.patch( "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) - @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + @mock.patch("airflow.utils.timezone.utcnow", return_value=timezone.datetime(2023, 1, 3, 13, 1, 1)) def test_adapter_complete_task_is_called_with_proper_arguments( self, mock_utcnow, @@ -416,11 +415,11 @@ def test_adapter_complete_task_is_called_with_proper_arguments( calls = listener.adapter.complete_task.call_args_list assert len(calls) == 1 assert calls[0][1] == dict( - end_time="2023-01-03T13:01:01", + end_time="2023-01-03T13:01:01+00:00", job_name="job_name", parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", + parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0", + run_id=f"2020-01-01T01:01:01+00:00.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", task=listener.extractor_manager.extract_metadata(), run_facets={ "custom_user_facet": 2, @@ -444,7 +443,7 @@ def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_met listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), try_number=1, map_index=-1, ) @@ -468,7 +467,7 @@ def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_meth listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), try_number=1, map_index=-1, ) @@ -488,7 +487,7 @@ def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_met listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), try_number=EXPECTED_TRY_NUMBER_1, map_index=-1, ) @@ -664,7 +663,7 @@ def set_result(*args, **kwargs): listener.adapter = OpenLineageAdapter( client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig())) ) - event_time = dt.datetime.now() + event_time = timezone.utcnow() fut = listener.submit_callable( listener.adapter.dag_failed, dag_id="", @@ -704,7 +703,7 @@ def test_listener_does_not_change_task_instance(self, render_mock, mock_supervis render_mock.return_value = render_df() - date = dt.datetime(2022, 1, 1) + date = timezone.datetime(2022, 1, 1) dag = DAG( "test", schedule=None, @@ -800,7 +799,7 @@ def sample_callable(**kwargs): task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario") # Use task_instance to simulate running a task in a test. """ - date = dt.datetime(2022, 1, 1) + date = timezone.datetime(2022, 1, 1) dag = DAG( f"test_{scenario_name}", schedule=None, @@ -892,21 +891,21 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): dag_run=SdkDagRun( dag_id="dag_id", run_id="dag_run_run_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), data_interval_start=None, data_interval_end=None, - start_date=dt.datetime(2023, 1, 1, 13, 1, 1), - end_date=dt.datetime(2023, 1, 3, 13, 1, 1), + start_date=timezone.datetime(2023, 1, 1, 13, 1, 1), + end_date=timezone.datetime(2023, 1, 3, 13, 1, 1), clear_number=0, run_type=DagRunType.MANUAL, - run_after=dt.datetime(2023, 1, 3, 13, 1, 1), + run_after=timezone.datetime(2023, 1, 3, 13, 1, 1), conf=None, ), task_reschedule_count=0, max_tries=1, should_retry=False, ), - start_date=dt.datetime(2023, 1, 1, 13, 1, 1), + start_date=timezone.datetime(2023, 1, 1, 13, 1, 1), ) return listener, runtime_ti @@ -946,12 +945,12 @@ def test_adapter_start_task_is_called_with_proper_arguments( listener.on_task_instance_running(None, task_instance) listener.adapter.start_task.assert_called_once_with( - run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + run_id="2020-01-01T01:01:01+00:00.dag_id.task_id.1.-1", job_name="job_name", job_description="Test DAG Description", - event_time="2023-01-01T13:01:01", + event_time="2023-01-01T13:01:01+00:00", parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", + parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0", code_location=None, nominal_start_time=None, nominal_end_time=None, @@ -973,7 +972,7 @@ def test_adapter_start_task_is_called_with_proper_arguments( @mock.patch( "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) - @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + @mock.patch("airflow.utils.timezone.utcnow", return_value=timezone.datetime(2023, 1, 3, 13, 1, 1)) def test_adapter_fail_task_is_called_with_proper_arguments( self, mock_utcnow, @@ -992,7 +991,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments( """ listener, task_instance = self._create_listener_and_task_instance() - task_instance.get_template_context()["dag_run"].logical_date = dt.datetime(2020, 1, 1, 1, 1, 1) + task_instance.get_template_context()["dag_run"].logical_date = timezone.datetime(2020, 1, 1, 1, 1, 1) mock_get_job_name.return_value = "job_name" mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2} mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} @@ -1006,11 +1005,11 @@ def test_adapter_fail_task_is_called_with_proper_arguments( previous_state=None, task_instance=task_instance, **on_task_failed_listener_kwargs ) listener.adapter.fail_task.assert_called_once_with( - end_time="2023-01-03T13:01:01", + end_time="2023-01-03T13:01:01+00:00", job_name="job_name", parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", + parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0", + run_id="2020-01-01T01:01:01+00:00.dag_id.task_id.1.-1", task=listener.extractor_manager.extract_metadata(), run_facets={ "custom_user_facet": 2, @@ -1028,7 +1027,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments( @mock.patch( "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call ) - @mock.patch("airflow.utils.timezone.utcnow", return_value=dt.datetime(2023, 1, 3, 13, 1, 1)) + @mock.patch("airflow.utils.timezone.utcnow", return_value=timezone.datetime(2023, 1, 3, 13, 1, 1)) def test_adapter_complete_task_is_called_with_proper_arguments( self, mock_utcnow, @@ -1059,11 +1058,11 @@ def test_adapter_complete_task_is_called_with_proper_arguments( calls = listener.adapter.complete_task.call_args_list assert len(calls) == 1 assert calls[0][1] == dict( - end_time="2023-01-03T13:01:01", + end_time="2023-01-03T13:01:01+00:00", job_name="job_name", parent_job_name="dag_id", - parent_run_id="2020-01-01T01:01:01.dag_id.0", - run_id=f"2020-01-01T01:01:01.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", + parent_run_id="2020-01-01T01:01:01+00:00.dag_id.0", + run_id=f"2020-01-01T01:01:01+00:00.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}.-1", task=listener.extractor_manager.extract_metadata(), run_facets={ "custom_user_facet": 2, @@ -1087,7 +1086,7 @@ def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_met listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), try_number=1, map_index=-1, ) @@ -1111,7 +1110,7 @@ def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_meth listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), try_number=1, map_index=-1, ) @@ -1131,7 +1130,7 @@ def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_met listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", - logical_date=dt.datetime(2020, 1, 1, 1, 1, 1), + logical_date=timezone.datetime(2020, 1, 1, 1, 1, 1), try_number=EXPECTED_TRY_NUMBER_1, map_index=-1, ) @@ -1248,7 +1247,7 @@ def set_result(*args, **kwargs): listener.adapter = OpenLineageAdapter( client=OpenLineageClient(transport=ConsoleTransport(config=ConsoleConfig())) ) - event_time = dt.datetime.now() + event_time = timezone.utcnow() fut = listener.submit_callable( listener.adapter.dag_failed, dag_id="", @@ -1270,7 +1269,7 @@ def set_result(*args, **kwargs): @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 2 tests") class TestOpenLineageSelectiveEnableAirflow2: def setup_method(self): - date = dt.datetime(2022, 1, 1) + date = timezone.datetime(2022, 1, 1) self.dag = DAG( "test_selective_enable", schedule=None, diff --git a/providers/standard/src/airflow/providers/standard/operators/datetime.py b/providers/standard/src/airflow/providers/standard/operators/datetime.py index 18c6b0a7a9aa5..a549b13354ced 100644 --- a/providers/standard/src/airflow/providers/standard/operators/datetime.py +++ b/providers/standard/src/airflow/providers/standard/operators/datetime.py @@ -80,9 +80,11 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: now = context.get("logical_date") if not now: dag_run = context.get("dag_run") - now = dag_run.run_after # type: ignore[union-attr] + now = dag_run.run_after # type: ignore[union-attr, assignment] else: now = timezone.coerce_datetime(timezone.utcnow()) + if TYPE_CHECKING: + assert isinstance(now, datetime.datetime) lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper) lower = timezone.coerce_datetime(lower, self.dag.timezone) upper = timezone.coerce_datetime(upper, self.dag.timezone) diff --git a/providers/standard/src/airflow/providers/standard/operators/weekday.py b/providers/standard/src/airflow/providers/standard/operators/weekday.py index bcae0b746c524..6d42ff2f3e0a0 100644 --- a/providers/standard/src/airflow/providers/standard/operators/weekday.py +++ b/providers/standard/src/airflow/providers/standard/operators/weekday.py @@ -119,7 +119,7 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: now = context.get("logical_date") if not now: dag_run = context.get("dag_run") - now = dag_run.run_after # type: ignore[union-attr] + now = dag_run.run_after # type: ignore[union-attr, assignment] else: now = timezone.make_naive(timezone.utcnow(), self.dag.timezone) diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index d29bae0a51746..54ee6b7203276 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -128,7 +128,7 @@ enable-version-header=true enum-field-as-literal='one' # When a single enum member, make it output a `Literal["..."]` input-file-type='openapi' output-model-type='pydantic_v2.BaseModel' -output-datetime-class='datetime' +output-datetime-class='AwareDatetime' target-python-version='3.9' use-annotated=true use-default=true diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index af24f756675a3..d0c7ae17a37ed 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -20,12 +20,12 @@ # under the License. from __future__ import annotations -from datetime import datetime, timedelta +from datetime import timedelta from enum import Enum from typing import Annotated, Any, Final, Literal from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, JsonValue +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue API_VERSION: Final[str] = "2025-03-26" @@ -88,12 +88,12 @@ class DagRunAssetReference(BaseModel): ) run_id: Annotated[str, Field(title="Run Id")] dag_id: Annotated[str, Field(title="Dag Id")] - logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None - start_date: Annotated[datetime, Field(title="Start Date")] - end_date: Annotated[datetime | None, Field(title="End Date")] = None + logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None + start_date: Annotated[AwareDatetime, Field(title="Start Date")] + end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None state: Annotated[str, Field(title="State")] - data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None - data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None + data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None + data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None class DagRunState(str, Enum): @@ -149,10 +149,10 @@ class PrevSuccessfulDagRunResponse(BaseModel): Schema for response with previous successful DagRun information for Task Template Context. """ - data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None - data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None - start_date: Annotated[datetime | None, Field(title="Start Date")] = None - end_date: Annotated[datetime | None, Field(title="End Date")] = None + data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None + data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None + start_date: Annotated[AwareDatetime | None, Field(title="Start Date")] = None + end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None class TIDeferredStatePayload(BaseModel): @@ -183,7 +183,7 @@ class TIEnterRunningPayload(BaseModel): hostname: Annotated[str, Field(title="Hostname")] unixname: Annotated[str, Field(title="Unixname")] pid: Annotated[int, Field(title="Pid")] - start_date: Annotated[datetime, Field(title="Start Date")] + start_date: Annotated[AwareDatetime, Field(title="Start Date")] class TIHeartbeatInfo(BaseModel): @@ -207,8 +207,8 @@ class TIRescheduleStatePayload(BaseModel): extra="forbid", ) state: Annotated[Literal["up_for_reschedule"] | None, Field(title="State")] = "up_for_reschedule" - reschedule_date: Annotated[datetime, Field(title="Reschedule Date")] - end_date: Annotated[datetime, Field(title="End Date")] + reschedule_date: Annotated[AwareDatetime, Field(title="Reschedule Date")] + end_date: Annotated[AwareDatetime, Field(title="End Date")] class TIRetryStatePayload(BaseModel): @@ -220,7 +220,7 @@ class TIRetryStatePayload(BaseModel): extra="forbid", ) state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry" - end_date: Annotated[datetime, Field(title="End Date")] + end_date: Annotated[AwareDatetime, Field(title="End Date")] class TISkippedDownstreamTasksStatePayload(BaseModel): @@ -243,7 +243,7 @@ class TISuccessStatePayload(BaseModel): extra="forbid", ) state: Annotated[Literal["success"] | None, Field(title="State")] = "success" - end_date: Annotated[datetime, Field(title="End Date")] + end_date: Annotated[AwareDatetime, Field(title="End Date")] task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet Events")] = None @@ -277,7 +277,7 @@ class TriggerDAGRunPayload(BaseModel): model_config = ConfigDict( extra="forbid", ) - logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None + logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None reset_dag_run: Annotated[bool | None, Field(title="Reset Dag Run")] = False @@ -360,7 +360,7 @@ class AssetEventResponse(BaseModel): """ id: Annotated[int, Field(title="Id")] - timestamp: Annotated[datetime, Field(title="Timestamp")] + timestamp: Annotated[AwareDatetime, Field(title="Timestamp")] extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None asset: AssetResponse created_dagruns: Annotated[list[DagRunAssetReference], Field(title="Created Dagruns")] @@ -388,12 +388,12 @@ class DagRun(BaseModel): ) dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] - logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None - data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None - data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None - run_after: Annotated[datetime, Field(title="Run After")] - start_date: Annotated[datetime, Field(title="Start Date")] - end_date: Annotated[datetime | None, Field(title="End Date")] = None + logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None + data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None + data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None + run_after: Annotated[AwareDatetime, Field(title="Run After")] + start_date: Annotated[AwareDatetime, Field(title="Start Date")] + end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None clear_number: Annotated[int | None, Field(title="Clear Number")] = 0 run_type: DagRunType conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None @@ -429,4 +429,4 @@ class TITerminalStatePayload(BaseModel): extra="forbid", ) state: TerminalStateNonSuccess - end_date: Annotated[datetime, Field(title="End Date")] + end_date: Annotated[AwareDatetime, Field(title="End Date")] diff --git a/task-sdk/src/airflow/sdk/bases/notifier.py b/task-sdk/src/airflow/sdk/bases/notifier.py index db79a0db4482c..457abc7f1407c 100644 --- a/task-sdk/src/airflow/sdk/bases/notifier.py +++ b/task-sdk/src/airflow/sdk/bases/notifier.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: import jinja2 - from airflow import DAG + from airflow.sdk import DAG from airflow.sdk.definitions.context import Context diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index c54a5edd6a878..6580b8bcf5e81 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -21,8 +21,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict if TYPE_CHECKING: - # TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x? - from datetime import datetime + from pendulum import DateTime from airflow.models.operator import Operator from airflow.sdk.bases.operator import BaseOperator @@ -41,8 +40,8 @@ class Context(TypedDict, total=False): conn: Any dag: DAG dag_run: DagRunProtocol - data_interval_end: datetime | None - data_interval_start: datetime | None + data_interval_end: DateTime | None + data_interval_start: DateTime | None outlet_events: OutletEventAccessorsProtocol ds: str ds_nodash: str @@ -50,18 +49,18 @@ class Context(TypedDict, total=False): exception: None | str | BaseException inlets: list inlet_events: InletEventsAccessors - logical_date: datetime + logical_date: DateTime macros: Any map_index_template: str | None outlets: list params: dict[str, Any] - prev_data_interval_start_success: datetime | None - prev_data_interval_end_success: datetime | None - prev_start_date_success: datetime | None - prev_end_date_success: datetime | None + prev_data_interval_start_success: DateTime | None + prev_data_interval_end_success: DateTime | None + prev_start_date_success: DateTime | None + prev_end_date_success: DateTime | None reason: str | None run_id: str - start_date: datetime + start_date: DateTime # TODO: Remove Operator from below once we have MappedOperator to the Task SDK # and once we can remove context related code from the Scheduler/models.TaskInstance task: BaseOperator | Operator diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index fa649e7d76a13..d579aa8fb4866 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -48,7 +48,7 @@ from uuid import UUID from fastapi import Body -from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetEventsResponse, @@ -218,7 +218,7 @@ def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> Pre class TaskRescheduleStartDate(BaseModel): """Response containing the first reschedule date for a task instance.""" - start_date: datetime | None + start_date: AwareDatetime | None type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate" diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 4d612a03708ae..9388cab18e155 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -35,7 +35,7 @@ import attrs import lazy_object_proxy import structlog -from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -87,9 +87,11 @@ from airflow.sdk.execution_time.xcom import XCom from airflow.utils.net import get_hostname from airflow.utils.state import TaskInstanceState +from airflow.utils.timezone import coerce_datetime if TYPE_CHECKING: import jinja2 + from pendulum.datetime import DateTime from structlog.typing import FilteringBoundLogger as Logger from airflow.exceptions import DagRunTriggerException @@ -116,7 +118,7 @@ class RuntimeTaskInstance(TaskInstance): max_tries: int = 0 """The maximum number of retries for the task.""" - start_date: datetime + start_date: AwareDatetime """Start date of the task instance.""" def __rich_repr__(self): @@ -144,7 +146,6 @@ def get_template_context(self) -> Context: validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False) - # TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime() context: Context = { # From the Task Execution interface "dag": self.task.dag, @@ -179,15 +180,17 @@ def get_template_context(self) -> Context: "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}", "task_reschedule_count": self._ti_context_from_server.task_reschedule_count or 0, "prev_start_date_success": lazy_object_proxy.Proxy( - lambda: get_previous_dagrun_success(self.id).start_date + lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date) ), "prev_end_date_success": lazy_object_proxy.Proxy( - lambda: get_previous_dagrun_success(self.id).end_date + lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) ), } context.update(context_from_server) - if logical_date := dag_run.logical_date: + if logical_date := coerce_datetime(dag_run.logical_date): + if TYPE_CHECKING: + assert isinstance(logical_date, DateTime) ds = logical_date.strftime("%Y-%m-%d") ds_nodash = ds.replace("-", "") ts = logical_date.isoformat() @@ -205,13 +208,13 @@ def get_template_context(self) -> Context: "ts_nodash": ts_nodash, "ts_nodash_with_tz": ts_nodash_with_tz, # keys that depend on data_interval - "data_interval_end": dag_run.data_interval_end, - "data_interval_start": dag_run.data_interval_start, + "data_interval_end": coerce_datetime(dag_run.data_interval_end), + "data_interval_start": coerce_datetime(dag_run.data_interval_start), "prev_data_interval_start_success": lazy_object_proxy.Proxy( - lambda: get_previous_dagrun_success(self.id).data_interval_start + lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start) ), "prev_data_interval_end_success": lazy_object_proxy.Proxy( - lambda: get_previous_dagrun_success(self.id).data_interval_end + lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end) ), } ) @@ -368,7 +371,7 @@ def get_relevant_upstream_map_indexes( # TODO: Implement this method return None - def get_first_reschedule_date(self, context: Context) -> datetime | None: + def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: """Get the first reschedule date for the task instance if found, none otherwise.""" if context.get("task_reschedule_count", 0) == 0: # If the task has not been rescheduled, there is no need to ask the supervisor diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index cd31c4e71c88f..0f37c532660e3 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -27,7 +27,8 @@ if TYPE_CHECKING: from collections.abc import Iterator - from datetime import datetime + + from pydantic import AwareDatetime from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey @@ -42,13 +43,13 @@ class DagRunProtocol(Protocol): dag_id: str run_id: str - logical_date: datetime | None - data_interval_start: datetime | None - data_interval_end: datetime | None - start_date: datetime - end_date: datetime | None + logical_date: AwareDatetime | None + data_interval_start: AwareDatetime | None + data_interval_end: AwareDatetime | None + start_date: AwareDatetime + end_date: AwareDatetime | None run_type: Any - run_after: datetime + run_after: AwareDatetime conf: dict[str, Any] | None @@ -64,7 +65,7 @@ class RuntimeTaskInstanceProtocol(Protocol): map_index: int | None max_tries: int hostname: str | None = None - start_date: datetime + start_date: AwareDatetime def xcom_pull( self, @@ -83,7 +84,7 @@ def xcom_push(self, key: str, value: Any) -> None: ... def get_template_context(self) -> Context: ... - def get_first_reschedule_date(self, first_try_number) -> datetime | None: ... + def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ... class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance): diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 0a6e30946c95c..7402ff3e7ef3a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -53,6 +53,7 @@ context_to_airflow_vars, set_current_context, ) +from airflow.utils import timezone def test_convert_connection_result_conn(): @@ -471,7 +472,7 @@ def test__get_item__(self, key, sample_inlet_evnets_accessor, mock_supervisor_co asset_event_resp = AssetEventResponse( id=1, created_dagruns=[], - timestamp=datetime.now(), + timestamp=timezone.utcnow(), asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index a41989d92791b..150b691fc179d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -816,7 +816,7 @@ def test_run_with_asset_inlets(create_runtime_ti, mock_supervisor_comms): asset_event_resp = AssetEventResponse( id=1, created_dagruns=[], - timestamp=datetime.now(), + timestamp=timezone.utcnow(), asset=AssetResponse(name="test", uri="test", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp])