diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py index 82cc887553d6..c397e2b3585f 100644 --- a/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/airflow/providers/elasticsearch/log/es_task_handler.py @@ -41,6 +41,7 @@ from airflow.utils import timezone from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin +from airflow.utils.module_loading import import_string from airflow.utils.session import create_session if TYPE_CHECKING: @@ -152,7 +153,8 @@ def __init__( offset_field: str = "offset", host: str = "http://localhost:9200", frontend: str = "localhost:5601", - index_patterns: str | None = conf.get("elasticsearch", "index_patterns", fallback="_all"), + index_patterns: str = conf.get("elasticsearch", "index_patterns"), + index_patterns_callable: str = conf.get("elasticsearch", "index_patterns_callable", fallback=""), es_kwargs: dict | None | Literal["default_es_kwargs"] = "default_es_kwargs", *, filename_template: str | None = None, @@ -184,6 +186,7 @@ def __init__( self.host_field = host_field self.offset_field = offset_field self.index_patterns = index_patterns + self.index_patterns_callable = index_patterns_callable self.context_set = False self.formatter: logging.Formatter @@ -213,6 +216,19 @@ def format_url(host: str) -> str: return host + def _get_index_patterns(self, ti: TaskInstance | None) -> str: + """ + Get index patterns by calling index_patterns_callable, if provided, or the configured index_patterns. + + :param ti: A TaskInstance object or None. + """ + if self.index_patterns_callable: + self.log.debug("Using index_patterns_callable: %s", self.index_patterns_callable) + index_pattern_callable_obj = import_string(self.index_patterns_callable) + return index_pattern_callable_obj(ti) + self.log.debug("Using index_patterns: %s", self.index_patterns) + return self.index_patterns + def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> str: from airflow.models.taskinstance import TaskInstanceKey @@ -302,7 +318,7 @@ def _read( offset = metadata["offset"] log_id = self._render_log_id(ti, try_number) - response = self._es_read(log_id, offset) + response = self._es_read(log_id, offset, ti) if response is not None and response.hits: logs_by_host = self._group_logs_by_host(response) next_offset = attrgetter(self.offset_field)(response[-1]) @@ -372,12 +388,13 @@ def _format_msg(self, hit: Hit): # Just a safe-guard to preserve backwards-compatibility return hit.message - def _es_read(self, log_id: str, offset: int | str) -> ElasticSearchResponse | None: + def _es_read(self, log_id: str, offset: int | str, ti: TaskInstance) -> ElasticSearchResponse | None: """ Return the logs matching log_id in Elasticsearch and next offset or ''. :param log_id: the log_id of the log to read. :param offset: the offset start to read log from. + :param ti: the task instance object :meta private: """ @@ -388,16 +405,17 @@ def _es_read(self, log_id: str, offset: int | str) -> ElasticSearchResponse | No } } + index_patterns = self._get_index_patterns(ti) try: - max_log_line = self.client.count(index=self.index_patterns, query=query)["count"] # type: ignore + max_log_line = self.client.count(index=index_patterns, query=query)["count"] # type: ignore except NotFoundError as e: - self.log.exception("The target index pattern %s does not exist", self.index_patterns) + self.log.exception("The target index pattern %s does not exist", index_patterns) raise e if max_log_line != 0: try: res = self.client.search( - index=self.index_patterns, + index=index_patterns, query=query, sort=[self.offset_field], size=self.MAX_LINE_PER_PAGE, diff --git a/airflow/providers/elasticsearch/provider.yaml b/airflow/providers/elasticsearch/provider.yaml index 73d6ae2fac6f..46ae0530673d 100644 --- a/airflow/providers/elasticsearch/provider.yaml +++ b/airflow/providers/elasticsearch/provider.yaml @@ -160,10 +160,19 @@ config: index_patterns: description: | Comma separated list of index patterns to use when searching for logs (default: `_all`). + The index_patterns_callable takes precedence over this. version_added: 2.6.0 type: string example: something-* default: "_all" + index_patterns_callable: + description: | + A string representing the full path to the Python callable path which accept TI object and + return comma separated list of index patterns. This will takes precedence over index_patterns. + version_added: 5.5.0 + type: string + example: module.callable + default: "" elasticsearch_configs: description: ~ options: diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index da9848291615..c920f0d46687 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -25,6 +25,7 @@ from io import StringIO from pathlib import Path from unittest import mock +from unittest.mock import Mock, patch from urllib.parse import quote import elasticsearch @@ -49,7 +50,6 @@ pytestmark = pytest.mark.db_test - AIRFLOW_SOURCES_ROOT_DIR = Path(__file__).parents[4].resolve() ES_PROVIDER_YAML_FILE = AIRFLOW_SOURCES_ROOT_DIR / "airflow" / "providers" / "elasticsearch" / "provider.yaml" @@ -643,6 +643,18 @@ def test_dynamic_offset(self, stdout_mock, ti, time_machine): assert second_log["asctime"] == t2.format("YYYY-MM-DDTHH:mm:ss.SSSZZ") assert third_log["asctime"] == t3.format("YYYY-MM-DDTHH:mm:ss.SSSZZ") + def test_get_index_patterns_with_callable(self): + with patch("airflow.providers.elasticsearch.log.es_task_handler.import_string") as mock_import_string: + mock_callable = Mock(return_value="callable_index_pattern") + mock_import_string.return_value = mock_callable + + self.es_task_handler.index_patterns_callable = "path.to.index_pattern_callable" + result = self.es_task_handler._get_index_patterns({}) + + mock_import_string.assert_called_once_with("path.to.index_pattern_callable") + mock_callable.assert_called_once_with({}) + assert result == "callable_index_pattern" + def test_safe_attrgetter(): class A: ...