Skip to content

Commit

Permalink
Mask secrets in stdout for 'airflow tasks test' (#24362)
Browse files Browse the repository at this point in the history
A stdout redirector is implemented to mask all values to stdout and
redact any secrets in it with the secrets masker. This redirector is
applied to the 'airflow.task' logger.

Co-authored-by: Alex Kennedy <alex.kennedy@astronomer.io>
  • Loading branch information
uranusjr and alex-astronomer authored Jun 12, 2022
1 parent 770ee07 commit 3007159
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 28 deletions.
10 changes: 6 additions & 4 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from airflow.utils.dates import timezone
from airflow.utils.log.logging_mixin import StreamLogWriter
from airflow.utils.log.secrets_masker import RedactedIO
from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState
Expand Down Expand Up @@ -546,10 +547,11 @@ def task_test(args, dag=None):
ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db")

try:
if args.dry_run:
ti.dry_run()
else:
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
with redirect_stdout(RedactedIO()):
if args.dry_run:
ti.dry_run()
else:
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
except Exception:
if args.post_mortem:
debugger = _guess_debugger()
Expand Down
38 changes: 28 additions & 10 deletions airflow/utils/log/secrets_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
import collections
import logging
import re
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import sys
from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple, TypeVar, Union

from airflow import settings
from airflow.compat.functools import cache, cached_property

if TYPE_CHECKING:
RedactableItem = Union[str, Dict[Any, Any], Tuple[Any, ...], List[Any]]

Redactable = TypeVar("Redactable", str, Dict[Any, Any], Tuple[Any, ...], List[Any])
Redacted = Union[Redactable, str]

log = logging.getLogger(__name__)


DEFAULT_SENSITIVE_FIELDS = frozenset(
{
'access_token',
Expand Down Expand Up @@ -91,14 +90,13 @@ def mask_secret(secret: Union[str, dict, Iterable], name: Optional[str] = None)
_secrets_masker().add_mask(secret, name)


def redact(value: "RedactableItem", name: Optional[str] = None) -> "RedactableItem":
def redact(value: Redactable, name: Optional[str] = None) -> Redacted:
"""Redact any secrets found in ``value``."""
return _secrets_masker().redact(value, name)


@cache
def _secrets_masker() -> "SecretsMasker":

for flt in logging.getLogger('airflow.task').filters:
if isinstance(flt, SecretsMasker):
return flt
Expand Down Expand Up @@ -177,7 +175,7 @@ def filter(self, record) -> bool:

return True

def _redact_all(self, item: "RedactableItem", depth: int) -> "RedactableItem":
def _redact_all(self, item: Redactable, depth: int) -> Redacted:
if depth > self.MAX_RECURSION_DEPTH or isinstance(item, str):
return '***'
if isinstance(item, dict):
Expand All @@ -190,7 +188,7 @@ def _redact_all(self, item: "RedactableItem", depth: int) -> "RedactableItem":
else:
return item

def _redact(self, item: "RedactableItem", name: Optional[str], depth: int) -> "RedactableItem":
def _redact(self, item: Redactable, name: Optional[str], depth: int) -> Redacted:
# Avoid spending too much effort on redacting on deeply nested
# structures. This also avoid infinite recursion if a structure has
# reference to self.
Expand Down Expand Up @@ -231,7 +229,7 @@ def _redact(self, item: "RedactableItem", name: Optional[str], depth: int) -> "R
)
return item

def redact(self, item: "RedactableItem", name: Optional[str] = None) -> "RedactableItem":
def redact(self, item: Redactable, name: Optional[str] = None) -> Redacted:
"""Redact an any secrets found in ``item``, if it is a string.
If ``name`` is given, and it's a "sensitive" name (see
Expand All @@ -258,3 +256,23 @@ def add_mask(self, secret: Union[str, dict, Iterable], name: Optional[str] = Non
elif isinstance(secret, collections.abc.Iterable):
for v in secret:
self.add_mask(v, name)


class RedactedIO(TextIO):
"""IO class that redacts values going into stdout.
Expected usage::
with contextlib.redirect_stdout(RedactedIO()):
... # Writes to stdout will be redacted.
"""

def __init__(self):
self.target = sys.stdout

def write(self, s: str) -> int:
s = redact(s)
return self.target.write(s)

def flush(self) -> None:
return self.target.flush()
40 changes: 27 additions & 13 deletions tests/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def reset(dag_id):


# TODO: Check if tests needs side effects - locally there's missing DAG
class TestCliTasks(unittest.TestCase):
class TestCliTasks:
run_id = 'TEST_RUN_ID'
dag_id = 'example_python_operator'
parser: ArgumentParser
Expand All @@ -67,7 +67,7 @@ class TestCliTasks(unittest.TestCase):
dag_run: DagRun

@classmethod
def setUpClass(cls):
def setup_class(cls):
cls.dagbag = DagBag(include_examples=True)
cls.parser = cli_parser.get_parser()
clear_db_runs()
Expand All @@ -78,7 +78,7 @@ def setUpClass(cls):
)

@classmethod
def tearDownClass(cls) -> None:
def teardown_class(cls) -> None:
clear_db_runs()

def test_cli_list_tasks(self):
Expand All @@ -103,20 +103,34 @@ def test_test(self):
assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()

@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test_with_existing_dag_run(self):
def test_test_with_existing_dag_run(self, caplog):
"""Test the `airflow test` command"""
task_id = 'print_the_context'

args = self.parser.parse_args(["tasks", "test", self.dag_id, task_id, DEFAULT_DATE.isoformat()])
with caplog.at_level("INFO", logger="airflow.task"):
task_command.task_test(args)
assert f"Marking task as SUCCESS. dag_id={self.dag_id}, task_id={task_id}" in caplog.text

@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test_filters_secrets(self, capsys):
"""Test ``airflow test`` does not print secrets to stdout.
Output should be filtered by SecretsMasker.
"""
password = "somepassword1234!"
logging.getLogger("airflow.task").filters[0].add_mask(password)
args = self.parser.parse_args(
["tasks", "test", "example_python_operator", "print_the_context", "2018-01-01"],
)

with self.assertLogs('airflow.task', level='INFO') as cm:
with mock.patch("airflow.models.TaskInstance.run", new=lambda *_, **__: print(password)):
task_command.task_test(args)
assert any(
[
f"Marking task as SUCCESS. dag_id={self.dag_id}, task_id={task_id}" in log
for log in cm.output
]
)
assert capsys.readouterr().out.endswith("***\n")

not_password = "!4321drowssapemos"
with mock.patch("airflow.models.TaskInstance.run", new=lambda *_, **__: print(not_password)):
task_command.task_test(args)
assert capsys.readouterr().out.endswith(f"{not_password}\n")

@mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
Expand Down Expand Up @@ -229,7 +243,7 @@ def test_run_raises_when_theres_no_dagrun(self, mock_local_job):
task0_id,
run_id,
]
with self.assertRaises(DagRunNotFound):
with pytest.raises(DagRunNotFound):
task_command.task_run(self.parser.parse_args(args0), dag=dag)

def test_cli_test_with_params(self):
Expand Down
23 changes: 22 additions & 1 deletion tests/utils/log/test_secrets_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import inspect
import logging
import logging.config
Expand All @@ -23,7 +24,7 @@
import pytest

from airflow import settings
from airflow.utils.log.secrets_masker import SecretsMasker, should_hide_value_for_key
from airflow.utils.log.secrets_masker import RedactedIO, SecretsMasker, should_hide_value_for_key
from tests.test_utils.config import conf_vars

settings.MASK_SECRETS_IN_LOGS = True
Expand Down Expand Up @@ -340,3 +341,23 @@ def formatException(self, exc_info):
def lineno():
"""Returns the current line number in our program."""
return inspect.currentframe().f_back.f_lineno


class TestRedactedIO:
def test_redacts_from_print(self, capsys):
# Without redacting, password is printed.
print(p)
stdout = capsys.readouterr().out
assert stdout == f"{p}\n"
assert "***" not in stdout

# With context manager, password is redacted.
with contextlib.redirect_stdout(RedactedIO()):
print(p)
stdout = capsys.readouterr().out
assert stdout == "***\n"

def test_write(self, capsys):
RedactedIO().write(p)
stdout = capsys.readouterr().out
assert stdout == "***"

0 comments on commit 3007159

Please sign in to comment.