Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple

from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
from collections.abc import Sized

from airflow.models import DagRun
from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef


class AirflowException(Exception):
Expand Down Expand Up @@ -121,6 +121,8 @@ def __init__(self, inactive_asset_keys: Collection[AssetUniqueKey | AssetNameRef

@staticmethod
def _render_asset_key(key: AssetUniqueKey | AssetNameRef | AssetUriRef) -> str:
from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef

if isinstance(key, AssetUniqueKey):
return f"Asset(name={key.name!r}, uri={key.uri!r})"
elif isinstance(key, AssetNameRef):
Expand Down
11 changes: 10 additions & 1 deletion airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from sqlalchemy.orm import relationship

from airflow.models.base import Base, StringID
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime
Expand All @@ -47,6 +46,8 @@

from sqlalchemy.orm import Session

from airflow.sdk.definitions.asset import Asset, AssetAlias


def fetch_active_assets_by_name(names: Iterable[str], session: Session) -> dict[str, Asset]:
return {
Expand Down Expand Up @@ -187,12 +188,16 @@ def __hash__(self):
return hash(self.name)

def __eq__(self, other):
from airflow.sdk.definitions.asset import AssetAlias

if isinstance(other, (self.__class__, AssetAlias)):
return self.name == other.name
else:
return NotImplemented

def to_public(self) -> AssetAlias:
from airflow.sdk.definitions.asset import AssetAlias

return AssetAlias(name=self.name)


Expand Down Expand Up @@ -280,6 +285,8 @@ def __init__(self, name: str = "", uri: str = "", **kwargs):
super().__init__(name=name, uri=uri, **kwargs)

def __eq__(self, other):
from airflow.sdk.definitions.asset import Asset

if isinstance(other, (self.__class__, Asset)):
return self.name == other.name and self.uri == other.uri
return NotImplemented
Expand All @@ -291,6 +298,8 @@ def __repr__(self):
return f"{self.__class__.__name__}(name={self.name!r}, uri={self.uri!r}, extra={self.extra!r})"

def to_public(self) -> Asset:
from airflow.sdk.definitions.asset import Asset

return Asset(name=self.name, uri=self.uri, group=self.group, extra=self.extra)


Expand Down
2 changes: 1 addition & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.taskinstance import (
Context,
TaskInstance,
TaskInstanceKey,
clear_task_instances,
Expand All @@ -105,6 +104,7 @@
OnceTimetable,
)
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down
37 changes: 24 additions & 13 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,13 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.param import process_params
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.sdk.execution_time.context import InletEventsAccessors
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.traces.tracer import Trace
from airflow.utils import timezone
from airflow.utils.context import (
ConnectionAccessor,
Context,
OutletEventAccessors,
VariableAccessor,
context_get_outlet_events,
context_merge,
)
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -160,9 +147,12 @@
from airflow.models.dagrun import DagRun
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions._internal.abstractoperator import Operator
from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.sdk.types import RuntimeTaskInstanceProtocol
from airflow.typing_compat import Literal
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup


Expand Down Expand Up @@ -261,6 +251,8 @@ def _run_raw_task(

try:
if ti.task:
from airflow.sdk.definitions.asset import Asset

inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
Expand Down Expand Up @@ -678,6 +670,8 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper
)

def _execute_callable(context: Context, **execute_callable_kwargs):
from airflow.utils.context import context_get_outlet_events

try:
# Print a marker for log grouping of details before task execution
log.info("::endgroup::")
Expand Down Expand Up @@ -903,6 +897,13 @@ def _get_template_context(
PrevSuccessfulDagRunResponse,
TIRunContext,
)
from airflow.sdk.definitions.param import process_params
from airflow.sdk.execution_time.context import InletEventsAccessors
from airflow.utils.context import (
ConnectionAccessor,
OutletEventAccessors,
VariableAccessor,
)

integrate_macros_plugins()

Expand Down Expand Up @@ -1347,6 +1348,9 @@ def _get_email_subject_content(
html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)

else:
from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.utils.context import context_merge

if TYPE_CHECKING:
assert task_instance.task

Expand Down Expand Up @@ -2736,6 +2740,8 @@ def register_asset_changes_in_db(
outlet_events: list[dict[str, Any]],
session: Session = NEW_SESSION,
) -> None:
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef

asset_keys = {
AssetUniqueKey(o.name, o.uri)
for o in task_outlets
Expand Down Expand Up @@ -3682,6 +3688,8 @@ def duration_expression_update(
def validate_inlet_outlet_assets_activeness(
inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session
) -> None:
from airflow.sdk.definitions.asset import AssetUniqueKey

if not (inlets or outlets):
return

Expand All @@ -3699,6 +3707,8 @@ def validate_inlet_outlet_assets_activeness(
def _get_inactive_asset_unique_keys(
asset_unique_keys: set[AssetUniqueKey], session: Session
) -> set[AssetUniqueKey]:
from airflow.sdk.definitions.asset import AssetUniqueKey

active_asset_unique_keys = {
AssetUniqueKey(name, uri)
for name, uri in session.execute(
Expand All @@ -3724,6 +3734,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp
def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup

if isinstance(operator, MappedOperator):
return True
Expand Down
3 changes: 2 additions & 1 deletion airflow/utils/operator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar

from airflow import settings
from airflow.sdk.definitions.asset.metadata import Metadata
from airflow.typing_compat import ParamSpec
from airflow.utils.types import NOTSET

Expand Down Expand Up @@ -257,6 +256,8 @@ def ExecutionCallableRunner(
class _ExecutionCallableRunnerImpl:
@staticmethod
def run(*args: P.args, **kwargs: P.kwargs) -> R:
from airflow.sdk.definitions.asset.metadata import Metadata

if not inspect.isgeneratorfunction(func):
return func(*args, **kwargs)

Expand Down