Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[airflow] Extend airflow context parameter check for BaseOperator.execute (AIR302) #15713

Merged
merged 9 commits into from
Jan 27, 2025
119 changes: 72 additions & 47 deletions crates/ruff_linter/resources/test/fixtures/airflow/AIR302_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from datetime import datetime

import pendulum

from airflow.decorators import dag, task
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
Expand All @@ -13,30 +14,22 @@

def access_invalid_key_in_context(**context):
print("access invalid key", context["conf"])
print("access invalid key", context.get("conf"))


@task
def access_invalid_key_task_out_of_dag(**context):
print("access invalid key", context["conf"])
print("access invalid key", context.get("conf"))

@dag(
schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=[""],
)
def invalid_dag():
@task()
def access_invalid_key_task(**context):
print("access invalid key", context.get("conf"))

task1 = PythonOperator(
task_id="task1",
python_callable=access_invalid_key_in_context,
)
access_invalid_key_task() >> task1
access_invalid_key_task_out_of_dag()
@task
def access_invalid_argument_task_out_of_dag(
execution_date, tomorrow_ds, logical_date, **context
):
print("execution date", execution_date)
print("access invalid key", context.get("conf"))

invalid_dag()

@task
def print_config(**context):
Expand All @@ -56,6 +49,63 @@ def print_config(**context):
yesterday_ds = context["yesterday_ds"]
yesterday_ds_nodash = context["yesterday_ds_nodash"]


@task
def print_config_with_get_current_context():
context = get_current_context()
execution_date = context["execution_date"]
next_ds = context["next_ds"]
next_ds_nodash = context["next_ds_nodash"]
next_execution_date = context["next_execution_date"]
prev_ds = context["prev_ds"]
prev_ds_nodash = context["prev_ds_nodash"]
prev_execution_date = context["prev_execution_date"]
prev_execution_date_success = context["prev_execution_date_success"]
tomorrow_ds = context["tomorrow_ds"]
yesterday_ds = context["yesterday_ds"]
yesterday_ds_nodash = context["yesterday_ds_nodash"]


@task(task_id="print_the_context")
def print_context(ds=None, **kwargs):
"""Print the Airflow context and ds variable from the context."""
print(ds)
print(kwargs.get("tomorrow_ds"))
c = get_current_context()
c.get("execution_date")


@dag(
schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=[""],
)
def invalid_dag():
@task()
def access_invalid_key_task(**context):
print("access invalid key", context.get("conf"))

@task()
def access_invalid_key_explicit_task(execution_date):
print(execution_date)

task1 = PythonOperator(
task_id="task1",
python_callable=access_invalid_key_in_context,
)

access_invalid_key_task() >> task1
access_invalid_key_explicit_task()
access_invalid_argument_task_out_of_dag()
access_invalid_key_task_out_of_dag()
print_config()
print_config_with_get_current_context()
print_context()


invalid_dag()

with DAG(
dag_id="example_dag",
schedule_interval="@daily",
Expand All @@ -68,34 +118,21 @@ def print_config(**context):
# Removed variables in template
"execution_date": "{{ execution_date }}",
"next_ds": "{{ next_ds }}",
"prev_ds": "{{ prev_ds }}"
"prev_ds": "{{ prev_ds }}",
},
)


class CustomMacrosPlugin(AirflowPlugin):
name = "custom_macros"
macros = {
"execution_date_macro": lambda context: context["execution_date"],
"next_ds_macro": lambda context: context["next_ds"]
"next_ds_macro": lambda context: context["next_ds"],
}

@task
def print_config():
context = get_current_context()
execution_date = context["execution_date"]
next_ds = context["next_ds"]
next_ds_nodash = context["next_ds_nodash"]
next_execution_date = context["next_execution_date"]
prev_ds = context["prev_ds"]
prev_ds_nodash = context["prev_ds_nodash"]
prev_execution_date = context["prev_execution_date"]
prev_execution_date_success = context["prev_execution_date_success"]
tomorrow_ds = context["tomorrow_ds"]
yesterday_ds = context["yesterday_ds"]
yesterday_ds_nodash = context["yesterday_ds_nodash"]

