|
| 1 | +# Licensed to Elasticsearch B.V under one or more agreements. |
| 2 | +# Elasticsearch B.V licenses this file to you under the Apache 2.0 License. |
| 3 | +# See the LICENSE file in the project root for more information |
| 4 | + |
| 5 | +import asyncio |
| 6 | +import ssl |
| 7 | +import os |
| 8 | +import urllib3 |
| 9 | +import warnings |
| 10 | + |
| 11 | +import aiohttp |
| 12 | +import yarl |
| 13 | +from aiohttp.client_exceptions import ServerFingerprintMismatch, ServerTimeoutError |
| 14 | + |
| 15 | +from .compat import get_running_loop |
| 16 | +from ..connection import Connection |
| 17 | +from ..compat import urlencode |
| 18 | +from ..exceptions import ( |
| 19 | + ConnectionError, |
| 20 | + ConnectionTimeout, |
| 21 | + ImproperlyConfigured, |
| 22 | + SSLError, |
| 23 | +) |
| 24 | + |
| 25 | + |
| 26 | +# sentinel value for `verify_certs`. |
| 27 | +# This is used to detect if a user is passing in a value |
| 28 | +# for SSL kwargs if also using an SSLContext. |
| 29 | +VERIFY_CERTS_DEFAULT = object() |
| 30 | +SSL_SHOW_WARN_DEFAULT = object() |
| 31 | + |
| 32 | +CA_CERTS = None |
| 33 | + |
| 34 | +try: |
| 35 | + import certifi |
| 36 | + |
| 37 | + CA_CERTS = certifi.where() |
| 38 | +except ImportError: |
| 39 | + pass |
| 40 | + |
| 41 | + |
| 42 | +class AIOHttpConnection(Connection): |
| 43 | + def __init__( |
| 44 | + self, |
| 45 | + host="localhost", |
| 46 | + port=None, |
| 47 | + http_auth=None, |
| 48 | + use_ssl=False, |
| 49 | + verify_certs=VERIFY_CERTS_DEFAULT, |
| 50 | + ssl_show_warn=SSL_SHOW_WARN_DEFAULT, |
| 51 | + ca_certs=None, |
| 52 | + client_cert=None, |
| 53 | + client_key=None, |
| 54 | + ssl_version=None, |
| 55 | + ssl_assert_fingerprint=None, |
| 56 | + maxsize=10, |
| 57 | + headers=None, |
| 58 | + ssl_context=None, |
| 59 | + http_compress=None, |
| 60 | + cloud_id=None, |
| 61 | + api_key=None, |
| 62 | + opaque_id=None, |
| 63 | + loop=None, |
| 64 | + **kwargs, |
| 65 | + ): |
| 66 | + """ |
| 67 | + Default connection class for ``AsyncElasticsearch`` using the `aiohttp` library and the http protocol. |
| 68 | +
|
| 69 | + :arg host: hostname of the node (default: localhost) |
| 70 | + :arg port: port to use (integer, default: 9200) |
| 71 | + :arg timeout: default timeout in seconds (float, default: 10) |
| 72 | + :arg http_auth: optional http auth information as either ':' separated |
| 73 | + string or a tuple |
| 74 | + :arg use_ssl: use ssl for the connection if `True` |
| 75 | + :arg verify_certs: whether to verify SSL certificates |
| 76 | + :arg ssl_show_warn: show warning when verify certs is disabled |
| 77 | + :arg ca_certs: optional path to CA bundle. |
| 78 | + See https://urllib3.readthedocs.io/en/latest/security.html#using-certifi-with-urllib3 |
| 79 | + for instructions how to get default set |
| 80 | + :arg client_cert: path to the file containing the private key and the |
| 81 | + certificate, or cert only if using client_key |
| 82 | + :arg client_key: path to the file containing the private key if using |
| 83 | + separate cert and key files (client_cert will contain only the cert) |
| 84 | + :arg ssl_version: version of the SSL protocol to use. Choices are: |
| 85 | + SSLv23 (default) SSLv2 SSLv3 TLSv1 (see ``PROTOCOL_*`` constants in the |
| 86 | + ``ssl`` module for exact options for your environment). |
| 87 | + :arg ssl_assert_hostname: use hostname verification if not `False` |
| 88 | + :arg ssl_assert_fingerprint: verify the supplied certificate fingerprint if not `None` |
| 89 | + :arg maxsize: the number of connections which will be kept open to this |
| 90 | + host. See https://urllib3.readthedocs.io/en/1.4/pools.html#api for more |
| 91 | + information. |
| 92 | + :arg headers: any custom http headers to be add to requests |
| 93 | + :arg http_compress: Use gzip compression |
| 94 | + :arg cloud_id: The Cloud ID from ElasticCloud. Convenient way to connect to cloud instances. |
| 95 | + Other host connection params will be ignored. |
| 96 | + :arg api_key: optional API Key authentication as either base64 encoded string or a tuple. |
| 97 | + :arg opaque_id: Send this value in the 'X-Opaque-Id' HTTP header |
| 98 | + For tracing all requests made by this transport. |
| 99 | + :arg loop: asyncio Event Loop to use with aiohttp. This is set by default to the currently running loop. |
| 100 | + """ |
| 101 | + |
| 102 | + self.headers = {} |
| 103 | + |
| 104 | + super().__init__( |
| 105 | + host=host, |
| 106 | + port=port, |
| 107 | + use_ssl=use_ssl, |
| 108 | + headers=headers, |
| 109 | + http_compress=http_compress, |
| 110 | + cloud_id=cloud_id, |
| 111 | + api_key=api_key, |
| 112 | + opaque_id=opaque_id, |
| 113 | + **kwargs, |
| 114 | + ) |
| 115 | + |
| 116 | + if http_auth is not None: |
| 117 | + if isinstance(http_auth, (tuple, list)): |
| 118 | + http_auth = ":".join(http_auth) |
| 119 | + self.headers.update(urllib3.make_headers(basic_auth=http_auth)) |
| 120 | + |
| 121 | + # if providing an SSL context, raise error if any other SSL related flag is used |
| 122 | + if ssl_context and ( |
| 123 | + (verify_certs is not VERIFY_CERTS_DEFAULT) |
| 124 | + or (ssl_show_warn is not SSL_SHOW_WARN_DEFAULT) |
| 125 | + or ca_certs |
| 126 | + or client_cert |
| 127 | + or client_key |
| 128 | + or ssl_version |
| 129 | + ): |
| 130 | + warnings.warn( |
| 131 | + "When using `ssl_context`, all other SSL related kwargs are ignored" |
| 132 | + ) |
| 133 | + |
| 134 | + self.ssl_assert_fingerprint = ssl_assert_fingerprint |
| 135 | + if self.use_ssl and ssl_context is None: |
| 136 | + ssl_context = ssl.SSLContext(ssl_version or ssl.PROTOCOL_TLS) |
| 137 | + |
| 138 | + # Convert all sentinel values to their actual default |
| 139 | + # values if not using an SSLContext. |
| 140 | + if verify_certs is VERIFY_CERTS_DEFAULT: |
| 141 | + verify_certs = True |
| 142 | + if ssl_show_warn is SSL_SHOW_WARN_DEFAULT: |
| 143 | + ssl_show_warn = True |
| 144 | + |
| 145 | + if verify_certs: |
| 146 | + ssl_context.verify_mode = ssl.CERT_REQUIRED |
| 147 | + ssl_context.check_hostname = True |
| 148 | + else: |
| 149 | + ssl_context.verify_mode = ssl.CERT_NONE |
| 150 | + ssl_context.check_hostname = False |
| 151 | + |
| 152 | + ca_certs = CA_CERTS if ca_certs is None else ca_certs |
| 153 | + if verify_certs: |
| 154 | + if not ca_certs: |
| 155 | + raise ImproperlyConfigured( |
| 156 | + "Root certificates are missing for certificate " |
| 157 | + "validation. Either pass them in using the ca_certs parameter or " |
| 158 | + "install certifi to use it automatically." |
| 159 | + ) |
| 160 | + else: |
| 161 | + if ssl_show_warn: |
| 162 | + warnings.warn( |
| 163 | + "Connecting to %s using SSL with verify_certs=False is insecure." |
| 164 | + % self.host |
| 165 | + ) |
| 166 | + |
| 167 | + if os.path.isfile(ca_certs): |
| 168 | + ssl_context.load_verify_locations(cafile=ca_certs) |
| 169 | + elif os.path.isdir(ca_certs): |
| 170 | + ssl_context.load_verify_locations(capath=ca_certs) |
| 171 | + else: |
| 172 | + raise ImproperlyConfigured("ca_certs parameter is not a path") |
| 173 | + |
| 174 | + self.headers.setdefault("connection", "keep-alive") |
| 175 | + self.loop = loop |
| 176 | + self.session = None |
| 177 | + |
| 178 | + # Parameters for creating an aiohttp.ClientSession later. |
| 179 | + self._limit = maxsize |
| 180 | + self._http_auth = http_auth |
| 181 | + self._ssl_context = ssl_context |
| 182 | + |
| 183 | + async def perform_request( |
| 184 | + self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None |
| 185 | + ): |
| 186 | + if self.session is None: |
| 187 | + await self._create_aiohttp_session() |
| 188 | + |
| 189 | + orig_body = body |
| 190 | + url_path = url |
| 191 | + if params: |
| 192 | + query_string = urlencode(params) |
| 193 | + else: |
| 194 | + query_string = "" |
| 195 | + |
| 196 | + # There is a bug in aiohttp that disables the re-use |
| 197 | + # of the connection in the pool when method=HEAD. |
| 198 | + # See: aio-libs/aiohttp#1769 |
| 199 | + is_head = False |
| 200 | + if method == "HEAD": |
| 201 | + method = "GET" |
| 202 | + is_head = True |
| 203 | + |
| 204 | + # Provide correct URL object to avoid string parsing in low-level code |
| 205 | + url = yarl.URL.build( |
| 206 | + scheme=self.scheme, |
| 207 | + host=self.hostname, |
| 208 | + port=self.port, |
| 209 | + path=url, |
| 210 | + query_string=query_string, |
| 211 | + encoded=True, |
| 212 | + ) |
| 213 | + |
| 214 | + timeout = aiohttp.ClientTimeout( |
| 215 | + total=timeout if timeout is not None else self.timeout |
| 216 | + ) |
| 217 | + |
| 218 | + req_headers = self.headers.copy() |
| 219 | + if headers: |
| 220 | + req_headers.update(headers) |
| 221 | + |
| 222 | + if self.http_compress and body: |
| 223 | + body = self._gzip_compress(body) |
| 224 | + req_headers["content-encoding"] = "gzip" |
| 225 | + |
| 226 | + start = self.loop.time() |
| 227 | + try: |
| 228 | + async with self.session.request( |
| 229 | + method, |
| 230 | + url, |
| 231 | + data=body, |
| 232 | + headers=req_headers, |
| 233 | + timeout=timeout, |
| 234 | + fingerprint=self.ssl_assert_fingerprint, |
| 235 | + ) as response: |
| 236 | + if is_head: # We actually called 'GET' so throw away the data. |
| 237 | + await response.release() |
| 238 | + raw_data = "" |
| 239 | + else: |
| 240 | + raw_data = (await response.read()).decode("utf-8", "surrogatepass") |
| 241 | + duration = self.loop.time() - start |
| 242 | + |
| 243 | + # We want to reraise a cancellation. |
| 244 | + except asyncio.CancelledError: |
| 245 | + raise |
| 246 | + |
| 247 | + except Exception as e: |
| 248 | + self.log_request_fail( |
| 249 | + method, url, url_path, orig_body, self.loop.time() - start, exception=e |
| 250 | + ) |
| 251 | + if isinstance(e, ServerFingerprintMismatch): |
| 252 | + raise SSLError("N/A", str(e), e) |
| 253 | + if isinstance(e, (asyncio.TimeoutError, ServerTimeoutError)): |
| 254 | + raise ConnectionTimeout("TIMEOUT", str(e), e) |
| 255 | + raise ConnectionError("N/A", str(e), e) |
| 256 | + |
| 257 | + # raise errors based on http status codes, let the client handle those if needed |
| 258 | + if not (200 <= response.status < 300) and response.status not in ignore: |
| 259 | + self.log_request_fail( |
| 260 | + method, |
| 261 | + url, |
| 262 | + url_path, |
| 263 | + orig_body, |
| 264 | + duration, |
| 265 | + status_code=response.status, |
| 266 | + response=raw_data, |
| 267 | + ) |
| 268 | + self._raise_error(response.status, raw_data) |
| 269 | + |
| 270 | + self.log_request_success( |
| 271 | + method, url, url_path, orig_body, response.status, raw_data, duration |
| 272 | + ) |
| 273 | + |
| 274 | + return response.status, response.headers, raw_data |
| 275 | + |
| 276 | + async def close(self): |
| 277 | + """ |
| 278 | + Explicitly closes connection |
| 279 | + """ |
| 280 | + if self.session: |
| 281 | + await self.session.close() |
| 282 | + |
| 283 | + async def _create_aiohttp_session(self): |
| 284 | + """Creates an aiohttp.ClientSession(). This is delayed until |
| 285 | + the first call to perform_request() so that AsyncTransport has |
| 286 | + a chance to set AIOHttpConnection.loop |
| 287 | + """ |
| 288 | + if self.loop is None: |
| 289 | + self.loop = get_running_loop() |
| 290 | + self.session = aiohttp.ClientSession( |
| 291 | + headers=self.headers, |
| 292 | + auto_decompress=True, |
| 293 | + loop=self.loop, |
| 294 | + cookie_jar=aiohttp.DummyCookieJar(), |
| 295 | + connector=aiohttp.TCPConnector( |
| 296 | + limit=self._limit, use_dns_cache=True, ssl=self._ssl_context, |
| 297 | + ), |
| 298 | + ) |
0 commit comments