diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py index 03bfe247c5821..0a85f178badac 100644 --- a/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/airflow/providers/elasticsearch/log/es_task_handler.py @@ -25,7 +25,7 @@ from operator import attrgetter from time import time from typing import TYPE_CHECKING, Any, Callable, List, Tuple -from urllib.parse import quote +from urllib.parse import quote, urlparse # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test. import elasticsearch @@ -98,6 +98,12 @@ def __init__( log_id_template: str | None = None, ): es_kwargs = es_kwargs or {} + # For elasticsearch>8,arguments like retry_timeout have changed for elasticsearch to retry_on_timeout + # in Elasticsearch() compared to previous versions. + # Read more at: https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch + if es_kwargs.get("retry_timeout"): + es_kwargs["retry_on_timeout"] = es_kwargs.pop("retry_timeout") + host = self.format_url(host) super().__init__(base_log_folder, filename_template) self.closed = False @@ -126,6 +132,27 @@ def __init__( self._doc_type_map: dict[Any, Any] = {} self._doc_type: list[Any] = [] + @staticmethod + def format_url(host: str) -> str: + """ + Formats the given host string to ensure it starts with 'http'. + Checks if the host string represents a valid URL. + + :params host: The host string to format and check. + """ + parsed_url = urlparse(host) + + # Check if the scheme is either http or https + if not parsed_url.scheme: + host = "http://" + host + parsed_url = urlparse(host) + + # Basic validation for a valid URL + if not parsed_url.netloc: + raise ValueError(f"'{host}' is not a valid URL.") + + return host + def _render_log_id(self, ti: TaskInstance, try_number: int) -> str: with create_session() as session: dag_run = ti.get_dagrun(session=session) diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index 7ae894f22a94c..4ffa958819666 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -125,6 +125,69 @@ def concat_logs(lines): "on 2023-07-09 07:47:32+00:00" ) + @pytest.mark.parametrize( + "host, expected", + [ + ("http://localhost:9200", "http://localhost:9200"), + ("https://localhost:9200", "https://localhost:9200"), + ("localhost:9200", "http://localhost:9200"), + ("someurl", "http://someurl"), + ("https://", "ValueError"), + ], + ) + def test_format_url(self, host, expected): + """ + Test the format_url method of the ElasticsearchTaskHandler class. + """ + if expected == "ValueError": + with pytest.raises(ValueError): + assert ElasticsearchTaskHandler.format_url(host) == expected + else: + assert ElasticsearchTaskHandler.format_url(host) == expected + + def test_elasticsearch_constructor_retry_timeout_handling(self): + """ + Test if the ElasticsearchTaskHandler constructor properly handles the retry_timeout argument. + """ + # Mock the Elasticsearch client + with mock.patch( + "airflow.providers.elasticsearch.log.es_task_handler.elasticsearch.Elasticsearch" + ) as mock_es: + # Test when 'retry_timeout' is present in es_kwargs + es_kwargs = {"retry_timeout": 10} + ElasticsearchTaskHandler( + base_log_folder="dummy_folder", + end_of_log_mark="end_of_log_mark", + write_stdout=False, + json_format=False, + json_fields="fields", + host_field="host", + offset_field="offset", + es_kwargs=es_kwargs, + ) + + # Check the arguments with which the Elasticsearch client is instantiated + mock_es.assert_called_once_with("http://localhost:9200", retry_on_timeout=10) + + # Reset the mock for the next test + mock_es.reset_mock() + + # Test when 'retry_timeout' is not present in es_kwargs + es_kwargs = {} + ElasticsearchTaskHandler( + base_log_folder="dummy_folder", + end_of_log_mark="end_of_log_mark", + write_stdout=False, + json_format=False, + json_fields="fields", + host_field="host", + offset_field="offset", + es_kwargs=es_kwargs, + ) + + # Check that the Elasticsearch client is instantiated without the 'retry_on_timeout' argument + mock_es.assert_called_once_with("http://localhost:9200") + def test_client(self): assert isinstance(self.es_task_handler.client, elasticsearch.Elasticsearch) assert self.es_task_handler.index_patterns == "_all"