Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def get_count(
states: list[str] | None = None,
) -> TICount:
"""Get count of task instances matching the given criteria."""
params: dict[str, Any]
params = {
"dag_id": dag_id,
"task_ids": task_ids,
Expand All @@ -246,7 +247,7 @@ def get_count(
params = {k: v for k, v in params.items() if v is not None}

if map_index is not None and map_index >= 0:
params.update({"map_index": map_index}) # type: ignore[dict-item]
params.update({"map_index": map_index})

resp = self.client.get("task-instances/count", params=params)
return TICount(count=resp.json())
Expand All @@ -261,6 +262,7 @@ def get_task_states(
run_ids: list[str] | None = None,
) -> TaskStatesResponse:
"""Get task states given criteria."""
params: dict[str, Any]
params = {
"dag_id": dag_id,
"task_ids": task_ids,
Expand All @@ -273,7 +275,7 @@ def get_task_states(
params = {k: v for k, v in params.items() if v is not None}

if map_index is not None and map_index >= 0:
params.update({"map_index": map_index}) # type: ignore[dict-item]
params.update({"map_index": map_index})

resp = self.client.get("task-instances/states", params=params)
return TaskStatesResponse.model_validate_json(resp.read())
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/bases/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
("resources", coerce_resources),
):
if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
partial_kwargs[fld] = convert(v) # type: ignore[operator]
partial_kwargs[fld] = convert(v)

partial_kwargs.setdefault("executor_config", {})
partial_kwargs.setdefault("op_args", [])
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/bases/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ def add_outlets(self, outlets: Iterable[Any]):
def get_dag(self) -> DAG | None:
return self._dag

@property # type: ignore[override]
@property
def dag(self) -> DAG:
"""Returns the Operator's DAG if set, otherwise raises an error."""
if dag := self._dag:
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def delete(
run_id=run_id,
map_index=map_index,
)
cls.purge(xcom_result) # type: ignore[call-arg]
cls.purge(xcom_result)
SUPERVISOR_COMMS.send(
DeleteXCom(
key=key,
Expand Down
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_uri(self) -> str:
uri = f"{self.conn_type.lower().replace('_', '-')}://"
else:
uri = "//"

host_to_use: str | None
if self.host and "://" in self.host:
protocol, host = self.host.split("://", 1)
# If the protocol in host matches the connection type, don't add it again
Expand All @@ -84,7 +84,7 @@ def get_uri(self) -> str:
host_to_use = host
protocol_to_add = protocol
else:
host_to_use = self.host # type: ignore[assignment]
host_to_use = self.host
protocol_to_add = None

if protocol_to_add:
Expand Down
8 changes: 4 additions & 4 deletions task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def partial_subset(
# deep-copying self.task_dict and self.task_group takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
memo = {id(self.task_dict): None, id(self.task_group): None}
dag = copy.deepcopy(self, memo) # type: ignore
dag = copy.deepcopy(self, memo)

if isinstance(task_ids, str):
matched_tasks = [t for t in self.tasks if task_ids in t.task_id]
Expand Down Expand Up @@ -935,8 +935,8 @@ def add_task(self, task: Operator) -> None:
) or task_id in self.task_group.used_group_ids:
raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
self.task_dict[task_id] = task
# TODO: Task-SDK: this type ignore shouldn't be needed!
task.dag = self # type: ignore[assignment]

task.dag = self
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
self.task_group.used_group_ids.add(task_id)

Expand Down Expand Up @@ -1089,7 +1089,7 @@ def add_logger_if_needed(ti: TaskInstance):
dags=[self],
start_date=logical_date,
end_date=logical_date,
dag_run_state=False, # type: ignore
dag_run_state=False,
)

log.debug("Getting dagrun for dag %s", self.dag_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def initialize_context(...):
if isinstance(func, _TaskGroupFactory):
raise AirflowException("Task groups cannot be marked as setup or teardown.")
func = cast("_TaskDecorator", func)
func.is_setup = True # type: ignore[attr-defined] # TODO: Remove this once mypy is bump to 1.16.1
func.is_setup = True
return func


Expand All @@ -80,9 +80,8 @@ def teardown(func: Callable) -> Callable:
raise AirflowException("Task groups cannot be marked as setup or teardown.")
func = cast("_TaskDecorator", func)

# TODO: Remove below attr-defined once mypy is bump to 1.16.1
func.is_teardown = True # type: ignore[attr-defined]
func.on_failure_fail_dagrun = on_failure_fail_dagrun # type: ignore[attr-defined]
func.is_teardown = True
func.on_failure_fail_dagrun = on_failure_fail_dagrun
return func

if _func is None:
Expand Down
18 changes: 9 additions & 9 deletions task-sdk/src/airflow/sdk/definitions/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def task_display_name(self) -> str:
return self.partial_kwargs.get("task_display_name") or self.task_id

@property
def owner(self) -> str: # type: ignore[override]
def owner(self) -> str:
return self.partial_kwargs.get("owner", DEFAULT_OWNER)

@owner.setter
Expand Down Expand Up @@ -537,15 +537,15 @@ def retry_exponential_backoff(self, value: bool) -> None:
self.partial_kwargs["retry_exponential_backoff"] = value

@property
def priority_weight(self) -> int: # type: ignore[override]
def priority_weight(self) -> int:
return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

@priority_weight.setter
def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value

@property
def weight_rule(self) -> PriorityWeightStrategy: # type: ignore[override]
def weight_rule(self) -> PriorityWeightStrategy:
return validate_and_load_priority_weight_strategy(
self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
)
Expand Down Expand Up @@ -626,20 +626,20 @@ def executor(self) -> str | None:
def executor_config(self) -> dict:
return self.partial_kwargs.get("executor_config", {})

@property # type: ignore[override]
def inlets(self) -> list[Any]: # type: ignore[override]
@property
def inlets(self) -> list[Any]:
return self.partial_kwargs.get("inlets", [])

@inlets.setter
def inlets(self, value: list[Any]) -> None: # type: ignore[override]
def inlets(self, value: list[Any]) -> None:
self.partial_kwargs["inlets"] = value

@property # type: ignore[override]
def outlets(self) -> list[Any]: # type: ignore[override]
@property
def outlets(self) -> list[Any]:
return self.partial_kwargs.get("outlets", [])

@outlets.setter
def outlets(self, value: list[Any]) -> None: # type: ignore[override]
def outlets(self, value: list[Any]) -> None:
self.partial_kwargs["outlets"] = value

@property
Expand Down
3 changes: 1 addition & 2 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,11 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
# enabled only if SecretCache.init() has been called first
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

var_val = None
backends = ensure_secrets_backend_loaded()
# iterate over backends if not in cache (or expired)
for secrets_backend in backends:
try:
var_val = secrets_backend.get_variable(key=key) # type: ignore[assignment]
var_val = secrets_backend.get_variable(key=key)
if var_val is not None:
if deserialize_json:
import json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int
if isinstance(item, Enum):
return self._redact(item=item.value, name=name, depth=depth, max_depth=max_depth)
if _is_v1_env_var(item) and hasattr(item, "to_dict"):
tmp: dict = item.to_dict() # type: ignore[attr-defined] # V1EnvVar has a to_dict method
tmp: dict = item.to_dict()
if should_hide_value_for_key(tmp.get("name", "")) and "value" in tmp:
tmp["value"] = "***"
else:
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ def _api_client(dag=None):

client = Client(base_url=None, token="", dry_run=True, transport=api.transport)
# Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str`
client.base_url = "http://in-process.invalid./" # type: ignore[assignment]
client.base_url = "http://in-process.invalid./"
return client

def send_msg(
Expand Down
Loading