Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize doc_md for jinja templating while initialising DAG #40520

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ def _create_orm_dagrun(
"dag_display_name": str,
}

SANITIZED_DOC_MD_TEXT = "[[ Riksy Jinja template detected & removed ]]"


@functools.total_ordering
class DAG(LoggingMixin):
Expand Down Expand Up @@ -768,10 +770,18 @@ def __init__(

validate_instance_args(self, DAG_ARGS_EXPECTED_TYPES)

def sanitize_doc_md(self, doc_md: str) -> str:
import re

jinja_regex = r"\{\{.*?\}\}"
sanitized_doc_md = re.sub(jinja_regex, SANITIZED_DOC_MD_TEXT, doc_md)
return sanitized_doc_md

def get_doc_md(self, doc_md: str | None) -> str | None:
if doc_md is None:
return doc_md

doc_md = self.sanitize_doc_md(doc_md)
env = self.get_template_env(force_sandboxed=True)

if not doc_md.endswith(".md"):
Expand Down
48 changes: 47 additions & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from airflow.hooks.base import BaseHook
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.models.dag import DAG, SANITIZED_DOC_MD_TEXT
from airflow.models.dagbag import DagBag
from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.mappedoperator import MappedOperator
Expand Down Expand Up @@ -2822,3 +2822,49 @@ def operator_extra_links(self):
}
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]


class TestDagDocMd:
def test_dag_doc_md_safe(self):
doc_md = "### Simple DAG"
expected = doc_md
dag = DAG(
dag_id="simple_dag",
default_args={
"retries": 1,
"retry_delay": timedelta(minutes=5),
"max_retry_delay": timedelta(minutes=10),
"depends_on_past": False,
"sla": timedelta(seconds=100),
},
start_date=datetime(2019, 8, 1),
is_paused_upon_creation=False,
doc_md=doc_md,
)

actual = dag.doc_md
assert actual == expected

def test_dag_doc_md_unsafe(self):
command = "hostname"
doc_md = (
f"{{ ''.__class__.__mro__[-1].__subclasses__()[138].__init__.__globals__['__builtins__']["
"'__import__']('subprocess').check_output('%s') }}" % command
)
expected = SANITIZED_DOC_MD_TEXT
dag = DAG(
dag_id="simple_dag",
default_args={
"retries": 1,
"retry_delay": timedelta(minutes=5),
"max_retry_delay": timedelta(minutes=10),
"depends_on_past": False,
"sla": timedelta(seconds=100),
},
start_date=datetime(2019, 8, 1),
is_paused_upon_creation=False,
doc_md=doc_md,
)

actual = dag.doc_md
assert actual == expected
Loading