Skip to content

Commit 5731264

Browse files
committed
Add test suite for async API
1 parent 719d1e4 commit 5731264

File tree

5 files changed

+344
-1
lines changed

5 files changed

+344
-1
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import os
6+
import pytest
7+
import asyncio
8+
import elasticsearch
9+
10+
pytestmark = pytest.mark.asyncio
11+
12+
13+
@pytest.fixture(scope="function")
14+
async def async_client():
15+
client = None
16+
try:
17+
if not hasattr(elasticsearch, "AsyncElasticsearch"):
18+
pytest.skip("test requires 'AsyncElasticsearch'")
19+
20+
kw = {
21+
"timeout": 30,
22+
"ca_certs": ".ci/certs/ca.pem",
23+
"connection_class": elasticsearch.AIOHttpConnection,
24+
}
25+
26+
client = elasticsearch.AsyncElasticsearch(
27+
[os.environ.get("ELASTICSEARCH_HOST", {})], **kw
28+
)
29+
30+
# wait for yellow status
31+
for _ in range(100):
32+
try:
33+
await client.cluster.health(wait_for_status="yellow")
34+
break
35+
except ConnectionError:
36+
await asyncio.sleep(0.1)
37+
else:
38+
# timeout
39+
pytest.skip("Elasticsearch failed to start.")
40+
41+
yield client
42+
43+
finally:
44+
if client:
45+
version = tuple(
46+
[
47+
int(x) if x.isdigit() else 999
48+
for x in (await client.info())["version"]["number"].split(".")
49+
]
50+
)
51+
52+
expand_wildcards = ["open", "closed"]
53+
if version >= (7, 7):
54+
expand_wildcards.append("hidden")
55+
56+
await client.indices.delete(
57+
index="*", ignore=404, expand_wildcards=expand_wildcards
58+
)
59+
await client.indices.delete_template(name="*", ignore=404)
60+
await client.indices.delete_index_template(name="*", ignore=404)
61+
await client.close()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding: utf-8 -*-
2+
# Licensed to Elasticsearch B.V under one or more agreements.
3+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
4+
# See the LICENSE file in the project root for more information
5+
6+
from __future__ import unicode_literals
7+
import pytest
8+
9+
pytestmark = pytest.mark.asyncio
10+
11+
12+
class TestUnicode:
13+
async def test_indices_analyze(self, async_client):
14+
await async_client.indices.analyze(body='{"text": "привет"}')
15+
16+
17+
class TestBulk:
18+
async def test_bulk_works_with_string_body(self, async_client):
19+
docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}'
20+
response = await async_client.bulk(body=docs)
21+
22+
assert response["errors"] is False
23+
assert len(response["items"]) == 1
24+
25+
async def test_bulk_works_with_bytestring_body(self, async_client):
26+
docs = b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}'
27+
response = await async_client.bulk(body=docs)
28+
29+
assert response["errors"] is False
30+
assert len(response["items"]) == 1
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
"""
6+
Dynamically generated set of TestCases based on set of yaml files decribing
7+
some integration tests. These files are shared among all official Elasticsearch
8+
clients.
9+
"""
10+
import pytest
11+
import re
12+
from shutil import rmtree
13+
import warnings
14+
import inspect
15+
16+
from elasticsearch import TransportError, RequestError, ElasticsearchDeprecationWarning
17+
from elasticsearch.helpers.test import _get_version
18+
from ...test_server.test_rest_api_spec import (
19+
YamlRunner,
20+
YAML_TEST_SPECS,
21+
InvalidActionType,
22+
RUN_ASYNC_REST_API_TESTS,
23+
)
24+
25+
pytestmark = pytest.mark.asyncio
26+
27+
# some params had to be changed in python, keep track of them so we can rename
28+
# those in the tests accordingly
29+
PARAMS_RENAMES = {"type": "doc_type", "from": "from_"}
30+
31+
# mapping from catch values to http status codes
32+
CATCH_CODES = {"missing": 404, "conflict": 409, "unauthorized": 401}
33+
34+
# test features we have implemented
35+
IMPLEMENTED_FEATURES = {
36+
"gtelte",
37+
"stash_in_path",
38+
"headers",
39+
"catch_unauthorized",
40+
"default_shards",
41+
"warnings",
42+
}
43+
44+
XPACK_FEATURES = None
45+
ES_VERSION = None
46+
47+
48+
async def await_if_coro(x):
49+
if inspect.iscoroutine(x):
50+
return await x
51+
return x
52+
53+
54+
class AsyncYamlRunner(YamlRunner):
55+
async def setup(self):
56+
if self._setup_code:
57+
await self.run_code(self._setup_code)
58+
59+
async def teardown(self):
60+
if self._teardown_code:
61+
await self.run_code(self._teardown_code)
62+
63+
for repo, definition in (
64+
await self.client.snapshot.get_repository(repository="_all")
65+
).items():
66+
await self.client.snapshot.delete_repository(repository=repo)
67+
if definition["type"] == "fs":
68+
rmtree(
69+
"/tmp/%s" % definition["settings"]["location"], ignore_errors=True
70+
)
71+
72+
# stop and remove all ML stuff
73+
if await self._feature_enabled("ml"):
74+
await self.client.ml.stop_datafeed(datafeed_id="*", force=True)
75+
for feed in (await self.client.ml.get_datafeeds(datafeed_id="*"))[
76+
"datafeeds"
77+
]:
78+
await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"])
79+
80+
await self.client.ml.close_job(job_id="*", force=True)
81+
for job in (await self.client.ml.get_jobs(job_id="*"))["jobs"]:
82+
await self.client.ml.delete_job(
83+
job_id=job["job_id"], wait_for_completion=True, force=True
84+
)
85+
86+
# stop and remove all Rollup jobs
87+
if await self._feature_enabled("rollup"):
88+
for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]:
89+
await self.client.rollup.stop_job(
90+
id=rollup["config"]["id"], wait_for_completion=True
91+
)
92+
await self.client.rollup.delete_job(id=rollup["config"]["id"])
93+
94+
async def es_version(self):
95+
global ES_VERSION
96+
if ES_VERSION is None:
97+
version_string = (await self.client.info())["version"]["number"]
98+
if "." not in version_string:
99+
return ()
100+
version = version_string.strip().split(".")
101+
ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version)
102+
return ES_VERSION
103+
104+
async def run(self):
105+
try:
106+
await self.setup()
107+
await self.run_code(self._run_code)
108+
finally:
109+
await self.teardown()
110+
111+
async def run_code(self, test):
112+
""" Execute an instruction based on it's type. """
113+
print(test)
114+
for action in test:
115+
assert len(action) == 1
116+
action_type, action = list(action.items())[0]
117+
118+
if hasattr(self, "run_" + action_type):
119+
await await_if_coro(getattr(self, "run_" + action_type)(action))
120+
else:
121+
raise InvalidActionType(action_type)
122+
123+
async def run_do(self, action):
124+
api = self.client
125+
headers = action.pop("headers", None)
126+
catch = action.pop("catch", None)
127+
warn = action.pop("warnings", ())
128+
assert len(action) == 1
129+
130+
method, args = list(action.items())[0]
131+
args["headers"] = headers
132+
133+
# locate api endpoint
134+
for m in method.split("."):
135+
assert hasattr(api, m)
136+
api = getattr(api, m)
137+
138+
# some parameters had to be renamed to not clash with python builtins,
139+
# compensate
140+
for k in PARAMS_RENAMES:
141+
if k in args:
142+
args[PARAMS_RENAMES[k]] = args.pop(k)
143+
144+
# resolve vars
145+
for k in args:
146+
args[k] = self._resolve(args[k])
147+
148+
warnings.simplefilter("always", category=ElasticsearchDeprecationWarning)
149+
with warnings.catch_warnings(record=True) as caught_warnings:
150+
try:
151+
self.last_response = await api(**args)
152+
except Exception as e:
153+
if not catch:
154+
raise
155+
self.run_catch(catch, e)
156+
else:
157+
if catch:
158+
raise AssertionError(
159+
"Failed to catch %r in %r." % (catch, self.last_response)
160+
)
161+
162+
# Filter out warnings raised by other components.
163+
caught_warnings = [
164+
str(w.message)
165+
for w in caught_warnings
166+
if w.category == ElasticsearchDeprecationWarning
167+
]
168+
169+
# Sorting removes the issue with order raised. We only care about
170+
# if all warnings are raised in the single API call.
171+
if sorted(warn) != sorted(caught_warnings):
172+
raise AssertionError(
173+
"Expected warnings not equal to actual warnings: expected=%r actual=%r"
174+
% (warn, caught_warnings)
175+
)
176+
177+
def run_catch(self, catch, exception):
178+
if catch == "param":
179+
assert isinstance(exception, TypeError)
180+
return
181+
182+
assert isinstance(exception, TransportError)
183+
if catch in CATCH_CODES:
184+
assert CATCH_CODES[catch] == exception.status_code
185+
elif catch[0] == "/" and catch[-1] == "/":
186+
assert (
187+
re.search(catch[1:-1], exception.error + " " + repr(exception.info)),
188+
"%s not in %r" % (catch, exception.info),
189+
) is not None
190+
self.last_response = exception.info
191+
192+
async def run_skip(self, skip):
193+
global IMPLEMENTED_FEATURES
194+
195+
if "features" in skip:
196+
features = skip["features"]
197+
if not isinstance(features, (tuple, list)):
198+
features = [features]
199+
for feature in features:
200+
if feature in IMPLEMENTED_FEATURES:
201+
continue
202+
pytest.skip("feature '%s' is not supported" % feature)
203+
204+
if "version" in skip:
205+
version, reason = skip["version"], skip["reason"]
206+
if version == "all":
207+
pytest.skip(reason)
208+
min_version, max_version = version.split("-")
209+
min_version = _get_version(min_version) or (0,)
210+
max_version = _get_version(max_version) or (999,)
211+
if min_version <= (await self.es_version()) <= max_version:
212+
pytest.skip(reason)
213+
214+
async def _feature_enabled(self, name):
215+
global XPACK_FEATURES, IMPLEMENTED_FEATURES
216+
if XPACK_FEATURES is None:
217+
try:
218+
xinfo = await self.client.xpack.info()
219+
XPACK_FEATURES = set(
220+
f for f in xinfo["features"] if xinfo["features"][f]["enabled"]
221+
)
222+
IMPLEMENTED_FEATURES.add("xpack")
223+
except RequestError:
224+
XPACK_FEATURES = set()
225+
IMPLEMENTED_FEATURES.add("no_xpack")
226+
return name in XPACK_FEATURES
227+
228+
229+
@pytest.fixture(scope="function")
230+
def runner(async_client):
231+
return AsyncYamlRunner(async_client)
232+
233+
234+
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
235+
async def test_rest_api_spec(test_spec, runner):
236+
if not RUN_ASYNC_REST_API_TESTS:
237+
pytest.skip("Skipped running async REST API tests")
238+
if test_spec.get("skip", False):
239+
pytest.skip("Manually skipped in 'SKIP_TESTS'")
240+
runner.use_spec(test_spec)
241+
await runner.run()

