|
| 1 | +# Licensed to Elasticsearch B.V under one or more agreements. |
| 2 | +# Elasticsearch B.V licenses this file to you under the Apache 2.0 License. |
| 3 | +# See the LICENSE file in the project root for more information |
| 4 | + |
| 5 | +""" |
| 6 | +Dynamically generated set of TestCases based on set of yaml files decribing |
| 7 | +some integration tests. These files are shared among all official Elasticsearch |
| 8 | +clients. |
| 9 | +""" |
| 10 | +import pytest |
| 11 | +import re |
| 12 | +from shutil import rmtree |
| 13 | +import warnings |
| 14 | +import inspect |
| 15 | + |
| 16 | +from elasticsearch import TransportError, RequestError, ElasticsearchDeprecationWarning |
| 17 | +from elasticsearch.helpers.test import _get_version |
| 18 | +from ...test_server.test_rest_api_spec import ( |
| 19 | + YamlRunner, |
| 20 | + YAML_TEST_SPECS, |
| 21 | + InvalidActionType, |
| 22 | +) |
| 23 | + |
| 24 | +pytestmark = pytest.mark.asyncio |
| 25 | + |
| 26 | +# some params had to be changed in python, keep track of them so we can rename |
| 27 | +# those in the tests accordingly |
| 28 | +PARAMS_RENAMES = {"type": "doc_type", "from": "from_"} |
| 29 | + |
| 30 | +# mapping from catch values to http status codes |
| 31 | +CATCH_CODES = {"missing": 404, "conflict": 409, "unauthorized": 401} |
| 32 | + |
| 33 | +# test features we have implemented |
| 34 | +IMPLEMENTED_FEATURES = { |
| 35 | + "gtelte", |
| 36 | + "stash_in_path", |
| 37 | + "headers", |
| 38 | + "catch_unauthorized", |
| 39 | + "default_shards", |
| 40 | + "warnings", |
| 41 | +} |
| 42 | + |
| 43 | +XPACK_FEATURES = None |
| 44 | +ES_VERSION = None |
| 45 | + |
| 46 | + |
| 47 | +async def await_if_coro(x): |
| 48 | + if inspect.iscoroutine(x): |
| 49 | + return await x |
| 50 | + return x |
| 51 | + |
| 52 | + |
| 53 | +class AsyncYamlRunner(YamlRunner): |
| 54 | + async def setup(self): |
| 55 | + if self._setup_code: |
| 56 | + await self.run_code(self._setup_code) |
| 57 | + |
| 58 | + async def teardown(self): |
| 59 | + if self._teardown_code: |
| 60 | + await self.run_code(self._teardown_code) |
| 61 | + |
| 62 | + for repo, definition in ( |
| 63 | + await self.client.snapshot.get_repository(repository="_all") |
| 64 | + ).items(): |
| 65 | + await self.client.snapshot.delete_repository(repository=repo) |
| 66 | + if definition["type"] == "fs": |
| 67 | + rmtree( |
| 68 | + "/tmp/%s" % definition["settings"]["location"], ignore_errors=True |
| 69 | + ) |
| 70 | + |
| 71 | + # stop and remove all ML stuff |
| 72 | + if await self._feature_enabled("ml"): |
| 73 | + await self.client.ml.stop_datafeed(datafeed_id="*", force=True) |
| 74 | + for feed in (await self.client.ml.get_datafeeds(datafeed_id="*"))[ |
| 75 | + "datafeeds" |
| 76 | + ]: |
| 77 | + await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"]) |
| 78 | + |
| 79 | + await self.client.ml.close_job(job_id="*", force=True) |
| 80 | + for job in (await self.client.ml.get_jobs(job_id="*"))["jobs"]: |
| 81 | + await self.client.ml.delete_job( |
| 82 | + job_id=job["job_id"], wait_for_completion=True, force=True |
| 83 | + ) |
| 84 | + |
| 85 | + # stop and remove all Rollup jobs |
| 86 | + if await self._feature_enabled("rollup"): |
| 87 | + for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]: |
| 88 | + await self.client.rollup.stop_job( |
| 89 | + id=rollup["config"]["id"], wait_for_completion=True |
| 90 | + ) |
| 91 | + await self.client.rollup.delete_job(id=rollup["config"]["id"]) |
| 92 | + |
| 93 | + async def es_version(self): |
| 94 | + global ES_VERSION |
| 95 | + if ES_VERSION is None: |
| 96 | + version_string = (await self.client.info())["version"]["number"] |
| 97 | + if "." not in version_string: |
| 98 | + return () |
| 99 | + version = version_string.strip().split(".") |
| 100 | + ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version) |
| 101 | + return ES_VERSION |
| 102 | + |
| 103 | + async def run(self): |
| 104 | + try: |
| 105 | + await self.setup() |
| 106 | + await self.run_code(self._run_code) |
| 107 | + finally: |
| 108 | + await self.teardown() |
| 109 | + |
| 110 | + async def run_code(self, test): |
| 111 | + """ Execute an instruction based on it's type. """ |
| 112 | + print(test) |
| 113 | + for action in test: |
| 114 | + assert len(action) == 1 |
| 115 | + action_type, action = list(action.items())[0] |
| 116 | + |
| 117 | + if hasattr(self, "run_" + action_type): |
| 118 | + await await_if_coro(getattr(self, "run_" + action_type)(action)) |
| 119 | + else: |
| 120 | + raise InvalidActionType(action_type) |
| 121 | + |
| 122 | + async def run_do(self, action): |
| 123 | + api = self.client |
| 124 | + headers = action.pop("headers", None) |
| 125 | + catch = action.pop("catch", None) |
| 126 | + warn = action.pop("warnings", ()) |
| 127 | + assert len(action) == 1 |
| 128 | + |
| 129 | + method, args = list(action.items())[0] |
| 130 | + args["headers"] = headers |
| 131 | + |
| 132 | + # locate api endpoint |
| 133 | + for m in method.split("."): |
| 134 | + assert hasattr(api, m) |
| 135 | + api = getattr(api, m) |
| 136 | + |
| 137 | + # some parameters had to be renamed to not clash with python builtins, |
| 138 | + # compensate |
| 139 | + for k in PARAMS_RENAMES: |
| 140 | + if k in args: |
| 141 | + args[PARAMS_RENAMES[k]] = args.pop(k) |
| 142 | + |
| 143 | + # resolve vars |
| 144 | + for k in args: |
| 145 | + args[k] = self._resolve(args[k]) |
| 146 | + |
| 147 | + warnings.simplefilter("always", category=ElasticsearchDeprecationWarning) |
| 148 | + with warnings.catch_warnings(record=True) as caught_warnings: |
| 149 | + try: |
| 150 | + self.last_response = await api(**args) |
| 151 | + except Exception as e: |
| 152 | + if not catch: |
| 153 | + raise |
| 154 | + self.run_catch(catch, e) |
| 155 | + else: |
| 156 | + if catch: |
| 157 | + raise AssertionError( |
| 158 | + "Failed to catch %r in %r." % (catch, self.last_response) |
| 159 | + ) |
| 160 | + |
| 161 | + # Filter out warnings raised by other components. |
| 162 | + caught_warnings = [ |
| 163 | + str(w.message) |
| 164 | + for w in caught_warnings |
| 165 | + if w.category == ElasticsearchDeprecationWarning |
| 166 | + ] |
| 167 | + |
| 168 | + # Sorting removes the issue with order raised. We only care about |
| 169 | + # if all warnings are raised in the single API call. |
| 170 | + if sorted(warn) != sorted(caught_warnings): |
| 171 | + raise AssertionError( |
| 172 | + "Expected warnings not equal to actual warnings: expected=%r actual=%r" |
| 173 | + % (warn, caught_warnings) |
| 174 | + ) |
| 175 | + |
| 176 | + def run_catch(self, catch, exception): |
| 177 | + if catch == "param": |
| 178 | + assert isinstance(exception, TypeError) |
| 179 | + return |
| 180 | + |
| 181 | + assert isinstance(exception, TransportError) |
| 182 | + if catch in CATCH_CODES: |
| 183 | + assert CATCH_CODES[catch] == exception.status_code |
| 184 | + elif catch[0] == "/" and catch[-1] == "/": |
| 185 | + assert ( |
| 186 | + re.search(catch[1:-1], exception.error + " " + repr(exception.info)), |
| 187 | + "%s not in %r" % (catch, exception.info), |
| 188 | + ) is not None |
| 189 | + self.last_response = exception.info |
| 190 | + |
| 191 | + async def run_skip(self, skip): |
| 192 | + global IMPLEMENTED_FEATURES |
| 193 | + |
| 194 | + if "features" in skip: |
| 195 | + features = skip["features"] |
| 196 | + if not isinstance(features, (tuple, list)): |
| 197 | + features = [features] |
| 198 | + for feature in features: |
| 199 | + if feature in IMPLEMENTED_FEATURES: |
| 200 | + continue |
| 201 | + pytest.skip("feature '%s' is not supported" % feature) |
| 202 | + |
| 203 | + if "version" in skip: |
| 204 | + version, reason = skip["version"], skip["reason"] |
| 205 | + if version == "all": |
| 206 | + pytest.skip(reason) |
| 207 | + min_version, max_version = version.split("-") |
| 208 | + min_version = _get_version(min_version) or (0,) |
| 209 | + max_version = _get_version(max_version) or (999,) |
| 210 | + if min_version <= (await self.es_version()) <= max_version: |
| 211 | + pytest.skip(reason) |
| 212 | + |
| 213 | + async def _feature_enabled(self, name): |
| 214 | + global XPACK_FEATURES, IMPLEMENTED_FEATURES |
| 215 | + if XPACK_FEATURES is None: |
| 216 | + try: |
| 217 | + xinfo = await self.client.xpack.info() |
| 218 | + XPACK_FEATURES = set( |
| 219 | + f for f in xinfo["features"] if xinfo["features"][f]["enabled"] |
| 220 | + ) |
| 221 | + IMPLEMENTED_FEATURES.add("xpack") |
| 222 | + except RequestError: |
| 223 | + XPACK_FEATURES = set() |
| 224 | + IMPLEMENTED_FEATURES.add("no_xpack") |
| 225 | + return name in XPACK_FEATURES |
| 226 | + |
| 227 | + |
| 228 | +@pytest.fixture(scope="function") |
| 229 | +def runner(async_client): |
| 230 | + return AsyncYamlRunner(async_client) |
| 231 | + |
| 232 | + |
| 233 | +@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS) |
| 234 | +async def test_rest_api_spec(test_spec, runner): |
| 235 | + if test_spec.get("skip", False): |
| 236 | + pytest.skip("Manually skipped in 'SKIP_TESTS'") |
| 237 | + runner.use_spec(test_spec) |
| 238 | + await runner.run() |
0 commit comments