Skip to content

Commit 2f32100

Browse files
committed
Add AIOHttpConnection
1 parent f0ebc12 commit 2f32100

File tree

10 files changed

+661
-1
lines changed

10 files changed

+661
-1
lines changed

dev-requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ pandas
1414
pyyaml<5.3
1515

1616
black; python_version>="3.6"
17+
18+
# Requirements for testing [async] extra
19+
aiohttp; python_version>="3.6"
20+
pytest-asyncio; python_version>="3.6"

elasticsearch/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
__version__ = VERSION
1010
__versionstr__ = ".".join(map(str, VERSION))
1111

12+
import sys
1213
import logging
1314
import warnings
1415

@@ -64,3 +65,14 @@
6465
"AuthorizationException",
6566
"ElasticsearchDeprecationWarning",
6667
]
68+
69+
try:
70+
# Asyncio only supported on Python 3.6+
71+
if sys.version_info < (3, 6):
72+
raise ImportError
73+
74+
from ._async.http_aiohttp import AIOHttpConnection
75+
76+
__all__ += ["AIOHttpConnection"]
77+
except (ImportError, SyntaxError):
78+
pass

elasticsearch/_async/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information

elasticsearch/_async/compat.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import asyncio
6+
7+
# Hack supporting Python 3.6 asyncio which didn't have 'get_running_loop()'.
8+
# Essentially we want to get away from having users pass in a loop to us.
9+
# Instead we should call 'get_running_loop()' whenever we need
10+
# the currently running loop.
11+
# See: https://aiopg.readthedocs.io/en/stable/run_loop.html#implementation
12+
try:
13+
from asyncio import get_running_loop
14+
except ImportError:
15+
16+
def get_running_loop():
17+
loop = asyncio.get_event_loop()
18+
if not loop.is_running():
19+
raise RuntimeError("no running event loop")
20+
return loop
21+
22+
23+
__all__ = ["get_running_loop"]

elasticsearch/_async/http_aiohttp.py

