Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,7 +2715,7 @@ def test_inlet_asset_extra(self, dag_maker, session, mock_supervisor_comms):
AssetEventResponse(
id=1,
created_dagruns=[],
timestamp=datetime.datetime.now(),
timestamp=timezone.utcnow(),
extra={"from": f"write{i}"},
asset=AssetResponse(
name="test_inlet_asset_extra", uri="test_inlet_asset_extra", group="asset"
Expand Down Expand Up @@ -2791,7 +2791,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session, mock_supervisor_comms
AssetEventResponse(
id=1,
created_dagruns=[],
timestamp=datetime.datetime.now(),
timestamp=timezone.utcnow(),
extra={"from": f"write{i}"},
asset=AssetResponse(
name="test_inlet_asset_extra_ds", uri="test_inlet_asset_extra_ds", group="asset"
Expand Down Expand Up @@ -2914,7 +2914,7 @@ def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected, moc
AssetEventResponse(
id=1,
created_dagruns=[],
timestamp=datetime.datetime.now(),
timestamp=timezone.utcnow(),
extra={"from": i},
asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"),
)
Expand Down Expand Up @@ -2981,7 +2981,7 @@ def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expecte
AssetEventResponse(
id=1,
created_dagruns=[],
timestamp=datetime.datetime.now(),
timestamp=timezone.utcnow(),
extra={"from": i},
asset=AssetResponse(name=asset_uri, uri=asset_uri, group="asset"),
)
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
now = context.get("logical_date")
if not now:
dag_run = context.get("dag_run")
now = dag_run.run_after # type: ignore[union-attr]
now = dag_run.run_after # type: ignore[union-attr, assignment]
else:
now = timezone.coerce_datetime(timezone.utcnow())
if TYPE_CHECKING:
assert isinstance(now, datetime.datetime)
lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper)
lower = timezone.coerce_datetime(lower, self.dag.timezone)
upper = timezone.coerce_datetime(upper, self.dag.timezone)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
now = context.get("logical_date")
if not now:
dag_run = context.get("dag_run")
now = dag_run.run_after # type: ignore[union-attr]
now = dag_run.run_after # type: ignore[union-attr, assignment]
else:
now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)

Expand Down
2 changes: 1 addition & 1 deletion task-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ enable-version-header=true
enum-field-as-literal='one' # When a single enum member, make it output a `Literal["..."]`
input-file-type='openapi'
output-model-type='pydantic_v2.BaseModel'
output-datetime-class='datetime'
output-datetime-class='AwareDatetime'
target-python-version='3.9'
use-annotated=true
use-default=true
Expand Down
50 changes: 25 additions & 25 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
# under the License.
from __future__ import annotations

from datetime import datetime, timedelta
from datetime import timedelta
from enum import Enum
from typing import Annotated, Any, Final, Literal
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, JsonValue
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue

API_VERSION: Final[str] = "2025-03-26"

Expand Down Expand Up @@ -88,12 +88,12 @@ class DagRunAssetReference(BaseModel):
)
run_id: Annotated[str, Field(title="Run Id")]
dag_id: Annotated[str, Field(title="Dag Id")]
logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None
start_date: Annotated[datetime, Field(title="Start Date")]
end_date: Annotated[datetime | None, Field(title="End Date")] = None
logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None
start_date: Annotated[AwareDatetime, Field(title="Start Date")]
end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None
state: Annotated[str, Field(title="State")]
data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None


class DagRunState(str, Enum):
Expand Down Expand Up @@ -149,10 +149,10 @@ class PrevSuccessfulDagRunResponse(BaseModel):
Schema for response with previous successful DagRun information for Task Template Context.
"""

data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
start_date: Annotated[datetime | None, Field(title="Start Date")] = None
end_date: Annotated[datetime | None, Field(title="End Date")] = None
data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None
start_date: Annotated[AwareDatetime | None, Field(title="Start Date")] = None
end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None


class TIDeferredStatePayload(BaseModel):
Expand Down Expand Up @@ -183,7 +183,7 @@ class TIEnterRunningPayload(BaseModel):
hostname: Annotated[str, Field(title="Hostname")]
unixname: Annotated[str, Field(title="Unixname")]
pid: Annotated[int, Field(title="Pid")]
start_date: Annotated[datetime, Field(title="Start Date")]
start_date: Annotated[AwareDatetime, Field(title="Start Date")]


class TIHeartbeatInfo(BaseModel):
Expand All @@ -207,8 +207,8 @@ class TIRescheduleStatePayload(BaseModel):
extra="forbid",
)
state: Annotated[Literal["up_for_reschedule"] | None, Field(title="State")] = "up_for_reschedule"
reschedule_date: Annotated[datetime, Field(title="Reschedule Date")]
end_date: Annotated[datetime, Field(title="End Date")]
reschedule_date: Annotated[AwareDatetime, Field(title="Reschedule Date")]
end_date: Annotated[AwareDatetime, Field(title="End Date")]


class TIRetryStatePayload(BaseModel):
Expand All @@ -220,7 +220,7 @@ class TIRetryStatePayload(BaseModel):
extra="forbid",
)
state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry"
end_date: Annotated[datetime, Field(title="End Date")]
end_date: Annotated[AwareDatetime, Field(title="End Date")]


class TISkippedDownstreamTasksStatePayload(BaseModel):
Expand All @@ -243,7 +243,7 @@ class TISuccessStatePayload(BaseModel):
extra="forbid",
)
state: Annotated[Literal["success"] | None, Field(title="State")] = "success"
end_date: Annotated[datetime, Field(title="End Date")]
end_date: Annotated[AwareDatetime, Field(title="End Date")]
task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task Outlets")] = None
outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet Events")] = None

Expand Down Expand Up @@ -277,7 +277,7 @@ class TriggerDAGRunPayload(BaseModel):
model_config = ConfigDict(
extra="forbid",
)
logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None
logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None
conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
reset_dag_run: Annotated[bool | None, Field(title="Reset Dag Run")] = False

Expand Down Expand Up @@ -360,7 +360,7 @@ class AssetEventResponse(BaseModel):
"""

id: Annotated[int, Field(title="Id")]
timestamp: Annotated[datetime, Field(title="Timestamp")]
timestamp: Annotated[AwareDatetime, Field(title="Timestamp")]
extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
asset: AssetResponse
created_dagruns: Annotated[list[DagRunAssetReference], Field(title="Created Dagruns")]
Expand Down Expand Up @@ -388,12 +388,12 @@ class DagRun(BaseModel):
)
dag_id: Annotated[str, Field(title="Dag Id")]
run_id: Annotated[str, Field(title="Run Id")]
logical_date: Annotated[datetime | None, Field(title="Logical Date")] = None
data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
run_after: Annotated[datetime, Field(title="Run After")]
start_date: Annotated[datetime, Field(title="Start Date")]
end_date: Annotated[datetime | None, Field(title="End Date")] = None
logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None
data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None
run_after: Annotated[AwareDatetime, Field(title="Run After")]
start_date: Annotated[AwareDatetime, Field(title="Start Date")]
end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None
clear_number: Annotated[int | None, Field(title="Clear Number")] = 0
run_type: DagRunType
conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
Expand Down Expand Up @@ -429,4 +429,4 @@ class TITerminalStatePayload(BaseModel):
extra="forbid",
)
state: TerminalStateNonSuccess
end_date: Annotated[datetime, Field(title="End Date")]
end_date: Annotated[AwareDatetime, Field(title="End Date")]
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/bases/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
if TYPE_CHECKING:
import jinja2

