Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backward compatibility for elasticsearch<8 #33281

Merged
merged 4 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion airflow/providers/elasticsearch/log/es_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from operator import attrgetter
from time import time
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from urllib.parse import quote
from urllib.parse import quote, urlparse

# Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
import elasticsearch
Expand Down Expand Up @@ -98,6 +98,12 @@ def __init__(
log_id_template: str | None = None,
):
es_kwargs = es_kwargs or {}
# For elasticsearch>=8, the `retry_timeout` argument of Elasticsearch() has been renamed to
# `retry_on_timeout`, so the old name is translated here to keep existing configurations working.
# Read more at: https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#module-elasticsearch
if es_kwargs.get("retry_timeout"):
es_kwargs["retry_on_timeout"] = es_kwargs.pop("retry_timeout")
host = self.format_url(host)
super().__init__(base_log_folder, filename_template)
self.closed = False

Expand Down Expand Up @@ -126,6 +132,27 @@ def __init__(
self._doc_type_map: dict[Any, Any] = {}
self._doc_type: list[Any] = []

@staticmethod
def format_url(host: str) -> str:
"""
Formats the given host string to ensure it starts with 'http'.
Checks if the host string represents a valid URL.

:params host: The host string to format and check.
"""
parsed_url = urlparse(host)

# Check if the scheme is either http or https
if not parsed_url.scheme:
host = "http://" + host
parsed_url = urlparse(host)

# Basic validation for a valid URL
if not parsed_url.netloc:
raise ValueError(f"'{host}' is not a valid URL.")

return host

def _render_log_id(self, ti: TaskInstance, try_number: int) -> str:
with create_session() as session:
dag_run = ti.get_dagrun(session=session)
Expand Down
63 changes: 63 additions & 0 deletions tests/providers/elasticsearch/log/test_es_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,69 @@ def concat_logs(lines):
"on 2023-07-09 07:47:32+00:00"
)

@pytest.mark.parametrize(
    "host, expected",
    [
        ("http://localhost:9200", "http://localhost:9200"),
        ("https://localhost:9200", "https://localhost:9200"),
        ("localhost:9200", "http://localhost:9200"),
        ("someurl", "http://someurl"),
        ("https://", "ValueError"),
    ],
)
def test_format_url(self, host, expected):
    """
    Test the format_url method of the ElasticsearchTaskHandler class.

    An ``expected`` value of ``"ValueError"`` is a sentinel meaning the call must raise
    ``ValueError``; any other value is the exact string the call must return.
    """
    if expected == "ValueError":
        # Only the call belongs inside the raises-block: a comparison placed after it
        # never executes when the call raises, and would compare against the sentinel.
        with pytest.raises(ValueError):
            ElasticsearchTaskHandler.format_url(host)
    else:
        assert ElasticsearchTaskHandler.format_url(host) == expected

def test_elasticsearch_constructor_retry_timeout_handling(self):
    """
    Verify how the ElasticsearchTaskHandler constructor forwards ``retry_timeout``.

    With ``retry_timeout`` in ``es_kwargs`` the patched Elasticsearch client must be
    built with ``retry_on_timeout``; with empty ``es_kwargs`` it must be built with the
    host only.
    """
    # Constructor arguments shared by both scenarios.
    common_args = dict(
        base_log_folder="dummy_folder",
        end_of_log_mark="end_of_log_mark",
        write_stdout=False,
        json_format=False,
        json_fields="fields",
        host_field="host",
        offset_field="offset",
    )
    patch_target = "airflow.providers.elasticsearch.log.es_task_handler.elasticsearch.Elasticsearch"

    with mock.patch(patch_target) as mock_es:
        # Scenario 1: 'retry_timeout' is supplied.
        ElasticsearchTaskHandler(**common_args, es_kwargs={"retry_timeout": 10})
        mock_es.assert_called_once_with("http://localhost:9200", retry_on_timeout=10)

        # Clear recorded calls before the second scenario.
        mock_es.reset_mock()

        # Scenario 2: 'retry_timeout' is absent.
        ElasticsearchTaskHandler(**common_args, es_kwargs={})
        mock_es.assert_called_once_with("http://localhost:9200")

def test_client(self):
assert isinstance(self.es_task_handler.client, elasticsearch.Elasticsearch)
assert self.es_task_handler.index_patterns == "_all"
Expand Down