+298
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import asyncio
6+
import ssl
7+
import os
8+
import urllib3
9+
import warnings
10+
11+
import aiohttp
12+
import yarl
13+
from aiohttp.client_exceptions import ServerFingerprintMismatch, ServerTimeoutError
14+
15+
from .compat import get_running_loop
16+
from ..connection import Connection
17+
from ..compat import urlencode
18+
from ..exceptions import (
19+
ConnectionError,
20+
ConnectionTimeout,
21+
ImproperlyConfigured,
22+
SSLError,
23+
)
24+
25+
26+
# sentinel value for `verify_certs`.
27+
# This is used to detect if a user is passing in a value
28+
# for SSL kwargs if also using an SSLContext.
29+
VERIFY_CERTS_DEFAULT = object()
30+
SSL_SHOW_WARN_DEFAULT = object()
31+
32+
CA_CERTS = None
33+
34+
try:
35+
import certifi
36+
37+
CA_CERTS = certifi.where()
38+
except ImportError:
39+
pass
40+
41+
42+
class AIOHttpConnection(Connection):
43+
def __init__(
44+
self,
45+
host="localhost",
46+
port=None,
47+
http_auth=None,
48+
use_ssl=False,
49+
verify_certs=VERIFY_CERTS_DEFAULT,
50+
ssl_show_warn=SSL_SHOW_WARN_DEFAULT,
51+
ca_certs=None,
52+
client_cert=None,
53+
client_key=None,
54+
ssl_version=None,
55+
ssl_assert_fingerprint=None,
56+
maxsize=10,
57+
headers=None,
58+
ssl_context=None,
59+
http_compress=None,
60+
cloud_id=None,
61+
api_key=None,
62+
opaque_id=None,
63+
loop=None,
64+
**kwargs,
65+
):
66+
"""
67+
Default connection class for ``AsyncElasticsearch`` using the `aiohttp` library and the http protocol.
68+
69+
:arg host: hostname of the node (default: localhost)
70+
:arg port: port to use (integer, default: 9200)
71+
:arg timeout: default timeout in seconds (float, default: 10)
72+
:arg http_auth: optional http auth information as either ':' separated
73+
string or a tuple
74+
:arg use_ssl: use ssl for the connection if `True`
75+
:arg verify_certs: whether to verify SSL certificates
76+
:arg ssl_show_warn: show warning when verify certs is disabled
77+
:arg ca_certs: optional path to CA bundle.
78+
See https://urllib3.readthedocs.io/en/latest/security.html#using-certifi-with-urllib3
79+
for instructions how to get default set
80+
:arg client_cert: path to the file containing the private key and the
81+
certificate, or cert only if using client_key
82+
:arg client_key: path to the file containing the private key if using
83+
separate cert and key files (client_cert will contain only the cert)
84+
:arg ssl_version: version of the SSL protocol to use. Choices are:
85+
SSLv23 (default) SSLv2 SSLv3 TLSv1 (see ``PROTOCOL_*`` constants in the
86+
``ssl`` module for exact options for your environment).
87+
:arg ssl_assert_hostname: use hostname verification if not `False`
88+
:arg ssl_assert_fingerprint: verify the supplied certificate fingerprint if not `None`
89+
:arg maxsize: the number of connections which will be kept open to this
90+
host. See https://urllib3.readthedocs.io/en/1.4/pools.html#api for more
91+
information.
92+
:arg headers: any custom http headers to be add to requests
93+
:arg http_compress: Use gzip compression
94+
:arg cloud_id: The Cloud ID from ElasticCloud. Convenient way to connect to cloud instances.
95+
Other host connection params will be ignored.
96+
:arg api_key: optional API Key authentication as either base64 encoded string or a tuple.
97+
:arg opaque_id: Send this value in the 'X-Opaque-Id' HTTP header
98+
For tracing all requests made by this transport.
99+
:arg loop: asyncio Event Loop to use with aiohttp. This is set by default to the currently running loop.
100+
"""
101+
102+
self.headers = {}
103+
104+
super().__init__(
105+
host=host,
106+
port=port,
107+
use_ssl=use_ssl,
108+
headers=headers,
109+
http_compress=http_compress,
110+
cloud_id=cloud_id,
111+
api_key=api_key,
112+
opaque_id=opaque_id,
113+
**kwargs,
114+
)
115+
116+
if http_auth is not None:
117+
if isinstance(http_auth, (tuple, list)):
118+
http_auth = ":".join(http_auth)
119+
self.headers.update(urllib3.make_headers(basic_auth=http_auth))
120+
121+
# if providing an SSL context, raise error if any other SSL related flag is used
122+
if ssl_context and (
123+
(verify_certs is not VERIFY_CERTS_DEFAULT)
124+
or (ssl_show_warn is not SSL_SHOW_WARN_DEFAULT)
125+
or ca_certs
126+
or client_cert
127+
or client_key
128+
or ssl_version
129+
):
130+
warnings.warn(
131+
"When using `ssl_context`, all other SSL related kwargs are ignored"
132+
)
133+
134+
self.ssl_assert_fingerprint = ssl_assert_fingerprint
135+
if self.use_ssl and ssl_context is None:
136+
ssl_context = ssl.SSLContext(ssl_version or ssl.PROTOCOL_TLS)
137+
138+
# Convert all sentinel values to their actual default
139+
# values if not using an SSLContext.
140+
if verify_certs is VERIFY_CERTS_DEFAULT:
141+
verify_certs = True
142+
if ssl_show_warn is SSL_SHOW_WARN_DEFAULT:
143+
ssl_show_warn = True
144+
145+
if verify_certs:
146+
ssl_context.verify_mode = ssl.CERT_REQUIRED
147+
ssl_context.check_hostname = True
148+
else:
149+
ssl_context.verify_mode = ssl.CERT_NONE
150+
ssl_context.check_hostname = False
151+
152+
ca_certs = CA_CERTS if ca_certs is None else ca_certs
153+
if verify_certs:
154+
if not ca_certs:
155+
raise ImproperlyConfigured(
156+
"Root certificates are missing for certificate "
157+
"validation. Either pass them in using the ca_certs parameter or "
158+
"install certifi to use it automatically."
159+
)
160+
else:
161+
if ssl_show_warn:
162+
warnings.warn(
163+
"Connecting to %s using SSL with verify_certs=False is insecure."
164+
% self.host
165+
)
166+
167+
if os.path.isfile(ca_certs):
168+
ssl_context.load_verify_locations(cafile=ca_certs)
169+
elif os.path.isdir(ca_certs):
170+
ssl_context.load_verify_locations(capath=ca_certs)
171+
else:
172+
raise ImproperlyConfigured("ca_certs parameter is not a path")
173+
174+
self.headers.setdefault("connection", "keep-alive")
175+
self.loop = loop
176+
self.session = None
177+
178+
# Parameters for creating an aiohttp.ClientSession later.
179+
self._limit = maxsize
180+
self._http_auth = http_auth
181+
self._ssl_context = ssl_context
182+
183+
async def perform_request(
184+
self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None
185+
):
186+
if self.session is None:
187+
await self._create_aiohttp_session()
188+
189+
orig_body = body
190+
url_path = url
191+
if params:
192+
query_string = urlencode(params)
193+
else:
194+
query_string = ""
195+
196+
# There is a bug in aiohttp that disables the re-use
197+
# of the connection in the pool when method=HEAD.
198+
# See: aio-libs/aiohttp#1769
199+
is_head = False
200+
if method == "HEAD":
201+
method = "GET"
202+
is_head = True
203+
204+
# Provide correct URL object to avoid string parsing in low-level code
205+
url = yarl.URL.build(
206+
scheme=self.scheme,
207+
host=self.hostname,
208+
port=self.port,
209+
path=url,
210+
query_string=query_string,
211+
encoded=True,
212+
)
213+
214+
timeout = aiohttp.ClientTimeout(
215+
total=timeout if timeout is not None else self.timeout
216+
)
217+
218+
req_headers = self.headers.copy()
219+
if headers:
220+
req_headers.update(headers)
221+
222+
if self.http_compress and body:
223+
body = self._gzip_compress(body)
224+
req_headers["content-encoding"] = "gzip"
225+
226+
start = self.loop.time()
227+
try:
228+
async with self.session.request(
229+
method,
230+
url,
231+
data=body,
232+
headers=req_headers,
233+
timeout=timeout,
234+
fingerprint=self.ssl_assert_fingerprint,
235+
) as response:
236+
if is_head: # We actually called 'GET' so throw away the data.
237+
await response.release()
238+
raw_data = ""
239+
else:
240+
raw_data = (await response.read()).decode("utf-8", "surrogatepass")
241+
duration = self.loop.time() - start
242+
243+
# We want to reraise a cancellation.
244+
except asyncio.CancelledError:
245+
raise
246+
247+
except Exception as e:
248+
self.log_request_fail(
249+
method, url, url_path, orig_body, self.loop.time() - start, exception=e
250+
)
251+
if isinstance(e, ServerFingerprintMismatch):
252+
raise SSLError("N/A", str(e), e)
253+
if isinstance(e, (asyncio.TimeoutError, ServerTimeoutError)):
254+
raise ConnectionTimeout("TIMEOUT", str(e), e)
255+
raise ConnectionError("N/A", str(e), e)
256+
257+
# raise errors based on http status codes, let the client handle those if needed
258+
if not (200 <= response.status < 300) and response.status not in ignore:
259+
self.log_request_fail(
260+
method,
261+
url,
262+
url_path,
263+
orig_body,
264+
duration,
265+
status_code=response.status,
266+
response=raw_data,
267+
)
268+
self._raise_error(response.status, raw_data)
269+
270+
self.log_request_success(
271+
method, url, url_path, orig_body, response.status, raw_data, duration
272+
)
273+
274+
return response.status, response.headers, raw_data
275+
276+
async def close(self):
277+
"""
278+
Explicitly closes connection
279+
"""
280+
if self.session:
281+
await self.session.close()
282+
283+
async def _create_aiohttp_session(self):
284+
"""Creates an aiohttp.ClientSession(). This is delayed until
285+
the first call to perform_request() so that AsyncTransport has
286+
a chance to set AIOHttpConnection.loop
287+
"""
288+
if self.loop is None:
289+
self.loop = get_running_loop()
290+
self.session = aiohttp.ClientSession(
291+
headers=self.headers,
292+
auto_decompress=True,
293+
loop=self.loop,
294+
cookie_jar=aiohttp.DummyCookieJar(),
295+
connector=aiohttp.TCPConnector(
296+
limit=self._limit, use_dns_cache=True, ssl=self._ssl_context,
297+
),
298+
)

