Skip to content

Commit 8d8f618

Browse files
committed
Add test suite for async API
1 parent 719d1e4 commit 8d8f618

File tree

5 files changed

+333
-1
lines changed

5 files changed

+333
-1
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import os
6+
import pytest
7+
import asyncio
8+
import elasticsearch
9+
10+
pytestmark = pytest.mark.asyncio
11+
12+
13+
@pytest.fixture(scope="function")
14+
async def async_client():
15+
client = None
16+
try:
17+
if not hasattr(elasticsearch, "AsyncElasticsearch"):
18+
pytest.skip("test requires 'AsyncElasticsearch'")
19+
20+
kw = {
21+
"timeout": 30,
22+
"ca_certs": ".ci/certs/ca.pem",
23+
"connection_class": elasticsearch.AIOHttpConnection,
24+
}
25+
26+
client = elasticsearch.AsyncElasticsearch(
27+
[os.environ.get("ELASTICSEARCH_HOST", {})], **kw
28+
)
29+
30+
# wait for yellow status
31+
for _ in range(100):
32+
try:
33+
await client.cluster.health(wait_for_status="yellow")
34+
break
35+
except ConnectionError:
36+
await asyncio.sleep(0.1)
37+
else:
38+
# timeout
39+
pytest.skip("Elasticsearch failed to start.")
40+
41+
yield client
42+
43+
finally:
44+
if client:
45+
version = tuple(
46+
[
47+
int(x) if x.isdigit() else 999
48+
for x in (await client.info())["version"]["number"].split(".")
49+
]
50+
)
51+
52+
expand_wildcards = ["open", "closed"]
53+
if version >= (7, 7):
54+
expand_wildcards.append("hidden")
55+
56+
await client.indices.delete(
57+
index="*", ignore=404, expand_wildcards=expand_wildcards
58+
)
59+
await client.indices.delete_template(name="*", ignore=404)
60+
await client.indices.delete_index_template(name="*", ignore=404)
61+
await client.close()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding: utf-8 -*-
2+
# Licensed to Elasticsearch B.V under one or more agreements.
3+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
4+
# See the LICENSE file in the project root for more information
5+
6+
from __future__ import unicode_literals
7+
import pytest
8+
9+
pytestmark = pytest.mark.asyncio
10+
11+
12+
class TestUnicode:
13+
async def test_indices_analyze(self, async_client):
14+
await async_client.indices.analyze(body='{"text": "привет"}')
15+
16+
17+
class TestBulk:
18+
async def test_bulk_works_with_string_body(self, async_client):
19+
docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}'
20+
response = await async_client.bulk(body=docs)
21+
22+
assert response["errors"] is False
23+
assert len(response["items"]) == 1
24+
25+
async def test_bulk_works_with_bytestring_body(self, async_client):
26+
docs = b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}'
27+
response = await async_client.bulk(body=docs)
28+
29+
assert response["errors"] is False
30+
assert len(response["items"]) == 1
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
"""
6+
Dynamically generated set of TestCases based on set of yaml files decribing
7+
some integration tests. These files are shared among all official Elasticsearch
8+
clients.
9+
"""
10+
import pytest
11+
import re
12+
from shutil import rmtree
13+
import warnings
14+
import inspect
15+
16+
from elasticsearch import TransportError, RequestError, ElasticsearchDeprecationWarning
17+
from elasticsearch.helpers.test import _get_version
18+
from ...test_server.test_rest_api_spec import (
19+
YamlRunner,
20+
YAML_TEST_SPECS,
21+
InvalidActionType,
22+
)
23+
24+
pytestmark = pytest.mark.asyncio
25+
26+
# some params had to be changed in python, keep track of them so we can rename
27+
# those in the tests accordingly
28+
PARAMS_RENAMES = {"type": "doc_type", "from": "from_"}
29+
30+
# mapping from catch values to http status codes
31+
CATCH_CODES = {"missing": 404, "conflict": 409, "unauthorized": 401}
32+
33+
# test features we have implemented
34+
IMPLEMENTED_FEATURES = {
35+
"gtelte",
36+
"stash_in_path",
37+
"headers",
38+
"catch_unauthorized",
39+
"default_shards",
40+
"warnings",
41+
}
42+
43+
XPACK_FEATURES = None
44+
ES_VERSION = None
45+
46+
47+
async def await_if_coro(x):
48+
if inspect.iscoroutine(x):
49+
return await x
50+
return x
51+
52+
53+
class AsyncYamlRunner(YamlRunner):
54+
async def setup(self):
55+
if self._setup_code:
56+
await self.run_code(self._setup_code)
57+
58+
async def teardown(self):
59+
if self._teardown_code:
60+
await self.run_code(self._teardown_code)
61+
62+
for repo, definition in (
63+
await self.client.snapshot.get_repository(repository="_all")
64+
).items():
65+
await self.client.snapshot.delete_repository(repository=repo)
66+
if definition["type"] == "fs":
67+
rmtree(
68+
"/tmp/%s" % definition["settings"]["location"], ignore_errors=True
69+
)
70+
71+
# stop and remove all ML stuff
72+
if await self._feature_enabled("ml"):
73+
await self.client.ml.stop_datafeed(datafeed_id="*", force=True)
74+
for feed in (await self.client.ml.get_datafeeds(datafeed_id="*"))[
75+
"datafeeds"
76+
]:
77+
await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"])
78+
79+
await self.client.ml.close_job(job_id="*", force=True)
80+
for job in (await self.client.ml.get_jobs(job_id="*"))["jobs"]:
81+
await self.client.ml.delete_job(
82+
job_id=job["job_id"], wait_for_completion=True, force=True
83+
)
84+
85+
# stop and remove all Rollup jobs
86+
if await self._feature_enabled("rollup"):
87+
for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]:
88+
await self.client.rollup.stop_job(
89+
id=rollup["config"]["id"], wait_for_completion=True
90+
)
91+
await self.client.rollup.delete_job(id=rollup["config"]["id"])
92+
93+
async def es_version(self):
94+
global ES_VERSION
95+
if ES_VERSION is None:
96+
version_string = (await self.client.info())["version"]["number"]
97+
if "." not in version_string:
98+
return ()
99+
version = version_string.strip().split(".")
100+
ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version)
101+
return ES_VERSION
102+
103+
async def run(self):
104+
try:
105+
await self.setup()
106+
await self.run_code(self._run_code)
107+
finally:
108+
await self.teardown()
109+
110+
async def run_code(self, test):
111+
""" Execute an instruction based on it's type. """
112+
print(test)
113+
for action in test:
114+
assert len(action) == 1
115+
action_type, action = list(action.items())[0]
116+
117+
if hasattr(self, "run_" + action_type):
118+
await await_if_coro(getattr(self, "run_" + action_type)(action))
119+
else:
120+
raise InvalidActionType(action_type)
121+
122+
async def run_do(self, action):
123+
api = self.client
124+
headers = action.pop("headers", None)
125+
catch = action.pop("catch", None)
126+
warn = action.pop("warnings", ())
127+
assert len(action) == 1
128+
129+
method, args = list(action.items())[0]
130+
args["headers"] = headers
131+
132+
# locate api endpoint
133+
for m in method.split("."):
134+
assert hasattr(api, m)
135+
api = getattr(api, m)
136+
137+
# some parameters had to be renamed to not clash with python builtins,
138+
# compensate
139+
for k in PARAMS_RENAMES:
140+
if k in args:
141+
args[PARAMS_RENAMES[k]] = args.pop(k)
142+
143+
# resolve vars
144+
for k in args:
145+
args[k] = self._resolve(args[k])
146+
147+
warnings.simplefilter("always", category=ElasticsearchDeprecationWarning)
148+
with warnings.catch_warnings(record=True) as caught_warnings:
149+
try:
150+
self.last_response = await api(**args)
151+
except Exception as e:
152+
if not catch:
153+
raise
154+
self.run_catch(catch, e)
155+
else:
156+
if catch:
157+
raise AssertionError(
158+
"Failed to catch %r in %r." % (catch, self.last_response)
159+
)
160+
161+
# Filter out warnings raised by other components.
162+
caught_warnings = [
163+
str(w.message)
164+
for w in caught_warnings
165+
if w.category == ElasticsearchDeprecationWarning
166+
]
167+
168+
# Sorting removes the issue with order raised. We only care about
169+
# if all warnings are raised in the single API call.
170+
if sorted(warn) != sorted(caught_warnings):
171+
raise AssertionError(
172+
"Expected warnings not equal to actual warnings: expected=%r actual=%r"
173+
% (warn, caught_warnings)
174+
)
175+
176+
def run_catch(self, catch, exception):
177+
if catch == "param":
178+
assert isinstance(exception, TypeError)
179+
return
180+
181+
assert isinstance(exception, TransportError)
182+
if catch in CATCH_CODES:
183+
assert CATCH_CODES[catch] == exception.status_code
184+
elif catch[0] == "/" and catch[-1] == "/":
185+
assert (
186+
re.search(catch[1:-1], exception.error + " " + repr(exception.info)),
187+
"%s not in %r" % (catch, exception.info),
188+
) is not None
189+
self.last_response = exception.info
190+
191+
async def run_skip(self, skip):
192+
global IMPLEMENTED_FEATURES
193+
194+
if "features" in skip:
195+
features = skip["features"]
196+
if not isinstance(features, (tuple, list)):
197+
features = [features]
198+
for feature in features:
199+
if feature in IMPLEMENTED_FEATURES:
200+
continue
201+
pytest.skip("feature '%s' is not supported" % feature)
202+
203+
if "version" in skip:
204+
version, reason = skip["version"], skip["reason"]
205+
if version == "all":
206+
pytest.skip(reason)
207+
min_version, max_version = version.split("-")
208+
min_version = _get_version(min_version) or (0,)
209+
max_version = _get_version(max_version) or (999,)
210+
if min_version <= (await self.es_version()) <= max_version:
211+
pytest.skip(reason)
212+
213+
async def _feature_enabled(self, name):
214+
global XPACK_FEATURES, IMPLEMENTED_FEATURES
215+
if XPACK_FEATURES is None:
216+
try:
217+
xinfo = await self.client.xpack.info()
218+
XPACK_FEATURES = set(
219+
f for f in xinfo["features"] if xinfo["features"][f]["enabled"]
220+
)
221+
IMPLEMENTED_FEATURES.add("xpack")
222+
except RequestError:
223+
XPACK_FEATURES = set()
224+
IMPLEMENTED_FEATURES.add("no_xpack")
225+
return name in XPACK_FEATURES
226+
227+
228+
@pytest.fixture(scope="function")
229+
def runner(async_client):
230+
return AsyncYamlRunner(async_client)
231+
232+
233+
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
234+
async def test_rest_api_spec(test_spec, runner):
235+
if test_spec.get("skip", False):
236+
pytest.skip("Manually skipped in 'SKIP_TESTS'")
237+
runner.use_spec(test_spec)
238+
await runner.run()

test_elasticsearch/test_server/test_rest_api_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(self, client):
7777
def use_spec(self, test_spec):
7878
self._setup_code = test_spec.pop("setup", None)
7979
self._run_code = test_spec.pop("run", None)
80-
self._teardown_code = test_spec.pop("teardown")
80+
self._teardown_code = test_spec.pop("teardown", None)
8181

8282
def setup(self):
8383
if self._setup_code:

0 commit comments

Comments
 (0)