Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,7 @@ def trigger_tasks(self, open_slots: int) -> None:

# If it's None, then the span for the current TaskInstanceKey hasn't been started.
if self.active_spans is not None and self.active_spans.get(key) is None:
from airflow.models.taskinstance import SimpleTaskInstance

if isinstance(ti, (SimpleTaskInstance, workloads.TaskInstance)):
if isinstance(ti, workloads.TaskInstance):
parent_context = Trace.extract(ti.parent_context_carrier)
else:
parent_context = Trace.extract(ti.dag_run.context_carrier)
Expand Down
91 changes: 0 additions & 91 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,97 +2510,6 @@ def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
return False


# State of the task instance.
# Stores string version of the task state.
TaskInstanceStateType = tuple[TaskInstanceKey, TaskInstanceState]


class SimpleTaskInstance:
"""
Simplified Task Instance.

Used to send data between processes via Queues.
"""

def __init__(
self,
dag_id: str,
task_id: str,
run_id: str,
queued_dttm: datetime | None,
start_date: datetime | None,
end_date: datetime | None,
try_number: int,
map_index: int,
state: str,
executor: str | None,
executor_config: Any,
pool: str,
queue: str,
key: TaskInstanceKey,
run_as_user: str | None = None,
priority_weight: int | None = None,
parent_context_carrier: dict | None = None,
context_carrier: dict | None = None,
span_status: str | None = None,
):
self.dag_id = dag_id
self.task_id = task_id
self.run_id = run_id
self.map_index = map_index
self.queued_dttm = queued_dttm
self.start_date = start_date
self.end_date = end_date
self.try_number = try_number
self.state = state
self.executor = executor
self.executor_config = executor_config
self.run_as_user = run_as_user
self.pool = pool
self.priority_weight = priority_weight
self.queue = queue
self.key = key
self.parent_context_carrier = parent_context_carrier
self.context_carrier = context_carrier
self.span_status = span_status

def __repr__(self) -> str:
attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
return f"SimpleTaskInstance({attrs})"

def __eq__(self, other) -> bool:
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

@classmethod
def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
return cls(
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
queued_dttm=ti.queued_dttm,
start_date=ti.start_date,
end_date=ti.end_date,
try_number=ti.try_number,
state=ti.state,
executor=ti.executor,
executor_config=ti.executor_config,
pool=ti.pool,
queue=ti.queue,
key=ti.key,
run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
# Inspect the ti, to check if the 'dag_run' relationship is loaded.
parent_context_carrier=ti.dag_run.context_carrier
if "dag_run" not in inspect(ti).unloaded
else None,
context_carrier=ti.context_carrier if hasattr(ti, "context_carrier") else None,
span_status=ti.span_status,
)


class TaskInstanceNote(Base):
"""For storage of arbitrary notes concerning the task instance."""

Expand Down
1 change: 0 additions & 1 deletion airflow-core/src/airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ class DagAttributeTypes(str, Enum):
ASSET_REF = "asset_ref"
ASSET_UNIQUE_KEY = "asset_unique_key"
ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
DAG_RUN = "dag_run"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from airflow.models.expandinput import (
create_expand_input,
)
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom import XComModel
from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
Expand Down Expand Up @@ -101,6 +100,7 @@

from airflow.models import DagRun
from airflow.models.expandinput import SchedulerExpandInput
from airflow.models.taskinstance import TaskInstance
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.types import Operator
Expand Down Expand Up @@ -822,11 +822,6 @@ def serialize(
return cls._encode(serialized_asset, type_=serialized_asset.pop("__type"))
elif isinstance(var, AssetRef):
return cls._encode(attrs.asdict(var), type_=DAT.ASSET_REF)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict),
type_=DAT.SIMPLE_TASK_INSTANCE,
)
elif isinstance(var, Connection):
return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
elif isinstance(var, TaskCallbackRequest):
Expand Down Expand Up @@ -939,8 +934,6 @@ def deserialize(cls, encoded_var: Any) -> Any:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
elif type_ == DAT.ASSET_REF:
return Asset.ref(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
return Connection(**var)
elif type_ == DAT.TASK_CALLBACK_REQUEST:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -327,7 +327,6 @@ def __len__(self) -> int:
DAT.ASSET,
equals,
),
(SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
(
Connection(conn_id="TEST_ID", uri="mysql://"),
DAT.CONNECTION,
Expand Down