diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e4157f..3a3ff97 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,8 @@ jobs: run: black --check setup.py es - name: flake8 run: flake8 es + - name: mypy + run: mypy es tests: runs-on: ubuntu-18.04 @@ -38,12 +40,19 @@ jobs: matrix: python-version: [3.6, 3.7, 3.8] services: - postgres: + elasticsearch: image: elasticsearch:7.3.2 env: discovery.type: single-node ports: - 9200:9200 + opendistro: + image: amazon/opendistro-for-elasticsearch:1.12.0 + env: + discovery.type: single-node + ports: + - 9400:9200 + steps: - uses: actions/checkout@v2 - name: Setup Python @@ -58,8 +67,18 @@ jobs: pip install -r requirements.txt pip install -r requirements-dev.txt pip install -e . - - name: Run tests + - name: Run tests on Elasticsearch + run: | + export ES_URI="http://localhost:9200" + nosetests -v --with-coverage --cover-package=es es.tests + - name: Run tests on Opendistro run: | + export ES_DRIVER=odelasticsearch + export ES_URI="https://admin:admin@localhost:9400" + export ES_PASSWORD=admin + export ES_PORT=9400 + export ES_SCHEME=https + export ES_USER=admin nosetests -v --with-coverage --cover-package=es es.tests - name: Upload code coverage run: | diff --git a/README.md b/README.md index e65da4a..03db68f 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,12 @@ `elasticsearch-dbapi` Implements a DBAPI (PEP-249) and SQLAlchemy dialect, that enables SQL access on elasticsearch clusters for query only access. 
+ +On Elastic Elasticsearch: Uses Elastic X-Pack [SQL API](https://www.elastic.co/guide/en/elasticsearch/reference/current/xpack-sql.html) -We are currently building support for `opendistro/_sql` API for AWS Elasticsearch Service / [Open Distro SQL](https://opendistro.github.io/for-elasticsearch-docs/docs/sql/) +On AWS ES, opendistro Elasticsearch: +[Open Distro SQL](https://opendistro.github.io/for-elasticsearch-docs/docs/sql/) This library supports Elasticsearch 7.X versions. @@ -23,7 +26,7 @@ $ pip install elasticsearch-dbapi To install support for AWS Elasticsearch Service / [Open Distro](https://opendistro.github.io/for-elasticsearch/features/SQL%20Support.html): ```bash -$ pip install elasticsearch-dbapi[aws] +$ pip install elasticsearch-dbapi[opendistro] ``` ### Usage: @@ -131,8 +134,7 @@ $ nosetests -v ### Special case for sql opendistro endpoint (AWS ES) AWS ES exposes the opendistro SQL plugin, and it follows a different SQL dialect. -Because of dialect and API response differences, we provide limited support for opendistro SQL -on this package using the `odelasticsearch` driver: +Using the `odelasticsearch` driver: ```python from sqlalchemy.engine import create_engine @@ -159,6 +161,9 @@ curs = conn.cursor().execute( print([row for row in curs]) ``` +To connect to the provided Opendistro ES on `docker-compose` use the following URI: +`odelasticsearch+https://admin:admin@localhost:9400/?verify_certs=False` + ### Known limitations This library does not yet support the following features: @@ -168,4 +173,7 @@ SQLAlchemy `get_columns` will exclude them. 
- `object` and `nested` column types are not well supported and are converted to strings - Indexes that whose name start with `.` - GEO points are not currently well-supported and are converted to strings -- Very limited support for AWS ES, no AWS Auth yet for example + +- AWS ES (opendistro elasticsearch) is supported (still beta), known limitations are: + * You are only able to `GROUP BY` keyword fields (new [experimental](https://github.com/opendistro-for-elasticsearch/sql#experimental) + opendistro SQL already supports it) diff --git a/docker-compose.yml b/docker-compose.yml index acdfc65..1ab8d45 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,7 @@ services: - 9300:9300 opendistro: - image: amazon/opendistro-for-elasticsearch:1.7.0 + image: amazon/opendistro-for-elasticsearch:1.12.0 env_file: .env ports: - 9400:9200 diff --git a/es/baseapi.py b/es/baseapi.py index 232cce1..2b9e962 100644 --- a/es/baseapi.py +++ b/es/baseapi.py @@ -1,3 +1,6 @@ +from collections import namedtuple +from typing import Dict, List, Optional, Tuple + from elasticsearch import exceptions as es_exceptions from es import exceptions from six import string_types @@ -7,6 +10,14 @@ from .const import DEFAULT_FETCH_SIZE, DEFAULT_SCHEMA, DEFAULT_SQL_PATH +CursorDescriptionRow = namedtuple( + "CursorDescriptionRow", + ["name", "type", "display_size", "internal_size", "precision", "scale", "null_ok"], +) + +CursorDescriptionType = List[CursorDescriptionRow] + + class Type(object): STRING = 1 NUMBER = 2 @@ -17,25 +28,80 @@ class Type(object): def check_closed(f): """Decorator that checks if connection/cursor is closed.""" - def g(self, *args, **kwargs): + def wrap(self, *args, **kwargs): if self.closed: raise exceptions.Error( "{klass} already closed".format(klass=self.__class__.__name__) ) return f(self, *args, **kwargs) - return g + return wrap def check_result(f): """Decorator that checks if the cursor has results from `execute`.""" - def g(self, *args, **kwargs): + def 
wrap(self, *args, **kwargs): if self._results is None: raise exceptions.Error("Called before `execute`") return f(self, *args, **kwargs) - return g + return wrap + + +def get_type(data_type): + type_map = { + "text": Type.STRING, + "keyword": Type.STRING, + "integer": Type.NUMBER, + "half_float": Type.NUMBER, + "scaled_float": Type.NUMBER, + "geo_point": Type.STRING, + # TODO get a solution for nested type + "nested": Type.STRING, + "object": Type.STRING, + "date": Type.DATETIME, + "datetime": Type.DATETIME, + "short": Type.NUMBER, + "long": Type.NUMBER, + "float": Type.NUMBER, + "double": Type.NUMBER, + "bytes": Type.NUMBER, + "boolean": Type.BOOLEAN, + "ip": Type.STRING, + "interval_minute_to_second": Type.STRING, + "interval_hour_to_second": Type.STRING, + "interval_hour_to_minute": Type.STRING, + "interval_day_to_second": Type.STRING, + "interval_day_to_minute": Type.STRING, + "interval_day_to_hour": Type.STRING, + "interval_year_to_month": Type.STRING, + "interval_second": Type.STRING, + "interval_minute": Type.STRING, + "interval_day": Type.STRING, + "interval_month": Type.STRING, + "interval_year": Type.STRING, + } + return type_map[data_type.lower()] + + +def get_description_from_columns( + columns: List[Dict[str, str]] +) -> CursorDescriptionType: + return [ + ( + CursorDescriptionRow( + column.get("name") if "alias" not in column else column.get("alias"), + get_type(column.get("type")), + None, # [display_size] + None, # [internal_size] + None, # [precision] + None, # [scale] + True, # [null_ok] + ) + ) + for column in columns + ] class BaseConnection(object): @@ -120,19 +186,19 @@ def __init__(self, url, es, **kwargs): # this is set to an iterator after a successfull query self._results = None - @property + @property # type: ignore @check_result @check_closed - def rowcount(self): + def rowcount(self) -> int: return len(self._results) @check_closed - def close(self): + def close(self) -> None: """Close the cursor.""" self.closed = True @check_closed - 
def execute(self, operation, parameters=None): + def execute(self, operation, parameters=None) -> "BaseCursor": raise NotImplementedError # pragma: no cover @check_closed @@ -143,7 +209,7 @@ def executemany(self, operation, seq_of_parameters=None): @check_result @check_closed - def fetchone(self): + def fetchone(self) -> Optional[Tuple[str]]: """ Fetch the next row of a query result set, returning a single sequence, or `None` when no more data is available. @@ -155,7 +221,7 @@ def fetchone(self): @check_result @check_closed - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[Tuple[str]]: """ Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a list of tuples). An empty sequence is returned when @@ -167,7 +233,7 @@ def fetchmany(self, size=None): @check_result @check_closed - def fetchall(self): + def fetchall(self) -> List[Tuple[str]]: """ Fetch all (remaining) rows of a query result, returning them as a sequence of sequences (e.g. a list of tuples). 
Note that the cursor's @@ -202,7 +268,7 @@ def sanitize_query(self, query): # remove dummy schema from queries return query.replace(f'FROM "{DEFAULT_SCHEMA}".', "FROM ") - def elastic_query(self, query: str, csv=False): + def elastic_query(self, query: str): """ Request an http SQL query to elasticsearch """ @@ -210,10 +276,7 @@ def elastic_query(self, query: str, csv=False): # Sanitize query query = self.sanitize_query(query) payload = {"query": query, "fetch_size": self.fetch_size} - if csv: - path = f"/{self.sql_path}/?format=csv" - else: - path = f"/{self.sql_path}/" + path = f"/{self.sql_path}/" try: resp = self.es.transport.perform_request("POST", path, body=payload) except es_exceptions.ConnectionError as e: diff --git a/es/basesqlalchemy.py b/es/basesqlalchemy.py index 47431d2..ae1702f 100644 --- a/es/basesqlalchemy.py +++ b/es/basesqlalchemy.py @@ -4,7 +4,7 @@ from __future__ import unicode_literals import logging -from typing import List +from typing import Any, List, Type import es from es import exceptions @@ -27,7 +27,7 @@ def parse_bool_argument(value: str) -> bool: class BaseESCompiler(compiler.SQLCompiler): - def visit_fromclause(self, fromclause, **kwargs): + def visit_fromclause(self, fromclause: str, **kwargs: Any): return fromclause.replace("default.", "") def visit_label(self, *args, **kwargs): @@ -38,10 +38,10 @@ def visit_label(self, *args, **kwargs): class BaseESTypeCompiler(compiler.GenericTypeCompiler): - def visit_REAL(self, type_, **kwargs): + def visit_REAL(self, type_, **kwargs: Any) -> str: return "DOUBLE" - def visit_NUMERIC(self, type_, **kwargs): + def visit_NUMERIC(self, type_, **kwargs: Any) -> str: return "LONG" visit_DECIMAL = visit_NUMERIC @@ -52,7 +52,7 @@ def visit_NUMERIC(self, type_, **kwargs): visit_TIMESTAMP = visit_NUMERIC visit_DATE = visit_NUMERIC - def visit_CHAR(self, type_, **kwargs): + def visit_CHAR(self, type_, **kwargs: Any) -> str: return "STRING" visit_NCHAR = visit_CHAR @@ -60,25 +60,25 @@ def 
visit_CHAR(self, type_, **kwargs): visit_NVARCHAR = visit_CHAR visit_TEXT = visit_CHAR - def visit_DATETIME(self, type_, **kwargs): + def visit_DATETIME(self, type_, **kwargs: Any) -> str: return "DATETIME" - def visit_TIME(self, type_, **kwargs): + def visit_TIME(self, type_, **kwargs: Any) -> str: raise exceptions.NotSupportedError("Type TIME is not supported") - def visit_BINARY(self, type_, **kwargs): + def visit_BINARY(self, type_, **kwargs: Any) -> str: raise exceptions.NotSupportedError("Type BINARY is not supported") - def visit_VARBINARY(self, type_, **kwargs): + def visit_VARBINARY(self, type_, **kwargs: Any) -> str: raise exceptions.NotSupportedError("Type VARBINARY is not supported") - def visit_BLOB(self, type_, **kwargs): + def visit_BLOB(self, type_, **kwargs: Any) -> str: raise exceptions.NotSupportedError("Type BLOB is not supported") - def visit_CLOB(self, type_, **kwargs): + def visit_CLOB(self, type_, **kwargs: Any) -> str: raise exceptions.NotSupportedError("Type CBLOB is not supported") - def visit_NCLOB(self, type_, **kwargs): + def visit_NCLOB(self, type_, **kwargs: Any) -> str: raise exceptions.NotSupportedError("Type NCBLOB is not supported") @@ -87,8 +87,8 @@ class BaseESDialect(default.DefaultDialect): name = "SET" scheme = "SET" driver = "SET" - statement_compiler = None - type_compiler = None + statement_compiler: Type[BaseESCompiler] = BaseESCompiler + type_compiler: Type[BaseESTypeCompiler] = BaseESTypeCompiler preparer = compiler.IdentifierPreparer supports_alter = False supports_pk_autoincrement = False @@ -147,15 +147,10 @@ def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs) -> List[str]: - query = "SHOW TABLES" - result = connection.execute(query) - # return a list of table names exclude hidden and empty indexes - return [ - table.name - for table in result - if table.name[0] != "." 
- and len(self.get_columns(connection, table.name)) > 0 - ] + pass + + def get_columns(self, connection, table_name, schema=None, **kw): + pass def get_view_names(self, connection, schema=None, **kwargs): return [] @@ -163,30 +158,6 @@ def get_view_names(self, connection, schema=None, **kwargs): def get_table_options(self, connection, table_name, schema=None, **kwargs): return {} - def get_columns(self, connection, table_name, schema=None, **kwargs): - query = f'SHOW COLUMNS FROM "{table_name}"' - # A bit of an hack this cmd does not exist on ES - array_columns_ = connection.execute( - f"SHOW ARRAY_COLUMNS FROM {table_name}" - ).fetchall() - if len(array_columns_[0]) == 0: - array_columns = [] - else: - array_columns = [col_name[0] for col_name in array_columns_] - - result = connection.execute(query) - return [ - { - "name": row.column, - "type": get_type(row.mapping), - "nullable": True, - "default": None, - } - for row in result - if row.mapping not in self._not_supported_column_types - and row.column not in array_columns - ] - def get_pk_constraint(self, connection, table_name, schema=None, **kwargs): return {"constrained_columns": [], "name": None} @@ -222,7 +193,7 @@ def get_type(data_type): type_map = { "bytes": types.LargeBinary, "boolean": types.Boolean, - "date": types.Date, + "date": types.DateTime, "datetime": types.DateTime, "double": types.Numeric, "text": types.String, diff --git a/es/elastic/api.py b/es/elastic/api.py index 94dd317..28a949f 100644 --- a/es/elastic/api.py +++ b/es/elastic/api.py @@ -8,7 +8,14 @@ from elasticsearch import Elasticsearch, exceptions as es_exceptions from es import exceptions -from es.baseapi import apply_parameters, BaseConnection, BaseCursor, check_closed, Type +from es.baseapi import ( + apply_parameters, + BaseConnection, + BaseCursor, + check_closed, + get_description_from_columns, + Type, +) def connect( @@ -18,7 +25,7 @@ def connect( scheme: str = "http", user: Optional[str] = None, password: Optional[str] = None, 
- context: Optional[Dict] = None, + context: Optional[Dict[str, Any]] = None, **kwargs: Any, ): """ @@ -32,57 +39,6 @@ def connect( return Connection(host, port, path, scheme, user, password, context, **kwargs) -def get_type(data_type): - type_map = { - "text": Type.STRING, - "keyword": Type.STRING, - "integer": Type.NUMBER, - "half_float": Type.NUMBER, - "scaled_float": Type.NUMBER, - "geo_point": Type.STRING, - # TODO get a solution for nested type - "nested": Type.STRING, - "object": Type.STRING, - "date": Type.DATETIME, - "datetime": Type.DATETIME, - "short": Type.NUMBER, - "long": Type.NUMBER, - "float": Type.NUMBER, - "double": Type.NUMBER, - "bytes": Type.NUMBER, - "boolean": Type.BOOLEAN, - "ip": Type.STRING, - "interval_minute_to_second": Type.STRING, - "interval_hour_to_second": Type.STRING, - "interval_hour_to_minute": Type.STRING, - "interval_day_to_second": Type.STRING, - "interval_day_to_minute": Type.STRING, - "interval_day_to_hour": Type.STRING, - "interval_year_to_month": Type.STRING, - "interval_second": Type.STRING, - "interval_minute": Type.STRING, - "interval_day": Type.STRING, - "interval_month": Type.STRING, - "interval_year": Type.STRING, - } - return type_map[data_type.lower()] - - -def get_description_from_columns(columns: Dict): - return [ - ( - column.get("name"), # name - get_type(column.get("type")), # type code - None, # [display_size] - None, # [internal_size] - None, # [precision] - None, # [scale] - True, # [null_ok] - ) - for column in columns - ] - - class Connection(BaseConnection): """Connection to an ES Cluster """ @@ -129,8 +85,37 @@ def __init__(self, url, es, **kwargs): super().__init__(url, es, **kwargs) self.sql_path = kwargs.get("sql_path") or "_sql" + def get_valid_table_names(self) -> "Cursor": + """ + Custom for "SHOW VALID_TABLES" excludes empty indices from the response + Mixes `SHOW TABLES` with direct index access info to exclude indexes + that have no rows so no columns (unless templated). 
SQLAlchemy will + not support reflection of tables with no columns + + https://github.com/preset-io/elasticsearch-dbapi/issues/38 + """ + results = self.execute("SHOW TABLES") + response = self.es.cat.indices(format="json") + + _results = [] + for result in results: + is_empty = False + for item in response: + # First column is TABLE_NAME + if item["index"] == result[0]: + if int(item["docs.count"]) == 0: + is_empty = True + break + if not is_empty: + _results.append(result) + self._results = _results + return self + @check_closed def execute(self, operation, parameters=None): + if operation == "SHOW VALID_TABLES": + return self.get_valid_table_names() + re_table_name = re.match("SHOW ARRAY_COLUMNS FROM (.*)", operation) if re_table_name: return self.get_array_type_columns(re_table_name[1]) diff --git a/es/elastic/sqlalchemy.py b/es/elastic/sqlalchemy.py index cb06169..3bdfcf9 100644 --- a/es/elastic/sqlalchemy.py +++ b/es/elastic/sqlalchemy.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import logging +from typing import List from es import basesqlalchemy import es.elastic @@ -31,6 +32,36 @@ class ESDialect(basesqlalchemy.BaseESDialect): def dbapi(cls): return es.elastic + def get_table_names(self, connection, schema=None, **kwargs) -> List[str]: + query = "SHOW VALID_TABLES" + result = connection.execute(query) + # return a list of table names exclude hidden and empty indexes + return [table.name for table in result if table.name[0] != "."] + + def get_columns(self, connection, table_name, schema=None, **kwargs): + query = f'SHOW COLUMNS FROM "{table_name}"' + # A bit of an hack this cmd does not exist on ES + array_columns_ = connection.execute( + f"SHOW ARRAY_COLUMNS FROM {table_name}" + ).fetchall() + if len(array_columns_[0]) == 0: + array_columns = [] + else: + array_columns = [col_name[0] for col_name in array_columns_] + + result = connection.execute(query) + return [ + { + "name": row.column, + "type": basesqlalchemy.get_type(row.mapping), + 
"nullable": True, + "default": None, + } + for row in result + if row.mapping not in self._not_supported_column_types + and row.column not in array_columns + ] + ESHTTPDialect = ESDialect diff --git a/es/opendistro/api.py b/es/opendistro/api.py index 7706adb..7c08a1c 100644 --- a/es/opendistro/api.py +++ b/es/opendistro/api.py @@ -3,13 +3,18 @@ from __future__ import print_function from __future__ import unicode_literals -import csv import re -from typing import Any, Dict, Optional # pragma: no cover +from typing import Any, Dict, List, Optional, Tuple # pragma: no cover from elasticsearch import Elasticsearch from es import exceptions -from es.baseapi import apply_parameters, BaseConnection, BaseCursor, check_closed, Type +from es.baseapi import ( + apply_parameters, + BaseConnection, + BaseCursor, + check_closed, + get_description_from_columns, +) from es.const import DEFAULT_SCHEMA @@ -20,7 +25,7 @@ def connect( scheme: str = "https", user: Optional[str] = None, password: Optional[str] = None, - context: Optional[Dict] = None, + context: Optional[Dict[Any, Any]] = None, **kwargs: Any, ): # pragma: no cover """ @@ -34,47 +39,20 @@ def connect( return Connection(host, port, path, scheme, user, password, context, **kwargs) -def get_type_from_value(value): # pragma: no cover - if value in ("true", "false"): - return Type.BOOLEAN - try: - float(value) - return Type.NUMBER - except ValueError: - return Type.STRING - - -def get_description_from_first_row(header: list, row: list): # pragma: no cover - description = [] - for i, col_name in enumerate(header): - description.append( - ( - col_name, - get_type_from_value(row[i]), - None, # [display_size] - None, # [internal_size] - None, # [precision] - None, # [scale] - True, # [null_ok] - ) - ) - return description - - class Connection(BaseConnection): # pragma: no cover """Connection to an ES Cluster """ def __init__( self, - host="localhost", - port=443, - path="", - scheme="https", - user=None, - password=None, - 
context=None, - **kwargs, + host: str = "localhost", + port: int = 443, + path: str = "", + scheme: str = "https", + user: Optional[str] = None, + password: Optional[str] = None, + context: Optional[Dict[Any, Any]] = None, + **kwargs: Dict[str, Any], ): super().__init__( host=host, @@ -91,13 +69,13 @@ def __init__( else: self.es = Elasticsearch(self.url, **self.kwargs) - def _aws_auth(self, aws_access_key, aws_secret_key, region): + def _aws_auth(self, aws_access_key: str, aws_secret_key: str, region: str) -> Any: from requests_4auth import AWS4Auth return AWS4Auth(aws_access_key, aws_secret_key, region, "es") @check_closed - def cursor(self): + def cursor(self) -> "Cursor": """Return a new Cursor Object using the connection.""" cursor = Cursor(self.url, self.es, **self.kwargs) self.cursors.append(cursor) @@ -112,63 +90,104 @@ def __init__(self, url, es, **kwargs): super().__init__(url, es, **kwargs) self.sql_path = kwargs.get("sql_path") or "_opendistro/_sql" - def _show_tables(self): + def get_valid_table_names(self) -> "Cursor": """ - Simulates SHOW TABLES more like SQL from elastic itself + Custom for "SHOW VALID_TABLES" excludes empty indices from the response + Mixes `SHOW TABLES LIKE` with direct index access info to exclude indexes + that have no rows so no columns (unless templated). 
SQLAlchemy will + not support reflection of tables with no columns + + https://github.com/preset-io/elasticsearch-dbapi/issues/38 """ - results = self.elastic_query("SHOW TABLES LIKE *") - self.description = [("name", Type.STRING, None, None, None, None, None)] - self._results = [[result] for result in results] + results = self.execute("SHOW TABLES LIKE %") + response = self.es.cat.indices(format="json") + + _results = [] + for result in results: + is_empty = False + for item in response: + # Third column is TABLE_NAME + if item["index"] == result[2]: + if int(item["docs.count"]) == 0: + is_empty = True + break + if not is_empty: + _results.append(result) + self._results = _results return self - def _show_columns(self, table_name): + def _traverse_mapping( + self, + mapping: Dict[str, Any], + results: List[Tuple[str, ...]], + parent_field_name=None, + ) -> List[Tuple[str, ...]]: + for field_name, metadata in mapping.items(): + if parent_field_name: + field_name = f"{parent_field_name}.{field_name}" + if "properties" in metadata: + self._traverse_mapping(metadata["properties"], results, field_name) + else: + results.append((field_name, metadata["type"])) + if "fields" in metadata: + for sub_field_name, sub_metadata in metadata["fields"].items(): + results.append( + (f"{field_name}.{sub_field_name}", sub_metadata["type"]) + ) + return results + + def get_valid_columns(self, index_name: str) -> "Cursor": """ - Simulates SHOW COLUMNS FROM more like SQL from elastic itself + Custom for "SHOW VALID_COLUMNS FROM " + Adds keywords to text if they exist and flattens nested structures + get's all fields by directly accessing `/_mapping/` endpoint + + https://github.com/preset-io/elasticsearch-dbapi/issues/38 """ - results = self.elastic_query(f"SHOW TABLES LIKE {table_name}") - if table_name not in results: - raise exceptions.ProgrammingError(f"Table {table_name} not found") - rows = [] - for col, value in results[table_name]["mappings"]["_doc"]["properties"].items(): - type 
= value.get("type") - if type: - rows.append([col, type]) - self.description = [ - ("column", Type.STRING, None, None, None, None, None), - ("mapping", Type.STRING, None, None, None, None, None), - ] - self._results = rows + response = self.es.indices.get_mapping(index=index_name, format="json") + self._results = self._traverse_mapping( + response[index_name]["mappings"]["properties"], [] + ) + + self.description = get_description_from_columns( + [ + {"name": "COLUMN_NAME", "type": "text"}, + {"name": "TYPE_NAME", "type": "text"}, + ] + ) + return self + + def get_valid_select_one(self) -> "Cursor": + res = self.es.ping() + if not res: + raise exceptions.DatabaseError() + self._results = [(1,)] + self.description = get_description_from_columns([{"name": "1", "type": "long"}]) return self @check_closed def execute(self, operation, parameters=None): - if operation == "SHOW TABLES": - return self._show_tables() - re_table_name = re.match("SHOW COLUMNS FROM (.*)", operation) - if re_table_name: - return self._show_columns(re_table_name[1]) + if operation == "SHOW VALID_TABLES": + return self.get_valid_table_names() + + if operation.lower() == "select 1": + return self.get_valid_select_one() - re_table_name = re.match("SHOW ARRAY_COLUMNS FROM (.*)", operation) + re_table_name = re.match("SHOW VALID_COLUMNS FROM (.*)", operation) if re_table_name: - return self.get_array_type_columns(re_table_name[1]) + return self.get_valid_columns(re_table_name[1]) query = apply_parameters(operation, parameters) - _results = self.elastic_query(query, csv=True).split("\n") - header = _results[0].split(",") - _results = _results[1:] - results = list(csv.reader(_results)) - self.description = get_description_from_first_row(header, results[0]) - self._results = results - return self + results = self.elastic_query(query) - def get_array_type_columns(self, table_name: str) -> "Cursor": - """ - Queries the index (table) for just one record - and return a list of array type columns. 
- This is useful since arrays are not supported by ES SQL - """ - self.description = [("name", Type.STRING, None, None, None, None, None)] - self._results = [[]] + rows = [tuple(row) for row in results.get("datarows")] + columns = results.get("schema") + if not columns: + raise exceptions.DataError( + "Missing columns field, maybe it's an elastic sql ep" + ) + self._results = rows + self.description = get_description_from_columns(columns) return self def sanitize_query(self, query): diff --git a/es/opendistro/sqlalchemy.py b/es/opendistro/sqlalchemy.py index 54003ee..db056c1 100644 --- a/es/opendistro/sqlalchemy.py +++ b/es/opendistro/sqlalchemy.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import logging +from typing import List from es import basesqlalchemy import es.opendistro @@ -31,6 +32,27 @@ class ESDialect(basesqlalchemy.BaseESDialect): # pragma: no cover def dbapi(cls): return es.opendistro + def get_table_names(self, connection, schema=None, **kwargs) -> List[str]: + # custom builtin query + query = "SHOW VALID_TABLES" + result = connection.execute(query) + # return a list of table names exclude hidden and empty indexes + return [table.TABLE_NAME for table in result if table.TABLE_NAME[0] != "."] + + def get_columns(self, connection, table_name, schema=None, **kwargs): + # custom builtin query + query = f"SHOW VALID_COLUMNS FROM {table_name}" + result = connection.execute(query) + return [ + { + "name": row.COLUMN_NAME, + "type": basesqlalchemy.get_type(row.TYPE_NAME), + "nullable": True, + "default": None, + } + for row in result + ] + ESHTTPDialect = ESDialect diff --git a/es/tests/fixtures/fixtures.py b/es/tests/fixtures/fixtures.py index 7157469..1d288f1 100644 --- a/es/tests/fixtures/fixtures.py +++ b/es/tests/fixtures/fixtures.py @@ -76,19 +76,19 @@ def import_file_to_es(base_url, file_path, index_name): fd.close() set_index_replica_zero(base_url, index_name) - es = Elasticsearch(base_url) + es = Elasticsearch(base_url, 
verify_certs=False) for doc in data: es.index(index=index_name, doc_type="_doc", body=doc, refresh=True) def set_index_replica_zero(base_url, index_name): settings = {"settings": {"number_of_shards": 1, "number_of_replicas": 0}} - es = Elasticsearch(base_url) + es = Elasticsearch(base_url, verify_certs=False) es.indices.create(index=index_name, ignore=400, body=settings) def delete_index(base_url, index_name): - es = Elasticsearch(base_url) + es = Elasticsearch(base_url, verify_certs=False) try: es.delete_by_query(index=index_name, body={"query": {"match_all": {}}}) except NotFoundError: diff --git a/es/tests/test_dbapi.py b/es/tests/test_dbapi.py index 43b2605..c8608f8 100644 --- a/es/tests/test_dbapi.py +++ b/es/tests/test_dbapi.py @@ -1,13 +1,34 @@ +import os import unittest from unittest.mock import patch -from es.elastic.api import connect, Type +from es.elastic.api import connect as elastic_connect, Type from es.exceptions import Error, NotSupportedError, OperationalError, ProgrammingError +from es.opendistro.api import connect as open_connect class TestData(unittest.TestCase): def setUp(self): - self.conn = connect(host="localhost") + self.driver_name = os.environ.get("ES_DRIVER", "elasticsearch") + host = os.environ.get("ES_HOST", "localhost") + port = int(os.environ.get("ES_PORT", 9200)) + scheme = os.environ.get("ES_SCHEME", "http") + verify_certs = os.environ.get("ES_VERIFY_CERTS", False) + user = os.environ.get("ES_USER", None) + password = os.environ.get("ES_PASSWORD", None) + + if self.driver_name == "elasticsearch": + self.connect_func = elastic_connect + else: + self.connect_func = open_connect + self.conn = self.connect_func( + host=host, + port=port, + scheme=scheme, + verify_certs=verify_certs, + user=user, + password=password, + ) self.cursor = self.conn.cursor() def tearDown(self): @@ -17,7 +38,7 @@ def test_connect_failed(self): """ DBAPI: Test connection failed """ - conn = connect(host="unknown") + conn = self.connect_func(host="unknown") 
curs = conn.cursor() with self.assertRaises(OperationalError): curs.execute("select Carrier from flights").fetchall() @@ -27,7 +48,7 @@ def test_close(self): """ DBAPI: Test connection failed """ - conn = connect(host="localhost") + conn = self.connect_func(host="localhost") conn.close() with self.assertRaises(Error): conn.close() @@ -151,7 +172,7 @@ def test_simple_group_by(self): DBAPI: Test simple group by """ rows = self.cursor.execute( - "select COUNT(*) as c, Carrier from flights GROUP BY Carrier" + "select COUNT(*) as c, Carrier.keyword from flights GROUP BY Carrier.keyword" ).fetchall() # poor assertion because that is loaded async self.assertEqual(len(rows), 4) @@ -162,7 +183,9 @@ def test_auth(self, mock_elasticsearch): DBAPI: test Elasticsearch is called with user password """ mock_elasticsearch.return_value = None - connect(host="localhost", user="user", password="password") + self.connect_func( + host="localhost", scheme="http", port=9200, user="user", password="password" + ) mock_elasticsearch.assert_called_once_with( "http://localhost:9200/", http_auth=("user", "password") ) @@ -173,7 +196,13 @@ def test_https(self, mock_elasticsearch): DBAPI: test Elasticsearch is called with https """ mock_elasticsearch.return_value = None - connect(host="localhost", user="user", password="password", scheme="https") + self.connect_func( + host="localhost", + user="user", + password="password", + scheme="https", + port=9200, + ) mock_elasticsearch.assert_called_once_with( "https://localhost:9200/", http_auth=("user", "password") ) diff --git a/es/tests/test_sqlalchemy.py b/es/tests/test_sqlalchemy.py index 4065864..ca3ce7c 100644 --- a/es/tests/test_sqlalchemy.py +++ b/es/tests/test_sqlalchemy.py @@ -1,16 +1,35 @@ +import os import unittest from unittest.mock import patch from es.tests.fixtures.fixtures import data1_columns, flights_columns from sqlalchemy import func, inspect, select from sqlalchemy.engine import create_engine +from sqlalchemy.engine.url import URL 
from sqlalchemy.exc import ProgrammingError from sqlalchemy.schema import MetaData, Table class TestData(unittest.TestCase): def setUp(self): - self.engine = create_engine("elasticsearch+http://localhost:9200/") + self.driver_name = os.environ.get("ES_DRIVER", "elasticsearch") + host = os.environ.get("ES_HOST", "localhost") + port = int(os.environ.get("ES_PORT", 9200)) + scheme = os.environ.get("ES_SCHEME", "http") + verify_certs = os.environ.get("ES_VERIFY_CERTS", "False") + user = os.environ.get("ES_USER", None) + password = os.environ.get("ES_PASSWORD", None) + + uri = URL( + f"{self.driver_name}+{scheme}", + user, + password, + host, + port, + None, + {"verify_certs": str(verify_certs)}, + ) + self.engine = create_engine(uri) self.connection = self.engine.connect() self.table_flights = Table("flights", MetaData(bind=self.engine), autoload=True) @@ -63,6 +82,8 @@ def test_get_columns_exclude_arrays(self): """ SQLAlchemy: Test get_columns exclude arrays """ + if self.driver_name == "odelasticsearch": + return metadata = MetaData() metadata.reflect(bind=self.engine) source_cols = [c.name for c in metadata.tables["data1"].c] diff --git a/requirements-dev.txt b/requirements-dev.txt index 59ff221..9292c8c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,4 @@ pip-tools==3.6.0 pre-commit==1.17.0 twine==2.0.0 readme_renderer==24.0 +mypy==0.790 diff --git a/setup.cfg b/setup.cfg index e304326..c35fb62 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,3 +5,9 @@ description-file = README.md author = Preset Inc. author-email = daniel@preset.io license = Apache License, Version 2.0 + +[mypy] +disallow_any_generics = true +ignore_missing_imports = true +no_implicit_optional = true +warn_unused_ignores = true