test_elasticsearch/test_server/test_rest_api_spec.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
some integration tests. These files are shared among all official Elasticsearch
88
clients.
99
"""
10+
import sys
1011
import re
12+
import os
1113
from os import walk, environ
1214
from os.path import exists, join, dirname, pardir, relpath
1315
import yaml
@@ -62,6 +64,10 @@
6264

6365
XPACK_FEATURES = None
6466
ES_VERSION = None
67+
RUN_ASYNC_REST_API_TESTS = (
68+
sys.version_info >= (3, 6)
69+
and os.environ.get("PYTHON_CONNECTION_CLASS") == "RequestsHttpConnection"
70+
)
6571

6672

6773
class YamlRunner:
@@ -77,7 +83,7 @@ def __init__(self, client):
7783
def use_spec(self, test_spec):
7884
self._setup_code = test_spec.pop("setup", None)
7985
self._run_code = test_spec.pop("run", None)
80-
self._teardown_code = test_spec.pop("teardown")
86+
self._teardown_code = test_spec.pop("teardown", None)
8187

8288
def setup(self):
8389
if self._setup_code:
@@ -414,6 +420,8 @@ def sync_runner(sync_client):
414420

415421
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
416422
def test_rest_api_spec(test_spec, sync_runner):
423+
if RUN_ASYNC_REST_API_TESTS:
424+
pytest.skip("Skipped running sync REST API tests")
417425
if test_spec.get("skip", False):
418426
pytest.skip("Manually skipped in 'SKIP_TESTS'")
419427
sync_runner.use_spec(test_spec)

0 commit comments

Comments
 (0)