elasticsearch/connection/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(
121121
self.use_ssl = use_ssl
122122
self.http_compress = http_compress or False
123123

124+
self.scheme = scheme
124125
self.hostname = host
125126
self.port = port
126127
self.host = "%s://%s" % (scheme, host)

setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"pytest",
2626
"pytest-cov",
2727
]
28+
async_require = ["aiohttp>=3,<4", "yarl"]
2829

2930
docs_require = ["sphinx<1.7", "sphinx_rtd_theme"]
3031
generate_require = ["black", "jinja2"]
@@ -67,5 +68,6 @@
6768
"develop": tests_require + docs_require + generate_require,
6869
"docs": docs_require,
6970
"requests": ["requests>=2.4.0, <3.0.0"],
71+
"async": async_require,
7072
},
7173
)

test_elasticsearch/run_tests.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,13 @@ def run_all(argv=None):
7878
"--log-level=DEBUG",
7979
"--cache-clear",
8080
"-vv",
81-
abspath(dirname(__file__)),
8281
]
8382

83+
if sys.version_info < (3, 6):
84+
argv.append("--ignore=test_elasticsearch/test_async/")
85+
86+
argv.append(abspath(dirname(__file__)),)
87+
8488
exit_code = 0
8589
try:
8690
subprocess.check_call(argv, stdout=sys.stdout, stderr=sys.stderr)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information

0 commit comments

Comments
 (0)