class CustomOperator(BaseOperator):
def execute(self, context):
def execute(self, next_ds, context):
execution_date = context["execution_date"]
next_ds = context["next_ds"]
next_ds_nodash = context["next_ds_nodash"]
Expand All @@ -108,18 +145,6 @@ def execute(self, context):
yesterday_ds = context["yesterday_ds"]
yesterday_ds_nodash = context["yesterday_ds_nodash"]

@task
def access_invalid_argument_task_out_of_dag(execution_date, tomorrow_ds, logical_date, **context):
print("execution date", execution_date)
print("access invalid key", context.get("conf"))

@task(task_id="print_the_context")
def print_context(ds=None, **kwargs):
"""Print the Airflow context and ds variable from the context."""
print(ds)
print(kwargs.get("tomorrow_ds"))
c = get_current_context()
c.get("execution_date")

class CustomOperatorNew(BaseOperator):
def execute(self, context):
Expand Down
10 changes: 5 additions & 5 deletions crates/ruff_linter/src/checkers/ast/analyze/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
pyupgrade::rules::use_pep646_unpack(checker, subscript);
}
if checker.enabled(Rule::Airflow3Removal) {
airflow::rules::removed_in_3(checker, expr);
airflow::rules::airflow_3_removal_expr(checker, expr);
}
pandas_vet::rules::subscript(checker, value, expr);
}
Expand Down Expand Up @@ -227,7 +227,7 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
refurb::rules::regex_flag_alias(checker, expr);
}
if checker.enabled(Rule::Airflow3Removal) {
airflow::rules::removed_in_3(checker, expr);
airflow::rules::airflow_3_removal_expr(checker, expr);
}
if checker.enabled(Rule::Airflow3MovedToProvider) {
airflow::rules::moved_to_provider_in_3(checker, expr);
Expand Down Expand Up @@ -311,7 +311,7 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
}
}
if checker.enabled(Rule::Airflow3Removal) {
airflow::rules::removed_in_3(checker, expr);
airflow::rules::airflow_3_removal_expr(checker, expr);
}
if checker.enabled(Rule::MixedCaseVariableInGlobalScope) {
if matches!(checker.semantic.current_scope().kind, ScopeKind::Module) {
Expand Down Expand Up @@ -449,7 +449,7 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
flake8_pyi::rules::bytestring_attribute(checker, expr);
}
if checker.enabled(Rule::Airflow3Removal) {
airflow::rules::removed_in_3(checker, expr);
airflow::rules::airflow_3_removal_expr(checker, expr);
}
}
Expr::Call(
Expand Down Expand Up @@ -1150,7 +1150,7 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
ruff::rules::unnecessary_regular_expression(checker, call);
}
if checker.enabled(Rule::Airflow3Removal) {
airflow::rules::removed_in_3(checker, expr);
airflow::rules::airflow_3_removal_expr(checker, expr);
}
if checker.enabled(Rule::UnnecessaryCastToInt) {
ruff::rules::unnecessary_cast_to_int(checker, call);
Expand Down
2 changes: 1 addition & 1 deletion crates/ruff_linter/src/checkers/ast/analyze/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
flake8_pytest_style::rules::parameter_with_default_argument(checker, function_def);
}
if checker.enabled(Rule::Airflow3Removal) {
airflow::rules::removed_in_3_function_def(checker, function_def);
airflow::rules::airflow_3_removal_function_def(checker, function_def);
}
if checker.enabled(Rule::NonPEP695GenericFunction) {
pyupgrade::rules::non_pep695_generic_function(checker, function_def);
Expand Down
53 changes: 45 additions & 8 deletions crates/ruff_linter/src/rules/airflow/rules/removal_in_3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ enum Replacement {
}

/// AIR302
pub(crate) fn removed_in_3(checker: &mut Checker, expr: &Expr) {
pub(crate) fn airflow_3_removal_expr(checker: &mut Checker, expr: &Expr) {
if !checker.semantic().seen_module(Modules::AIRFLOW) {
return;
}
Expand Down Expand Up @@ -117,7 +117,10 @@ pub(crate) fn removed_in_3(checker: &mut Checker, expr: &Expr) {
}

/// AIR302
pub(crate) fn removed_in_3_function_def(checker: &mut Checker, function_def: &StmtFunctionDef) {
pub(crate) fn airflow_3_removal_function_def(
checker: &mut Checker,
function_def: &StmtFunctionDef,
) {
if !checker.semantic().seen_module(Modules::AIRFLOW) {
return;
}
Expand Down Expand Up @@ -154,7 +157,9 @@ const REMOVED_CONTEXT_KEYS: [&str; 12] = [
/// pass
/// ```
fn check_function_parameters(checker: &mut Checker, function_def: &StmtFunctionDef) {
if !is_airflow_task(function_def, checker.semantic()) {
if !is_airflow_task_function_def(function_def, checker.semantic())
&& !is_execute_method_inherits_from_airflow_operator(function_def, checker.semantic())
{
return;
}

Expand Down Expand Up @@ -346,7 +351,7 @@ fn check_class_attribute(checker: &mut Checker, attribute_expr: &ExprAttribute)
/// context.get("conf") # 'conf' is removed in Airflow 3.0
/// ```
fn check_context_key_usage_in_call(checker: &mut Checker, call_expr: &ExprCall) {
if !in_airflow_task_function(checker.semantic()) {
if !in_airflow_task_function_def(checker.semantic()) {
return;
}

Expand Down Expand Up @@ -395,7 +400,7 @@ fn check_context_key_usage_in_call(checker: &mut Checker, call_expr: &ExprCall)
/// Check if a subscript expression accesses a removed Airflow context variable.
/// If a removed key is found, push a corresponding diagnostic.
fn check_context_key_usage_in_subscript(checker: &mut Checker, subscript: &ExprSubscript) {
if !in_airflow_task_function(checker.semantic()) {
if !in_airflow_task_function_def(checker.semantic()) {
return;
}

Expand Down Expand Up @@ -1059,15 +1064,15 @@ fn is_airflow_builtin_or_provider(segments: &[&str], module: &str, symbol_suffix

/// Returns `true` if the current statement hierarchy has a function that's decorated with
/// `@airflow.decorators.task`.
fn in_airflow_task_function(semantic: &SemanticModel) -> bool {
fn in_airflow_task_function_def(semantic: &SemanticModel) -> bool {
semantic
.current_statements()
.find_map(|stmt| stmt.as_function_def_stmt())
.is_some_and(|function_def| is_airflow_task(function_def, semantic))
.is_some_and(|function_def| is_airflow_task_function_def(function_def, semantic))
}

/// Returns `true` if the given function is decorated with `@airflow.decorators.task`.
fn is_airflow_task(function_def: &StmtFunctionDef, semantic: &SemanticModel) -> bool {
fn is_airflow_task_function_def(function_def: &StmtFunctionDef, semantic: &SemanticModel) -> bool {
Lee-W marked this conversation as resolved.
Show resolved Hide resolved
function_def.decorator_list.iter().any(|decorator| {
semantic
.resolve_qualified_name(map_callable(&decorator.expression))
Expand All @@ -1076,3 +1081,35 @@ fn is_airflow_task(function_def: &StmtFunctionDef, semantic: &SemanticModel) ->
})
})
}

/// Check whether the given function is an `execute` method whose class inherits from the Airflow base operator
///
/// For example:
///
/// ```python
/// from airflow.models.baseoperator import BaseOperator
///
/// class CustomOperator(BaseOperator):
/// def execute(self):
/// pass
/// ```
fn is_execute_method_inherits_from_airflow_operator(
function_def: &StmtFunctionDef,
semantic: &SemanticModel,
) -> bool {
if function_def.name.as_str() != "execute" {
return false;
}

let ScopeKind::Class(class_def) = semantic.current_scope().kind else {
return false;
};

class_def.bases().iter().any(|class_base| {
semantic
.resolve_qualified_name(class_base)
.is_some_and(|qualified_name| {
matches!(qualified_name.segments(), ["airflow", .., "BaseOperator"])
})
})
}
Loading
Loading