From 57eb6bb5b84876849cc09f26df4c1cc86f827ee7 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Wed, 25 Aug 2021 10:29:57 -0500 Subject: [PATCH] Remove 'elasticsearch.helpers.test', switch to pytest --- .ci/Dockerfile | 6 +- .ci/run-nox.sh | 6 + .ci/run-repository.sh | 2 +- .ci/test-matrix.yml | 3 - .github/workflows/ci.yml | 23 +- elasticsearch/helpers/test.py | 99 -- noxfile.py | 23 +- setup.py | 19 +- .../conftest.py | 31 +- test_elasticsearch/run_tests.py | 133 -- .../test_async/test_server/conftest.py | 37 +- .../test_server/test_rest_api_spec.py | 8 +- test_elasticsearch/test_cases.py | 31 +- test_elasticsearch/test_client/__init__.py | 87 +- .../test_client/test_cluster.py | 4 +- .../test_client/test_indices.py | 15 +- .../test_client/test_overrides.py | 4 +- test_elasticsearch/test_client/test_utils.py | 128 +- test_elasticsearch/test_connection.py | 813 +++++------ test_elasticsearch/test_connection_pool.py | 68 +- test_elasticsearch/test_exceptions.py | 14 +- test_elasticsearch/test_helpers.py | 139 +- test_elasticsearch/test_serializer.py | 259 ++-- test_elasticsearch/test_server/__init__.py | 43 - test_elasticsearch/test_server/conftest.py | 22 +- .../test_server/test_clients.py | 54 +- .../test_server/test_helpers.py | 1292 +++++++++-------- .../test_server/test_rest_api_spec.py | 18 +- test_elasticsearch/test_transport.py | 216 +-- test_elasticsearch/utils.py | 76 +- 30 files changed, 1687 insertions(+), 1986 deletions(-) create mode 100755 .ci/run-nox.sh delete mode 100644 elasticsearch/helpers/test.py rename elasticsearch/helpers/test.pyi => test_elasticsearch/conftest.py (60%) delete mode 100755 test_elasticsearch/run_tests.py diff --git a/.ci/Dockerfile b/.ci/Dockerfile index 091819863..e50db6dd6 100644 --- a/.ci/Dockerfile +++ b/.ci/Dockerfile @@ -6,11 +6,7 @@ COPY dev-requirements.txt . RUN python -m pip install \ -U --no-cache-dir \ --disable-pip-version-check \ - pip \ - && python -m pip install \ - --no-cache-dir \ - --disable-pip-version-check \ - -r dev-requirements.txt + nox -rdev-requirements.txt COPY . . RUN python -m pip install -e . 
diff --git a/.ci/run-nox.sh b/.ci/run-nox.sh new file mode 100755 index 000000000..ab8f4be01 --- /dev/null +++ b/.ci/run-nox.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +if [[ -z "$NOX_SESSION" ]]; then + NOX_SESSION=test-${PYTHON_VERSION%-dev} +fi +nox -s $NOX_SESSION diff --git a/.ci/run-repository.sh b/.ci/run-repository.sh index 969009ff9..a18df1bf1 100755 --- a/.ci/run-repository.sh +++ b/.ci/run-repository.sh @@ -42,4 +42,4 @@ docker run \ --name elasticsearch-py \ --rm \ elastic/elasticsearch-py \ - python setup.py test + nox -s test diff --git a/.ci/test-matrix.yml b/.ci/test-matrix.yml index 77b8416c6..f8d8559bc 100644 --- a/.ci/test-matrix.yml +++ b/.ci/test-matrix.yml @@ -5,9 +5,6 @@ TEST_SUITE: - platinum PYTHON_VERSION: - - "2.7" - - "3.4" - - "3.5" - "3.6" - "3.7" - "3.8" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c2320ad6..276fca543 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,13 +9,13 @@ jobs: steps: - name: Checkout Repository uses: actions/checkout@v1 - - name: Set up Python 3.7 + - name: Set up Python 3.x uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.x - name: Install dependencies run: | - python3.7 -m pip install nox + python3 -m pip install nox - name: Lint the code run: nox -s lint @@ -27,10 +27,10 @@ jobs: - name: Set up Python 3.7 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.x - name: Install dependencies run: | - python3.7 -m pip install nox + python3 -m pip install nox - name: Build the docs run: nox -s docs @@ -38,11 +38,13 @@ jobs: strategy: fail-fast: false matrix: - python-version: [2.7, 3.5, 3.6, 3.7, 3.8, 3.9] + python-version: [3.6, 3.7, 3.8, 3.9] experimental: [false] + nox-session: [""] include: - python-version: 3.10-dev experimental: true + nox-session: test-3.10 runs-on: ubuntu-latest name: test-${{ matrix.python-version }} @@ -56,7 +58,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install Dependencies run: | - python -m pip install -r dev-requirements.txt + python -m pip install nox - name: Run Tests - run: | - python setup.py test + shell: bash + run: .ci/run-nox.sh + env: + PYTHON_VERSION: ${{ matrix.python-version }} + NOX_SESSION: ${{ matrix.nox-session }} diff --git a/elasticsearch/helpers/test.py b/elasticsearch/helpers/test.py deleted file mode 100644 index e87139b2d..000000000 --- a/elasticsearch/helpers/test.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# type: ignore - -import os -import time -from os.path import abspath, dirname, join -from unittest import SkipTest, TestCase - -from elasticsearch import Elasticsearch -from elasticsearch.exceptions import ConnectionError - -if "ELASTICSEARCH_URL" in os.environ: - ELASTICSEARCH_URL = os.environ["ELASTICSEARCH_URL"] -else: - ELASTICSEARCH_URL = "https://elastic:changeme@localhost:9200" - -CA_CERTS = join(dirname(dirname(dirname(abspath(__file__)))), ".ci/certs/ca.pem") - - -def get_test_client(nowait=False, **kwargs): - # construct kwargs from the environment - kw = {"timeout": 30, "ca_certs": CA_CERTS} - if "PYTHON_CONNECTION_CLASS" in os.environ: - from elasticsearch import connection - - kw["connection_class"] = getattr( - connection, os.environ["PYTHON_CONNECTION_CLASS"] - ) - - kw.update(kwargs) - client = Elasticsearch(ELASTICSEARCH_URL, **kw) - - # wait for yellow status - for _ in range(1 if nowait else 100): - try: - client.cluster.health(wait_for_status="yellow") - return client - except ConnectionError: - time.sleep(0.1) - else: - # timeout - raise SkipTest("Elasticsearch failed to start.") - - -class ElasticsearchTestCase(TestCase): - @staticmethod - def _get_client(): - return get_test_client() - - @classmethod - def setup_class(cls): - cls.client = cls._get_client() - - def teardown_method(self, _): - # Hidden indices expanded in wildcards in ES 7.7 - expand_wildcards = ["open", "closed"] - if self.es_version() >= (7, 7): - expand_wildcards.append("hidden") - - self.client.indices.delete_data_stream( - name="*", ignore=404, expand_wildcards=expand_wildcards - ) - self.client.indices.delete( - index="*", ignore=404, expand_wildcards=expand_wildcards - ) - self.client.indices.delete_template(name="*", ignore=404) - self.client.indices.delete_index_template(name="*", ignore=404) - - def es_version(self): - if not hasattr(self, "_es_version"): - self._es_version = es_version(self.client) - return self._es_version - - -def _get_version(version_string): - if "." not in version_string: - return () - version = version_string.strip().split(".") - return tuple(int(v) if v.isdigit() else 999 for v in version) - - -def es_version(client): - return _get_version(client.info()["version"]["number"]) diff --git a/noxfile.py b/noxfile.py index 6d28cad73..cc2e397f1 100644 --- a/noxfile.py +++ b/noxfile.py @@ -15,8 +15,11 @@ # specific language governing permissions and limitations # under the License. +import os + import nox +SOURCE_DIR = os.path.dirname(os.path.abspath(__file__)) SOURCE_FILES = ( "setup.py", "noxfile.py", @@ -26,12 +29,26 @@ ) -@nox.session(python=["2.7", "3.4", "3.5", "3.6", "3.7", "3.8", "3.9"]) +@nox.session(python=["2.7", "3.4", "3.5", "3.6", "3.7", "3.8", "3.9", "3.10"]) def test(session): session.install(".") session.install("-r", "dev-requirements.txt") - session.run("python", "setup.py", "test") + python_version = tuple(int(x) for x in session.python.split(".")) + junit_xml = os.path.join(SOURCE_DIR, "junit", "elasticsearch-py-junit.xml") + pytest_argv = [ + "pytest", + "--cov=elasticsearch", + "--junitxml=%s" % junit_xml, + "--log-level=DEBUG", + "--cache-clear", + "-vv", + ] + # Python 3.6+ is required for async + if python_version < (3, 6): + pytest_argv.append("--ignore=test_elasticsearch/test_async/") + + session.run(*pytest_argv) @nox.session() @@ -55,7 +72,7 @@ def lint(session): session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) # Workaround to make '-r' to still work despite uninstalling aiohttp below. 
- session.run("python", "-m", "pip", "install", "aiohttp") + session.install("aiohttp") # Run mypy on the package and then the type examples separately for # the two different mypy use-cases, ourselves and our users. diff --git a/setup.py b/setup.py index 89dd43e24..66363ef30 100644 --- a/setup.py +++ b/setup.py @@ -55,18 +55,7 @@ "urllib3>=1.21.1, <2", "certifi", ] -tests_require = [ - "requests>=2.0.0, <3.0.0", - "coverage", - "mock", - "pyyaml", - "pytest", - "pytest-cov", -] -async_require = ["aiohttp>=3,<4"] - -docs_require = ["sphinx<1.7", "sphinx_rtd_theme"] -generate_require = ["black", "jinja2"] +async_requires = ["aiohttp>=3,<4"] setup( name=package_name, @@ -109,12 +98,8 @@ ], python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4", install_requires=install_requires, - test_suite="test_elasticsearch.run_tests.run_all", - tests_require=tests_require, extras_require={ - "develop": tests_require + docs_require + generate_require, - "docs": docs_require, "requests": ["requests>=2.4.0, <3.0.0"], - "async": async_require, + "async": async_requires, }, ) diff --git a/elasticsearch/helpers/test.pyi b/test_elasticsearch/conftest.py similarity index 60% rename from elasticsearch/helpers/test.pyi rename to test_elasticsearch/conftest.py index 209d26ed7..97c87fba1 100644 --- a/elasticsearch/helpers/test.pyi +++ b/test_elasticsearch/conftest.py @@ -15,21 +15,24 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Tuple -from unittest import TestCase +from typing import Tuple -from ..client import Elasticsearch +import pytest -ELASTICSEARCH_URL: str -CA_CERTS: str +from elasticsearch import Elasticsearch -def get_test_client(nowait: bool = ..., **kwargs: Any) -> Elasticsearch: ... -def _get_version(version_string: str) -> Tuple[int, ...]: ... +from .utils import CA_CERTS, es_url, es_version -class ElasticsearchTestCase(TestCase): - @staticmethod - def _get_client() -> Elasticsearch: ... - @classmethod - def setup_class(cls) -> None: ... - def teardown_method(self, _: Any) -> None: ... - def es_version(self) -> Tuple[int, ...]: ... + +@pytest.fixture(scope="session") +def elasticsearch_url(): + try: + return es_url() + except RuntimeError as e: + pytest.skip(str(e)) + + +@pytest.fixture(scope="session") +def elasticsearch_version(elasticsearch_url) -> Tuple[int, ...]: + """Returns the version of the current Elasticsearch cluster""" + return es_version(Elasticsearch(elasticsearch_url, ca_certs=CA_CERTS)) diff --git a/test_elasticsearch/run_tests.py b/test_elasticsearch/run_tests.py deleted file mode 100755 index 5200093ba..000000000 --- a/test_elasticsearch/run_tests.py +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from __future__ import print_function - -import atexit -import subprocess -import sys -from os import environ -from os.path import abspath, dirname, exists, join, pardir - - -def fetch_es_repo(): - # user is manually setting YAML dir, don't tamper with it - if "TEST_ES_YAML_DIR" in environ: - return - - repo_path = environ.get( - "TEST_ES_REPO", - abspath(join(dirname(__file__), pardir, pardir, "elasticsearch")), - ) - - # no repo - if not exists(repo_path) or not exists(join(repo_path, ".git")): - subprocess.check_call( - "git clone https://github.com/elastic/elasticsearch %s" % repo_path, - shell=True, - ) - - # set YAML test dir - environ["TEST_ES_YAML_DIR"] = join( - repo_path, "rest-api-spec", "src", "main", "resources", "rest-api-spec", "test" - ) - - # fetching of yaml tests disabled, we'll run with what's there - if environ.get("TEST_ES_NOFETCH", False): - return - - from test_elasticsearch.test_cases import SkipTest - from test_elasticsearch.test_server import get_client - - # find out the sha of the running es - try: - es = get_client() - sha = es.info()["version"]["build_hash"] - except (SkipTest, KeyError): - print("No running elasticsearch >1.X server...") - return - - # fetch new commits to be sure... - print("Fetching elasticsearch repo...") - subprocess.check_call( - "cd %s && git fetch https://github.com/elastic/elasticsearch.git" % repo_path, - shell=True, - ) - # reset to the version fron info() - subprocess.check_call("cd %s && git fetch" % repo_path, shell=True) - subprocess.check_call("cd %s && git reset --hard %s" % (repo_path, sha), shell=True) - - -def run_all(argv=None): - atexit.register(lambda: sys.stderr.write("Shutting down....\n")) - - # fetch yaml tests anywhere that's not GitHub Actions - if "GITHUB_ACTION" not in environ: - fetch_es_repo() - - # always insert coverage when running tests - if argv is None: - junit_xml = join( - abspath(dirname(dirname(__file__))), "junit", "elasticsearch-py-junit.xml" - ) - argv = [ - "pytest", - "--cov=elasticsearch", - "--junitxml=%s" % junit_xml, - "--log-level=DEBUG", - "--cache-clear", - "-vv", - ] - - ignores = [] - # Python 3.6+ is required for async - if sys.version_info < (3, 6): - ignores.append("test_elasticsearch/test_async/") - - # GitHub Actions, run non-server tests - if "GITHUB_ACTION" in environ: - ignores.extend( - [ - "test_elasticsearch/test_server/", - "test_elasticsearch/test_async/test_server/", - ] - ) - if ignores: - argv.extend(["--ignore=%s" % ignore for ignore in ignores]) - - # Jenkins, only run server tests - if environ.get("TEST_TYPE") == "server": - test_dir = abspath(dirname(__file__)) - argv.append(join(test_dir, "test_server")) - if sys.version_info >= (3, 6): - argv.append(join(test_dir, "test_async/test_server")) - - # Not in CI, run all tests specified. - else: - argv.append(abspath(dirname(__file__))) - - exit_code = 0 - try: - subprocess.check_call(argv, stdout=sys.stdout, stderr=sys.stderr) - except subprocess.CalledProcessError as e: - exit_code = e.returncode - sys.exit(exit_code) - - -if __name__ == "__main__": - run_all(sys.argv) diff --git a/test_elasticsearch/test_async/test_server/conftest.py b/test_elasticsearch/test_async/test_server/conftest.py index fd77027e2..3a90fd17d 100644 --- a/test_elasticsearch/test_async/test_server/conftest.py +++ b/test_elasticsearch/test_async/test_server/conftest.py @@ -15,41 +15,32 @@ # specific language governing permissions and limitations # under the License. 
-import asyncio - import pytest import elasticsearch -from elasticsearch.helpers.test import CA_CERTS, ELASTICSEARCH_URL -from ...utils import wipe_cluster +from ...utils import CA_CERTS, wipe_cluster pytestmark = pytest.mark.asyncio @pytest.fixture(scope="function") -async def async_client(): +@pytest.mark.usefixtures("sync_client") +async def async_client(elasticsearch_url): + # 'sync_client' fixture is used for the guaranteed wipe_cluster() call. + + if not hasattr(elasticsearch, "AsyncElasticsearch"): + pytest.skip("test requires 'AsyncElasticsearch' and aiohttp to be installed") + + # Unfortunately the asyncio client needs to be rebuilt every + # test execution due to how pytest-asyncio manages + # event loops (one per test!) client = None try: - if not hasattr(elasticsearch, "AsyncElasticsearch"): - pytest.skip("test requires 'AsyncElasticsearch'") - - kw = {"timeout": 3, "ca_certs": CA_CERTS} - client = elasticsearch.AsyncElasticsearch(ELASTICSEARCH_URL, **kw) - - # wait for yellow status - for _ in range(100): - try: - await client.cluster.health(wait_for_status="yellow") - break - except ConnectionError: - await asyncio.sleep(0.1) - else: - # timeout - pytest.skip("Elasticsearch failed to start.") - + client = elasticsearch.AsyncElasticsearch( + elasticsearch_url, timeout=3, ca_certs=CA_CERTS + ) yield client - finally: if client: wipe_cluster(client) diff --git a/test_elasticsearch/test_async/test_server/test_rest_api_spec.py b/test_elasticsearch/test_async/test_server/test_rest_api_spec.py index 4101d4360..de8adde9c 100644 --- a/test_elasticsearch/test_async/test_server/test_rest_api_spec.py +++ b/test_elasticsearch/test_async/test_server/test_rest_api_spec.py @@ -26,7 +26,6 @@ import pytest from elasticsearch import ElasticsearchWarning, RequestError -from elasticsearch.helpers.test import _get_version from ...test_server.test_rest_api_spec import ( IMPLEMENTED_FEATURES, @@ -35,6 +34,7 @@ YAML_TEST_SPECS, YamlRunner, ) +from ...utils import parse_version pytestmark = pytest.mark.asyncio @@ -188,9 +188,9 @@ async def run_skip(self, skip): version, reason = skip["version"], skip["reason"] if version == "all": pytest.skip(reason) - min_version, max_version = version.split("-") - min_version = _get_version(min_version) or (0,) - max_version = _get_version(max_version) or (999,) + min_version, _, max_version = version.partition("-") + min_version = parse_version(min_version.strip()) or (0,) + max_version = parse_version(max_version.strip()) or (999,) if min_version <= (await self.es_version()) <= max_version: pytest.skip(reason) diff --git a/test_elasticsearch/test_cases.py b/test_elasticsearch/test_cases.py index f3fcc025a..c97fbb88d 100644 --- a/test_elasticsearch/test_cases.py +++ b/test_elasticsearch/test_cases.py @@ -16,14 +16,12 @@ # under the License. 
from collections import defaultdict -from unittest import SkipTest # noqa: F401 -from unittest import TestCase from elasticsearch import Elasticsearch class DummyTransport(object): - def __init__(self, hosts, responses=None, **kwargs): + def __init__(self, hosts, responses=None, **_): self.hosts = hosts self.responses = responses self.call_count = 0 @@ -38,32 +36,15 @@ def perform_request(self, method, url, params=None, headers=None, body=None): return resp -class ElasticsearchTestCase(TestCase): - def setUp(self): - super(ElasticsearchTestCase, self).setUp() +class DummyTransportTestCase: + def setup_method(self, _): self.client = Elasticsearch(transport_class=DummyTransport) def assert_call_count_equals(self, count): - self.assertEqual(count, self.client.transport.call_count) + assert count == self.client.transport.call_count def assert_url_called(self, method, url, count=1): - self.assertIn((method, url), self.client.transport.calls) + assert (method, url) in self.client.transport.calls calls = self.client.transport.calls[(method, url)] - self.assertEqual(count, len(calls)) + assert count == len(calls) return calls - - -class TestElasticsearchTestCase(ElasticsearchTestCase): - def test_our_transport_used(self): - self.assertIsInstance(self.client.transport, DummyTransport) - - def test_start_with_0_call(self): - self.assert_call_count_equals(0) - - def test_each_call_is_recorded(self): - self.client.transport.perform_request("GET", "/") - self.client.transport.perform_request("DELETE", "/42", params={}, body="body") - self.assert_call_count_equals(2) - self.assertEqual( - [({}, None, "body")], self.assert_url_called("DELETE", "/42", 1) - ) diff --git a/test_elasticsearch/test_client/__init__.py b/test_elasticsearch/test_client/__init__.py index c0d47aa70..82834986c 100644 --- a/test_elasticsearch/test_client/__init__.py +++ b/test_elasticsearch/test_client/__init__.py @@ -19,57 +19,50 @@ from elasticsearch.client import Elasticsearch, _normalize_hosts -from ..test_cases import ElasticsearchTestCase, TestCase +from ..test_cases import DummyTransportTestCase -class TestNormalizeHosts(TestCase): +class TestNormalizeHosts: def test_none_uses_defaults(self): - self.assertEqual([{}], _normalize_hosts(None)) + assert [{}] == _normalize_hosts(None) def test_strings_are_used_as_hostnames(self): - self.assertEqual([{"host": "elastic.co"}], _normalize_hosts(["elastic.co"])) + assert [{"host": "elastic.co"}] == _normalize_hosts(["elastic.co"]) def test_strings_are_parsed_for_port_and_user(self): - self.assertEqual( - [ - {"host": "elastic.co", "port": 42}, - {"host": "elastic.co", "http_auth": "user:secre]"}, - ], - _normalize_hosts(["elastic.co:42", "user:secre%5D@elastic.co"]), - ) + assert [ + {"host": "elastic.co", "port": 42}, + {"host": "elastic.co", "http_auth": "user:secre]"}, + ] == _normalize_hosts(["elastic.co:42", "user:secre%5D@elastic.co"]) def test_strings_are_parsed_for_scheme(self): - self.assertEqual( - [ - {"host": "elastic.co", "port": 42, "use_ssl": True}, - { - "host": "elastic.co", - "http_auth": "user:secret", - "use_ssl": True, - "port": 443, - "url_prefix": "/prefix", - }, - ], - _normalize_hosts( - ["https://elastic.co:42", "https://user:secret@elastic.co/prefix"] - ), + assert [ + {"host": "elastic.co", "port": 42, "use_ssl": True}, + { + "host": "elastic.co", + "http_auth": "user:secret", + "use_ssl": True, + "port": 443, + "url_prefix": "/prefix", + }, + ] == _normalize_hosts( + ["https://elastic.co:42", "https://user:secret@elastic.co/prefix"] ) def 
test_dicts_are_left_unchanged(self):
-        self.assertEqual(
-            [{"host": "local", "extra": 123}],
-            _normalize_hosts([{"host": "local", "extra": 123}]),
+        assert [{"host": "local", "extra": 123}] == _normalize_hosts(
+            [{"host": "local", "extra": 123}]
         )
 
     def test_single_string_is_wrapped_in_list(self):
-        self.assertEqual([{"host": "elastic.co"}], _normalize_hosts("elastic.co"))
+        assert [{"host": "elastic.co"}] == _normalize_hosts("elastic.co")
 
 
-class TestClient(ElasticsearchTestCase):
+class TestClient(DummyTransportTestCase):
     def test_request_timeout_is_passed_through_unescaped(self):
         self.client.ping(request_timeout=0.1)
         calls = self.assert_url_called("HEAD", "/")
-        self.assertEqual([({"request_timeout": 0.1}, {}, None)], calls)
+        assert [({"request_timeout": 0.1}, {}, None)] == calls
 
     def test_params_is_copied_when(self):
         rt = object()
@@ -77,11 +70,11 @@ def test_params_is_copied_when(self):
         self.client.ping(params=params)
         self.client.ping(params=params)
         calls = self.assert_url_called("HEAD", "/", 2)
-        self.assertEqual(
-            [({"request_timeout": rt}, {}, None), ({"request_timeout": rt}, {}, None)],
-            calls,
-        )
-        self.assertFalse(calls[0][0] is calls[1][0])
+        assert [
+            ({"request_timeout": rt}, {}, None),
+            ({"request_timeout": rt}, {}, None),
+        ] == calls
+        assert not (calls[0][0] is calls[1][0])
 
     def test_headers_is_copied_when(self):
         hv = "value"
@@ -89,34 +82,34 @@ def test_headers_is_copied_when(self):
         self.client.ping(headers=headers)
         self.client.ping(headers=headers)
         calls = self.assert_url_called("HEAD", "/", 2)
-        self.assertEqual(
-            [({}, {"authentication": hv}, None), ({}, {"authentication": hv}, None)],
-            calls,
-        )
-        self.assertFalse(calls[0][0] is calls[1][0])
+        assert [
+            ({}, {"authentication": hv}, None),
+            ({}, {"authentication": hv}, None),
+        ] == calls
+        assert not (calls[0][0] is calls[1][0])
 
     def test_from_in_search(self):
         self.client.search(index="i", from_=10)
         calls = self.assert_url_called("POST", "/i/_search")
-        self.assertEqual([({"from": "10"}, {}, None)], calls)
+        assert [({"from": "10"}, {}, None)] == calls
 
     def test_repr_contains_hosts(self):
-        self.assertEqual("<Elasticsearch([{}])>", repr(self.client))
+        assert "<Elasticsearch([{}])>" == repr(self.client)
 
     def test_repr_subclass(self):
         class OtherElasticsearch(Elasticsearch):
             pass
 
-        self.assertEqual("<OtherElasticsearch([{}])>", repr(OtherElasticsearch()))
+        assert "<OtherElasticsearch([{}])>" == repr(OtherElasticsearch())
 
     def test_repr_contains_hosts_passed_in(self):
-        self.assertIn("es.org", repr(Elasticsearch(["es.org:123"])))
+        assert "es.org" in repr(Elasticsearch(["es.org:123"]))
 
     def test_repr_truncates_host_to_5(self):
         hosts = [{"host": "es" + str(i)} for i in range(10)]
         es = Elasticsearch(hosts)
-        self.assertNotIn("es5", repr(es))
-        self.assertIn("...", repr(es))
+        assert "es5" not in repr(es)
+        assert "..." in repr(es)
 
     def test_index_uses_post_if_id_is_empty(self):
         self.client.index(index="my-index", id="", body={})
diff --git a/test_elasticsearch/test_client/test_cluster.py b/test_elasticsearch/test_client/test_cluster.py
index d92a648a7..a623f03ea 100644
--- a/test_elasticsearch/test_client/test_cluster.py
+++ b/test_elasticsearch/test_client/test_cluster.py
@@ -15,10 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
-from test_elasticsearch.test_cases import ElasticsearchTestCase +from test_elasticsearch.test_cases import DummyTransportTestCase -class TestCluster(ElasticsearchTestCase): +class TestCluster(DummyTransportTestCase): def test_stats_without_node_id(self): self.client.cluster.stats() self.assert_url_called("GET", "/_cluster/stats") diff --git a/test_elasticsearch/test_client/test_indices.py b/test_elasticsearch/test_client/test_indices.py index 152764399..2d71593bd 100644 --- a/test_elasticsearch/test_client/test_indices.py +++ b/test_elasticsearch/test_client/test_indices.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. -from test_elasticsearch.test_cases import ElasticsearchTestCase +import pytest +from test_elasticsearch.test_cases import DummyTransportTestCase -class TestIndices(ElasticsearchTestCase): + +class TestIndices(DummyTransportTestCase): def test_create_one_index(self): self.client.indices.create("test-index") self.assert_url_called("PUT", "/test-index") @@ -32,6 +34,9 @@ def test_exists_index(self): self.assert_url_called("HEAD", "/second.index,third%2Findex") def test_passing_empty_value_for_required_param_raises_exception(self): - self.assertRaises(ValueError, self.client.indices.exists, index=None) - self.assertRaises(ValueError, self.client.indices.exists, index=[]) - self.assertRaises(ValueError, self.client.indices.exists, index="") + with pytest.raises(ValueError): + self.client.indices.exists(index=None) + with pytest.raises(ValueError): + self.client.indices.exists(index=[]) + with pytest.raises(ValueError): + self.client.indices.exists(index="") diff --git a/test_elasticsearch/test_client/test_overrides.py b/test_elasticsearch/test_client/test_overrides.py index 88702841a..bc3ea2ef3 100644 --- a/test_elasticsearch/test_client/test_overrides.py +++ b/test_elasticsearch/test_client/test_overrides.py @@ -16,10 +16,10 @@ # specific language governing permissions and limitations # under the License. 
-from test_elasticsearch.test_cases import ElasticsearchTestCase +from test_elasticsearch.test_cases import DummyTransportTestCase -class TestOverriddenUrlTargets(ElasticsearchTestCase): +class TestOverriddenUrlTargets(DummyTransportTestCase): def test_create(self): self.client.create(index="test-index", id="test-id", body={}) self.assert_url_called("PUT", "/test-index/_create/test-id") diff --git a/test_elasticsearch/test_client/test_utils.py b/test_elasticsearch/test_client/test_utils.py index 064d720a7..f02969776 100644 --- a/test_elasticsearch/test_client/test_utils.py +++ b/test_elasticsearch/test_client/test_utils.py @@ -18,13 +18,13 @@ from __future__ import unicode_literals +import pytest + from elasticsearch.client.utils import _bulk_body, _escape, _make_path, query_params from elasticsearch.compat import PY2 -from ..test_cases import SkipTest, TestCase - -class TestQueryParams(TestCase): +class TestQueryParams: def setup_method(self, _): self.calls = [] @@ -34,140 +34,130 @@ def func_to_wrap(self, *args, **kwargs): def test_handles_params(self): self.func_to_wrap(params={"simple_param_2": "2"}, simple_param="3") - self.assertEqual( - self.calls, - [ - ( - (), - { - "params": {"simple_param": b"3", "simple_param_2": "2"}, - "headers": {}, - }, - ) - ], - ) + assert self.calls == [ + ( + (), + { + "params": {"simple_param": b"3", "simple_param_2": "2"}, + "headers": {}, + }, + ) + ] def test_handles_headers(self): self.func_to_wrap(headers={"X-Opaque-Id": "app-1"}) - self.assertEqual( - self.calls, [((), {"params": {}, "headers": {"x-opaque-id": "app-1"}})] - ) + assert self.calls == [((), {"params": {}, "headers": {"x-opaque-id": "app-1"}})] def test_handles_opaque_id(self): self.func_to_wrap(opaque_id="request-id") - self.assertEqual( - self.calls, [((), {"params": {}, "headers": {"x-opaque-id": "request-id"}})] - ) + assert self.calls == [ + ((), {"params": {}, "headers": {"x-opaque-id": "request-id"}}) + ] def test_handles_empty_none_and_normalization(self): self.func_to_wrap(params=None) - self.assertEqual(self.calls[-1], ((), {"params": {}, "headers": {}})) + assert self.calls[-1] == ((), {"params": {}, "headers": {}}) self.func_to_wrap(headers=None) - self.assertEqual(self.calls[-1], ((), {"params": {}, "headers": {}})) + assert self.calls[-1] == ((), {"params": {}, "headers": {}}) self.func_to_wrap(headers=None, params=None) - self.assertEqual(self.calls[-1], ((), {"params": {}, "headers": {}})) + assert self.calls[-1] == ((), {"params": {}, "headers": {}}) self.func_to_wrap(headers={}, params={}) - self.assertEqual(self.calls[-1], ((), {"params": {}, "headers": {}})) + assert self.calls[-1] == ((), {"params": {}, "headers": {}}) self.func_to_wrap(headers={"X": "y"}) - self.assertEqual(self.calls[-1], ((), {"params": {}, "headers": {"x": "y"}})) + assert self.calls[-1] == ((), {"params": {}, "headers": {"x": "y"}}) def test_per_call_authentication(self): self.func_to_wrap(api_key=("name", "key")) - self.assertEqual( - self.calls[-1], - ((), {"headers": {"authorization": "ApiKey bmFtZTprZXk="}, "params": {}}), + assert self.calls[-1] == ( + (), + {"headers": {"authorization": "ApiKey bmFtZTprZXk="}, "params": {}}, ) self.func_to_wrap(http_auth=("user", "password")) - self.assertEqual( - self.calls[-1], - ( - (), - { - "headers": {"authorization": "Basic dXNlcjpwYXNzd29yZA=="}, - "params": {}, - }, - ), + assert self.calls[-1] == ( + (), + { + "headers": {"authorization": "Basic dXNlcjpwYXNzd29yZA=="}, + "params": {}, + }, ) self.func_to_wrap(http_auth="abcdef") - 
self.assertEqual( - self.calls[-1], - ((), {"headers": {"authorization": "Basic abcdef"}, "params": {}}), + assert self.calls[-1] == ( + (), + {"headers": {"authorization": "Basic abcdef"}, "params": {}}, ) # If one or the other is 'None' it's all good! self.func_to_wrap(http_auth=None, api_key=None) - self.assertEqual(self.calls[-1], ((), {"headers": {}, "params": {}})) + assert self.calls[-1] == ((), {"headers": {}, "params": {}}) self.func_to_wrap(http_auth="abcdef", api_key=None) - self.assertEqual( - self.calls[-1], - ((), {"headers": {"authorization": "Basic abcdef"}, "params": {}}), + assert self.calls[-1] == ( + (), + {"headers": {"authorization": "Basic abcdef"}, "params": {}}, ) # If both are given values an error is raised. - with self.assertRaises(ValueError) as e: + with pytest.raises(ValueError) as e: self.func_to_wrap(http_auth="key", api_key=("1", "2")) - self.assertEqual( - str(e.exception), - "Only one of 'http_auth' and 'api_key' may be passed at a time", + assert ( + str(e.value) + == "Only one of 'http_auth' and 'api_key' may be passed at a time" ) -class TestMakePath(TestCase): +class TestMakePath: def test_handles_unicode(self): id = "中文" - self.assertEqual( - "/some-index/type/%E4%B8%AD%E6%96%87", _make_path("some-index", "type", id) + assert "/some-index/type/%E4%B8%AD%E6%96%87" == _make_path( + "some-index", "type", id ) + @pytest.mark.skipif(not PY2, reason="Only relevant for Python 2") def test_handles_utf_encoded_string(self): - if not PY2: - raise SkipTest("Only relevant for py2") id = "中文".encode("utf-8") - self.assertEqual( - "/some-index/type/%E4%B8%AD%E6%96%87", _make_path("some-index", "type", id) + assert "/some-index/type/%E4%B8%AD%E6%96%87" == _make_path( + "some-index", "type", id ) -class TestEscape(TestCase): +class TestEscape: def test_handles_ascii(self): string = "abc123" - self.assertEqual(b"abc123", _escape(string)) + assert b"abc123" == _escape(string) def test_handles_unicode(self): string = "中文" - self.assertEqual(b"\xe4\xb8\xad\xe6\x96\x87", _escape(string)) + assert b"\xe4\xb8\xad\xe6\x96\x87" == _escape(string) def test_handles_bytestring(self): string = b"celery-task-meta-c4f1201f-eb7b-41d5-9318-a75a8cfbdaa0" - self.assertEqual(string, _escape(string)) + assert string == _escape(string) -class TestBulkBody(TestCase): +class TestBulkBody: def test_proper_bulk_body_as_string_is_not_modified(self): string_body = '"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"\n' - self.assertEqual(string_body, _bulk_body(None, string_body)) + assert string_body == _bulk_body(None, string_body) def test_proper_bulk_body_as_bytestring_is_not_modified(self): bytestring_body = b'"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"\n' - self.assertEqual(bytestring_body, _bulk_body(None, bytestring_body)) + assert bytestring_body == _bulk_body(None, bytestring_body) def test_bulk_body_as_string_adds_trailing_newline(self): string_body = '"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"' - self.assertEqual( - '"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"\n', - _bulk_body(None, string_body), + assert '"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"\n' == _bulk_body( + None, string_body ) def test_bulk_body_as_bytestring_adds_trailing_newline(self): bytestring_body = b'"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"' - self.assertEqual( - b'"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"\n', - _bulk_body(None, bytestring_body), + assert ( + b'"{"index":{ "_index" : "test"}}\n{"field1": "value1"}"\n' + == 
_bulk_body(None, bytestring_body) ) diff --git a/test_elasticsearch/test_connection.py b/test_elasticsearch/test_connection.py index 0baa4f456..c4cc90733 100644 --- a/test_elasticsearch/test_connection.py +++ b/test_elasticsearch/test_connection.py @@ -46,8 +46,6 @@ TransportError, ) -from .test_cases import SkipTest, TestCase - CLOUD_ID_PORT_443 = "cluster:d2VzdGV1cm9wZS5henVyZS5lbGFzdGljLWNsb3VkLmNvbTo0NDMkZTdkZTlmMTM0NWU0NDkwMjgzZDkwM2JlNWI2ZjkxOWUk" CLOUD_ID_KIBANA = "cluster:d2VzdGV1cm9wZS5henVyZS5lbGFzdGljLWNsb3VkLmNvbSQ4YWY3ZWUzNTQyMGY0NThlOTAzMDI2YjQwNjQwODFmMiQyMDA2MTU1NmM1NDA0OTg2YmZmOTU3ZDg0YTZlYjUxZg==" CLOUD_ID_PORT_AND_KIBANA = "cluster:d2VzdGV1cm9wZS5henVyZS5lbGFzdGljLWNsb3VkLmNvbTo5MjQzJGM2NjM3ZjMxMmM1MjQzY2RhN2RlZDZlOTllM2QyYzE5JA==" @@ -59,77 +57,69 @@ def gzip_decompress(data): return buf.read() -class TestBaseConnection(TestCase): +class TestBaseConnection: def test_parse_cloud_id(self): # Embedded port in cloud_id - con = Connection(cloud_id=CLOUD_ID_PORT_AND_KIBANA) - self.assertEqual( - con.host, - "https://c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com:9243", + conn = Connection(cloud_id=CLOUD_ID_PORT_AND_KIBANA) + assert ( + conn.host + == "https://c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com:9243" ) - self.assertEqual(con.port, 9243) - self.assertEqual( - con.hostname, - "c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com", + assert conn.port == 9243 + assert ( + conn.hostname + == "c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com" ) - - # Embedded port but overridden - con = Connection( + conn = Connection( cloud_id=CLOUD_ID_PORT_AND_KIBANA, port=443, ) - self.assertEqual( - con.host, - "https://c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com:443", + assert ( + conn.host + == "https://c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com:443" ) - self.assertEqual(con.port, 443) - self.assertEqual( - con.hostname, - "c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com", + assert conn.port == 443 + assert ( + conn.hostname + == "c6637f312c5243cda7ded6e99e3d2c19.westeurope.azure.elastic-cloud.com" ) - - # Port is 443, removed by default. 
- con = Connection(cloud_id=CLOUD_ID_PORT_443) - self.assertEqual( - con.host, - "https://e7de9f1345e4490283d903be5b6f919e.westeurope.azure.elastic-cloud.com", + conn = Connection(cloud_id=CLOUD_ID_PORT_443) + assert ( + conn.host + == "https://e7de9f1345e4490283d903be5b6f919e.westeurope.azure.elastic-cloud.com" ) - self.assertEqual(con.port, None) - self.assertEqual( - con.hostname, - "e7de9f1345e4490283d903be5b6f919e.westeurope.azure.elastic-cloud.com", + assert conn.port is None + assert ( + conn.hostname + == "e7de9f1345e4490283d903be5b6f919e.westeurope.azure.elastic-cloud.com" ) - - # No port, contains Kibana UUID - con = Connection(cloud_id=CLOUD_ID_KIBANA) - self.assertEqual( - con.host, - "https://8af7ee35420f458e903026b4064081f2.westeurope.azure.elastic-cloud.com", + conn = Connection(cloud_id=CLOUD_ID_KIBANA) + assert ( + conn.host + == "https://8af7ee35420f458e903026b4064081f2.westeurope.azure.elastic-cloud.com" ) - self.assertEqual(con.port, None) - self.assertEqual( - con.hostname, - "8af7ee35420f458e903026b4064081f2.westeurope.azure.elastic-cloud.com", + assert conn.port is None + assert ( + conn.hostname + == "8af7ee35420f458e903026b4064081f2.westeurope.azure.elastic-cloud.com" ) def test_empty_warnings(self): - con = Connection() + conn = Connection() with warnings.catch_warnings(record=True) as w: - con._raise_warnings(()) - con._raise_warnings([]) + conn._raise_warnings(()) + conn._raise_warnings([]) - self.assertEqual(w, []) + assert w == [] def test_raises_warnings(self): - con = Connection() - + conn = Connection() with warnings.catch_warnings(record=True) as warn: - con._raise_warnings(['299 Elasticsearch-7.6.1-aa751 "this is deprecated"']) - - self.assertEqual([str(w.message) for w in warn], ["this is deprecated"]) + conn._raise_warnings(['299 Elasticsearch-7.6.1-aa751 "this is deprecated"']) + assert [str(w.message) for w in warn] == ["this is deprecated"] with warnings.catch_warnings(record=True) as warn: - con._raise_warnings( + conn._raise_warnings( [ '299 Elasticsearch-7.6.1-aa751 "this is also deprecated"', '299 Elasticsearch-7.6.1-aa751 "this is also deprecated"', @@ -137,22 +127,22 @@ def test_raises_warnings(self): ] ) - self.assertEqual( - [str(w.message) for w in warn], - ["this is also deprecated", "guess what? deprecated"], - ) + assert [str(w.message) for w in warn] == [ + "this is also deprecated", + "guess what? 
deprecated", + ] def test_raises_warnings_when_folded(self): - con = Connection() + conn = Connection() with warnings.catch_warnings(record=True) as warn: - con._raise_warnings( + conn._raise_warnings( [ '299 Elasticsearch-7.6.1-aa751 "warning",' '299 Elasticsearch-7.6.1-aa751 "folded"', ] ) - self.assertEqual([str(w.message) for w in warn], ["warning", "folded"]) + assert [str(w.message) for w in warn] == ["warning", "folded"] def test_ipv6_host_and_port(self): for kwargs, expected_host in [ @@ -170,7 +160,6 @@ def test_meta_header(self): assert conn.meta_header is True conn = Connection(meta_header=False) assert conn.meta_header is False - with pytest.raises(TypeError) as e: Connection(meta_header=1) assert str(e.value) == "meta_header must be of type bool" @@ -179,14 +168,10 @@ def test_compatibility_accept_header(self): try: conn = Connection() assert "accept" not in conn.headers - os.environ["ELASTIC_CLIENT_APIVERSIONING"] = "0" - conn = Connection() assert "accept" not in conn.headers - os.environ["ELASTIC_CLIENT_APIVERSIONING"] = "1" - conn = Connection() assert ( conn.headers["accept"] @@ -196,9 +181,9 @@ def test_compatibility_accept_header(self): os.environ.pop("ELASTIC_CLIENT_APIVERSIONING") -class TestUrllib3Connection(TestCase): - def _get_mock_connection(self, connection_params={}, response_body=b"{}"): - con = Urllib3HttpConnection(**connection_params) +class TestUrllib3Connection: + def get_mock_urllib3_connection(self, connection_params={}, response_body=b"{}"): + conn = Urllib3HttpConnection(**connection_params) def _dummy_urlopen(*args, **kwargs): dummy_response = Mock() @@ -208,226 +193,196 @@ def _dummy_urlopen(*args, **kwargs): _dummy_urlopen.call_args = (args, kwargs) return dummy_response - con.pool.urlopen = _dummy_urlopen - return con + conn.pool.urlopen = _dummy_urlopen + return conn def test_ssl_context(self): try: context = ssl.create_default_context() except AttributeError: - # if create_default_context raises an AttributeError Exception + # if create_default_context raises an AttributeError exception # it means SSLContext is not available for that version of python # and we should skip this test. 
- raise SkipTest( - "Test test_ssl_context is skipped cause SSLContext is not available for this version of ptyhon" + pytest.skip( + "test_ssl_context is skipped cause SSLContext is not available for this version of python" ) - con = Urllib3HttpConnection(use_ssl=True, ssl_context=context) - self.assertEqual(len(con.pool.conn_kw.keys()), 1) - self.assertIsInstance(con.pool.conn_kw["ssl_context"], ssl.SSLContext) - self.assertTrue(con.use_ssl) + conn = Urllib3HttpConnection(use_ssl=True, ssl_context=context) + assert len(conn.pool.conn_kw.keys()) == 1 + assert isinstance(conn.pool.conn_kw["ssl_context"], ssl.SSLContext) + assert conn.use_ssl def test_opaque_id(self): - con = Urllib3HttpConnection(opaque_id="app-1") - self.assertEqual(con.headers["x-opaque-id"], "app-1") + conn = Urllib3HttpConnection(opaque_id="app-1") + assert conn.headers["x-opaque-id"] == "app-1" def test_http_cloud_id(self): - con = Urllib3HttpConnection( + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==" ) - self.assertTrue(con.use_ssl) - self.assertEqual( - con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.use_ssl + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) - self.assertEqual(con.port, None) - self.assertEqual( - con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.port is None + assert ( + conn.hostname == "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) - self.assertTrue(con.http_compress) - - con = Urllib3HttpConnection( + assert conn.http_compress + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", port=9243, ) - self.assertEqual( - con.host, - "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io:9243", + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io:9243" ) - self.assertEqual(con.port, 9243) - self.assertEqual( - con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.port == 9243 + assert ( + conn.hostname == "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) def test_api_key_auth(self): # test with tuple - con = Urllib3HttpConnection( + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", api_key=("elastic", "changeme1"), ) - self.assertEqual( - con.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTE=" + assert conn.headers["authorization"] == "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTE=" + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) - self.assertEqual( - con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" - ) - - # test with base64 encoded string - con = Urllib3HttpConnection( + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", api_key="ZWxhc3RpYzpjaGFuZ2VtZTI=", ) - self.assertEqual( - con.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTI=" - ) - self.assertEqual( - con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.headers["authorization"] == "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTI=" + assert ( + 
conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) def test_no_http_compression(self): - con = self._get_mock_connection() - self.assertFalse(con.http_compress) - self.assertNotIn("accept-encoding", con.headers) - - con.perform_request("GET", "/") - - (_, _, req_body), kwargs = con.pool.urlopen.call_args - - self.assertFalse(req_body) - self.assertNotIn("accept-encoding", kwargs["headers"]) - self.assertNotIn("content-encoding", kwargs["headers"]) + conn = self.get_mock_urllib3_connection() + assert not conn.http_compress + assert "accept-encoding" not in conn.headers + conn.perform_request("GET", "/") + (_, _, req_body), kwargs = conn.pool.urlopen.call_args + assert not req_body + assert "accept-encoding" not in kwargs["headers"] + assert "content-encoding" not in kwargs["headers"] def test_http_compression(self): - con = self._get_mock_connection({"http_compress": True}) - self.assertTrue(con.http_compress) - self.assertEqual(con.headers["accept-encoding"], "gzip,deflate") - - # 'content-encoding' shouldn't be set at a connection level. - # Should be applied only if the request is sent with a body. - self.assertNotIn("content-encoding", con.headers) - - con.perform_request("GET", "/", body=b"{}") - - (_, _, req_body), kwargs = con.pool.urlopen.call_args - - self.assertEqual(gzip_decompress(req_body), b"{}") - self.assertEqual(kwargs["headers"]["accept-encoding"], "gzip,deflate") - self.assertEqual(kwargs["headers"]["content-encoding"], "gzip") - - con.perform_request("GET", "/") - - (_, _, req_body), kwargs = con.pool.urlopen.call_args - - self.assertFalse(req_body) - self.assertEqual(kwargs["headers"]["accept-encoding"], "gzip,deflate") - self.assertNotIn("content-encoding", kwargs["headers"]) + conn = self.get_mock_urllib3_connection({"http_compress": True}) + assert conn.http_compress + assert conn.headers["accept-encoding"] == "gzip,deflate" + assert "content-encoding" not in conn.headers + conn.perform_request("GET", "/", body=b"{}") + (_, _, req_body), kwargs = conn.pool.urlopen.call_args + assert gzip_decompress(req_body) == b"{}" + assert kwargs["headers"]["accept-encoding"] == "gzip,deflate" + assert kwargs["headers"]["content-encoding"] == "gzip" + conn.perform_request("GET", "/") + (_, _, req_body), kwargs = conn.pool.urlopen.call_args + assert not req_body + assert kwargs["headers"]["accept-encoding"] == "gzip,deflate" + assert "content-encoding" not in kwargs["headers"] def test_cloud_id_http_compress_override(self): # 'http_compress' will be 'True' by default for connections with # 'cloud_id' set but should prioritize user-defined values. 
- con = Urllib3HttpConnection( + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", ) - self.assertEqual(con.http_compress, True) - - con = Urllib3HttpConnection( + assert conn.http_compress is True + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", http_compress=False, ) - self.assertEqual(con.http_compress, False) - - con = Urllib3HttpConnection( + assert conn.http_compress is False + conn = Urllib3HttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", http_compress=True, ) - self.assertEqual(con.http_compress, True) + assert conn.http_compress is True def test_default_user_agent(self): - con = Urllib3HttpConnection() - self.assertEqual( - con._get_default_user_agent(), - "elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version()), + conn = Urllib3HttpConnection() + assert conn._get_default_user_agent() == "elasticsearch-py/%s (Python %s)" % ( + __versionstr__, + python_version(), ) def test_timeout_set(self): - con = Urllib3HttpConnection(timeout=42) - self.assertEqual(42, con.timeout) + conn = Urllib3HttpConnection(timeout=42) + assert 42 == conn.timeout def test_keep_alive_is_on_by_default(self): - con = Urllib3HttpConnection() - self.assertEqual( - { - "connection": "keep-alive", - "content-type": "application/json", - "user-agent": con._get_default_user_agent(), - }, - con.headers, - ) + conn = Urllib3HttpConnection() + assert { + "connection": "keep-alive", + "content-type": "application/json", + "user-agent": conn._get_default_user_agent(), + } == conn.headers def test_http_auth(self): - con = Urllib3HttpConnection(http_auth="username:secret") - self.assertEqual( - { - "authorization": "Basic dXNlcm5hbWU6c2VjcmV0", - "connection": "keep-alive", - "content-type": "application/json", - "user-agent": con._get_default_user_agent(), - }, - con.headers, - ) + conn = Urllib3HttpConnection(http_auth="username:secret") + assert { + "authorization": "Basic dXNlcm5hbWU6c2VjcmV0", + "connection": "keep-alive", + "content-type": "application/json", + "user-agent": conn._get_default_user_agent(), + } == conn.headers def test_http_auth_tuple(self): - con = Urllib3HttpConnection(http_auth=("username", "secret")) - self.assertEqual( - { - "authorization": "Basic dXNlcm5hbWU6c2VjcmV0", - "content-type": "application/json", - "connection": "keep-alive", - "user-agent": con._get_default_user_agent(), - }, - con.headers, - ) + conn = Urllib3HttpConnection(http_auth=("username", "secret")) + assert { + "authorization": "Basic dXNlcm5hbWU6c2VjcmV0", + "content-type": "application/json", + "connection": "keep-alive", + "user-agent": conn._get_default_user_agent(), + } == conn.headers def test_http_auth_list(self): - con = Urllib3HttpConnection(http_auth=["username", "secret"]) - self.assertEqual( - { - "authorization": "Basic dXNlcm5hbWU6c2VjcmV0", - "content-type": "application/json", - "connection": "keep-alive", - "user-agent": con._get_default_user_agent(), - }, - con.headers, - ) + conn = Urllib3HttpConnection(http_auth=["username", "secret"]) + assert { + "authorization": "Basic dXNlcm5hbWU6c2VjcmV0", + "content-type": "application/json", + "connection": "keep-alive", + "user-agent": conn._get_default_user_agent(), + } == conn.headers def 
test_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: - con = Urllib3HttpConnection(use_ssl=True, verify_certs=False) - self.assertEqual(1, len(w)) - self.assertEqual( - "Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.", - str(w[0].message), + conn = Urllib3HttpConnection(use_ssl=True, verify_certs=False) + assert 1 == len(w) + assert ( + "Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure." + == str(w[0].message) ) - self.assertIsInstance(con.pool, urllib3.HTTPSConnectionPool) + assert isinstance(conn.pool, urllib3.HTTPSConnectionPool) def test_nowarn_when_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: - con = Urllib3HttpConnection( + conn = Urllib3HttpConnection( use_ssl=True, verify_certs=False, ssl_show_warn=False ) - self.assertEqual(0, len(w)) + assert 0 == len(w) - self.assertIsInstance(con.pool, urllib3.HTTPSConnectionPool) + assert isinstance(conn.pool, urllib3.HTTPSConnectionPool) def test_doesnt_use_https_if_not_specified(self): - con = Urllib3HttpConnection() - self.assertIsInstance(con.pool, urllib3.HTTPConnectionPool) + conn = Urllib3HttpConnection() + assert isinstance(conn.pool, urllib3.HTTPConnectionPool) def test_no_warning_when_using_ssl_context(self): ctx = ssl.create_default_context() with warnings.catch_warnings(record=True) as w: Urllib3HttpConnection(ssl_context=ctx) - self.assertEqual(0, len(w)) + assert 0 == len(w) def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self): for kwargs in ( @@ -439,34 +394,31 @@ def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self): {"ssl_show_warn": True, "ca_certs": "/path/to/certs"}, ): kwargs["ssl_context"] = ssl.create_default_context() - with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - Urllib3HttpConnection(**kwargs) - - self.assertEqual(1, len(w)) - self.assertEqual( - "When using `ssl_context`, all other SSL related kwargs are ignored", - str(w[0].message), - ) + assert 1 == len(w) + assert ( + "When using `ssl_context`, all other SSL related kwargs are ignored" + == str(w[0].message) + ) @patch("elasticsearch.connection.base.logger") def test_uncompressed_body_logged(self, logger): - con = self._get_mock_connection(connection_params={"http_compress": True}) - con.perform_request("GET", "/", body=b'{"example": "body"}') - - self.assertEqual(2, logger.debug.call_count) + conn = self.get_mock_urllib3_connection( + connection_params={"http_compress": True} + ) + conn.perform_request("GET", "/", body=b'{"example": "body"}') + assert 2 == logger.debug.call_count req, resp = logger.debug.call_args_list - - self.assertEqual('> {"example": "body"}', req[0][0] % req[0][1:]) - self.assertEqual("< {}", resp[0][0] % resp[0][1:]) + assert '> {"example": "body"}' == req[0][0] % req[0][1:] + assert "< {}" == resp[0][0] % resp[0][1:] def test_surrogatepass_into_bytes(self): buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" - con = self._get_mock_connection(response_body=buf) - status, headers, data = con.perform_request("GET", "/") - self.assertEqual(u"你好\uda6a", data) + conn = self.get_mock_urllib3_connection(response_body=buf) + status, headers, data = conn.perform_request("GET", "/") + assert u"你好\uda6a" == data @pytest.mark.skipif( not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5" @@ -478,17 +430,16 @@ def urlopen_raise(*_, **__): raise RecursionError("Wasn't modified!") conn.pool.urlopen = urlopen_raise - 
with pytest.raises(RecursionError) as e: conn.perform_request("GET", "/") assert str(e.value) == "Wasn't modified!" -class TestRequestsConnection(TestCase): - def _get_mock_connection( +class TestRequestsConnection: + def get_mock_requests_connection( self, connection_params={}, status_code=200, response_body=b"{}" ): - con = RequestsHttpConnection(**connection_params) + conn = RequestsHttpConnection(**connection_params) def _dummy_send(*args, **kwargs): dummy_response = Mock() @@ -500,164 +451,151 @@ def _dummy_send(*args, **kwargs): _dummy_send.call_args = (args, kwargs) return dummy_response - con.session.send = _dummy_send - return con + conn.session.send = _dummy_send + return conn def _get_request(self, connection, *args, **kwargs): if "body" in kwargs: kwargs["body"] = kwargs["body"].encode("utf-8") status, headers, data = connection.perform_request(*args, **kwargs) - self.assertEqual(200, status) - self.assertEqual("{}", data) - + assert 200 == status + assert "{}" == data timeout = kwargs.pop("timeout", connection.timeout) args, kwargs = connection.session.send.call_args - self.assertEqual(timeout, kwargs["timeout"]) - self.assertEqual(1, len(args)) + assert timeout == kwargs["timeout"] + assert 1 == len(args) return args[0] def test_custom_http_auth_is_allowed(self): auth = AuthBase() c = RequestsHttpConnection(http_auth=auth) - - self.assertEqual(auth, c.session.auth) + assert auth == c.session.auth def test_timeout_set(self): - con = RequestsHttpConnection(timeout=42) - self.assertEqual(42, con.timeout) + conn = RequestsHttpConnection(timeout=42) + assert 42 == conn.timeout def test_opaque_id(self): - con = RequestsHttpConnection(opaque_id="app-1") - self.assertEqual(con.headers["x-opaque-id"], "app-1") + conn = RequestsHttpConnection(opaque_id="app-1") + assert conn.headers["x-opaque-id"] == "app-1" def test_http_cloud_id(self): - con = RequestsHttpConnection( + conn = RequestsHttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==" ) - self.assertTrue(con.use_ssl) - self.assertEqual( - con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.use_ssl + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) - self.assertEqual(con.port, None) - self.assertEqual( - con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.port is None + assert ( + conn.hostname == "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) - self.assertTrue(con.http_compress) - - con = RequestsHttpConnection( + assert conn.http_compress + conn = RequestsHttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", port=9243, ) - self.assertEqual( - con.host, - "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io:9243", + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io:9243" ) - self.assertEqual(con.port, 9243) - self.assertEqual( - con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert conn.port == 9243 + assert ( + conn.hostname == "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) def test_api_key_auth(self): # test with tuple - con = RequestsHttpConnection( + conn = RequestsHttpConnection( 
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", api_key=("elastic", "changeme1"), ) - self.assertEqual( - con.session.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTE=" + assert ( + conn.session.headers["authorization"] == "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTE=" ) - self.assertEqual( - con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) - - # test with base64 encoded string - con = RequestsHttpConnection( + conn = RequestsHttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", api_key="ZWxhc3RpYzpjaGFuZ2VtZTI=", ) - self.assertEqual( - con.session.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTI=" + assert ( + conn.session.headers["authorization"] == "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTI=" ) - self.assertEqual( - con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" + assert ( + conn.host + == "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io" ) def test_no_http_compression(self): - con = self._get_mock_connection() - - self.assertFalse(con.http_compress) - self.assertNotIn("content-encoding", con.session.headers) + conn = self.get_mock_requests_connection() + assert not conn.http_compress + assert "content-encoding" not in conn.session.headers + conn.perform_request("GET", "/") - con.perform_request("GET", "/") - - req = con.session.send.call_args[0][0] - self.assertNotIn("content-encoding", req.headers) - self.assertNotIn("accept-encoding", req.headers) + req = conn.session.send.call_args[0][0] + assert "content-encoding" not in req.headers + assert "accept-encoding" not in req.headers def test_http_compression(self): - con = self._get_mock_connection( + conn = self.get_mock_requests_connection( {"http_compress": True}, ) + assert conn.http_compress + assert "content-encoding" not in conn.session.headers + conn.perform_request("GET", "/", body=b"{}") - self.assertTrue(con.http_compress) - - # 'content-encoding' shouldn't be set at a session level. - # Should be applied only if the request is sent with a body. - self.assertNotIn("content-encoding", con.session.headers) - - con.perform_request("GET", "/", body=b"{}") + req = conn.session.send.call_args[0][0] + assert req.headers["content-encoding"] == "gzip" + assert req.headers["accept-encoding"] == "gzip,deflate" + conn.perform_request("GET", "/") - req = con.session.send.call_args[0][0] - self.assertEqual(req.headers["content-encoding"], "gzip") - self.assertEqual(req.headers["accept-encoding"], "gzip,deflate") - - con.perform_request("GET", "/") - - req = con.session.send.call_args[0][0] - self.assertNotIn("content-encoding", req.headers) - self.assertEqual(req.headers["accept-encoding"], "gzip,deflate") + req = conn.session.send.call_args[0][0] + assert "content-encoding" not in req.headers + assert req.headers["accept-encoding"] == "gzip,deflate" def test_cloud_id_http_compress_override(self): # 'http_compress' will be 'True' by default for connections with # 'cloud_id' set but should prioritize user-defined values. 
- con = RequestsHttpConnection( + conn = RequestsHttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", ) - self.assertEqual(con.http_compress, True) - - con = RequestsHttpConnection( + assert conn.http_compress is True + conn = RequestsHttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", http_compress=False, ) - self.assertEqual(con.http_compress, False) - - con = RequestsHttpConnection( + assert conn.http_compress is False + conn = RequestsHttpConnection( cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", http_compress=True, ) - self.assertEqual(con.http_compress, True) + assert conn.http_compress is True def test_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: - con = self._get_mock_connection( + conn = self.get_mock_requests_connection( {"use_ssl": True, "url_prefix": "url", "verify_certs": False} ) - self.assertEqual(1, len(w)) - self.assertEqual( - "Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.", - str(w[0].message), + assert 1 == len(w) + assert ( + "Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure." + == str(w[0].message) ) - request = self._get_request(con, "GET", "/") - - self.assertEqual("https://localhost:9200/url/", request.url) - self.assertEqual("GET", request.method) - self.assertEqual(None, request.body) + request = self._get_request(conn, "GET", "/") + assert "https://localhost:9200/url/" == request.url + assert "GET" == request.method + assert request.body is None def test_nowarn_when_uses_https_if_verify_certs_is_off(self): with warnings.catch_warnings(record=True) as w: - con = self._get_mock_connection( + conn = self.get_mock_requests_connection( { "use_ssl": True, "url_prefix": "url", @@ -665,33 +603,32 @@ def test_nowarn_when_uses_https_if_verify_certs_is_off(self): "ssl_show_warn": False, } ) - self.assertEqual(0, len(w)) + assert 0 == len(w) - request = self._get_request(con, "GET", "/") - - self.assertEqual("https://localhost:9200/url/", request.url) - self.assertEqual("GET", request.method) - self.assertEqual(None, request.body) + request = self._get_request(conn, "GET", "/") + assert "https://localhost:9200/url/" == request.url + assert "GET" == request.method + assert request.body is None def test_merge_headers(self): - con = self._get_mock_connection( + conn = self.get_mock_requests_connection( connection_params={"headers": {"h1": "v1", "h2": "v2"}} ) - req = self._get_request(con, "GET", "/", headers={"h2": "v2p", "h3": "v3"}) - self.assertEqual(req.headers["h1"], "v1") - self.assertEqual(req.headers["h2"], "v2p") - self.assertEqual(req.headers["h3"], "v3") + req = self._get_request(conn, "GET", "/", headers={"h2": "v2p", "h3": "v3"}) + assert req.headers["h1"] == "v1" + assert req.headers["h2"] == "v2p" + assert req.headers["h3"] == "v3" def test_default_headers(self): - con = self._get_mock_connection() - req = self._get_request(con, "GET", "/") - self.assertEqual(req.headers["content-type"], "application/json") - self.assertEqual(req.headers["user-agent"], con._get_default_user_agent()) + conn = self.get_mock_requests_connection() + req = self._get_request(conn, "GET", "/") + assert req.headers["content-type"] == "application/json" + assert 
req.headers["user-agent"] == conn._get_default_user_agent() def test_custom_headers(self): - con = self._get_mock_connection() + conn = self.get_mock_requests_connection() req = self._get_request( - con, + conn, "GET", "/", headers={ @@ -699,191 +636,173 @@ def test_custom_headers(self): "user-agent": "custom-agent/1.2.3", }, ) - self.assertEqual(req.headers["content-type"], "application/x-ndjson") - self.assertEqual(req.headers["user-agent"], "custom-agent/1.2.3") + assert req.headers["content-type"] == "application/x-ndjson" + assert req.headers["user-agent"] == "custom-agent/1.2.3" def test_http_auth(self): - con = RequestsHttpConnection(http_auth="username:secret") - self.assertEqual(("username", "secret"), con.session.auth) + conn = RequestsHttpConnection(http_auth="username:secret") + assert ("username", "secret") == conn.session.auth def test_http_auth_tuple(self): - con = RequestsHttpConnection(http_auth=("username", "secret")) - self.assertEqual(("username", "secret"), con.session.auth) + conn = RequestsHttpConnection(http_auth=("username", "secret")) + assert ("username", "secret") == conn.session.auth def test_http_auth_list(self): - con = RequestsHttpConnection(http_auth=["username", "secret"]) - self.assertEqual(("username", "secret"), con.session.auth) + conn = RequestsHttpConnection(http_auth=["username", "secret"]) + assert ("username", "secret") == conn.session.auth def test_repr(self): - con = self._get_mock_connection({"host": "elasticsearch.com", "port": 443}) - self.assertEqual( - "<RequestsHttpConnection: http://elasticsearch.com:443>", repr(con) + conn = self.get_mock_requests_connection( + {"host": "elasticsearch.com", "port": 443} ) + assert "<RequestsHttpConnection: http://elasticsearch.com:443>" == repr(conn) def test_conflict_error_is_returned_on_409(self): - con = self._get_mock_connection(status_code=409) - self.assertRaises(ConflictError, con.perform_request, "GET", "/", {}, "") + conn = self.get_mock_requests_connection(status_code=409) + with pytest.raises(ConflictError): + conn.perform_request("GET", "/", {}, "") def test_not_found_error_is_returned_on_404(self): - con = self._get_mock_connection(status_code=404) - self.assertRaises(NotFoundError, con.perform_request, "GET", "/", {}, "") + conn = self.get_mock_requests_connection(status_code=404) + with pytest.raises(NotFoundError): + conn.perform_request("GET", "/", {}, "") def test_request_error_is_returned_on_400(self): - con = self._get_mock_connection(status_code=400) - self.assertRaises(RequestError, con.perform_request, "GET", "/", {}, "") + conn = self.get_mock_requests_connection(status_code=400) + with pytest.raises(RequestError): + conn.perform_request("GET", "/", {}, "") @patch("elasticsearch.connection.base.logger") def test_head_with_404_doesnt_get_logged(self, logger): - con = self._get_mock_connection(status_code=404) - self.assertRaises(NotFoundError, con.perform_request, "HEAD", "/", {}, "") - self.assertEqual(0, logger.warning.call_count) + conn = self.get_mock_requests_connection(status_code=404) + with pytest.raises(NotFoundError): + conn.perform_request("HEAD", "/", {}, "") + assert 0 == logger.warning.call_count @patch("elasticsearch.connection.base.tracer") @patch("elasticsearch.connection.base.logger") def test_failed_request_logs_and_traces(self, logger, tracer): - con = self._get_mock_connection( + conn = self.get_mock_requests_connection( response_body=b'{"answer": 42}', status_code=500 ) - self.assertRaises( - TransportError, - con.perform_request, - "GET", - "/", - {"param": 42}, - "{}".encode("utf-8"), - ) - - # trace request - self.assertEqual(1, tracer.info.call_count) - # trace 
response - self.assertEqual(1, tracer.debug.call_count) - # log url and duration - self.assertEqual(1, logger.warning.call_count) - self.assertTrue( - re.match( - r"^GET http://localhost:9200/\?param=42 \[status:500 request:0.[0-9]{3}s\]", - logger.warning.call_args[0][0] % logger.warning.call_args[0][1:], + with pytest.raises(TransportError): + conn.perform_request( + "GET", + "/", + {"param": 42}, + "{}".encode("utf-8"), ) + assert 1 == tracer.info.call_count + assert 1 == tracer.debug.call_count + assert 1 == logger.warning.call_count + assert re.match( + r"^GET http://localhost:9200/\?param=42 \[status:500 request:0.[0-9]{3}s\]", + logger.warning.call_args[0][0] % logger.warning.call_args[0][1:], ) @patch("elasticsearch.connection.base.tracer") @patch("elasticsearch.connection.base.logger") def test_success_logs_and_traces(self, logger, tracer): - con = self._get_mock_connection(response_body=b"""{"answer": "that's it!"}""") - status, headers, data = con.perform_request( + conn = self.get_mock_requests_connection( + response_body=b"""{"answer": "that's it!"}""" + ) + status, headers, data = conn.perform_request( "GET", "/", {"param": 42}, """{"question": "what's that?"}""".encode("utf-8"), ) - - # trace request - self.assertEqual(1, tracer.info.call_count) - self.assertEqual( - """curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n "question": "what\\u0027s that?"\n}'""", - tracer.info.call_args[0][0] % tracer.info.call_args[0][1:], + assert 1 == tracer.info.call_count + assert ( + """curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n "question": "what\\u0027s that?"\n}'""" + == tracer.info.call_args[0][0] % tracer.info.call_args[0][1:] ) - # trace response - self.assertEqual(1, tracer.debug.call_count) - self.assertTrue( - re.match( - r'#\[200\] \(0.[0-9]{3}s\)\n#{\n# "answer": "that\\u0027s it!"\n#}', - tracer.debug.call_args[0][0] % tracer.debug.call_args[0][1:], - ) + assert 1 == tracer.debug.call_count + assert re.match( + r'#\[200\] \(0.[0-9]{3}s\)\n#{\n# "answer": "that\\u0027s it!"\n#}', + tracer.debug.call_args[0][0] % tracer.debug.call_args[0][1:], ) - - # log url and duration - self.assertEqual(1, logger.info.call_count) - self.assertTrue( - re.match( - r"GET http://localhost:9200/\?param=42 \[status:200 request:0.[0-9]{3}s\]", - logger.info.call_args[0][0] % logger.info.call_args[0][1:], - ) + assert 1 == logger.info.call_count + assert re.match( + r"GET http://localhost:9200/\?param=42 \[status:200 request:0.[0-9]{3}s\]", + logger.info.call_args[0][0] % logger.info.call_args[0][1:], ) - # log request body and response - self.assertEqual(2, logger.debug.call_count) + assert 2 == logger.debug.call_count req, resp = logger.debug.call_args_list - self.assertEqual('> {"question": "what\'s that?"}', req[0][0] % req[0][1:]) - self.assertEqual('< {"answer": "that\'s it!"}', resp[0][0] % resp[0][1:]) + assert '> {"question": "what\'s that?"}' == req[0][0] % req[0][1:] + assert '< {"answer": "that\'s it!"}' == resp[0][0] % resp[0][1:] @patch("elasticsearch.connection.base.logger") def test_uncompressed_body_logged(self, logger): - con = self._get_mock_connection(connection_params={"http_compress": True}) - con.perform_request("GET", "/", body=b'{"example": "body"}') - - self.assertEqual(2, logger.debug.call_count) + conn = self.get_mock_requests_connection( + connection_params={"http_compress": True} + ) + conn.perform_request("GET", "/", body=b'{"example": "body"}') + assert 2 == 
logger.debug.call_count req, resp = logger.debug.call_args_list - self.assertEqual('> {"example": "body"}', req[0][0] % req[0][1:]) - self.assertEqual("< {}", resp[0][0] % resp[0][1:]) - - con = self._get_mock_connection( + assert '> {"example": "body"}' == req[0][0] % req[0][1:] + assert "< {}" == resp[0][0] % resp[0][1:] + conn = self.get_mock_requests_connection( connection_params={"http_compress": True}, status_code=500, response_body=b'{"hello":"world"}', ) with pytest.raises(TransportError): - con.perform_request("GET", "/", body=b'{"example": "body2"}') + conn.perform_request("GET", "/", body=b'{"example": "body2"}') - self.assertEqual(4, logger.debug.call_count) + assert 4 == logger.debug.call_count _, _, req, resp = logger.debug.call_args_list - self.assertEqual('> {"example": "body2"}', req[0][0] % req[0][1:]) - self.assertEqual('< {"hello":"world"}', resp[0][0] % resp[0][1:]) + assert '> {"example": "body2"}' == req[0][0] % req[0][1:] + assert '< {"hello":"world"}' == resp[0][0] % resp[0][1:] def test_defaults(self): - con = self._get_mock_connection() - request = self._get_request(con, "GET", "/") - - self.assertEqual("http://localhost:9200/", request.url) - self.assertEqual("GET", request.method) - self.assertEqual(None, request.body) + conn = self.get_mock_requests_connection() + request = self._get_request(conn, "GET", "/") + assert "http://localhost:9200/" == request.url + assert "GET" == request.method + assert request.body is None def test_params_properly_encoded(self): - con = self._get_mock_connection() + conn = self.get_mock_requests_connection() request = self._get_request( - con, "GET", "/", params={"param": "value with spaces"} + conn, "GET", "/", params={"param": "value with spaces"} ) - - self.assertEqual("http://localhost:9200/?param=value+with+spaces", request.url) - self.assertEqual("GET", request.method) - self.assertEqual(None, request.body) + assert "http://localhost:9200/?param=value+with+spaces" == request.url + assert "GET" == request.method + assert request.body is None def test_body_attached(self): - con = self._get_mock_connection() - request = self._get_request(con, "GET", "/", body='{"answer": 42}') - - self.assertEqual("http://localhost:9200/", request.url) - self.assertEqual("GET", request.method) - self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body) + conn = self.get_mock_requests_connection() + request = self._get_request(conn, "GET", "/", body='{"answer": 42}') + assert "http://localhost:9200/" == request.url + assert "GET" == request.method + assert '{"answer": 42}'.encode("utf-8") == request.body def test_http_auth_attached(self): - con = self._get_mock_connection({"http_auth": "username:secret"}) - request = self._get_request(con, "GET", "/") - - self.assertEqual(request.headers["authorization"], "Basic dXNlcm5hbWU6c2VjcmV0") + conn = self.get_mock_requests_connection({"http_auth": "username:secret"}) + request = self._get_request(conn, "GET", "/") + assert request.headers["authorization"] == "Basic dXNlcm5hbWU6c2VjcmV0" @patch("elasticsearch.connection.base.tracer") def test_url_prefix(self, tracer): - con = self._get_mock_connection({"url_prefix": "/some-prefix/"}) + conn = self.get_mock_requests_connection({"url_prefix": "/some-prefix/"}) request = self._get_request( - con, "GET", "/_search", body='{"answer": 42}', timeout=0.1 + conn, "GET", "/_search", body='{"answer": 42}', timeout=0.1 ) - - self.assertEqual("http://localhost:9200/some-prefix/_search", request.url) - self.assertEqual("GET", request.method) - 
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body) - - # trace request - self.assertEqual(1, tracer.info.call_count) - self.assertEqual( - "curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' -d '{\n \"answer\": 42\n}'", - tracer.info.call_args[0][0] % tracer.info.call_args[0][1:], + assert "http://localhost:9200/some-prefix/_search" == request.url + assert "GET" == request.method + assert '{"answer": 42}'.encode("utf-8") == request.body + assert 1 == tracer.info.call_count + assert ( + "curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' -d '{\n \"answer\": 42\n}'" + == tracer.info.call_args[0][0] % tracer.info.call_args[0][1:] ) def test_surrogatepass_into_bytes(self): buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" - con = self._get_mock_connection(response_body=buf) - status, headers, data = con.perform_request("GET", "/") - self.assertEqual(u"你好\uda6a", data) + conn = self.get_mock_requests_connection(response_body=buf) + status, headers, data = conn.perform_request("GET", "/") + assert u"你好\uda6a" == data @pytest.mark.skipif( not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5" diff --git a/test_elasticsearch/test_connection_pool.py b/test_elasticsearch/test_connection_pool.py index adb11bea3..35ea1bc39 100644 --- a/test_elasticsearch/test_connection_pool.py +++ b/test_elasticsearch/test_connection_pool.py @@ -17,6 +17,8 @@ import time +import pytest + from elasticsearch.connection import Connection from elasticsearch.connection_pool import ( ConnectionPool, @@ -25,18 +27,17 @@ ) from elasticsearch.exceptions import ImproperlyConfigured -from .test_cases import TestCase - -class TestConnectionPool(TestCase): +class TestConnectionPool: def test_dummy_cp_raises_exception_on_more_connections(self): - self.assertRaises(ImproperlyConfigured, DummyConnectionPool, []) - self.assertRaises( - ImproperlyConfigured, DummyConnectionPool, [object(), object()] - ) + with pytest.raises(ImproperlyConfigured): + DummyConnectionPool([]) + with pytest.raises(ImproperlyConfigured): + DummyConnectionPool([object(), object()]) def test_raises_exception_when_no_connections_defined(self): - self.assertRaises(ImproperlyConfigured, ConnectionPool, []) + with pytest.raises(ImproperlyConfigured): + ConnectionPool([]) def test_default_round_robin(self): pool = ConnectionPool([(x, {}) for x in range(100)]) @@ -44,7 +45,7 @@ def test_default_round_robin(self): connections = set() for _ in range(100): connections.add(pool.get_connection()) - self.assertEqual(connections, set(range(100))) + assert connections == set(range(100)) def test_disable_shuffling(self): pool = ConnectionPool([(x, {}) for x in range(100)], randomize_hosts=False) @@ -52,7 +53,7 @@ def test_disable_shuffling(self): connections = [] for _ in range(100): connections.append(pool.get_connection()) - self.assertEqual(connections, list(range(100))) + assert connections == list(range(100)) def test_selectors_have_access_to_connection_opts(self): class MySelector(RoundRobinSelector): @@ -70,25 +71,26 @@ def select(self, connections): connections = [] for _ in range(100): connections.append(pool.get_connection()) - self.assertEqual(connections, [x * x for x in range(100)]) + assert connections == [x * x for x in range(100)] def test_dead_nodes_are_removed_from_active_connections(self): pool = ConnectionPool([(x, {}) for x in range(100)]) now = time.time() pool.mark_dead(42, now=now) - self.assertEqual(99, len(pool.connections)) - self.assertEqual(1, 
pool.dead.qsize()) - self.assertEqual((now + 60, 42), pool.dead.get()) + assert 99 == len(pool.connections) + assert 1 == pool.dead.qsize() + assert (now + 60, 42) == pool.dead.get() def test_connection_is_skipped_when_dead(self): pool = ConnectionPool([(x, {}) for x in range(2)]) pool.mark_dead(0) - self.assertEqual( - [1, 1, 1], - [pool.get_connection(), pool.get_connection(), pool.get_connection()], - ) + assert [1, 1, 1] == [ + pool.get_connection(), + pool.get_connection(), + pool.get_connection(), + ] def test_new_connection_is_not_marked_dead(self): # Create 10 connections @@ -99,7 +101,7 @@ def test_new_connection_is_not_marked_dead(self): pool.mark_dead(new_connection) # Nothing should be marked dead - self.assertEqual(0, len(pool.dead_count)) + assert 0 == len(pool.dead_count) def test_connection_is_forcibly_resurrected_when_no_live_ones_are_availible(self): pool = ConnectionPool([(x, {}) for x in range(2)]) @@ -107,9 +109,9 @@ def test_connection_is_forcibly_resurrected_when_no_live_ones_are_availible(self pool.mark_dead(0) # failed twice, longer timeout pool.mark_dead(1) # failed the first time, first to be resurrected - self.assertEqual([], pool.connections) - self.assertEqual(1, pool.get_connection()) - self.assertEqual([1], pool.connections) + assert [] == pool.connections + assert 1 == pool.get_connection() + assert [1] == pool.connections def test_connection_is_resurrected_after_its_timeout(self): pool = ConnectionPool([(x, {}) for x in range(100)]) @@ -117,16 +119,16 @@ def test_connection_is_resurrected_after_its_timeout(self): now = time.time() pool.mark_dead(42, now=now - 61) pool.get_connection() - self.assertEqual(42, pool.connections[-1]) - self.assertEqual(100, len(pool.connections)) + assert 42 == pool.connections[-1] + assert 100 == len(pool.connections) def test_force_resurrect_always_returns_a_connection(self): pool = ConnectionPool([(0, {})]) pool.connections = [] - self.assertEqual(0, pool.get_connection()) - self.assertEqual([], pool.connections) - self.assertTrue(pool.dead.empty()) + assert 0 == pool.get_connection() + assert [] == pool.connections + assert pool.dead.empty() def test_already_failed_connection_has_longer_timeout(self): pool = ConnectionPool([(x, {}) for x in range(100)]) @@ -134,8 +136,8 @@ def test_already_failed_connection_has_longer_timeout(self): pool.dead_count[42] = 2 pool.mark_dead(42, now=now) - self.assertEqual(3, pool.dead_count[42]) - self.assertEqual((now + 4 * 60, 42), pool.dead.get()) + assert 3 == pool.dead_count[42] + assert (now + 4 * 60, 42) == pool.dead.get() def test_timeout_for_failed_connections_is_limitted(self): pool = ConnectionPool([(x, {}) for x in range(100)]) @@ -143,8 +145,8 @@ def test_timeout_for_failed_connections_is_limitted(self): pool.dead_count[42] = 245 pool.mark_dead(42, now=now) - self.assertEqual(246, pool.dead_count[42]) - self.assertEqual((now + 32 * 60, 42), pool.dead.get()) + assert 246 == pool.dead_count[42] + assert (now + 32 * 60, 42) == pool.dead.get() def test_dead_count_is_wiped_clean_for_connection_if_marked_live(self): pool = ConnectionPool([(x, {}) for x in range(100)]) @@ -152,6 +154,6 @@ def test_dead_count_is_wiped_clean_for_connection_if_marked_live(self): pool.dead_count[42] = 2 pool.mark_dead(42, now=now) - self.assertEqual(3, pool.dead_count[42]) + assert 3 == pool.dead_count[42] pool.mark_live(42) - self.assertNotIn(42, pool.dead_count) + assert 42 not in pool.dead_count diff --git a/test_elasticsearch/test_exceptions.py b/test_elasticsearch/test_exceptions.py index 
98b466b7a..6e1875b53 100644 --- a/test_elasticsearch/test_exceptions.py +++ b/test_elasticsearch/test_exceptions.py @@ -17,10 +17,8 @@ from elasticsearch.exceptions import TransportError -from .test_cases import TestCase - -class TestTransformError(TestCase): +class TestTransformError: def test_transform_error_parse_with_error_reason(self): e = TransportError( 500, @@ -28,16 +26,14 @@ def test_transform_error_parse_with_error_reason(self): {"error": {"root_cause": [{"type": "error", "reason": "error reason"}]}}, ) - self.assertEqual( - str(e), "TransportError(500, 'InternalServerError', 'error reason')" - ) + assert str(e) == "TransportError(500, 'InternalServerError', 'error reason')" def test_transform_error_parse_with_error_string(self): e = TransportError( 500, "InternalServerError", {"error": "something error message"} ) - self.assertEqual( - str(e), - "TransportError(500, 'InternalServerError', 'something error message')", + assert ( + str(e) + == "TransportError(500, 'InternalServerError', 'something error message')" ) diff --git a/test_elasticsearch/test_helpers.py b/test_elasticsearch/test_helpers.py index e6aee4fe7..7c824a100 100644 --- a/test_elasticsearch/test_helpers.py +++ b/test_elasticsearch/test_helpers.py @@ -26,8 +26,6 @@ from elasticsearch.helpers import actions from elasticsearch.serializer import JSONSerializer -from .test_cases import TestCase - lock_side_effect = threading.Lock() @@ -46,7 +44,7 @@ def mock_process_bulk_chunk(*args, **kwargs): mock_process_bulk_chunk.call_count = 0 -class TestParallelBulk(TestCase): +class TestParallelBulk: @mock.patch( "elasticsearch.helpers.actions._process_bulk_chunk", side_effect=mock_process_bulk_chunk, @@ -55,7 +53,7 @@ def test_all_chunks_sent(self, _process_bulk_chunk): actions = ({"x": i} for i in range(100)) list(helpers.parallel_bulk(Elasticsearch(), actions, chunk_size=2)) - self.assertEqual(50, mock_process_bulk_chunk.call_count) + assert 50 == mock_process_bulk_chunk.call_count @pytest.mark.skip @mock.patch( @@ -72,39 +70,28 @@ def test_chunk_sent_from_different_threads(self, _process_bulk_chunk): Elasticsearch(), actions, thread_count=10, chunk_size=2 ) ) - self.assertTrue(len(set([r[1] for r in results])) > 1) + assert len(set([r[1] for r in results])) > 1 -class TestChunkActions(TestCase): +class TestChunkActions: def setup_method(self, _): self.actions = [({"index": {}}, {"some": u"datá", "i": i}) for i in range(100)] def test_expand_action(self): - self.assertEqual(helpers.expand_action({}), ({"index": {}}, {})) - self.assertEqual( - helpers.expand_action({"key": "val"}), ({"index": {}}, {"key": "val"}) - ) + assert helpers.expand_action({}) == ({"index": {}}, {}) + assert helpers.expand_action({"key": "val"}) == ({"index": {}}, {"key": "val"}) def test_expand_action_actions(self): - self.assertEqual( - helpers.expand_action( - {"_op_type": "delete", "_id": "id", "_index": "index"} - ), - ({"delete": {"_id": "id", "_index": "index"}}, None), - ) - self.assertEqual( - helpers.expand_action( - {"_op_type": "update", "_id": "id", "_index": "index", "key": "val"} - ), - ({"update": {"_id": "id", "_index": "index"}}, {"key": "val"}), - ) - self.assertEqual( - helpers.expand_action( - {"_op_type": "create", "_id": "id", "_index": "index", "key": "val"} - ), - ({"create": {"_id": "id", "_index": "index"}}, {"key": "val"}), - ) - self.assertEqual( + assert helpers.expand_action( + {"_op_type": "delete", "_id": "id", "_index": "index"} + ) == ({"delete": {"_id": "id", "_index": "index"}}, None) + assert 
helpers.expand_action( + {"_op_type": "update", "_id": "id", "_index": "index", "key": "val"} + ) == ({"update": {"_id": "id", "_index": "index"}}, {"key": "val"}) + assert helpers.expand_action( + {"_op_type": "create", "_id": "id", "_index": "index", "key": "val"} + ) == ({"create": {"_id": "id", "_index": "index"}}, {"key": "val"}) + assert ( helpers.expand_action( { "_op_type": "create", @@ -112,8 +99,8 @@ def test_expand_action_actions(self): "_index": "index", "_source": {"key": "val"}, } - ), - ({"create": {"_id": "id", "_index": "index"}}, {"key": "val"}), + ) + == ({"create": {"_id": "id", "_index": "index"}}, {"key": "val"}) ) def test_expand_action_options(self): @@ -143,55 +130,38 @@ def test_expand_action_options(self): action_option = option else: option, action_option = option - self.assertEqual( - helpers.expand_action({"key": "val", option: 0}), - ({"index": {action_option: 0}}, {"key": "val"}), + assert helpers.expand_action({"key": "val", option: 0}) == ( + {"index": {action_option: 0}}, + {"key": "val"}, ) def test__source_metadata_or_source(self): - self.assertEqual( - helpers.expand_action({"_source": {"key": "val"}}), - ({"index": {}}, {"key": "val"}), + assert helpers.expand_action({"_source": {"key": "val"}}) == ( + {"index": {}}, + {"key": "val"}, ) - self.assertEqual( - helpers.expand_action( - {"_source": ["key"], "key": "val", "_op_type": "update"} - ), - ({"update": {"_source": ["key"]}}, {"key": "val"}), - ) + assert helpers.expand_action( + {"_source": ["key"], "key": "val", "_op_type": "update"} + ) == ({"update": {"_source": ["key"]}}, {"key": "val"}) - self.assertEqual( - helpers.expand_action( - {"_source": True, "key": "val", "_op_type": "update"} - ), - ({"update": {"_source": True}}, {"key": "val"}), - ) + assert helpers.expand_action( + {"_source": True, "key": "val", "_op_type": "update"} + ) == ({"update": {"_source": True}}, {"key": "val"}) # This case is only to ensure backwards compatibility with old functionality. 
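For reference, this is the expand_action contract that the assertions above and below pin down, written out as plain calls. The expected tuples come straight from the tests; dict key order in the printed output may vary.

    from elasticsearch import helpers

    # A bare document becomes an "index" action with empty metadata.
    print(helpers.expand_action({"key": "val"}))
    # ({'index': {}}, {'key': 'val'})

    # Underscore-prefixed fields move into the action metadata, and
    # deletes carry no source document at all.
    print(helpers.expand_action({"_op_type": "delete", "_id": "id", "_index": "index"}))
    # ({'delete': {'_id': 'id', '_index': 'index'}}, None)

    # For updates, '_source' is copied into the action metadata rather
    # than being treated as the document body.
    print(helpers.expand_action({"_source": ["key"], "key": "val", "_op_type": "update"}))
    # ({'update': {'_source': ['key']}}, {'key': 'val'})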
- self.assertEqual( - helpers.expand_action( - {"_source": {"key2": "val2"}, "key": "val", "_op_type": "update"} - ), - ({"update": {}}, {"key2": "val2"}), - ) + assert helpers.expand_action( + {"_source": {"key2": "val2"}, "key": "val", "_op_type": "update"} + ) == ({"update": {}}, {"key2": "val2"}) def test_chunks_are_chopped_by_byte_size(self): - self.assertEqual( - 100, - len( - list(helpers._chunk_actions(self.actions, 100000, 1, JSONSerializer())) - ), + assert 100 == len( + list(helpers._chunk_actions(self.actions, 100000, 1, JSONSerializer())) ) def test_chunks_are_chopped_by_chunk_size(self): - self.assertEqual( - 10, - len( - list( - helpers._chunk_actions(self.actions, 10, 99999999, JSONSerializer()) - ) - ), + assert 10 == len( + list(helpers._chunk_actions(self.actions, 10, 99999999, JSONSerializer())) ) def test_chunks_are_chopped_by_byte_size_properly(self): @@ -201,29 +171,24 @@ def test_chunks_are_chopped_by_byte_size_properly(self): self.actions, 100000, max_byte_size, JSONSerializer() ) ) - self.assertEqual(25, len(chunks)) + assert 25 == len(chunks) for chunk_data, chunk_actions in chunks: chunk = u"".join(chunk_actions) chunk = chunk if isinstance(chunk, str) else chunk.encode("utf-8") - self.assertLessEqual(len(chunk), max_byte_size) + assert len(chunk) <= max_byte_size def test_add_helper_meta_to_kwargs(self): - self.assertEqual( - actions._add_helper_meta_to_kwargs({}, "b"), - {"params": {"__elastic_client_meta": (("h", "b"),)}}, - ) - self.assertEqual( - actions._add_helper_meta_to_kwargs({"params": {}}, "b"), - {"params": {"__elastic_client_meta": (("h", "b"),)}}, - ) - self.assertEqual( - actions._add_helper_meta_to_kwargs({"params": {"key": "value"}}, "b"), - {"params": {"__elastic_client_meta": (("h", "b"),), "key": "value"}}, - ) - - -class TestExpandActions(TestCase): + assert actions._add_helper_meta_to_kwargs({}, "b") == { + "params": {"__elastic_client_meta": (("h", "b"),)} + } + assert actions._add_helper_meta_to_kwargs({"params": {}}, "b") == { + "params": {"__elastic_client_meta": (("h", "b"),)} + } + assert actions._add_helper_meta_to_kwargs( + {"params": {"key": "value"}}, "b" + ) == {"params": {"__elastic_client_meta": (("h", "b"),), "key": "value"}} + + +class TestExpandActions: def test_string_actions_are_marked_as_simple_inserts(self): - self.assertEqual( - ('{"index":{}}', "whatever"), helpers.expand_action("whatever") - ) + assert ('{"index":{}}', "whatever") == helpers.expand_action("whatever") diff --git a/test_elasticsearch/test_serializer.py b/test_elasticsearch/test_serializer.py index 5f9002d52..4bf229b1f 100644 --- a/test_elasticsearch/test_serializer.py +++ b/test_elasticsearch/test_serializer.py @@ -21,12 +21,16 @@ from datetime import datetime from decimal import Decimal +import pytest + try: import numpy as np import pandas as pd except ImportError: np = pd = None +import re + from elasticsearch.exceptions import ImproperlyConfigured, SerializationError from elasticsearch.serializer import ( DEFAULT_SERIALIZERS, @@ -35,186 +39,161 @@ TextSerializer, ) -from .test_cases import SkipTest, TestCase +requires_numpy_and_pandas = pytest.mark.skipif( + np is None or pd is None, reason="Test requires numpy or pandas to be available" +) -def requires_numpy_and_pandas(): - if np is None or pd is None: - raise SkipTest("Test requires numpy or pandas to be available") +def test_datetime_serialization(): + assert '{"d":"2010-10-01T02:30:00"}' == JSONSerializer().dumps( + {"d": datetime(2010, 10, 1, 2, 30)} + ) -class TestJSONSerializer(TestCase): - 
def test_datetime_serialization(self): - self.assertEqual( - '{"d":"2010-10-01T02:30:00"}', - JSONSerializer().dumps({"d": datetime(2010, 10, 1, 2, 30)}), - ) +def test_decimal_serialization(): + requires_numpy_and_pandas() - def test_decimal_serialization(self): - requires_numpy_and_pandas() + if sys.version_info[:2] == (2, 6): + pytest.skip("Float rounding is broken in 2.6.") + assert '{"d":3.8}' == JSONSerializer().dumps({"d": Decimal("3.8")}) - if sys.version_info[:2] == (2, 6): - raise SkipTest("Float rounding is broken in 2.6.") - self.assertEqual('{"d":3.8}', JSONSerializer().dumps({"d": Decimal("3.8")})) - def test_uuid_serialization(self): - self.assertEqual( - '{"d":"00000000-0000-0000-0000-000000000003"}', - JSONSerializer().dumps( - {"d": uuid.UUID("00000000-0000-0000-0000-000000000003")} - ), - ) +def test_uuid_serialization(): + assert '{"d":"00000000-0000-0000-0000-000000000003"}' == JSONSerializer().dumps( + {"d": uuid.UUID("00000000-0000-0000-0000-000000000003")} + ) - def test_serializes_numpy_bool(self): - requires_numpy_and_pandas() - self.assertEqual('{"d":true}', JSONSerializer().dumps({"d": np.bool_(True)})) +@requires_numpy_and_pandas +def test_serializes_numpy_bool(): + assert '{"d":true}' == JSONSerializer().dumps({"d": np.bool_(True)}) - def test_serializes_numpy_integers(self): - requires_numpy_and_pandas() - ser = JSONSerializer() - for np_type in ( - np.int_, - np.int8, - np.int16, - np.int32, - np.int64, - ): - self.assertEqual(ser.dumps({"d": np_type(-1)}), '{"d":-1}') +@requires_numpy_and_pandas +def test_serializes_numpy_integers(): + ser = JSONSerializer() + for np_type in ( + np.int_, + np.int8, + np.int16, + np.int32, + np.int64, + ): + assert ser.dumps({"d": np_type(-1)}) == '{"d":-1}' - for np_type in ( - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ): - self.assertEqual(ser.dumps({"d": np_type(1)}), '{"d":1}') + for np_type in ( + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ): + assert ser.dumps({"d": np_type(1)}) == '{"d":1}' - def test_serializes_numpy_floats(self): - requires_numpy_and_pandas() - ser = JSONSerializer() - for np_type in ( - np.float_, - np.float32, - np.float64, - ): - self.assertRegexpMatches( - ser.dumps({"d": np_type(1.2)}), r'^\{"d":1\.2[\d]*}$' - ) +@requires_numpy_and_pandas +def test_serializes_numpy_floats(): + ser = JSONSerializer() + for np_type in ( + np.float_, + np.float32, + np.float64, + ): + assert re.search(r'^\{"d":1\.2[\d]*}$', ser.dumps({"d": np_type(1.2)})) - def test_serializes_numpy_datetime(self): - requires_numpy_and_pandas() - self.assertEqual( - '{"d":"2010-10-01T02:30:00"}', - JSONSerializer().dumps({"d": np.datetime64("2010-10-01T02:30:00")}), - ) +@requires_numpy_and_pandas +def test_serializes_numpy_datetime(): + assert '{"d":"2010-10-01T02:30:00"}' == JSONSerializer().dumps( + {"d": np.datetime64("2010-10-01T02:30:00")} + ) - def test_serializes_numpy_ndarray(self): - requires_numpy_and_pandas() - self.assertEqual( - '{"d":[0,0,0,0,0]}', - JSONSerializer().dumps({"d": np.zeros((5,), dtype=np.uint8)}), - ) - # This isn't useful for Elasticsearch, just want to make sure it works. - self.assertEqual( - '{"d":[[0,0],[0,0]]}', - JSONSerializer().dumps({"d": np.zeros((2, 2), dtype=np.uint8)}), - ) +@requires_numpy_and_pandas +def test_serializes_numpy_ndarray(): + assert '{"d":[0,0,0,0,0]}' == JSONSerializer().dumps( + {"d": np.zeros((5,), dtype=np.uint8)} + ) + # This isn't useful for Elasticsearch, just want to make sure it works. 
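The module-level requires_numpy_and_pandas marker introduced above replaces the old helper that raised SkipTest mid-test: the condition is evaluated once at import time, and pytest reports the skip reason itself. The same pattern in isolation, reduced to a single assumed dependency; the expected JSON string matches the numpy-bool test in this file.

    import pytest

    try:
        import numpy as np
    except ImportError:
        np = None

    # Build the marker once, then apply it declaratively per test.
    requires_numpy = pytest.mark.skipif(np is None, reason="requires numpy")


    @requires_numpy
    def test_numpy_bool_serializes_to_json_bool():
        from elasticsearch.serializer import JSONSerializer

        assert JSONSerializer().dumps({"d": np.bool_(True)}) == '{"d":true}'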
+ assert '{"d":[[0,0],[0,0]]}' == JSONSerializer().dumps( + {"d": np.zeros((2, 2), dtype=np.uint8)} + ) - def test_serializes_numpy_nan_to_nan(self): - requires_numpy_and_pandas() - self.assertEqual( - '{"d":NaN}', - JSONSerializer().dumps({"d": np.nan}), - ) +@requires_numpy_and_pandas +def test_serializes_numpy_nan_to_nan(): + assert '{"d":NaN}' == JSONSerializer().dumps({"d": np.nan}) - def test_serializes_pandas_timestamp(self): - requires_numpy_and_pandas() - self.assertEqual( - '{"d":"2010-10-01T02:30:00"}', - JSONSerializer().dumps({"d": pd.Timestamp("2010-10-01T02:30:00")}), - ) +@requires_numpy_and_pandas +def test_serializes_pandas_timestamp(): + assert '{"d":"2010-10-01T02:30:00"}' == JSONSerializer().dumps( + {"d": pd.Timestamp("2010-10-01T02:30:00")} + ) - def test_serializes_pandas_series(self): - requires_numpy_and_pandas() - self.assertEqual( - '{"d":["a","b","c","d"]}', - JSONSerializer().dumps({"d": pd.Series(["a", "b", "c", "d"])}), - ) +@requires_numpy_and_pandas +def test_serializes_pandas_series(): + assert '{"d":["a","b","c","d"]}' == JSONSerializer().dumps( + {"d": pd.Series(["a", "b", "c", "d"])} + ) - def test_serializes_pandas_na(self): - requires_numpy_and_pandas() - if not hasattr(pd, "NA"): # pandas.NA added in v1 - raise SkipTest("pandas.NA required") - self.assertEqual( - '{"d":null}', - JSONSerializer().dumps({"d": pd.NA}), - ) +@requires_numpy_and_pandas +@pytest.mark.skipif(not hasattr(pd, "NA"), reason="pandas.NA is required") +def test_serializes_pandas_na(): + assert '{"d":null}' == JSONSerializer().dumps({"d": pd.NA}) - def test_raises_serialization_error_pandas_nat(self): - requires_numpy_and_pandas() - if not hasattr(pd, "NaT"): - raise SkipTest("pandas.NaT required") - self.assertRaises(SerializationError, JSONSerializer().dumps, {"d": pd.NaT}) +@requires_numpy_and_pandas +@pytest.mark.skipif(not hasattr(pd, "NaT"), reason="pandas.NaT required") +def test_raises_serialization_error_pandas_nat(): + with pytest.raises(SerializationError): + JSONSerializer().dumps({"d": pd.NaT}) - def test_serializes_pandas_category(self): - requires_numpy_and_pandas() - cat = pd.Categorical(["a", "c", "b", "a"], categories=["a", "b", "c"]) - self.assertEqual( - '{"d":["a","c","b","a"]}', - JSONSerializer().dumps({"d": cat}), - ) +@requires_numpy_and_pandas +def test_serializes_pandas_category(): + cat = pd.Categorical(["a", "c", "b", "a"], categories=["a", "b", "c"]) + assert '{"d":["a","c","b","a"]}' == JSONSerializer().dumps({"d": cat}) + + cat = pd.Categorical([1, 2, 3], categories=[1, 2, 3]) + assert '{"d":[1,2,3]}' == JSONSerializer().dumps({"d": cat}) - cat = pd.Categorical([1, 2, 3], categories=[1, 2, 3]) - self.assertEqual( - '{"d":[1,2,3]}', - JSONSerializer().dumps({"d": cat}), - ) - def test_raises_serialization_error_on_dump_error(self): - self.assertRaises(SerializationError, JSONSerializer().dumps, object()) +def test_json_raises_serialization_error_on_dump_error(): + with pytest.raises(SerializationError): + JSONSerializer().dumps(object()) - def test_raises_serialization_error_on_load_error(self): - self.assertRaises(SerializationError, JSONSerializer().loads, object()) - self.assertRaises(SerializationError, JSONSerializer().loads, "") - self.assertRaises(SerializationError, JSONSerializer().loads, "{{") - def test_strings_are_left_untouched(self): - self.assertEqual("你好", JSONSerializer().dumps("你好")) +def test_raises_serialization_error_on_load_error(): + with pytest.raises(SerializationError): + JSONSerializer().loads(object()) + with 
pytest.raises(SerializationError): + JSONSerializer().loads("") + with pytest.raises(SerializationError): + JSONSerializer().loads("{{") -class TestTextSerializer(TestCase): - def test_strings_are_left_untouched(self): - self.assertEqual("你好", TextSerializer().dumps("你好")) +def test_strings_are_left_untouched(): + assert "你好" == TextSerializer().dumps("你好") - def test_raises_serialization_error_on_dump_error(self): - self.assertRaises(SerializationError, TextSerializer().dumps, {}) +def test_text_raises_serialization_error_on_dump_error(): + with pytest.raises(SerializationError): + TextSerializer().dumps({}) -class TestDeserializer(TestCase): + +class TestDeserializer: def setup_method(self, _): self.de = Deserializer(DEFAULT_SERIALIZERS) def test_deserializes_json_by_default(self): - self.assertEqual({"some": "data"}, self.de.loads('{"some":"data"}')) + assert {"some": "data"} == self.de.loads('{"some":"data"}') def test_deserializes_text_with_correct_ct(self): - self.assertEqual( - '{"some":"data"}', self.de.loads('{"some":"data"}', "text/plain") - ) - self.assertEqual( - '{"some":"data"}', - self.de.loads('{"some":"data"}', "text/plain; charset=whatever"), + assert '{"some":"data"}' == self.de.loads('{"some":"data"}', "text/plain") + assert '{"some":"data"}' == self.de.loads( + '{"some":"data"}', "text/plain; charset=whatever" ) def test_deserialize_compatibility_header(self): @@ -224,14 +203,14 @@ def test_deserialize_compatibility_header(self): "application/vnd.elasticsearch+json;compatible-with=8", "application/vnd.elasticsearch+json; compatible-with=8", ): - self.assertEqual( - {"some": "data"}, self.de.loads('{"some":"data"}', content_type) - ) + assert {"some": "data"} == self.de.loads('{"some":"data"}', content_type) def test_raises_serialization_error_on_unknown_mimetype(self): - self.assertRaises(SerializationError, self.de.loads, "{}", "text/html") + with pytest.raises(SerializationError): + self.de.loads("{}", "text/html") def test_raises_improperly_configured_when_default_mimetype_cannot_be_deserialized( self, ): - self.assertRaises(ImproperlyConfigured, Deserializer, {}) + with pytest.raises(ImproperlyConfigured): + Deserializer({}) diff --git a/test_elasticsearch/test_server/__init__.py b/test_elasticsearch/test_server/__init__.py index c9ac5ad56..2a87d183f 100644 --- a/test_elasticsearch/test_server/__init__.py +++ b/test_elasticsearch/test_server/__init__.py @@ -14,46 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
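The deletions just below retire the module-level get_client() singleton and the ElasticsearchTestCase base class; server tests now receive a client through the sync_client_factory / sync_client fixtures in conftest.py instead. A condensed sketch of that split follows, with wipe_cluster stubbed out as a simplified stand-in for the real helper in test_elasticsearch.utils and the URL handling reduced to a plain localhost default.

    import pytest

    import elasticsearch


    def wipe_cluster(client):
        # Simplified stand-in: the real helper also clears templates,
        # ML jobs, transforms, etc. before the next test runs.
        client.indices.delete(index="*", expand_wildcards="all", ignore=404)


    @pytest.fixture(scope="session")
    def sync_client_factory():
        # One shared client per test session, closed when the session ends.
        client = elasticsearch.Elasticsearch("http://localhost:9200")
        try:
            yield client
        finally:
            client.close()


    @pytest.fixture
    def sync_client(sync_client_factory):
        # Hand the shared client to the test, then wipe cluster state
        # afterwards so nothing leaks into the next test.
        try:
            yield sync_client_factory
        finally:
            wipe_cluster(sync_client_factory)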
- -from unittest import SkipTest - -from elasticsearch.helpers import test -from elasticsearch.helpers.test import ElasticsearchTestCase as BaseTestCase - -client = None - - -def get_client(**kwargs): - global client - if client is False: - raise SkipTest("No client is available") - if client is not None and not kwargs: - return client - - # try and locate manual override in the local environment - try: - from test_elasticsearch.local import get_client as local_get_client - - new_client = local_get_client(**kwargs) - except ImportError: - # fallback to using vanilla client - try: - new_client = test.get_test_client(**kwargs) - except SkipTest: - client = False - raise - - if not kwargs: - client = new_client - - return new_client - - -def setup_module(): - get_client() - - -class ElasticsearchTestCase(BaseTestCase): - @staticmethod - def _get_client(**kwargs): - return get_client(**kwargs) diff --git a/test_elasticsearch/test_server/conftest.py b/test_elasticsearch/test_server/conftest.py index 8f98ab449..cd95312f7 100644 --- a/test_elasticsearch/test_server/conftest.py +++ b/test_elasticsearch/test_server/conftest.py @@ -16,14 +16,12 @@ # under the License. import os -import time import pytest import elasticsearch -from elasticsearch.helpers.test import CA_CERTS, ELASTICSEARCH_URL -from ..utils import wipe_cluster +from ..utils import CA_CERTS, wipe_cluster # Information about the Elasticsearch instance running, if any # Used for @@ -33,7 +31,7 @@ @pytest.fixture(scope="session") -def sync_client_factory(): +def sync_client_factory(elasticsearch_url): client = None try: # Configure the client with certificates and optionally @@ -53,20 +51,13 @@ def sync_client_factory(): # We do this little dance with the URL to force # Requests to respect 'headers: None' within rest API spec tests. client = elasticsearch.Elasticsearch( - ELASTICSEARCH_URL.replace("elastic:changeme@", ""), **kw + elasticsearch_url.replace("elastic:changeme@", ""), **kw ) - # Wait for the cluster to report a status of 'yellow' - for _ in range(100): - try: - client.cluster.health(wait_for_status="yellow") - break - except ConnectionError: - time.sleep(0.1) - else: - pytest.skip("Elasticsearch wasn't running at %r" % (ELASTICSEARCH_URL,)) - + # Wipe the cluster before we start testing just in case it wasn't wiped + # cleanly from the previous run of pytest? wipe_cluster(client) + yield client finally: if client: @@ -78,4 +69,5 @@ def sync_client(sync_client_factory): try: yield sync_client_factory finally: + # Wipe the cluster clean after every test execution. wipe_cluster(sync_client_factory) diff --git a/test_elasticsearch/test_server/test_clients.py b/test_elasticsearch/test_server/test_clients.py index 56f24b73f..facafe752 100644 --- a/test_elasticsearch/test_server/test_clients.py +++ b/test_elasticsearch/test_server/test_clients.py @@ -18,25 +18,49 @@ from __future__ import unicode_literals -from . 
import ElasticsearchTestCase +def test_indices_analyze_unicode(sync_client): + resp = sync_client.indices.analyze(body='{"text": "привет"}') + assert resp == { + "tokens": [ + { + "end_offset": 6, + "position": 0, + "start_offset": 0, + "token": "привет", + "type": "<ALPHANUM>", + } + ] + } -class TestUnicode(ElasticsearchTestCase): - def test_indices_analyze(self): - self.client.indices.analyze(body='{"text": "привет"}') +def test_bulk_works_with_string_body(sync_client): + docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}' + resp = sync_client.bulk(body=docs) -class TestBulk(ElasticsearchTestCase): - def test_bulk_works_with_string_body(self): - docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}' - response = self.client.bulk(body=docs) + assert resp["errors"] is False + assert 1 == len(resp["items"]) - self.assertFalse(response["errors"]) - self.assertEqual(1, len(response["items"])) - def test_bulk_works_with_bytestring_body(self): - docs = b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}' - response = self.client.bulk(body=docs) +def test_bulk_works_with_bytestring_body(sync_client): + docs = ( + b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}' + ) + resp = sync_client.bulk(body=docs) - self.assertFalse(response["errors"]) - self.assertEqual(1, len(response["items"])) + assert resp["errors"] is False + assert 1 == len(resp["items"]) + + # Pop inconsistent items before asserting + resp["items"][0]["index"].pop("_id") + resp["items"][0]["index"].pop("_version") + assert resp["items"][0] == { + "index": { + "_index": "bulk_test_index", + "result": "created", + "_shards": {"total": 2, "successful": 1, "failed": 0}, + "_seq_no": 0, + "_primary_term": 1, + "status": 201, + } + } diff --git a/test_elasticsearch/test_server/test_helpers.py b/test_elasticsearch/test_server/test_helpers.py index 6536fd5b1..141eec2ed 100644 --- a/test_elasticsearch/test_server/test_helpers.py +++ b/test_elasticsearch/test_server/test_helpers.py @@ -24,9 +24,6 @@ from elasticsearch import TransportError, helpers from elasticsearch.helpers import ScanError -from ..test_cases import SkipTest -from . 
import ElasticsearchTestCase - class FailingBulkClient(object): def __init__( @@ -45,719 +42,776 @@ def bulk(self, *args, **kwargs): return self.client.bulk(*args, **kwargs) -class TestStreamingBulk(ElasticsearchTestCase): - def test_actions_remain_unchanged(self): - actions = [{"_id": 1}, {"_id": 2}] - for ok, item in helpers.streaming_bulk( - self.client, actions, index="test-index" - ): - self.assertTrue(ok) - self.assertEqual([{"_id": 1}, {"_id": 2}], actions) +def test_bulk_actions_remain_unchanged(sync_client): + actions = [{"_id": 1}, {"_id": 2}] + for ok, item in helpers.streaming_bulk(sync_client, actions, index="test-index"): + assert ok + assert [{"_id": 1}, {"_id": 2}] == actions + + +def test_bulk_all_documents_get_inserted(sync_client): + docs = [{"answer": x, "_id": x} for x in range(100)] + for ok, item in helpers.streaming_bulk( + sync_client, docs, index="test-index", refresh=True + ): + assert ok + + assert 100 == sync_client.count(index="test-index")["count"] + assert {"answer": 42} == sync_client.get(index="test-index", id=42)["_source"] + + +def test_bulk_all_errors_from_chunk_are_raised_on_failure(sync_client): + sync_client.indices.create( + "i", + { + "mappings": {"properties": {"a": {"type": "integer"}}}, + "settings": {"number_of_shards": 1, "number_of_replicas": 0}, + }, + ) + sync_client.cluster.health(wait_for_status="yellow") - def test_all_documents_get_inserted(self): - docs = [{"answer": x, "_id": x} for x in range(100)] + try: for ok, item in helpers.streaming_bulk( - self.client, docs, index="test-index", refresh=True + sync_client, [{"a": "b"}, {"a": "c"}], index="i", raise_on_error=True ): - self.assertTrue(ok) + assert ok + except helpers.BulkIndexError as e: + assert 2 == len(e.errors) + else: + assert False, "exception should have been raised" + + +def test_bulk_different_op_types(sync_client): + sync_client.index(index="i", id=45, body={}) + sync_client.index(index="i", id=42, body={}) + docs = [ + {"_index": "i", "_id": 47, "f": "v"}, + {"_op_type": "delete", "_index": "i", "_id": 45}, + {"_op_type": "update", "_index": "i", "_id": 42, "doc": {"answer": 42}}, + ] + for ok, item in helpers.streaming_bulk(sync_client, docs): + assert ok - self.assertEqual(100, self.client.count(index="test-index")["count"]) - self.assertEqual( - {"answer": 42}, self.client.get(index="test-index", id=42)["_source"] - ) + assert not sync_client.exists(index="i", id=45) + assert {"answer": 42} == sync_client.get(index="i", id=42)["_source"] + assert {"f": "v"} == sync_client.get(index="i", id=47)["_source"] - def test_all_errors_from_chunk_are_raised_on_failure(self): - self.client.indices.create( - "i", - { - "mappings": {"properties": {"a": {"type": "integer"}}}, - "settings": {"number_of_shards": 1, "number_of_replicas": 0}, - }, - ) - self.client.cluster.health(wait_for_status="yellow") - try: - for ok, item in helpers.streaming_bulk( - self.client, [{"a": "b"}, {"a": "c"}], index="i", raise_on_error=True - ): - self.assertTrue(ok) - except helpers.BulkIndexError as e: - self.assertEqual(2, len(e.errors)) - else: - assert False, "exception should have been raised" - - def test_different_op_types(self): - if self.es_version() < (0, 90, 1): - raise SkipTest("update supported since 0.90.1") - self.client.index(index="i", id=45, body={}) - self.client.index(index="i", id=42, body={}) - docs = [ - {"_index": "i", "_id": 47, "f": "v"}, - {"_op_type": "delete", "_index": "i", "_id": 45}, - {"_op_type": "update", "_index": "i", "_id": 42, "doc": {"answer": 42}}, - ] - for 
ok, item in helpers.streaming_bulk(self.client, docs): - self.assertTrue(ok) - - self.assertFalse(self.client.exists(index="i", id=45)) - self.assertEqual({"answer": 42}, self.client.get(index="i", id=42)["_source"]) - self.assertEqual({"f": "v"}, self.client.get(index="i", id=47)["_source"]) - - def test_transport_error_can_becaught(self): - failing_client = FailingBulkClient(self.client) - docs = [ - {"_index": "i", "_id": 47, "f": "v"}, - {"_index": "i", "_id": 45, "f": "v"}, - {"_index": "i", "_id": 42, "f": "v"}, - ] +def test_bulk_transport_error_can_becaught(sync_client): + failing_client = FailingBulkClient(sync_client) + docs = [ + {"_index": "i", "_id": 47, "f": "v"}, + {"_index": "i", "_id": 45, "f": "v"}, + {"_index": "i", "_id": 42, "f": "v"}, + ] - results = list( - helpers.streaming_bulk( - failing_client, - docs, - raise_on_exception=False, - raise_on_error=False, - chunk_size=1, - ) + results = list( + helpers.streaming_bulk( + failing_client, + docs, + raise_on_exception=False, + raise_on_error=False, + chunk_size=1, ) - self.assertEqual(3, len(results)) - self.assertEqual([True, False, True], [r[0] for r in results]) + ) + assert 3 == len(results) + assert [True, False, True] == [r[0] for r in results] + + exc = results[1][1]["index"].pop("exception") + assert isinstance(exc, TransportError) + assert 599 == exc.status_code + assert { + "index": { + "_index": "i", + "_id": 45, + "data": {"f": "v"}, + "error": "TransportError(599, 'Error!')", + "status": 599, + } + } == results[1][1] - exc = results[1][1]["index"].pop("exception") - self.assertIsInstance(exc, TransportError) - self.assertEqual(599, exc.status_code) - self.assertEqual( - { - "index": { - "_index": "i", - "_id": 45, - "data": {"f": "v"}, - "error": "TransportError(599, 'Error!')", - "status": 599, - } - }, - results[1][1], - ) - def test_rejected_documents_are_retried(self): - failing_client = FailingBulkClient( - self.client, fail_with=TransportError(429, "Rejected!", {}) - ) - docs = [ - {"_index": "i", "_id": 47, "f": "v"}, - {"_index": "i", "_id": 45, "f": "v"}, - {"_index": "i", "_id": 42, "f": "v"}, - ] - results = list( - helpers.streaming_bulk( - failing_client, - docs, - raise_on_exception=False, - raise_on_error=False, - chunk_size=1, - max_retries=1, - initial_backoff=0, - ) +def test_bulk_rejected_documents_are_retried(sync_client): + failing_client = FailingBulkClient( + sync_client, fail_with=TransportError(429, "Rejected!", {}) + ) + docs = [ + {"_index": "i", "_id": 47, "f": "v"}, + {"_index": "i", "_id": 45, "f": "v"}, + {"_index": "i", "_id": 42, "f": "v"}, + ] + results = list( + helpers.streaming_bulk( + failing_client, + docs, + raise_on_exception=False, + raise_on_error=False, + chunk_size=1, + max_retries=1, + initial_backoff=0, ) - self.assertEqual(3, len(results)) - self.assertEqual([True, True, True], [r[0] for r in results]) - self.client.indices.refresh(index="i") - res = self.client.search(index="i") - self.assertEqual({"value": 3, "relation": "eq"}, res["hits"]["total"]) - self.assertEqual(4, failing_client._called) - - def test_rejected_documents_are_retried_at_most_max_retries_times(self): - failing_client = FailingBulkClient( - self.client, fail_at=(1, 2), fail_with=TransportError(429, "Rejected!", {}) + ) + assert 3 == len(results) + assert [True, True, True] == [r[0] for r in results] + sync_client.indices.refresh(index="i") + res = sync_client.search(index="i") + assert {"value": 3, "relation": "eq"} == res["hits"]["total"] + assert 4 == failing_client._called + + +def 
test_bulk_rejected_documents_are_retried_at_most_max_retries_times(sync_client): + failing_client = FailingBulkClient( + sync_client, fail_at=(1, 2), fail_with=TransportError(429, "Rejected!", {}) + ) + + docs = [ + {"_index": "i", "_id": 47, "f": "v"}, + {"_index": "i", "_id": 45, "f": "v"}, + {"_index": "i", "_id": 42, "f": "v"}, + ] + results = list( + helpers.streaming_bulk( + failing_client, + docs, + raise_on_exception=False, + raise_on_error=False, + chunk_size=1, + max_retries=1, + initial_backoff=0, ) + ) + assert 3 == len(results) + assert [False, True, True] == [r[0] for r in results] + sync_client.indices.refresh(index="i") + res = sync_client.search(index="i") + assert {"value": 2, "relation": "eq"} == res["hits"]["total"] + assert 4 == failing_client._called + + +def test_bulk_transport_error_is_raised_with_max_retries(sync_client): + failing_client = FailingBulkClient( + sync_client, + fail_at=(1, 2, 3, 4), + fail_with=TransportError(429, "Rejected!", {}), + ) - docs = [ - {"_index": "i", "_id": 47, "f": "v"}, - {"_index": "i", "_id": 45, "f": "v"}, - {"_index": "i", "_id": 42, "f": "v"}, - ] + def streaming_bulk(): results = list( helpers.streaming_bulk( failing_client, - docs, - raise_on_exception=False, - raise_on_error=False, - chunk_size=1, - max_retries=1, + [{"a": 42}, {"a": 39}], + raise_on_exception=True, + max_retries=3, initial_backoff=0, ) ) - self.assertEqual(3, len(results)) - self.assertEqual([False, True, True], [r[0] for r in results]) - self.client.indices.refresh(index="i") - res = self.client.search(index="i") - self.assertEqual({"value": 2, "relation": "eq"}, res["hits"]["total"]) - self.assertEqual(4, failing_client._called) - - def test_transport_error_is_raised_with_max_retries(self): - failing_client = FailingBulkClient( - self.client, - fail_at=(1, 2, 3, 4), - fail_with=TransportError(429, "Rejected!", {}), - ) + return results - def streaming_bulk(): - results = list( - helpers.streaming_bulk( - failing_client, - [{"a": 42}, {"a": 39}], - raise_on_exception=True, - max_retries=3, - initial_backoff=0, - ) - ) - return results + with pytest.raises(TransportError): + streaming_bulk() + assert 4 == failing_client._called - self.assertRaises(TransportError, streaming_bulk) - self.assertEqual(4, failing_client._called) +def test_bulk_works_with_single_item(sync_client): + docs = [{"answer": 42, "_id": 1}] + success, failed = helpers.bulk(sync_client, docs, index="test-index", refresh=True) -class TestBulk(ElasticsearchTestCase): - def test_bulk_works_with_single_item(self): - docs = [{"answer": 42, "_id": 1}] - success, failed = helpers.bulk( - self.client, docs, index="test-index", refresh=True - ) - - self.assertEqual(1, success) - self.assertFalse(failed) - self.assertEqual(1, self.client.count(index="test-index")["count"]) - self.assertEqual( - {"answer": 42}, self.client.get(index="test-index", id=1)["_source"] - ) + assert 1 == success + assert not failed + assert 1 == sync_client.count(index="test-index")["count"] + assert {"answer": 42} == sync_client.get(index="test-index", id=1)["_source"] - def test_all_documents_get_inserted(self): - docs = [{"answer": x, "_id": x} for x in range(100)] - success, failed = helpers.bulk( - self.client, docs, index="test-index", refresh=True - ) - self.assertEqual(100, success) - self.assertFalse(failed) - self.assertEqual(100, self.client.count(index="test-index")["count"]) - self.assertEqual( - {"answer": 42}, self.client.get(index="test-index", id=42)["_source"] - ) +def 
test_all_documents_get_inserted(sync_client): + docs = [{"answer": x, "_id": x} for x in range(100)] + success, failed = helpers.bulk(sync_client, docs, index="test-index", refresh=True) - def test_stats_only_reports_numbers(self): - docs = [{"answer": x} for x in range(100)] - success, failed = helpers.bulk( - self.client, docs, index="test-index", refresh=True, stats_only=True - ) + assert 100 == success + assert not failed + assert 100 == sync_client.count(index="test-index")["count"] + assert {"answer": 42} == sync_client.get(index="test-index", id=42)["_source"] - self.assertEqual(100, success) - self.assertEqual(0, failed) - self.assertEqual(100, self.client.count(index="test-index")["count"]) - def test_errors_are_reported_correctly(self): - self.client.indices.create( - "i", - { - "mappings": {"properties": {"a": {"type": "integer"}}}, - "settings": {"number_of_shards": 1, "number_of_replicas": 0}, - }, - ) - self.client.cluster.health(wait_for_status="yellow") +def test_stats_only_reports_numbers(sync_client): + docs = [{"answer": x} for x in range(100)] + success, failed = helpers.bulk( + sync_client, docs, index="test-index", refresh=True, stats_only=True + ) - success, failed = helpers.bulk( - self.client, - [{"a": 42}, {"a": "c", "_id": 42}], - index="i", - raise_on_error=False, - ) - self.assertEqual(1, success) - self.assertEqual(1, len(failed)) - error = failed[0] - self.assertEqual("42", error["index"]["_id"]) - self.assertEqual("i", error["index"]["_index"]) - print(error["index"]["error"]) - self.assertTrue( - "MapperParsingException" in repr(error["index"]["error"]) - or "mapper_parsing_exception" in repr(error["index"]["error"]) - ) + assert 100 == success + assert 0 == failed + assert 100 == sync_client.count(index="test-index")["count"] - def test_error_is_raised(self): - self.client.indices.create( - "i", - { - "mappings": {"properties": {"a": {"type": "integer"}}}, - "settings": {"number_of_shards": 1, "number_of_replicas": 0}, - }, - ) - self.client.cluster.health(wait_for_status="yellow") - self.assertRaises( - helpers.BulkIndexError, - helpers.bulk, - self.client, - [{"a": 42}, {"a": "c"}], - index="i", - ) +def test_errors_are_reported_correctly(sync_client): + sync_client.indices.create( + "i", + { + "mappings": {"properties": {"a": {"type": "integer"}}}, + "settings": {"number_of_shards": 1, "number_of_replicas": 0}, + }, + ) + sync_client.cluster.health(wait_for_status="yellow") - def test_ignore_error_if_raised(self): - # ignore the status code 400 in tuple - helpers.bulk( - self.client, [{"a": 42}, {"a": "c"}], index="i", ignore_status=(400,) - ) + success, failed = helpers.bulk( + sync_client, + [{"a": 42}, {"a": "c", "_id": 42}], + index="i", + raise_on_error=False, + ) + assert 1 == success + assert 1 == len(failed) + error = failed[0] + assert "42" == error["index"]["_id"] + assert "i" == error["index"]["_index"] + print(error["index"]["error"]) + assert "MapperParsingException" in repr( + error["index"]["error"] + ) or "mapper_parsing_exception" in repr(error["index"]["error"]) + + +def test_error_is_raised(sync_client): + sync_client.indices.create( + "i", + { + "mappings": {"properties": {"a": {"type": "integer"}}}, + "settings": {"number_of_shards": 1, "number_of_replicas": 0}, + }, + ) + sync_client.cluster.health(wait_for_status="yellow") - # ignore the status code 400 in list + with pytest.raises(helpers.BulkIndexError): helpers.bulk( - self.client, + sync_client, [{"a": 42}, {"a": "c"}], index="i", - ignore_status=[ - 400, - ], ) - # ignore the 
status code 400 - helpers.bulk(self.client, [{"a": 42}, {"a": "c"}], index="i", ignore_status=400) - # ignore only the status code in the `ignore_status` argument - self.assertRaises( - helpers.BulkIndexError, - helpers.bulk, - self.client, - [{"a": 42}, {"a": "c"}], - index="i", - ignore_status=(444,), - ) +def test_ignore_error_if_raised(sync_client): + # ignore the status code 400 in tuple + helpers.bulk(sync_client, [{"a": 42}, {"a": "c"}], index="i", ignore_status=(400,)) - # ignore transport error exception - failing_client = FailingBulkClient(self.client) - helpers.bulk(failing_client, [{"a": 42}], index="i", ignore_status=(599,)) + # ignore the status code 400 in list + helpers.bulk( + sync_client, + [{"a": 42}, {"a": "c"}], + index="i", + ignore_status=[ + 400, + ], + ) - def test_errors_are_collected_properly(self): - self.client.indices.create( - "i", - { - "mappings": {"properties": {"a": {"type": "integer"}}}, - "settings": {"number_of_shards": 1, "number_of_replicas": 0}, - }, - ) - self.client.cluster.health(wait_for_status="yellow") + # ignore the status code 400 + helpers.bulk(sync_client, [{"a": 42}, {"a": "c"}], index="i", ignore_status=400) - success, failed = helpers.bulk( - self.client, + # ignore only the status code in the `ignore_status` argument + with pytest.raises(helpers.BulkIndexError): + helpers.bulk( + sync_client, [{"a": 42}, {"a": "c"}], index="i", - stats_only=True, - raise_on_error=False, + ignore_status=(444,), ) - self.assertEqual(1, success) - self.assertEqual(1, failed) + # ignore transport error exception + failing_client = FailingBulkClient(sync_client) + helpers.bulk(failing_client, [{"a": 42}], index="i", ignore_status=(599,)) -class TestScan(ElasticsearchTestCase): - mock_scroll_responses = [ - { - "_scroll_id": "dummy_id", - "_shards": {"successful": 4, "total": 5, "skipped": 0}, - "hits": {"hits": [{"scroll_data": 42}]}, - }, + +def test_errors_are_collected_properly(sync_client): + sync_client.indices.create( + "i", { - "_scroll_id": "dummy_id", - "_shards": {"successful": 4, "total": 5, "skipped": 0}, - "hits": {"hits": []}, + "mappings": {"properties": {"a": {"type": "integer"}}}, + "settings": {"number_of_shards": 1, "number_of_replicas": 0}, }, - ] + ) + sync_client.cluster.health(wait_for_status="yellow") + + success, failed = helpers.bulk( + sync_client, + [{"a": 42}, {"a": "c"}], + index="i", + stats_only=True, + raise_on_error=False, + ) + assert 1 == success + assert 1 == failed - def teardown_method(self, m): - self.client.transport.perform_request("DELETE", "/_search/scroll/_all") - super(TestScan, self).teardown_method(m) - def test_order_can_be_preserved(self): - bulk = [] - for x in range(100): - bulk.append({"index": {"_index": "test_index", "_id": x}}) - bulk.append({"answer": x, "correct": x == 42}) - self.client.bulk(bulk, refresh=True) +mock_scroll_responses = [ + { + "_scroll_id": "dummy_id", + "_shards": {"successful": 4, "total": 5, "skipped": 0}, + "hits": {"hits": [{"scroll_data": 42}]}, + }, + { + "_scroll_id": "dummy_id", + "_shards": {"successful": 4, "total": 5, "skipped": 0}, + "hits": {"hits": []}, + }, +] - docs = list( - helpers.scan( - self.client, - index="test_index", - query={"sort": "answer"}, - preserve_order=True, - ) + +@pytest.fixture(scope="function") +def scan_teardown(sync_client): + yield + sync_client.clear_scroll(scroll_id="_all") + + +@pytest.mark.usefixtures("scan_teardown") +def test_order_can_be_preserved(sync_client): + bulk = [] + for x in range(100): + bulk.append({"index": {"_index": 
"test_index", "_id": x}}) + bulk.append({"answer": x, "correct": x == 42}) + sync_client.bulk(bulk, refresh=True) + + docs = list( + helpers.scan( + sync_client, + index="test_index", + query={"sort": "answer"}, + preserve_order=True, ) + ) - self.assertEqual(100, len(docs)) - self.assertEqual(list(map(str, range(100))), list(d["_id"] for d in docs)) - self.assertEqual(list(range(100)), list(d["_source"]["answer"] for d in docs)) + assert 100 == len(docs) + assert list(map(str, range(100))) == list(d["_id"] for d in docs) + assert list(range(100)) == list(d["_source"]["answer"] for d in docs) - def test_all_documents_are_read(self): - bulk = [] - for x in range(100): - bulk.append({"index": {"_index": "test_index", "_id": x}}) - bulk.append({"answer": x, "correct": x == 42}) - self.client.bulk(bulk, refresh=True) - docs = list(helpers.scan(self.client, index="test_index", size=2)) +@pytest.mark.usefixtures("scan_teardown") +def test_all_documents_are_read(sync_client): + bulk = [] + for x in range(100): + bulk.append({"index": {"_index": "test_index", "_id": x}}) + bulk.append({"answer": x, "correct": x == 42}) + sync_client.bulk(bulk, refresh=True) - self.assertEqual(100, len(docs)) - self.assertEqual(set(map(str, range(100))), set(d["_id"] for d in docs)) - self.assertEqual(set(range(100)), set(d["_source"]["answer"] for d in docs)) + docs = list(helpers.scan(sync_client, index="test_index", size=2)) + + assert 100 == len(docs) + assert set(map(str, range(100))) == set(d["_id"] for d in docs) + assert set(range(100)) == set(d["_source"]["answer"] for d in docs) + + +@pytest.mark.usefixtures("scan_teardown") +def test_scroll_error(sync_client): + bulk = [] + for x in range(4): + bulk.append({"index": {"_index": "test_index"}}) + bulk.append({"value": x}) + sync_client.bulk(bulk, refresh=True) - def test_scroll_error(self): - bulk = [] - for x in range(4): - bulk.append({"index": {"_index": "test_index"}}) - bulk.append({"value": x}) - self.client.bulk(bulk, refresh=True) + with patch.object(sync_client, "scroll") as scroll_mock: + scroll_mock.side_effect = mock_scroll_responses + data = list( + helpers.scan( + sync_client, + index="test_index", + size=2, + raise_on_error=False, + clear_scroll=False, + ) + ) + assert len(data) == 3 + assert data[-1] == {"scroll_data": 42} - with patch.object(self.client, "scroll") as scroll_mock: - scroll_mock.side_effect = self.mock_scroll_responses + scroll_mock.side_effect = mock_scroll_responses + with pytest.raises(ScanError): data = list( helpers.scan( - self.client, + sync_client, index="test_index", size=2, - raise_on_error=False, + raise_on_error=True, clear_scroll=False, ) ) - self.assertEqual(len(data), 3) - self.assertEqual(data[-1], {"scroll_data": 42}) + assert len(data) == 3 + assert data[-1] == {"scroll_data": 42} - scroll_mock.side_effect = self.mock_scroll_responses - with self.assertRaises(ScanError): - data = list( - helpers.scan( - self.client, - index="test_index", - size=2, - raise_on_error=True, - clear_scroll=False, - ) - ) - self.assertEqual(len(data), 3) - self.assertEqual(data[-1], {"scroll_data": 42}) - def test_initial_search_error(self): - with patch.object(self, "client") as client_mock: - client_mock.search.return_value = { - "_scroll_id": "dummy_id", - "_shards": {"successful": 4, "total": 5, "skipped": 0}, - "hits": {"hits": [{"search_data": 1}]}, - } - client_mock.scroll.side_effect = self.mock_scroll_responses +def test_initial_search_error(sync_client): + with patch.object( + sync_client, + "search", + return_value={ + 
"_scroll_id": "dummy_id", + "_shards": {"successful": 4, "total": 5, "skipped": 0}, + "hits": {"hits": [{"search_data": 1}]}, + }, + ): + with patch.object(sync_client, "scroll") as scroll_mock, patch.object( + sync_client, "clear_scroll" + ) as clear_scroll_mock: + scroll_mock.side_effect = mock_scroll_responses data = list( helpers.scan( - self.client, index="test_index", size=2, raise_on_error=False + sync_client, index="test_index", size=2, raise_on_error=False ) ) - self.assertEqual(data, [{"search_data": 1}, {"scroll_data": 42}]) + assert data == [{"search_data": 1}, {"scroll_data": 42}] + + # Scrolled at least once and received a scroll_id to clear. + scroll_mock.assert_called_with( + body={"scroll_id": "dummy_id", "scroll": "5m"}, + params={"__elastic_client_meta": (("h", "s"),)}, + ) + clear_scroll_mock.assert_called_once_with( + body={"scroll_id": ["dummy_id"]}, + ignore=(404,), + params={"__elastic_client_meta": (("h", "s"),)}, + ) + + with patch.object(sync_client, "scroll") as scroll_mock, patch.object( + sync_client, "clear_scroll" + ) as clear_scroll_mock: - client_mock.scroll.side_effect = self.mock_scroll_responses - with self.assertRaises(ScanError): + scroll_mock.side_effect = mock_scroll_responses + with pytest.raises(ScanError): data = list( helpers.scan( - self.client, index="test_index", size=2, raise_on_error=True + sync_client, index="test_index", size=2, raise_on_error=True ) ) - self.assertEqual(data, [{"search_data": 1}]) - client_mock.scroll.assert_not_called() - - def test_no_scroll_id_fast_route(self): - with patch.object(self, "client") as client_mock: - client_mock.search.return_value = {"no": "_scroll_id"} - data = list(helpers.scan(self.client, index="test_index")) - - self.assertEqual(data, []) - client_mock.scroll.assert_not_called() - client_mock.clear_scroll.assert_not_called() - - def test_scan_auth_kwargs_forwarded(self): - for key, val in { - "api_key": ("name", "value"), - "http_auth": ("username", "password"), - "headers": {"custom": "header"}, - }.items(): - with patch.object(self, "client") as client_mock: - client_mock.search.return_value = { - "_scroll_id": "scroll_id", - "_shards": {"successful": 5, "total": 5, "skipped": 0}, - "hits": {"hits": [{"search_data": 1}]}, - } - client_mock.scroll.return_value = { - "_scroll_id": "scroll_id", - "_shards": {"successful": 5, "total": 5, "skipped": 0}, - "hits": {"hits": []}, - } - client_mock.clear_scroll.return_value = {} - - data = list(helpers.scan(self.client, index="test_index", **{key: val})) - - self.assertEqual(data, [{"search_data": 1}]) + assert data == [{"search_data": 1}] + + # Never scrolled but did receive a scroll_id to clear. + scroll_mock.assert_not_called() + clear_scroll_mock.assert_called_once_with( + body={"scroll_id": ["dummy_id"]}, + ignore=(404,), + params={"__elastic_client_meta": (("h", "s"),)}, + ) - # Assert that 'search', 'scroll' and 'clear_scroll' all - # received the extra kwarg related to authentication. 
-                for api_mock in (
-                    client_mock.search,
-                    client_mock.scroll,
-                    client_mock.clear_scroll,
-                ):
-                    self.assertEqual(api_mock.call_args[1][key], val)
-
-    def test_scan_auth_kwargs_favor_scroll_kwargs_option(self):
-        with patch.object(self, "client") as client_mock:
-            client_mock.search.return_value = {
-                "_scroll_id": "scroll_id",
-                "_shards": {"successful": 5, "total": 5, "skipped": 0},
-                "hits": {"hits": [{"search_data": 1}]},
-            }
-            client_mock.scroll.return_value = {
+def test_no_scroll_id_fast_route(sync_client):
+    with patch.object(
+        sync_client, "search", return_value={"no": "_scroll_id"}
+    ) as search_mock, patch.object(sync_client, "scroll") as scroll_mock, patch.object(
+        sync_client, "clear_scroll"
+    ) as clear_scroll_mock:
+        data = list(helpers.scan(sync_client, index="test_index"))
+
+        assert data == []
+        search_mock.assert_called_once_with(
+            body={"sort": "_doc"},
+            scroll="5m",
+            size=1000,
+            request_timeout=None,
+            index="test_index",
+        )
+        scroll_mock.assert_not_called()
+        clear_scroll_mock.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {"api_key": ("name", "value")},
+        {"http_auth": ("username", "password")},
+        {"headers": {"custom": "header"}},
+    ],
+)
+@pytest.mark.usefixtures("scan_teardown")
+def test_scan_auth_kwargs_forwarded(sync_client, kwargs):
+    ((key, val),) = kwargs.items()
+
+    with patch.object(
+        sync_client,
+        "search",
+        return_value={
+            "_scroll_id": "scroll_id",
+            "_shards": {"successful": 5, "total": 5, "skipped": 0},
+            "hits": {"hits": [{"search_data": 1}]},
+        },
+    ) as search_mock:
+        with patch.object(
+            sync_client,
+            "scroll",
+            return_value={
                 "_scroll_id": "scroll_id",
                 "_shards": {"successful": 5, "total": 5, "skipped": 0},
                 "hits": {"hits": []},
-            }
-            client_mock.clear_scroll.return_value = {}
+            },
+        ) as scroll_mock:
+            with patch.object(
+                sync_client, "clear_scroll", return_value={}
+            ) as clear_mock:
 
-            data = list(
-                helpers.scan(
-                    self.client,
-                    index="test_index",
-                    scroll_kwargs={"headers": {"scroll": "kwargs"}, "sort": "asc"},
-                    headers={"not scroll": "kwargs"},
-                )
+                data = list(helpers.scan(sync_client, index="test_index", **kwargs))
+
+                assert data == [{"search_data": 1}]
+
+                for api_mock in (search_mock, scroll_mock, clear_mock):
+                    assert api_mock.call_args[1][key] == val
+
+
+def test_scan_auth_kwargs_favor_scroll_kwargs_option(sync_client):
+
+    with patch.object(
+        sync_client,
+        "search",
+        return_value={
+            "_scroll_id": "scroll_id",
+            "_shards": {"successful": 5, "total": 5, "skipped": 0},
+            "hits": {"hits": [{"search_data": 1}]},
+        },
+    ) as search_mock, patch.object(
+        sync_client,
+        "scroll",
+        return_value={
+            "_scroll_id": "scroll_id",
+            "_shards": {"successful": 5, "total": 5, "skipped": 0},
+            "hits": {"hits": []},
+        },
+    ) as scroll_mock, patch.object(
+        sync_client, "clear_scroll", return_value={}
+    ):
+
+        data = list(
+            helpers.scan(
+                sync_client,
+                index="test_index",
+                scroll_kwargs={"headers": {"scroll": "kwargs"}, "sort": "asc"},
+                headers={"not scroll": "kwargs"},
             )
+        )
+
+        assert data == [{"search_data": 1}]
+
+        # Assert that we see 'scroll_kwargs' options used instead of 'kwargs'
+        search_mock.assert_called_once_with(
+            body={"sort": "_doc"},
+            scroll="5m",
+            size=1000,
+            request_timeout=None,
+            index="test_index",
+            headers={"not scroll": "kwargs"},
+        )
+        scroll_mock.assert_called_once_with(
+            body={"scroll_id": "scroll_id", "scroll": "5m"},
+            headers={"scroll": "kwargs"},
+            sort="asc",
+            params={"__elastic_client_meta": (("h", "s"),)},
+        )
+
 
-        self.assertEqual(data, [{"search_data": 1}])
+def 
test_log_warning_on_shard_failures(sync_client): + bulk = [] + for x in range(4): + bulk.append({"index": {"_index": "test_index"}}) + bulk.append({"value": x}) + sync_client.bulk(bulk, refresh=True) - # Assert that we see 'scroll_kwargs' options used instead of 'kwargs' - self.assertEqual( - client_mock.scroll.call_args[1]["headers"], {"scroll": "kwargs"} + with patch("elasticsearch.helpers.actions.logger") as logger_mock, patch.object( + sync_client, "scroll" + ) as scroll_mock: + scroll_mock.side_effect = mock_scroll_responses + list( + helpers.scan( + sync_client, + index="test_index", + size=2, + raise_on_error=False, + clear_scroll=False, ) - self.assertEqual(client_mock.scroll.call_args[1]["sort"], "asc") - - @patch("elasticsearch.helpers.actions.logger") - def test_logger(self, logger_mock): - bulk = [] - for x in range(4): - bulk.append({"index": {"_index": "test_index"}}) - bulk.append({"value": x}) - self.client.bulk(bulk, refresh=True) - - with patch.object(self.client, "scroll") as scroll_mock: - scroll_mock.side_effect = self.mock_scroll_responses + ) + logger_mock.warning.assert_called() + + scroll_mock.side_effect = mock_scroll_responses + try: list( helpers.scan( - self.client, + sync_client, index="test_index", size=2, - raise_on_error=False, + raise_on_error=True, clear_scroll=False, ) ) - logger_mock.warning.assert_called() + except ScanError: + pass + logger_mock.warning.assert_called() - scroll_mock.side_effect = self.mock_scroll_responses - try: - list( - helpers.scan( - self.client, - index="test_index", - size=2, - raise_on_error=True, - clear_scroll=False, - ) - ) - except ScanError: - pass - logger_mock.warning.assert_called() - def test_clear_scroll(self): - bulk = [] - for x in range(4): - bulk.append({"index": {"_index": "test_index"}}) - bulk.append({"value": x}) - self.client.bulk(bulk, refresh=True) +def test_clear_scroll(sync_client): + bulk = [] + for x in range(4): + bulk.append({"index": {"_index": "test_index"}}) + bulk.append({"value": x}) + sync_client.bulk(bulk, refresh=True) - with patch.object( - self.client, "clear_scroll", wraps=self.client.clear_scroll - ) as spy: - list(helpers.scan(self.client, index="test_index", size=2)) - spy.assert_called_once() + with patch.object( + sync_client, "clear_scroll", wraps=sync_client.clear_scroll + ) as clear_scroll_mock: + list(helpers.scan(sync_client, index="test_index", size=2)) + clear_scroll_mock.assert_called_once() - spy.reset_mock() - list( - helpers.scan(self.client, index="test_index", size=2, clear_scroll=True) - ) - spy.assert_called_once() + clear_scroll_mock.reset_mock() + list(helpers.scan(sync_client, index="test_index", size=2, clear_scroll=True)) + clear_scroll_mock.assert_called_once() + + clear_scroll_mock.reset_mock() + list(helpers.scan(sync_client, index="test_index", size=2, clear_scroll=False)) + clear_scroll_mock.assert_not_called() - spy.reset_mock() - list( - helpers.scan( - self.client, index="test_index", size=2, clear_scroll=False - ) - ) - spy.assert_not_called() - def test_shards_no_skipped_field(self): - with patch.object(self, "client") as client_mock: - client_mock.search.return_value = { +def test_shards_no_skipped_field(sync_client): + # Test that scan doesn't fail if 'hits.skipped' isn't available. 
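+    # ('skipped' here refers to the optional key inside '_shards'; both the
+    # mocked 'search' and 'scroll' responses below omit it entirely, so scan
+    # must not assume it is present.)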
+ with patch.object( + sync_client, + "search", + return_value={ + "_scroll_id": "dummy_id", + "_shards": {"successful": 5, "total": 5}, + "hits": {"hits": [{"search_data": 1}]}, + }, + ), patch.object(sync_client, "scroll") as scroll_mock, patch.object( + sync_client, "clear_scroll" + ): + scroll_mock.side_effect = [ + { "_scroll_id": "dummy_id", "_shards": {"successful": 5, "total": 5}, - "hits": {"hits": [{"search_data": 1}]}, - } - client_mock.scroll.side_effect = [ - { - "_scroll_id": "dummy_id", - "_shards": {"successful": 5, "total": 5}, - "hits": {"hits": [{"scroll_data": 42}]}, - }, - { - "_scroll_id": "dummy_id", - "_shards": {"successful": 5, "total": 5}, - "hits": {"hits": []}, - }, - ] - - data = list( - helpers.scan( - self.client, index="test_index", size=2, raise_on_error=True - ) - ) - self.assertEqual(data, [{"search_data": 1}, {"scroll_data": 42}]) - - -class TestReindex(ElasticsearchTestCase): - def setup_method(self, _): - bulk = [] - for x in range(100): - bulk.append({"index": {"_index": "test_index", "_id": x}}) - bulk.append( - { - "answer": x, - "correct": x == 42, - "type": "answers" if x % 2 == 0 else "questions", - } - ) - self.client.bulk(bulk, refresh=True) + "hits": {"hits": [{"scroll_data": 42}]}, + }, + { + "_scroll_id": "dummy_id", + "_shards": {"successful": 5, "total": 5}, + "hits": {"hits": []}, + }, + ] - def test_reindex_passes_kwargs_to_scan_and_bulk(self): - helpers.reindex( - self.client, - "test_index", - "prod_index", - scan_kwargs={"q": "type:answers"}, - bulk_kwargs={"refresh": True}, + data = list( + helpers.scan(sync_client, index="test_index", size=2, raise_on_error=True) ) + assert data == [{"search_data": 1}, {"scroll_data": 42}] - self.assertTrue(self.client.indices.exists("prod_index")) - self.assertEqual( - 50, self.client.count(index="prod_index", q="type:answers")["count"] - ) - self.assertEqual( - {"answer": 42, "correct": True, "type": "answers"}, - self.client.get(index="prod_index", id=42)["_source"], +@pytest.fixture(scope="function") +def reindex_setup(sync_client): + bulk = [] + for x in range(100): + bulk.append({"index": {"_index": "test_index", "_id": x}}) + bulk.append( + { + "answer": x, + "correct": x == 42, + "type": "answers" if x % 2 == 0 else "questions", + } ) + sync_client.bulk(bulk, refresh=True) - def test_reindex_accepts_a_query(self): - helpers.reindex( - self.client, - "test_index", - "prod_index", - query={"query": {"bool": {"filter": {"term": {"type": "answers"}}}}}, - ) - self.client.indices.refresh() - self.assertTrue(self.client.indices.exists("prod_index")) - self.assertEqual( - 50, self.client.count(index="prod_index", q="type:answers")["count"] - ) +@pytest.mark.usefixtures("reindex_setup") +def test_reindex_passes_kwargs_to_scan_and_bulk(sync_client): + helpers.reindex( + sync_client, + "test_index", + "prod_index", + scan_kwargs={"q": "type:answers"}, + bulk_kwargs={"refresh": True}, + ) - self.assertEqual( - {"answer": 42, "correct": True, "type": "answers"}, - self.client.get(index="prod_index", id=42)["_source"], - ) + assert sync_client.indices.exists("prod_index") + assert 50 == sync_client.count(index="prod_index", q="type:answers")["count"] - def test_all_documents_get_moved(self): - helpers.reindex(self.client, "test_index", "prod_index") - self.client.indices.refresh() + assert {"answer": 42, "correct": True, "type": "answers"} == sync_client.get( + index="prod_index", id=42 + )["_source"] - self.assertTrue(self.client.indices.exists("prod_index")) - self.assertEqual( - 50, 
self.client.count(index="prod_index", q="type:questions")["count"] - ) - self.assertEqual( - 50, self.client.count(index="prod_index", q="type:answers")["count"] - ) - self.assertEqual( - {"answer": 42, "correct": True, "type": "answers"}, - self.client.get(index="prod_index", id=42)["_source"], - ) +@pytest.mark.usefixtures("reindex_setup") +def test_reindex_accepts_a_query(sync_client): + helpers.reindex( + sync_client, + "test_index", + "prod_index", + query={"query": {"bool": {"filter": {"term": {"type": "answers"}}}}}, + ) + sync_client.indices.refresh() + assert sync_client.indices.exists("prod_index") + assert 50 == sync_client.count(index="prod_index", q="type:answers")["count"] -class TestParentChildReindex(ElasticsearchTestCase): - def setup_method(self, _): - body = { - "settings": {"number_of_shards": 1, "number_of_replicas": 0}, - "mappings": { - "properties": { - "question_answer": { - "type": "join", - "relations": {"question": "answer"}, - } - } - }, - } - self.client.indices.create(index="test-index", body=body) - self.client.indices.create(index="real-index", body=body) + assert {"answer": 42, "correct": True, "type": "answers"} == sync_client.get( + index="prod_index", id=42 + )["_source"] - self.client.index( - index="test-index", id=42, body={"question_answer": "question"} - ) - self.client.index( - index="test-index", - id=47, - routing=42, - body={"some": "data", "question_answer": {"name": "answer", "parent": 42}}, - ) - self.client.indices.refresh(index="test-index") - def test_children_are_reindexed_correctly(self): - helpers.reindex(self.client, "test-index", "real-index") +@pytest.mark.usefixtures("reindex_setup") +def test_all_documents_get_moved(sync_client): + helpers.reindex(sync_client, "test_index", "prod_index") + sync_client.indices.refresh() - q = self.client.get(index="real-index", id=42) - self.assertEqual( - { - "_id": "42", - "_index": "real-index", - "_primary_term": 1, - "_seq_no": 0, - "_source": {"question_answer": "question"}, - "_version": 1, - "found": True, - }, - q, - ) - q = self.client.get(index="test-index", id=47, routing=42) - self.assertEqual( - { - "_routing": "42", - "_id": "47", - "_index": "test-index", - "_primary_term": 1, - "_seq_no": 1, - "_source": { - "some": "data", - "question_answer": {"name": "answer", "parent": 42}, - }, - "_version": 1, - "found": True, - }, - q, - ) + assert sync_client.indices.exists("prod_index") + assert 50 == sync_client.count(index="prod_index", q="type:questions")["count"] + assert 50 == sync_client.count(index="prod_index", q="type:answers")["count"] + + assert {"answer": 42, "correct": True, "type": "answers"} == sync_client.get( + index="prod_index", id=42 + )["_source"] + + +@pytest.fixture(scope="function") +def parent_child_reindex_setup(sync_client): + body = { + "settings": {"number_of_shards": 1, "number_of_replicas": 0}, + "mappings": { + "properties": { + "question_answer": { + "type": "join", + "relations": {"question": "answer"}, + } + } + }, + } + sync_client.indices.create(index="test-index", body=body) + sync_client.indices.create(index="real-index", body=body) + + sync_client.index(index="test-index", id=42, body={"question_answer": "question"}) + sync_client.index( + index="test-index", + id=47, + routing=42, + body={"some": "data", "question_answer": {"name": "answer", "parent": 42}}, + ) + sync_client.indices.refresh(index="test-index") + + +@pytest.mark.usefixtures("parent_child_reindex_setup") +def test_children_are_reindexed_correctly(sync_client): + 
helpers.reindex(sync_client, "test-index", "real-index") + + q = sync_client.get(index="real-index", id=42) + assert { + "_id": "42", + "_index": "real-index", + "_primary_term": 1, + "_seq_no": 0, + "_source": {"question_answer": "question"}, + "_version": 1, + "found": True, + } == q + q = sync_client.get(index="test-index", id=47, routing=42) + assert { + "_routing": "42", + "_id": "47", + "_index": "test-index", + "_primary_term": 1, + "_seq_no": 1, + "_source": { + "some": "data", + "question_answer": {"name": "answer", "parent": 42}, + }, + "_version": 1, + "found": True, + } == q @pytest.fixture(scope="function") @@ -786,32 +840,30 @@ def reindex_data_stream_setup(sync_client): sync_client.indices.refresh() -class TestDataStreamReindex(object): - @pytest.mark.usefixtures("reindex_data_stream_setup") - @pytest.mark.parametrize("op_type", [None, "create"]) - def test_reindex_index_datastream(self, op_type, sync_client): +@pytest.mark.usefixtures("reindex_data_stream_setup") +@pytest.mark.parametrize("op_type", [None, "create"]) +def test_reindex_index_datastream(op_type, sync_client): + helpers.reindex( + sync_client, + source_index="test_index_stream", + target_index="py-test-stream", + query={"query": {"bool": {"filter": {"term": {"type": "answers"}}}}}, + op_type=op_type, + ) + sync_client.indices.refresh() + assert sync_client.indices.exists(index="py-test-stream") + assert 50 == sync_client.count(index="py-test-stream", q="type:answers")["count"] + + +@pytest.mark.usefixtures("reindex_data_stream_setup") +def test_reindex_index_datastream_op_type_index(sync_client): + with pytest.raises( + ValueError, match="Data streams must have 'op_type' set to 'create'" + ): helpers.reindex( sync_client, source_index="test_index_stream", target_index="py-test-stream", query={"query": {"bool": {"filter": {"term": {"type": "answers"}}}}}, - op_type=op_type, - ) - sync_client.indices.refresh() - assert sync_client.indices.exists(index="py-test-stream") - assert ( - 50 == sync_client.count(index="py-test-stream", q="type:answers")["count"] + op_type="_index", ) - - @pytest.mark.usefixtures("reindex_data_stream_setup") - def test_reindex_index_datastream_op_type_index(self, sync_client): - with pytest.raises( - ValueError, match="Data streams must have 'op_type' set to 'create'" - ): - helpers.reindex( - sync_client, - source_index="test_index_stream", - target_index="py-test-stream", - query={"query": {"bool": {"filter": {"term": {"type": "answers"}}}}}, - op_type="_index", - ) diff --git a/test_elasticsearch/test_server/test_rest_api_spec.py b/test_elasticsearch/test_server/test_rest_api_spec.py index 1c74880bb..8869f60f7 100644 --- a/test_elasticsearch/test_server/test_rest_api_spec.py +++ b/test_elasticsearch/test_server/test_rest_api_spec.py @@ -32,12 +32,16 @@ import urllib3 import yaml -from elasticsearch import ElasticsearchWarning, RequestError, TransportError +from elasticsearch import ( + Elasticsearch, + ElasticsearchWarning, + RequestError, + TransportError, +) from elasticsearch.client.utils import _base64_auth_header from elasticsearch.compat import string_types -from elasticsearch.helpers.test import _get_version -from . 
import get_client +from ..utils import CA_CERTS, es_url, parse_version # some params had to be changed in python, keep track of them so we can rename # those in the tests accordingly @@ -300,9 +304,9 @@ def run_skip(self, skip): version, reason = skip["version"], skip["reason"] if version == "all": pytest.skip(reason) - min_version, max_version = version.split("-") - min_version = _get_version(min_version) or (0,) - max_version = _get_version(max_version) or (999,) + min_version, _, max_version = version.partition("-") + min_version = parse_version(min_version.strip()) or (0,) + max_version = parse_version(max_version.strip()) or (999,) if min_version <= (self.es_version()) <= max_version: pytest.skip(reason) @@ -490,7 +494,7 @@ def remove_implicit_resolver(cls, tag_to_remove): try: # Construct the HTTP and Elasticsearch client http = urllib3.PoolManager(retries=10) - client = get_client() + client = Elasticsearch(es_url(), timeout=3, ca_certs=CA_CERTS) # Make a request to Elasticsearch for the build hash, we'll be looking for # an artifact with this same hash to download test specs for. diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py index bb520d0e7..f432ea739 100644 --- a/test_elasticsearch/test_transport.py +++ b/test_elasticsearch/test_transport.py @@ -19,6 +19,7 @@ from __future__ import unicode_literals import json +import re import time import pytest @@ -34,8 +35,6 @@ ) from elasticsearch.transport import Transport, get_host_info -from .test_cases import TestCase - class DummyConnection(Connection): def __init__(self, **kwargs): @@ -103,7 +102,7 @@ def perform_request(self, *args, **kwargs): }""" -class TestHostsInfoCallback(TestCase): +class TestHostsInfoCallback: def test_master_only_nodes_are_ignored(self): nodes = [ {"roles": ["master"]}, @@ -117,26 +116,27 @@ def test_master_only_nodes_are_ignored(self): for i, node_info in enumerate(nodes) if get_host_info(node_info, i) is not None ] - self.assertEqual([1, 2, 3, 4], chosen) + assert [1, 2, 3, 4] == chosen -class TestTransport(TestCase): +class TestTransport: def test_single_connection_uses_dummy_connection_pool(self): t = Transport([{}]) - self.assertIsInstance(t.connection_pool, DummyConnectionPool) + assert isinstance(t.connection_pool, DummyConnectionPool) t = Transport([{"host": "localhost"}]) - self.assertIsInstance(t.connection_pool, DummyConnectionPool) + assert isinstance(t.connection_pool, DummyConnectionPool) def test_request_timeout_extracted_from_params_and_passed(self): t = Transport([{}], meta_header=False, connection_class=DummyConnection) t.perform_request("GET", "/", params={"request_timeout": 42}) - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual(("GET", "/", {}, None), t.get_connection().calls[0][0]) - self.assertEqual( - {"timeout": 42, "ignore": (), "headers": None}, - t.get_connection().calls[0][1], - ) + assert 1 == len(t.get_connection().calls) + assert ("GET", "/", {}, None) == t.get_connection().calls[0][0] + assert { + "timeout": 42, + "ignore": (), + "headers": None, + } == t.get_connection().calls[0][1] def test_opaque_id(self): t = Transport( @@ -144,60 +144,57 @@ def test_opaque_id(self): ) t.perform_request("GET", "/") - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual(("GET", "/", None, None), t.get_connection().calls[0][0]) - self.assertEqual( - {"timeout": None, "ignore": (), "headers": None}, - t.get_connection().calls[0][1], - ) + assert 1 == len(t.get_connection().calls) + assert ("GET", "/", None, None) == 
t.get_connection().calls[0][0] + assert { + "timeout": None, + "ignore": (), + "headers": None, + } == t.get_connection().calls[0][1] # Now try with an 'x-opaque-id' set on perform_request(). t.perform_request("GET", "/", headers={"x-opaque-id": "request-1"}) - self.assertEqual(2, len(t.get_connection().calls)) - self.assertEqual(("GET", "/", None, None), t.get_connection().calls[1][0]) - self.assertEqual( - {"timeout": None, "ignore": (), "headers": {"x-opaque-id": "request-1"}}, - t.get_connection().calls[1][1], - ) + assert 2 == len(t.get_connection().calls) + assert ("GET", "/", None, None) == t.get_connection().calls[1][0] + assert { + "timeout": None, + "ignore": (), + "headers": {"x-opaque-id": "request-1"}, + } == t.get_connection().calls[1][1] def test_request_with_custom_user_agent_header(self): t = Transport([{}], meta_header=False, connection_class=DummyConnection) t.perform_request("GET", "/", headers={"user-agent": "my-custom-value/1.2.3"}) - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual( - { - "timeout": None, - "ignore": (), - "headers": {"user-agent": "my-custom-value/1.2.3"}, - }, - t.get_connection().calls[0][1], - ) + assert 1 == len(t.get_connection().calls) + assert { + "timeout": None, + "ignore": (), + "headers": {"user-agent": "my-custom-value/1.2.3"}, + } == t.get_connection().calls[0][1] def test_send_get_body_as_source(self): t = Transport([{}], send_get_body_as="source", connection_class=DummyConnection) t.perform_request("GET", "/", body={}) - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual( - ("GET", "/", {"source": "{}"}, None), t.get_connection().calls[0][0] - ) + assert 1 == len(t.get_connection().calls) + assert ("GET", "/", {"source": "{}"}, None) == t.get_connection().calls[0][0] def test_send_get_body_as_post(self): t = Transport([{}], send_get_body_as="POST", connection_class=DummyConnection) t.perform_request("GET", "/", body={}) - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual(("POST", "/", None, b"{}"), t.get_connection().calls[0][0]) + assert 1 == len(t.get_connection().calls) + assert ("POST", "/", None, b"{}") == t.get_connection().calls[0][0] def test_client_meta_header(self): t = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", body={}) - self.assertEqual(1, len(t.get_connection().calls)) + assert 1 == len(t.get_connection().calls) headers = t.get_connection().calls[0][1]["headers"] - self.assertRegexpMatches( - headers["x-elastic-client-meta"], r"^es=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?$" + assert re.search( + r"^es=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?$", headers["x-elastic-client-meta"] ) class DummyConnectionWithMeta(DummyConnection): @@ -206,21 +203,21 @@ class DummyConnectionWithMeta(DummyConnection): t = Transport([{}], connection_class=DummyConnectionWithMeta) t.perform_request("GET", "/", body={}, headers={"Custom": "header"}) - self.assertEqual(1, len(t.get_connection().calls)) + assert 1 == len(t.get_connection().calls) headers = t.get_connection().calls[0][1]["headers"] - self.assertRegexpMatches( - headers["x-elastic-client-meta"], + assert re.search( r"^es=[0-9.]+p?,py=[0-9.]+p?,t=[0-9.]+p?,dm=1.2.3$", + headers["x-elastic-client-meta"], ) - self.assertEqual(headers["Custom"], "header") + assert headers["Custom"] == "header" def test_client_meta_header_not_sent(self): t = Transport([{}], meta_header=False, connection_class=DummyConnection) t.perform_request("GET", "/", body={}) - self.assertEqual(1, len(t.get_connection().calls)) + 
assert 1 == len(t.get_connection().calls) headers = t.get_connection().calls[0][1]["headers"] - self.assertIs(headers, None) + assert headers is None def test_meta_header_type_error(self): with pytest.raises(TypeError) as e: @@ -231,39 +228,43 @@ def test_body_gets_encoded_into_bytes(self): t = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", body="你好") - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual( - ("GET", "/", None, b"\xe4\xbd\xa0\xe5\xa5\xbd"), - t.get_connection().calls[0][0], - ) + assert 1 == len(t.get_connection().calls) + assert ( + "GET", + "/", + None, + b"\xe4\xbd\xa0\xe5\xa5\xbd", + ) == t.get_connection().calls[0][0] def test_body_bytes_get_passed_untouched(self): t = Transport([{}], connection_class=DummyConnection) body = b"\xe4\xbd\xa0\xe5\xa5\xbd" t.perform_request("GET", "/", body=body) - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual(("GET", "/", None, body), t.get_connection().calls[0][0]) + assert 1 == len(t.get_connection().calls) + assert ("GET", "/", None, body) == t.get_connection().calls[0][0] def test_body_surrogates_replaced_encoded_into_bytes(self): t = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", body="你好\uda6a") - self.assertEqual(1, len(t.get_connection().calls)) - self.assertEqual( - ("GET", "/", None, b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"), - t.get_connection().calls[0][0], - ) + assert 1 == len(t.get_connection().calls) + assert ( + "GET", + "/", + None, + b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa", + ) == t.get_connection().calls[0][0] def test_kwargs_passed_on_to_connections(self): t = Transport([{"host": "google.com"}], port=123) - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertEqual("http://google.com:123", t.connection_pool.connections[0].host) + assert 1 == len(t.connection_pool.connections) + assert "http://google.com:123" == t.connection_pool.connections[0].host def test_kwargs_passed_on_to_connection_pool(self): dt = object() t = Transport([{}, {}], dead_timeout=dt) - self.assertIs(dt, t.connection_pool.dead_timeout) + assert dt is t.connection_pool.dead_timeout def test_custom_connection_class(self): class MyConnection(object): @@ -271,17 +272,15 @@ def __init__(self, **kwargs): self.kwargs = kwargs t = Transport([{}], connection_class=MyConnection) - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertIsInstance(t.connection_pool.connections[0], MyConnection) + assert 1 == len(t.connection_pool.connections) + assert isinstance(t.connection_pool.connections[0], MyConnection) def test_add_connection(self): t = Transport([{}], randomize_hosts=False) t.add_connection({"host": "google.com", "port": 1234}) - self.assertEqual(2, len(t.connection_pool.connections)) - self.assertEqual( - "http://google.com:1234", t.connection_pool.connections[1].host - ) + assert 2 == len(t.connection_pool.connections) + assert "http://google.com:1234" == t.connection_pool.connections[1].host def test_request_will_fail_after_X_retries(self): t = Transport( @@ -289,8 +288,9 @@ def test_request_will_fail_after_X_retries(self): connection_class=DummyConnection, ) - self.assertRaises(ConnectionError, t.perform_request, "GET", "/") - self.assertEqual(4, len(t.get_connection().calls)) + with pytest.raises(ConnectionError): + t.perform_request("GET", "/") + assert 4 == len(t.get_connection().calls) def test_failed_connection_will_be_marked_as_dead(self): t = Transport( @@ -298,8 +298,9 @@ def 
test_failed_connection_will_be_marked_as_dead(self): connection_class=DummyConnection, ) - self.assertRaises(ConnectionError, t.perform_request, "GET", "/") - self.assertEqual(0, len(t.connection_pool.connections)) + with pytest.raises(ConnectionError): + t.perform_request("GET", "/") + assert 0 == len(t.connection_pool.connections) def test_resurrected_connection_will_be_marked_as_live_on_success(self): for method in ("GET", "HEAD"): @@ -310,16 +311,16 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self): t.connection_pool.mark_dead(con2) t.perform_request(method, "/") - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertEqual(1, len(t.connection_pool.dead_count)) + assert 1 == len(t.connection_pool.connections) + assert 1 == len(t.connection_pool.dead_count) def test_sniff_will_use_seed_connections(self): t = Transport([{"data": CLUSTER_NODES}], connection_class=DummyConnection) t.set_connections([{"data": "invalid"}]) t.sniff_hosts() - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertEqual("http://1.1.1.1:123", t.get_connection().host) + assert 1 == len(t.connection_pool.connections) + assert "http://1.1.1.1:123" == t.get_connection().host def test_sniff_on_start_fetches_and_uses_nodes_list(self): t = Transport( @@ -327,8 +328,8 @@ def test_sniff_on_start_fetches_and_uses_nodes_list(self): connection_class=DummyConnection, sniff_on_start=True, ) - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertEqual("http://1.1.1.1:123", t.get_connection().host) + assert 1 == len(t.connection_pool.connections) + assert "http://1.1.1.1:123" == t.get_connection().host def test_sniff_on_start_ignores_sniff_timeout(self): t = Transport( @@ -337,10 +338,9 @@ def test_sniff_on_start_ignores_sniff_timeout(self): sniff_on_start=True, sniff_timeout=12, ) - self.assertEqual( - (("GET", "/_nodes/_all/http"), {"timeout": None}), - t.seed_connections[0].calls[0], - ) + assert (("GET", "/_nodes/_all/http"), {"timeout": None}) == t.seed_connections[ + 0 + ].calls[0] def test_sniff_uses_sniff_timeout(self): t = Transport( @@ -349,10 +349,9 @@ def test_sniff_uses_sniff_timeout(self): sniff_timeout=42, ) t.sniff_hosts() - self.assertEqual( - (("GET", "/_nodes/_all/http"), {"timeout": 42}), - t.seed_connections[0].calls[0], - ) + assert (("GET", "/_nodes/_all/http"), {"timeout": 42}) == t.seed_connections[ + 0 + ].calls[0] def test_sniff_reuses_connection_instances_if_possible(self): t = Transport( @@ -363,8 +362,8 @@ def test_sniff_reuses_connection_instances_if_possible(self): connection = t.connection_pool.connections[1] t.sniff_hosts() - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertIs(connection, t.get_connection()) + assert 1 == len(t.connection_pool.connections) + assert connection is t.get_connection() def test_sniff_on_fail_triggers_sniffing_on_fail(self): t = Transport( @@ -375,9 +374,10 @@ def test_sniff_on_fail_triggers_sniffing_on_fail(self): randomize_hosts=False, ) - self.assertRaises(ConnectionError, t.perform_request, "GET", "/") - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertEqual("http://1.1.1.1:123", t.get_connection().host) + with pytest.raises(ConnectionError): + t.perform_request("GET", "/") + assert 1 == len(t.connection_pool.connections) + assert "http://1.1.1.1:123" == t.get_connection().host @patch("elasticsearch.transport.Transport.sniff_hosts") def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): @@ -392,10 +392,10 @@ def 
test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): conn_err, conn_data = t.connection_pool.connections response = t.perform_request("GET", "/") - self.assertEqual(json.loads(CLUSTER_NODES), response) - self.assertEqual(1, sniff_hosts.call_count) - self.assertEqual(1, len(conn_err.calls)) - self.assertEqual(1, len(conn_data.calls)) + assert json.loads(CLUSTER_NODES) == response + assert 1 == sniff_hosts.call_count + assert 1 == len(conn_err.calls) + assert 1 == len(conn_data.calls) def test_sniff_after_n_seconds(self): t = Transport( @@ -406,14 +406,14 @@ def test_sniff_after_n_seconds(self): for _ in range(4): t.perform_request("GET", "/") - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertIsInstance(t.get_connection(), DummyConnection) + assert 1 == len(t.connection_pool.connections) + assert isinstance(t.get_connection(), DummyConnection) t.last_sniff = time.time() - 5.1 t.perform_request("GET", "/") - self.assertEqual(1, len(t.connection_pool.connections)) - self.assertEqual("http://1.1.1.1:123", t.get_connection().host) - self.assertTrue(time.time() - 1 < t.last_sniff < time.time() + 0.01) + assert 1 == len(t.connection_pool.connections) + assert "http://1.1.1.1:123" == t.get_connection().host + assert time.time() - 1 < t.last_sniff < time.time() + 0.01 def test_sniff_7x_publish_host(self): # Test the response shaped when a 7.x node has publish_host set @@ -425,10 +425,10 @@ def test_sniff_7x_publish_host(self): ) t.sniff_hosts() # Ensure we parsed out the fqdn and port from the fqdn/ip:port string. - self.assertEqual( - t.connection_pool.connection_opts[0][1], - {"host": "somehost.tld", "port": 123}, - ) + assert t.connection_pool.connection_opts[0][1] == { + "host": "somehost.tld", + "port": 123, + } @patch("elasticsearch.transport.Transport.sniff_hosts") def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): @@ -439,8 +439,8 @@ def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", ) - self.assertFalse(t.sniff_on_connection_fail) - self.assertIs(sniff_hosts.call_args, None) # Assert not called. + assert not t.sniff_on_connection_fail + assert sniff_hosts.call_args is None # Assert not called. @pytest.mark.parametrize("headers", [{}, {"X-elastic-product": "BAD HEADER"}]) diff --git a/test_elasticsearch/utils.py b/test_elasticsearch/utils.py index ef6305de0..8468b3541 100644 --- a/test_elasticsearch/utils.py +++ b/test_elasticsearch/utils.py @@ -15,10 +15,84 @@ # specific language governing permissions and limitations # under the License. +import os +import re import time +from pathlib import Path +from typing import Optional, Tuple from elasticsearch import Elasticsearch, NotFoundError, RequestError -from elasticsearch.helpers.test import es_version + +SOURCE_DIR = Path(__file__).absolute().parent.parent +CA_CERTS = str(SOURCE_DIR / ".ci/certs/ca.crt") + + +def es_url() -> str: + """Grabs an Elasticsearch URL which can be designated via + an environment variable otherwise falls back on localhost. + """ + urls_to_try = [] + + # Try user-supplied URLs before generic localhost ones. 
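+    # Both the upper- and lowercase spellings of 'ELASTICSEARCH_URL' are honored.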
+    if os.environ.get("ELASTICSEARCH_URL", ""):
+        urls_to_try.append(os.environ["ELASTICSEARCH_URL"])
+    if os.environ.get("elasticsearch_url", ""):
+        urls_to_try.append(os.environ["elasticsearch_url"])
+    urls_to_try.extend(
+        [
+            "https://localhost:9200",
+            "http://localhost:9200",
+            "https://elastic:changeme@localhost:9200",
+            "http://elastic:changeme@localhost:9200",
+        ]
+    )
+
+    error = None
+    for url in urls_to_try:
+        client = Elasticsearch(url, timeout=3, ca_certs=CA_CERTS)
+        try:
+            # Check that we get any sort of connection first.
+            client.info()
+
+            # After we get a connection let's wait for the cluster
+            # to be in 'yellow' state, otherwise we could start
+            # tests too early and get failures.
+            for _ in range(100):
+                try:
+                    client.cluster.health(wait_for_status="yellow")
+                    break
+                except ConnectionError:
+                    time.sleep(0.1)
+
+        except Exception as e:
+            if error is None:
+                error = str(e)
+        else:
+            successful_url = url
+            break
+    else:
+        raise RuntimeError(
+            "Could not connect to Elasticsearch (tried %s): %s"
+            % (", ".join(urls_to_try), error)
+        )
+    return successful_url
+
+
+def es_version(client) -> Tuple[int, ...]:
+    """Determines the cluster's version number and parses it into a tuple of ints"""
+    resp = client.info()
+    return parse_version(resp["version"]["number"])
+
+
+def parse_version(version: Optional[str]) -> Optional[Tuple[int, ...]]:
+    """Parses a version number string into a tuple of its major, minor, and patch numbers"""
+    if not version:
+        return None
+    version_number = tuple(
+        int(x)
+        for x in re.search(r"^([0-9]+(?:\.[0-9]+)*)", version).group(1).split(".")
+    )
+    return version_number
 
 
 def wipe_cluster(client):