From f85e06eb860e43a87a34423b4d9fd726c49af80e Mon Sep 17 00:00:00 2001 From: Nikhil Badyal Date: Wed, 21 Aug 2024 18:06:17 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Ability=20to=20fetch=20size=20fr?= =?UTF-8?q?om=20query?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- esxport/click_opt/cli_options.py | 5 ++- esxport/esxport.py | 52 +++++++++++++--------- esxport/exceptions.py | 4 ++ esxport/strings.py | 1 + test/conftest.py | 2 +- test/elastic/client_test.py | 16 +++++++ test/esxport/_export_test.py | 13 ++++++ test/esxport/_prepare_search_query_test.py | 10 ++++- 8 files changed, 80 insertions(+), 23 deletions(-) diff --git a/esxport/click_opt/cli_options.py b/esxport/click_opt/cli_options.py index 866f13c..93e99ea 100644 --- a/esxport/click_opt/cli_options.py +++ b/esxport/click_opt/cli_options.py @@ -1,6 +1,7 @@ """CLII options.""" from __future__ import annotations +import ast import json from typing import Any @@ -61,7 +62,9 @@ def __init__(self: Self, myclass_kwargs: dict[str, Any]) -> None: self.fields: list[str] = list(self.fields) self.index_prefixes: list[str] = list(self.index_prefixes) self.meta_fields: list[str] = list(self.meta_fields) - self.max_results = int(self.max_results) + if isinstance(self.query, str): + self.query = ast.literal_eval(self.query) + self.max_results = self.query["size"] if self.query.get("size") else int(self.max_results) self.scroll_size = int(self.scroll_size) self.export_format: str = "csv" diff --git a/esxport/esxport.py b/esxport/esxport.py index 40971da..a2f7f23 100644 --- a/esxport/esxport.py +++ b/esxport/esxport.py @@ -19,10 +19,19 @@ FieldNotFoundError, HealthCheckError, IndexNotFoundError, + InvalidEsQueryError, MetaFieldNotFoundError, ScrollExpiredError, ) -from .strings import index_not_found, meta_field_not_found, output_fields, sorting_by, using_indexes, using_query +from .strings import ( + index_not_found, + meta_field_not_found, + output_fields, + 
query_key_missing, + sorting_by, + using_indexes, + using_query, +) from .writer import Writer if TYPE_CHECKING: @@ -99,25 +108,28 @@ def _validate_fields(self: Self) -> None: def _prepare_search_query(self: Self) -> None: """Prepares search query from input.""" - self.search_args = { - "index": ",".join(self.opts.index_prefixes), - "scroll": self.scroll_time, - "size": self.opts.scroll_size, - "terminate_after": self.opts.max_results, - "query": Json().convert(self.opts.query, None, None)["query"], - } - if self.opts.sort: - self.search_args["sort"] = self.opts.sort - - if "_all" not in self.opts.fields: - self.search_args["_source_includes"] = ",".join(self.opts.fields) - - if self.opts.debug: - logger.debug(using_indexes.format(indexes={", ".join(self.opts.index_prefixes)})) - query = json.dumps(self.opts.query, default=str) - logger.debug(using_query.format(query={query})) - logger.debug(output_fields.format(fields={", ".join(self.opts.fields)})) - logger.debug(sorting_by.format(sort=self.opts.sort)) + try: + self.search_args = { + "index": ",".join(self.opts.index_prefixes), + "scroll": self.scroll_time, + "size": self.opts.scroll_size, + "terminate_after": self.opts.max_results, + "query": Json().convert(self.opts.query, None, None)["query"], + } + if self.opts.sort: + self.search_args["sort"] = self.opts.sort + + if "_all" not in self.opts.fields: + self.search_args["_source_includes"] = ",".join(self.opts.fields) + + if self.opts.debug: + logger.debug(using_indexes.format(indexes={", ".join(self.opts.index_prefixes)})) + query = json.dumps(self.opts.query, default=str) + logger.debug(using_query.format(query={query})) + logger.debug(output_fields.format(fields={", ".join(self.opts.fields)})) + logger.debug(sorting_by.format(sort=self.opts.sort)) + except KeyError as e: + raise InvalidEsQueryError(query_key_missing) from e @retry( wait=wait_exponential(2), diff --git a/esxport/exceptions.py b/esxport/exceptions.py index 6d3fad3..645e164 100644 --- 
a/esxport/exceptions.py +++ b/esxport/exceptions.py @@ -27,3 +27,7 @@ class ScrollExpiredError(EsXportError): class HealthCheckError(EsXportError): """Health check error.""" + + +class InvalidEsQueryError(EsXportError): + """Invalid query param.""" diff --git a/esxport/strings.py b/esxport/strings.py index c235c3d..3fcfc36 100644 --- a/esxport/strings.py +++ b/esxport/strings.py @@ -8,3 +8,4 @@ invalid_sort_format = 'Invalid input format: "{value}". Use the format "field:sort_order".' invalid_query_format = "{value} is not a valid json string, caused {exc}" cli_version = "EsXport Cli {__version__}" +query_key_missing = "Query key not found." diff --git a/test/conftest.py b/test/conftest.py index dd67600..2d817fc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -165,7 +165,7 @@ def index_name() -> str: @pytest.fixture() -def es_index(index_name: str, elasticsearch_proc: Elasticsearch) -> Any: +def es_index(index_name: str, elasticsearch_proc: Elasticsearch) -> str: """Create index.""" elasticsearch_proc.indices.create(index=index_name) return index_name diff --git a/test/elastic/client_test.py b/test/elastic/client_test.py index 0ff590c..6963ce6 100644 --- a/test/elastic/client_test.py +++ b/test/elastic/client_test.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import pytest +from elastic_transport import ObjectApiResponse from esxport.exceptions import ScrollExpiredError @@ -64,3 +65,18 @@ def test_scroll_expired(self: Self, elastic_client: ElasticsearchClient) -> None """Test client return true when index exists.""" with pytest.raises(ScrollExpiredError): elastic_client.scroll(scroll="5m", scroll_id="brqwdwefwef") + + @pytest.mark.xdist_group(name="elastic") + def test_ping(self: Self, elastic_client: ElasticsearchClient) -> None: + """Test that ping returns valid cluster information.""" + response = elastic_client.ping() + + # Assert that the response is an instance of ObjectApiResponse + assert isinstance(response, ObjectApiResponse), "Ping response 
should be an ObjectApiResponse." + + # Convert to dictionary and check for cluster information + response_dict = response.raw + assert isinstance(response_dict, dict), "Ping response should be convertible to a dictionary." + assert "cluster_name" in response_dict, "Cluster name should be present in the ping response." + assert "version" in response_dict, "Elasticsearch version should be present in the ping response." + assert "tagline" in response_dict, "Tagline should be present in the ping response." diff --git a/test/esxport/_export_test.py b/test/esxport/_export_test.py index 3a7cd4b..177a6ed 100644 --- a/test/esxport/_export_test.py +++ b/test/esxport/_export_test.py @@ -9,6 +9,7 @@ from typing_extensions import Self from esxport.esxport import EsXport +from esxport.exceptions import HealthCheckError @patch("esxport.esxport.EsXport._validate_fields") @@ -65,3 +66,15 @@ def test_headers_extraction( json.dump(test_json, tmp_file) assert esxport_obj._extract_headers() == list(test_json.keys()) TestExport.rm_export_file(f"{inspect.stack()[0].function}.csv") + + def test_ping_cluster_failure(self: Self, _: Any, esxport_obj: EsXport) -> None: + """Test that _ping_cluster raises HealthCheckError when ping fails.""" + with patch.object(esxport_obj.es_client, "ping", side_effect=ConnectionError("mocked error")), pytest.raises( + HealthCheckError, + ): + esxport_obj._ping_cluster() + + def test_ping_cluster_success(self: Self, _: Any, esxport_obj: EsXport) -> None: + """Test that _ping_cluster succeeds when ping is successful.""" + with patch.object(esxport_obj.es_client, "ping", return_value={}): + esxport_obj._ping_cluster() diff --git a/test/esxport/_prepare_search_query_test.py b/test/esxport/_prepare_search_query_test.py index 1bf9444..402b6ff 100644 --- a/test/esxport/_prepare_search_query_test.py +++ b/test/esxport/_prepare_search_query_test.py @@ -8,7 +8,7 @@ import pytest -from esxport.exceptions import IndexNotFoundError +from esxport.exceptions import 
IndexNotFoundError, InvalidEsQueryError from esxport.strings import index_not_found, output_fields, sorting_by, using_indexes if TYPE_CHECKING: @@ -147,3 +147,11 @@ def test_custom_output_fields(self: Self, _: Any, esxport_obj: EsXport) -> None: esxport_obj.opts.fields = random_strings esxport_obj._prepare_search_query() assert esxport_obj.search_args["_source_includes"] == ",".join(random_strings) + + def test_error_raised_when_query_key_missing(self: Self, _: Any, esxport_obj: EsXport) -> None: + """Test that InvalidEsQueryError is raised when the query key is missing.""" + expected_query: dict[str, Any] = {"size": 10} + esxport_obj.opts.query = expected_query + + with pytest.raises(InvalidEsQueryError): + esxport_obj._prepare_search_query()