from airflow import DAG
from airflow.sdk import DAG
from airflow.sdk.definitions.context import Context


Expand Down
19 changes: 9 additions & 10 deletions task-sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict

if TYPE_CHECKING:
# TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x?
from datetime import datetime
from pendulum import DateTime

from airflow.models.operator import Operator
from airflow.sdk.bases.operator import BaseOperator
Expand All @@ -41,27 +40,27 @@ class Context(TypedDict, total=False):
conn: Any
dag: DAG
dag_run: DagRunProtocol
data_interval_end: datetime | None
data_interval_start: datetime | None
data_interval_end: DateTime | None
data_interval_start: DateTime | None
outlet_events: OutletEventAccessorsProtocol
ds: str
ds_nodash: str
expanded_ti_count: int | None
exception: None | str | BaseException
inlets: list
inlet_events: InletEventsAccessors
logical_date: datetime
logical_date: DateTime
macros: Any
map_index_template: str | None
outlets: list
params: dict[str, Any]
prev_data_interval_start_success: datetime | None
prev_data_interval_end_success: datetime | None
prev_start_date_success: datetime | None
prev_end_date_success: datetime | None
prev_data_interval_start_success: DateTime | None
prev_data_interval_end_success: DateTime | None
prev_start_date_success: DateTime | None
prev_end_date_success: DateTime | None
reason: str | None
run_id: str
start_date: datetime
start_date: DateTime
# TODO: Remove Operator from below once we have MappedOperator to the Task SDK
# and once we can remove context related code from the Scheduler/models.TaskInstance
task: BaseOperator | Operator
Expand Down
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from uuid import UUID

from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer

from airflow.sdk.api.datamodels._generated import (
AssetEventsResponse,
Expand Down Expand Up @@ -218,7 +218,7 @@ def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> Pre
class TaskRescheduleStartDate(BaseModel):
"""Response containing the first reschedule date for a task instance."""

start_date: datetime | None
start_date: AwareDatetime | None
type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"


Expand Down
25 changes: 14 additions & 11 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import attrs
import lazy_object_proxy
import structlog
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter

from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
Expand Down Expand Up @@ -87,9 +87,11 @@
from airflow.sdk.execution_time.xcom import XCom
from airflow.utils.net import get_hostname
from airflow.utils.state import TaskInstanceState
from airflow.utils.timezone import coerce_datetime

if TYPE_CHECKING:
import jinja2
from pendulum.datetime import DateTime
from structlog.typing import FilteringBoundLogger as Logger

from airflow.exceptions import DagRunTriggerException
Expand All @@ -116,7 +118,7 @@ class RuntimeTaskInstance(TaskInstance):
max_tries: int = 0
"""The maximum number of retries for the task."""

start_date: datetime
start_date: AwareDatetime
"""Start date of the task instance."""

def __rich_repr__(self):
Expand Down Expand Up @@ -144,7 +146,6 @@ def get_template_context(self) -> Context:

validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

# TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime()
context: Context = {
# From the Task Execution interface
"dag": self.task.dag,
Expand Down Expand Up @@ -179,15 +180,17 @@ def get_template_context(self) -> Context:
"task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
"task_reschedule_count": self._ti_context_from_server.task_reschedule_count or 0,
"prev_start_date_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).start_date
lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
),
"prev_end_date_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).end_date
lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
),
}
context.update(context_from_server)

if logical_date := dag_run.logical_date:
if logical_date := coerce_datetime(dag_run.logical_date):
if TYPE_CHECKING:
assert isinstance(logical_date, DateTime)
ds = logical_date.strftime("%Y-%m-%d")
ds_nodash = ds.replace("-", "")
ts = logical_date.isoformat()
Expand All @@ -205,13 +208,13 @@ def get_template_context(self) -> Context:
"ts_nodash": ts_nodash,
"ts_nodash_with_tz": ts_nodash_with_tz,
# keys that depend on data_interval
"data_interval_end": dag_run.data_interval_end,
"data_interval_start": dag_run.data_interval_start,
"data_interval_end": coerce_datetime(dag_run.data_interval_end),
"data_interval_start": coerce_datetime(dag_run.data_interval_start),
"prev_data_interval_start_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).data_interval_start
lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
),
"prev_data_interval_end_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).data_interval_end
lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
),
}
)
Expand Down Expand Up @@ -368,7 +371,7 @@ def get_relevant_upstream_map_indexes(
# TODO: Implement this method
return None

def get_first_reschedule_date(self, context: Context) -> datetime | None:
def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
"""Get the first reschedule date for the task instance if found, none otherwise."""
if context.get("task_reschedule_count", 0) == 0:
# If the task has not been rescheduled, there is no need to ask the supervisor
Expand Down
Loading