diff --git a/src/databricks/sql/auth/__init__.py b/databricks_sql_connector/databricks_sql_connector/__init__.py similarity index 100% rename from src/databricks/sql/auth/__init__.py rename to databricks_sql_connector/databricks_sql_connector/__init__.py diff --git a/databricks_sql_connector/pyproject.toml b/databricks_sql_connector/pyproject.toml new file mode 100644 index 00000000..6e7297d1 --- /dev/null +++ b/databricks_sql_connector/pyproject.toml @@ -0,0 +1,23 @@ +[tool.poetry] +name = "databricks-sql-connector" +version = "3.5.0" +description = "Databricks SQL Connector for Python" +authors = ["Databricks "] +license = "Apache-2.0" + + +[tool.poetry.dependencies] +databricks_sql_connector_core = { version = ">=1.0.0", extras=["all"]} +databricks_sqlalchemy = { version = ">=1.0.0", optional = true } + +[tool.poetry.extras] +databricks_sqlalchemy = ["databricks_sqlalchemy"] + +[tool.poetry.urls] +"Homepage" = "https://github.com/databricks/databricks-sql-python" +"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + diff --git a/poetry.lock b/databricks_sql_connector_core/poetry.lock similarity index 100% rename from poetry.lock rename to databricks_sql_connector_core/poetry.lock diff --git a/pyproject.toml b/databricks_sql_connector_core/pyproject.toml similarity index 67% rename from pyproject.toml rename to databricks_sql_connector_core/pyproject.toml index 44d25ef9..a6e36091 100644 --- a/pyproject.toml +++ b/databricks_sql_connector_core/pyproject.toml @@ -1,12 +1,9 @@ [tool.poetry] -name = "databricks-sql-connector" -version = "3.3.0" -description = "Databricks SQL Connector for Python" +name = "databricks-sql-connector-core" +version = "1.0.0" +description = "Databricks SQL Connector core for Python" authors = ["Databricks "] -license = "Apache-2.0" -readme = "README.md" packages = [{ include = "databricks", from = "src" }] -include = ["CHANGELOG.md"] [tool.poetry.dependencies] python = "^3.8.0" @@ -14,23 +11,16 @@ thrift = ">=0.16.0,<0.21.0" pandas = [ { version = ">=1.2.5,<2.3.0", python = ">=3.8" } ] -pyarrow = ">=14.0.1,<17" - lz4 = "^4.0.2" requests = "^2.18.1" oauthlib = "^3.1.0" -numpy = [ - { version = "^1.16.6", python = ">=3.8,<3.11" }, - { version = "^1.23.4", python = ">=3.11" }, -] -sqlalchemy = { version = ">=2.0.21", optional = true } openpyxl = "^3.0.10" alembic = { version = "^1.0.11", optional = true } urllib3 = ">=1.26" +pyarrow = {version = ">=14.0.1,<17", optional = true} [tool.poetry.extras] -sqlalchemy = ["sqlalchemy"] -alembic = ["sqlalchemy", "alembic"] +pyarrow = ["pyarrow"] [tool.poetry.dev-dependencies] pytest = "^7.1.2" @@ -43,8 +33,6 @@ pytest-dotenv = "^0.5.2" "Homepage" = "https://github.com/databricks/databricks-sql-python" "Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues" -[tool.poetry.plugins."sqlalchemy.dialects"] -"databricks" = "databricks.sqlalchemy:DatabricksDialect" [build-system] requires = ["poetry-core>=1.0.0"] @@ -62,5 +50,5 @@ markers = {"reviewed" = "Test case has been reviewed by Databricks"} minversion = "6.0" log_cli = "false" log_cli_level = "INFO" -testpaths = ["tests", "src/databricks/sqlalchemy/test_local"] +testpaths = ["tests", "databricks_sql_connector_core/tests"] env_files = ["test.env"] diff --git a/src/databricks/__init__.py b/databricks_sql_connector_core/src/databricks/__init__.py similarity index 100% rename from src/databricks/__init__.py rename to 
databricks_sql_connector_core/src/databricks/__init__.py diff --git a/src/databricks/sql/__init__.py b/databricks_sql_connector_core/src/databricks/sql/__init__.py similarity index 100% rename from src/databricks/sql/__init__.py rename to databricks_sql_connector_core/src/databricks/sql/__init__.py diff --git a/src/databricks/sql/experimental/__init__.py b/databricks_sql_connector_core/src/databricks/sql/auth/__init__.py similarity index 100% rename from src/databricks/sql/experimental/__init__.py rename to databricks_sql_connector_core/src/databricks/sql/auth/__init__.py diff --git a/src/databricks/sql/auth/auth.py b/databricks_sql_connector_core/src/databricks/sql/auth/auth.py similarity index 100% rename from src/databricks/sql/auth/auth.py rename to databricks_sql_connector_core/src/databricks/sql/auth/auth.py diff --git a/src/databricks/sql/auth/authenticators.py b/databricks_sql_connector_core/src/databricks/sql/auth/authenticators.py similarity index 100% rename from src/databricks/sql/auth/authenticators.py rename to databricks_sql_connector_core/src/databricks/sql/auth/authenticators.py diff --git a/src/databricks/sql/auth/endpoint.py b/databricks_sql_connector_core/src/databricks/sql/auth/endpoint.py similarity index 100% rename from src/databricks/sql/auth/endpoint.py rename to databricks_sql_connector_core/src/databricks/sql/auth/endpoint.py diff --git a/src/databricks/sql/auth/oauth.py b/databricks_sql_connector_core/src/databricks/sql/auth/oauth.py similarity index 100% rename from src/databricks/sql/auth/oauth.py rename to databricks_sql_connector_core/src/databricks/sql/auth/oauth.py diff --git a/src/databricks/sql/auth/oauth_http_handler.py b/databricks_sql_connector_core/src/databricks/sql/auth/oauth_http_handler.py similarity index 100% rename from src/databricks/sql/auth/oauth_http_handler.py rename to databricks_sql_connector_core/src/databricks/sql/auth/oauth_http_handler.py diff --git a/src/databricks/sql/auth/retry.py b/databricks_sql_connector_core/src/databricks/sql/auth/retry.py similarity index 100% rename from src/databricks/sql/auth/retry.py rename to databricks_sql_connector_core/src/databricks/sql/auth/retry.py diff --git a/src/databricks/sql/auth/thrift_http_client.py b/databricks_sql_connector_core/src/databricks/sql/auth/thrift_http_client.py similarity index 100% rename from src/databricks/sql/auth/thrift_http_client.py rename to databricks_sql_connector_core/src/databricks/sql/auth/thrift_http_client.py diff --git a/src/databricks/sql/client.py b/databricks_sql_connector_core/src/databricks/sql/client.py similarity index 99% rename from src/databricks/sql/client.py rename to databricks_sql_connector_core/src/databricks/sql/client.py index c0bf534d..72811628 100755 --- a/src/databricks/sql/client.py +++ b/databricks_sql_connector_core/src/databricks/sql/client.py @@ -1,7 +1,6 @@ from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas -import pyarrow import requests import json import os @@ -43,6 +42,10 @@ TSparkParameter, ) +try: + import pyarrow +except ImportError: + pyarrow = None logger = logging.getLogger(__name__) @@ -977,14 +980,14 @@ def fetchmany(self, size: int) -> List[Row]: else: raise Error("There is no active result set") - def fetchall_arrow(self) -> pyarrow.Table: + def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: raise Error("There is no active result set") - def fetchmany_arrow(self, size) -> pyarrow.Table: + 
def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) @@ -1171,7 +1174,7 @@ def _convert_arrow_table(self, table): def rownumber(self): return self._next_row_index - def fetchmany_arrow(self, size: int) -> pyarrow.Table: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1196,7 +1199,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table: return results - def fetchall_arrow(self) -> pyarrow.Table: + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/databricks_sql_connector_core/src/databricks/sql/cloudfetch/download_manager.py similarity index 100% rename from src/databricks/sql/cloudfetch/download_manager.py rename to databricks_sql_connector_core/src/databricks/sql/cloudfetch/download_manager.py diff --git a/src/databricks/sql/cloudfetch/downloader.py b/databricks_sql_connector_core/src/databricks/sql/cloudfetch/downloader.py similarity index 100% rename from src/databricks/sql/cloudfetch/downloader.py rename to databricks_sql_connector_core/src/databricks/sql/cloudfetch/downloader.py diff --git a/src/databricks/sql/exc.py b/databricks_sql_connector_core/src/databricks/sql/exc.py similarity index 100% rename from src/databricks/sql/exc.py rename to databricks_sql_connector_core/src/databricks/sql/exc.py diff --git a/src/databricks/sql/thrift_api/__init__.py b/databricks_sql_connector_core/src/databricks/sql/experimental/__init__.py similarity index 100% rename from src/databricks/sql/thrift_api/__init__.py rename to databricks_sql_connector_core/src/databricks/sql/experimental/__init__.py diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/databricks_sql_connector_core/src/databricks/sql/experimental/oauth_persistence.py similarity index 100% rename from src/databricks/sql/experimental/oauth_persistence.py rename to databricks_sql_connector_core/src/databricks/sql/experimental/oauth_persistence.py diff --git a/src/databricks/sql/parameters/__init__.py b/databricks_sql_connector_core/src/databricks/sql/parameters/__init__.py similarity index 100% rename from src/databricks/sql/parameters/__init__.py rename to databricks_sql_connector_core/src/databricks/sql/parameters/__init__.py diff --git a/src/databricks/sql/parameters/native.py b/databricks_sql_connector_core/src/databricks/sql/parameters/native.py similarity index 100% rename from src/databricks/sql/parameters/native.py rename to databricks_sql_connector_core/src/databricks/sql/parameters/native.py diff --git a/src/databricks/sql/parameters/py.typed b/databricks_sql_connector_core/src/databricks/sql/parameters/py.typed similarity index 100% rename from src/databricks/sql/parameters/py.typed rename to databricks_sql_connector_core/src/databricks/sql/parameters/py.typed diff --git a/src/databricks/sql/py.typed b/databricks_sql_connector_core/src/databricks/sql/py.typed similarity index 100% rename from src/databricks/sql/py.typed rename to databricks_sql_connector_core/src/databricks/sql/py.typed diff --git a/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote b/databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote similarity index 100% rename from 
src/databricks/sql/thrift_api/TCLIService/TCLIService-remote rename to databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/TCLIService-remote diff --git a/src/databricks/sql/thrift_api/TCLIService/TCLIService.py b/databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/TCLIService.py similarity index 100% rename from src/databricks/sql/thrift_api/TCLIService/TCLIService.py rename to databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/TCLIService.py diff --git a/src/databricks/sql/thrift_api/TCLIService/__init__.py b/databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/__init__.py similarity index 100% rename from src/databricks/sql/thrift_api/TCLIService/__init__.py rename to databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/__init__.py diff --git a/src/databricks/sql/thrift_api/TCLIService/constants.py b/databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/constants.py similarity index 100% rename from src/databricks/sql/thrift_api/TCLIService/constants.py rename to databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/constants.py diff --git a/src/databricks/sql/thrift_api/TCLIService/ttypes.py b/databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/ttypes.py similarity index 100% rename from src/databricks/sql/thrift_api/TCLIService/ttypes.py rename to databricks_sql_connector_core/src/databricks/sql/thrift_api/TCLIService/ttypes.py diff --git a/tests/__init__.py b/databricks_sql_connector_core/src/databricks/sql/thrift_api/__init__.py similarity index 100% rename from tests/__init__.py rename to databricks_sql_connector_core/src/databricks/sql/thrift_api/__init__.py diff --git a/src/databricks/sql/thrift_backend.py b/databricks_sql_connector_core/src/databricks/sql/thrift_backend.py similarity index 99% rename from src/databricks/sql/thrift_backend.py rename to databricks_sql_connector_core/src/databricks/sql/thrift_backend.py index 56412fce..42daf85e 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/databricks_sql_connector_core/src/databricks/sql/thrift_backend.py @@ -8,7 +8,6 @@ from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union -import pyarrow import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket @@ -37,6 +36,11 @@ convert_column_based_set_to_arrow_table, ) +try: + import pyarrow +except ImportError: + pyarrow = None + logger = logging.getLogger(__name__) unsafe_logger = logging.getLogger("databricks.sql.unsafe") @@ -652,6 +656,12 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): + + if pyarrow is None: + raise ImportError( + "pyarrow is required to convert Hive schema to Arrow schema" + ) + def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -858,7 +868,7 @@ def execute_command( getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), - canReadArrowResult=True, + canReadArrowResult=True if pyarrow else False, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, confOverlay={ diff --git a/src/databricks/sql/types.py b/databricks_sql_connector_core/src/databricks/sql/types.py similarity index 100% rename from src/databricks/sql/types.py rename to databricks_sql_connector_core/src/databricks/sql/types.py diff --git a/src/databricks/sql/utils.py b/databricks_sql_connector_core/src/databricks/sql/utils.py 
similarity index 97% rename from src/databricks/sql/utils.py rename to databricks_sql_connector_core/src/databricks/sql/utils.py index c22688bb..1bcc8a88 100644 --- a/src/databricks/sql/utils.py +++ b/databricks_sql_connector_core/src/databricks/sql/utils.py @@ -12,7 +12,6 @@ from ssl import SSLContext import lz4.frame -import pyarrow from databricks.sql import OperationalError, exc from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager @@ -28,16 +27,21 @@ import logging +try: + import pyarrow +except ImportError: + pyarrow = None + logger = logging.getLogger(__name__) class ResultSetQueue(ABC): @abstractmethod - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int): pass @abstractmethod - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self): pass @@ -100,7 +104,7 @@ def build_queue( class ArrowQueue(ResultSetQueue): def __init__( self, - arrow_table: pyarrow.Table, + arrow_table: "pyarrow.Table", n_valid_rows: int, start_row_index: int = 0, ): @@ -115,7 +119,7 @@ def __init__( self.arrow_table = arrow_table self.n_valid_rows = n_valid_rows - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """Get upto the next n rows of the Arrow dataframe""" length = min(num_rows, self.n_valid_rows - self.cur_row_index) # Note that the table.slice API is not the same as Python's slice @@ -124,7 +128,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: self.cur_row_index += slice.num_rows return slice - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self) -> "pyarrow.Table": slice = self.arrow_table.slice( self.cur_row_index, self.n_valid_rows - self.cur_row_index ) @@ -184,7 +188,7 @@ def __init__( self.table = self._create_next_table() self.table_row_index = 0 - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int) -> "pyarrow.Table": """ Get up to the next n rows of the cloud fetch Arrow dataframes. @@ -216,7 +220,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) return results - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self) -> "pyarrow.Table": """ Get all remaining rows of the cloud fetch Arrow dataframes. 
@@ -237,7 +241,7 @@ def remaining_rows(self) -> pyarrow.Table: self.table_row_index = 0 return results - def _create_next_table(self) -> Union[pyarrow.Table, None]: + def _create_next_table(self) -> Union["pyarrow.Table", None]: logger.debug( "CloudFetchQueue: Trying to get downloaded file for row {}".format( self.start_row_index @@ -276,7 +280,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]: return arrow_table - def _create_empty_table(self) -> pyarrow.Table: + def _create_empty_table(self) -> "pyarrow.Table": # Create a 0-row table with just the schema bytes return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) @@ -515,7 +519,7 @@ def transform_paramstyle( return output -def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table: +def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table": arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -542,7 +546,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema return arrow_table, n_rows -def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table: +def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table": for i, col in enumerate(table.itercolumns()): if description[i][1] == "decimal": decimal_col = col.to_pandas().apply( diff --git a/databricks_sql_connector_core/src/databricks/sqlalchemy/__init__.py b/databricks_sql_connector_core/src/databricks/sqlalchemy/__init__.py new file mode 100644 index 00000000..f79d4c20 --- /dev/null +++ b/databricks_sql_connector_core/src/databricks/sqlalchemy/__init__.py @@ -0,0 +1,6 @@ +try: + from databricks_sqlalchemy import * +except: + import warnings + + warnings.warn("Install databricks-sqlalchemy plugin before using this") \ No newline at end of file diff --git a/tests/e2e/__init__.py b/databricks_sql_connector_core/tests/__init__.py similarity index 100% rename from tests/e2e/__init__.py rename to databricks_sql_connector_core/tests/__init__.py diff --git a/conftest.py b/databricks_sql_connector_core/tests/conftest.py similarity index 100% rename from conftest.py rename to databricks_sql_connector_core/tests/conftest.py diff --git a/tests/e2e/common/__init__.py b/databricks_sql_connector_core/tests/e2e/__init__.py similarity index 100% rename from tests/e2e/common/__init__.py rename to databricks_sql_connector_core/tests/e2e/__init__.py diff --git a/tests/unit/__init__.py b/databricks_sql_connector_core/tests/e2e/common/__init__.py similarity index 100% rename from tests/unit/__init__.py rename to databricks_sql_connector_core/tests/e2e/common/__init__.py diff --git a/tests/e2e/common/core_tests.py b/databricks_sql_connector_core/tests/e2e/common/core_tests.py similarity index 100% rename from tests/e2e/common/core_tests.py rename to databricks_sql_connector_core/tests/e2e/common/core_tests.py diff --git a/tests/e2e/common/decimal_tests.py b/databricks_sql_connector_core/tests/e2e/common/decimal_tests.py similarity index 79% rename from tests/e2e/common/decimal_tests.py rename to databricks_sql_connector_core/tests/e2e/common/decimal_tests.py index 5005cdf1..47fc2070 100644 --- a/tests/e2e/common/decimal_tests.py +++ b/databricks_sql_connector_core/tests/e2e/common/decimal_tests.py @@ -1,11 +1,20 @@ from decimal import Decimal -import pyarrow import pytest +try: + import pyarrow +except ImportError: + pyarrow = None -class DecimalTestsMixin: - 
decimal_and_expected_results = [ +from tests.e2e.common.predicates import pysql_supports_arrow + +def decimal_and_expected_results(): + + if pyarrow is None: + return [] + + return [ ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)), ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)), ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)), @@ -17,7 +26,12 @@ class DecimalTestsMixin: ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)), ] - multi_decimals_and_expected_results = [ +def multi_decimals_and_expected_results(): + + if pyarrow is None: + return [] + + return [ ( ["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"], [Decimal("1.00"), Decimal("100.001"), None], @@ -30,7 +44,9 @@ class DecimalTestsMixin: ), ] - @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results) +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") +class DecimalTestsMixin: + @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results()) def test_decimals(self, decimal, expected_value, expected_type): with self.cursor({}) as cursor: query = "SELECT CAST ({})".format(decimal) @@ -39,9 +55,7 @@ def test_decimals(self, decimal, expected_value, expected_type): assert table.field(0).type == expected_type assert table.to_pydict().popitem()[1][0] == expected_value - @pytest.mark.parametrize( - "decimals, expected_values, expected_type", multi_decimals_and_expected_results - ) + @pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results()) def test_multi_decimals(self, decimals, expected_values, expected_type): with self.cursor({}) as cursor: union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) diff --git a/tests/e2e/common/large_queries_mixin.py b/databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py similarity index 95% rename from tests/e2e/common/large_queries_mixin.py rename to databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py index 9ebc3f01..07d02447 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py @@ -1,6 +1,10 @@ import logging import math import time +from unittest import skipUnless + +import pytest +from tests.e2e.common.predicates import pysql_supports_arrow log = logging.getLogger(__name__) @@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): + "assuming 10K fetch size." 
) + @pytest.mark.skipif(not pysql_supports_arrow(), "Without pyarrow lz4 compression is not supported") def test_query_with_large_wide_result_set(self): resultSize = 300 * 1000 * 1000 # 300 MB width = 8192 # B diff --git a/tests/e2e/common/predicates.py b/databricks_sql_connector_core/tests/e2e/common/predicates.py similarity index 95% rename from tests/e2e/common/predicates.py rename to databricks_sql_connector_core/tests/e2e/common/predicates.py index 88b14961..99e6f701 100644 --- a/tests/e2e/common/predicates.py +++ b/databricks_sql_connector_core/tests/e2e/common/predicates.py @@ -8,9 +8,13 @@ def pysql_supports_arrow(): - """Import databricks.sql and test whether Cursor has fetchall_arrow.""" - from databricks.sql.client import Cursor - return hasattr(Cursor, 'fetchall_arrow') + """Checks if the pyarrow library is installed or not""" + try: + import pyarrow + + return True + except ImportError: + return False def pysql_has_version(compare, version): diff --git a/tests/e2e/common/retry_test_mixins.py b/databricks_sql_connector_core/tests/e2e/common/retry_test_mixins.py similarity index 100% rename from tests/e2e/common/retry_test_mixins.py rename to databricks_sql_connector_core/tests/e2e/common/retry_test_mixins.py diff --git a/tests/e2e/common/staging_ingestion_tests.py b/databricks_sql_connector_core/tests/e2e/common/staging_ingestion_tests.py similarity index 100% rename from tests/e2e/common/staging_ingestion_tests.py rename to databricks_sql_connector_core/tests/e2e/common/staging_ingestion_tests.py diff --git a/tests/e2e/common/timestamp_tests.py b/databricks_sql_connector_core/tests/e2e/common/timestamp_tests.py similarity index 100% rename from tests/e2e/common/timestamp_tests.py rename to databricks_sql_connector_core/tests/e2e/common/timestamp_tests.py diff --git a/tests/e2e/common/uc_volume_tests.py b/databricks_sql_connector_core/tests/e2e/common/uc_volume_tests.py similarity index 100% rename from tests/e2e/common/uc_volume_tests.py rename to databricks_sql_connector_core/tests/e2e/common/uc_volume_tests.py diff --git a/tests/e2e/test_complex_types.py b/databricks_sql_connector_core/tests/e2e/test_complex_types.py similarity index 93% rename from tests/e2e/test_complex_types.py rename to databricks_sql_connector_core/tests/e2e/test_complex_types.py index 0a7f514a..acac4e44 100644 --- a/tests/e2e/test_complex_types.py +++ b/databricks_sql_connector_core/tests/e2e/test_complex_types.py @@ -2,8 +2,9 @@ from numpy import ndarray from tests.e2e.test_driver import PySQLPytestTestCase +from tests.e2e.common.predicates import pysql_supports_arrow - +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class TestComplexTypes(PySQLPytestTestCase): @pytest.fixture(scope="class") def table_fixture(self, connection_details): diff --git a/tests/e2e/test_driver.py b/databricks_sql_connector_core/tests/e2e/test_driver.py similarity index 97% rename from tests/e2e/test_driver.py rename to databricks_sql_connector_core/tests/e2e/test_driver.py index c23e4f79..6fa686e9 100644 --- a/tests/e2e/test_driver.py +++ b/databricks_sql_connector_core/tests/e2e/test_driver.py @@ -12,7 +12,6 @@ from uuid import uuid4 import numpy as np -import pyarrow import pytz import thrift import pytest @@ -35,6 +34,7 @@ pysql_supports_arrow, compare_dbr_versions, is_thrift_v5_plus, + pysql_supports_arrow ) from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin from tests.e2e.common.large_queries_mixin import LargeQueriesMixin @@ -48,6 +48,11 @@ from 
databricks.sql.exc import SessionAlreadyClosedError +try: + import pyarrow +except: + pyarrow = None + log = logging.getLogger(__name__) unsafe_logger = logging.getLogger("databricks.sql.unsafe") @@ -591,7 +596,7 @@ def test_ssp_passthrough(self): cursor.execute("SET ansi_mode") assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: @@ -611,7 +616,7 @@ def test_timestamps_arrow(self): aware_timestamp and aware_timestamp.timestamp() * 1000000 ), "timestamp {} did not match {}".format(timestamp, expected) - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_multi_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: query, expected = self.multi_query() @@ -627,7 +632,7 @@ def test_multi_timestamps_arrow(self): ] assert result == expected - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_timezone_with_timestamp(self): if self.should_add_timezone(): with self.cursor() as cursor: @@ -646,7 +651,7 @@ def test_timezone_with_timestamp(self): assert arrow_result_table.field(0).type == ts_type assert arrow_result_value == expected.timestamp() * 1000000 - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_can_flip_compression(self): with self.cursor() as cursor: cursor.execute("SELECT array(1,2,3,4)") @@ -663,7 +668,7 @@ def test_can_flip_compression(self): def _should_have_native_complex_types(self): return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments) - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_arrays_are_not_returned_as_strings_arrow(self): if self._should_have_native_complex_types(): with self.cursor() as cursor: @@ -674,7 +679,7 @@ def test_arrays_are_not_returned_as_strings_arrow(self): assert pyarrow.types.is_list(list_type) assert pyarrow.types.is_integer(list_type.value_type) - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_structs_are_not_returned_as_strings_arrow(self): if self._should_have_native_complex_types(): with self.cursor() as cursor: @@ -684,7 +689,7 @@ def test_structs_are_not_returned_as_strings_arrow(self): struct_type = arrow_df.field(0).type assert pyarrow.types.is_struct(struct_type) - @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support") + @pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") def test_decimal_not_returned_as_strings_arrow(self): if self._should_have_native_complex_types(): with self.cursor() as cursor: diff --git a/tests/e2e/test_parameterized_queries.py b/databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py 
similarity index 98% rename from tests/e2e/test_parameterized_queries.py rename to databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py index 47dfc38c..e2eac174 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py @@ -28,6 +28,7 @@ VoidParameter, ) from tests.e2e.test_driver import PySQLPytestTestCase +from tests.e2e.common.predicates import pysql_supports_arrow class ParamStyle(Enum): @@ -284,6 +285,8 @@ def test_primitive_single( (PrimitiveExtra.TINYINT, TinyIntParameter), ], ) + + @pytest.mark.skipif(not pysql_supports_arrow(),reason="Without pyarrow TIMESTAMP_NTZ datatype cannot be inferred",) def test_dbsqlparameter_single( self, primitive: Primitive, diff --git a/src/databricks/sqlalchemy/py.typed b/databricks_sql_connector_core/tests/unit/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from src/databricks/sqlalchemy/py.typed rename to databricks_sql_connector_core/tests/unit/__init__.py diff --git a/tests/unit/test_arrow_queue.py b/databricks_sql_connector_core/tests/unit/test_arrow_queue.py similarity index 82% rename from tests/unit/test_arrow_queue.py rename to databricks_sql_connector_core/tests/unit/test_arrow_queue.py index 6834cc9c..ac98e137 100644 --- a/tests/unit/test_arrow_queue.py +++ b/databricks_sql_connector_core/tests/unit/test_arrow_queue.py @@ -1,10 +1,17 @@ import unittest -import pyarrow as pa +import pytest from databricks.sql.utils import ArrowQueue +try: + import pyarrow as pa +except ImportError: + pa = None +from tests.e2e.common.predicates import pysql_supports_arrow + +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class ArrowQueueSuite(unittest.TestCase): @staticmethod def make_arrow_table(batch): diff --git a/tests/unit/test_auth.py b/databricks_sql_connector_core/tests/unit/test_auth.py similarity index 100% rename from tests/unit/test_auth.py rename to databricks_sql_connector_core/tests/unit/test_auth.py diff --git a/tests/unit/test_client.py b/databricks_sql_connector_core/tests/unit/test_client.py similarity index 100% rename from tests/unit/test_client.py rename to databricks_sql_connector_core/tests/unit/test_client.py diff --git a/tests/unit/test_cloud_fetch_queue.py b/databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py similarity index 98% rename from tests/unit/test_cloud_fetch_queue.py rename to databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py index cd14c676..def6b8aa 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py @@ -1,11 +1,18 @@ -import pyarrow +import pytest import unittest from unittest.mock import MagicMock, patch from ssl import create_default_context from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink import databricks.sql.utils as utils +from tests.e2e.common.predicates import pysql_supports_arrow +try: + import pyarrow +except ImportError: + pyarrow = None + +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): def create_result_link( diff --git a/tests/unit/test_download_manager.py b/databricks_sql_connector_core/tests/unit/test_download_manager.py similarity index 93% rename from tests/unit/test_download_manager.py rename to databricks_sql_connector_core/tests/unit/test_download_manager.py index c084d8e4..f17049e8 100644 --- 
a/tests/unit/test_download_manager.py +++ b/databricks_sql_connector_core/tests/unit/test_download_manager.py @@ -1,12 +1,15 @@ import unittest from unittest.mock import patch, MagicMock +import pytest from ssl import create_default_context import databricks.sql.cloudfetch.download_manager as download_manager from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink +from tests.e2e.common.predicates import pysql_supports_arrow +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class DownloadManagerTests(unittest.TestCase): """ Unit tests for checking download manager logic. diff --git a/tests/unit/test_downloader.py b/databricks_sql_connector_core/tests/unit/test_downloader.py similarity index 100% rename from tests/unit/test_downloader.py rename to databricks_sql_connector_core/tests/unit/test_downloader.py diff --git a/tests/unit/test_endpoint.py b/databricks_sql_connector_core/tests/unit/test_endpoint.py similarity index 100% rename from tests/unit/test_endpoint.py rename to databricks_sql_connector_core/tests/unit/test_endpoint.py diff --git a/tests/unit/test_fetches.py b/databricks_sql_connector_core/tests/unit/test_fetches.py similarity index 97% rename from tests/unit/test_fetches.py rename to databricks_sql_connector_core/tests/unit/test_fetches.py index 7d5686f8..c1aeadca 100644 --- a/tests/unit/test_fetches.py +++ b/databricks_sql_connector_core/tests/unit/test_fetches.py @@ -1,12 +1,17 @@ import unittest from unittest.mock import Mock - -import pyarrow as pa +import pytest import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue +from tests.e2e.common.predicates import pysql_supports_arrow +try: + import pyarrow as pa +except ImportError: + pa = None +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class FetchTests(unittest.TestCase): """ Unit tests for checking the fetch logic. diff --git a/tests/unit/test_fetches_bench.py b/databricks_sql_connector_core/tests/unit/test_fetches_bench.py similarity index 90% rename from tests/unit/test_fetches_bench.py rename to databricks_sql_connector_core/tests/unit/test_fetches_bench.py index e322b44a..bba18247 100644 --- a/tests/unit/test_fetches_bench.py +++ b/databricks_sql_connector_core/tests/unit/test_fetches_bench.py @@ -1,15 +1,20 @@ import unittest from unittest.mock import Mock -import pyarrow as pa import uuid import time import pytest import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue +from tests.e2e.common.predicates import pysql_supports_arrow +try: + import pyarrow as pa +except ImportError: + pa = None +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class FetchBenchmarkTests(unittest.TestCase): """ Micro benchmark test for Arrow result handling. 
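Editor's note: the test hunks above all apply the same guard, so a condensed sketch may help when reading them. `pysql_supports_arrow()` now simply reports whether the optional `pyarrow` package can be imported, and each Arrow-dependent test module or class is wrapped in `pytest.mark.skipif`. One caution: pytest accepts the skip reason only as a `reason=` keyword, so the positional string in the `large_queries_mixin.py` hunk would be treated as an extra condition string and should likely use the same `reason=...` form as the other decorators. The class and test names below are illustrative only, not code from this patch.

```python
import unittest

import pytest

try:
    import pyarrow as pa
except ImportError:
    pa = None


def pysql_supports_arrow() -> bool:
    """Return True when the optional pyarrow dependency is importable."""
    return pa is not None


@pytest.mark.skipif(
    not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed"
)
class ExampleArrowQueueTests(unittest.TestCase):
    def test_round_trip(self):
        # Only runs when pyarrow is installed; otherwise the whole class is skipped.
        table = pa.table({"x": [1, 2, 3]})
        self.assertEqual(table.num_rows, 3)
```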
diff --git a/tests/unit/test_init_file.py b/databricks_sql_connector_core/tests/unit/test_init_file.py similarity index 100% rename from tests/unit/test_init_file.py rename to databricks_sql_connector_core/tests/unit/test_init_file.py diff --git a/tests/unit/test_oauth_persistence.py b/databricks_sql_connector_core/tests/unit/test_oauth_persistence.py similarity index 100% rename from tests/unit/test_oauth_persistence.py rename to databricks_sql_connector_core/tests/unit/test_oauth_persistence.py diff --git a/tests/unit/test_param_escaper.py b/databricks_sql_connector_core/tests/unit/test_param_escaper.py similarity index 100% rename from tests/unit/test_param_escaper.py rename to databricks_sql_connector_core/tests/unit/test_param_escaper.py diff --git a/tests/unit/test_parameters.py b/databricks_sql_connector_core/tests/unit/test_parameters.py similarity index 100% rename from tests/unit/test_parameters.py rename to databricks_sql_connector_core/tests/unit/test_parameters.py diff --git a/tests/unit/test_retry.py b/databricks_sql_connector_core/tests/unit/test_retry.py similarity index 100% rename from tests/unit/test_retry.py rename to databricks_sql_connector_core/tests/unit/test_retry.py diff --git a/tests/unit/test_thrift_backend.py b/databricks_sql_connector_core/tests/unit/test_thrift_backend.py similarity index 99% rename from tests/unit/test_thrift_backend.py rename to databricks_sql_connector_core/tests/unit/test_thrift_backend.py index 4bcf84d2..9b53a17e 100644 --- a/tests/unit/test_thrift_backend.py +++ b/databricks_sql_connector_core/tests/unit/test_thrift_backend.py @@ -2,18 +2,22 @@ from decimal import Decimal import itertools import unittest +import pytest from unittest.mock import patch, MagicMock, Mock from ssl import CERT_NONE, CERT_REQUIRED -import pyarrow - import databricks.sql from databricks.sql import utils from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.thrift_backend import ThriftBackend +from tests.e2e.common.predicates import pysql_supports_arrow +try: + import pyarrow +except ImportError: + pyarrow = None def retry_policy_factory(): return { # (type, default, min, max) @@ -24,7 +28,7 @@ def retry_policy_factory(): "_retry_delay_default": (float, 5, 1, 60), } - +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed") class ThriftBackendTestSuite(unittest.TestCase): okay_status = ttypes.TStatus(statusCode=ttypes.TStatusCode.SUCCESS_STATUS) diff --git a/src/databricks/sqlalchemy/README.sqlalchemy.md b/src/databricks/sqlalchemy/README.sqlalchemy.md deleted file mode 100644 index 8aa51973..00000000 --- a/src/databricks/sqlalchemy/README.sqlalchemy.md +++ /dev/null @@ -1,203 +0,0 @@ -## Databricks dialect for SQLALchemy 2.0 - -The Databricks dialect for SQLAlchemy serves as bridge between [SQLAlchemy](https://www.sqlalchemy.org/) and the Databricks SQL Python driver. The dialect is included with `databricks-sql-connector==3.0.0` and above. A working example demonstrating usage can be found in `examples/sqlalchemy.py`. - -## Usage with SQLAlchemy <= 2.0 -A SQLAlchemy 1.4 compatible dialect was first released in connector [version 2.4](https://github.com/databricks/databricks-sql-python/releases/tag/v2.4.0). Support for SQLAlchemy 1.4 was dropped from the dialect as part of `databricks-sql-connector==3.0.0`. To continue using the dialect with SQLAlchemy 1.x, you can use `databricks-sql-connector^2.4.0`. 
- - -## Installation - -To install the dialect and its dependencies: - -```shell -pip install databricks-sql-connector[sqlalchemy] -``` - -If you also plan to use `alembic` you can alternatively run: - -```shell -pip install databricks-sql-connector[alembic] -``` - -## Connection String - -Every SQLAlchemy application that connects to a database needs to use an [Engine](https://docs.sqlalchemy.org/en/20/tutorial/engine.html#tutorial-engine), which you can create by passing a connection string to `create_engine`. The connection string must include these components: - -1. Host -2. HTTP Path for a compute resource -3. API access token -4. Initial catalog for the connection -5. Initial schema for the connection - -**Note: Our dialect is built and tested on workspaces with Unity Catalog enabled. Support for the `hive_metastore` catalog is untested.** - -For example: - -```python -import os -from sqlalchemy import create_engine - -host = os.getenv("DATABRICKS_SERVER_HOSTNAME") -http_path = os.getenv("DATABRICKS_HTTP_PATH") -access_token = os.getenv("DATABRICKS_TOKEN") -catalog = os.getenv("DATABRICKS_CATALOG") -schema = os.getenv("DATABRICKS_SCHEMA") - -engine = create_engine( - f"databricks://token:{access_token}@{host}?http_path={http_path}&catalog={catalog}&schema={schema}" - ) -``` - -## Types - -The [SQLAlchemy type hierarchy](https://docs.sqlalchemy.org/en/20/core/type_basics.html) contains backend-agnostic type implementations (represented in CamelCase) and backend-specific types (represented in UPPERCASE). The majority of SQLAlchemy's [CamelCase](https://docs.sqlalchemy.org/en/20/core/type_basics.html#the-camelcase-datatypes) types are supported. This means that a SQLAlchemy application using these types should "just work" with Databricks. - -|SQLAlchemy Type|Databricks SQL Type| -|-|-| -[`BigInteger`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.BigInteger)| [`BIGINT`](https://docs.databricks.com/en/sql/language-manual/data-types/bigint-type.html) -[`LargeBinary`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.LargeBinary)| (not supported)| -[`Boolean`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Boolean)| [`BOOLEAN`](https://docs.databricks.com/en/sql/language-manual/data-types/boolean-type.html) -[`Date`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Date)| [`DATE`](https://docs.databricks.com/en/sql/language-manual/data-types/date-type.html) -[`DateTime`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.DateTime)| [`TIMESTAMP_NTZ`](https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html)| -[`Double`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Double)| [`DOUBLE`](https://docs.databricks.com/en/sql/language-manual/data-types/double-type.html) -[`Enum`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Enum)| (not supported)| -[`Float`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Float)| [`FLOAT`](https://docs.databricks.com/en/sql/language-manual/data-types/float-type.html) -[`Integer`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Integer)| [`INT`](https://docs.databricks.com/en/sql/language-manual/data-types/int-type.html) -[`Numeric`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Numeric)| [`DECIMAL`](https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html)| 
-[`PickleType`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.PickleType)| (not supported)| -[`SmallInteger`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.SmallInteger)| [`SMALLINT`](https://docs.databricks.com/en/sql/language-manual/data-types/smallint-type.html) -[`String`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.String)| [`STRING`](https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html)| -[`Text`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Text)| [`STRING`](https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html)| -[`Time`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Time)| [`STRING`](https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html)| -[`Unicode`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Unicode)| [`STRING`](https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html)| -[`UnicodeText`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.UnicodeText)| [`STRING`](https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html)| -[`Uuid`](https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Uuid)| [`STRING`](https://docs.databricks.com/en/sql/language-manual/data-types/string-type.html) - -In addition, the dialect exposes three UPPERCASE SQLAlchemy types which are specific to Databricks: - -- [`databricks.sqlalchemy.TINYINT`](https://docs.databricks.com/en/sql/language-manual/data-types/tinyint-type.html) -- [`databricks.sqlalchemy.TIMESTAMP`](https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-type.html) -- [`databricks.sqlalchemy.TIMESTAMP_NTZ`](https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html) - - -### `LargeBinary()` and `PickleType()` - -Databricks Runtime doesn't currently support binding of binary values in SQL queries, which is a pre-requisite for this functionality in SQLAlchemy. - -## `Enum()` and `CHECK` constraints - -Support for `CHECK` constraints is not implemented in this dialect. Support is planned for a future release. - -SQLAlchemy's `Enum()` type depends on `CHECK` constraints and is therefore not yet supported. - -### `DateTime()`, `TIMESTAMP_NTZ()`, and `TIMESTAMP()` - -Databricks Runtime provides two datetime-like types: `TIMESTAMP` which is always timezone-aware and `TIMESTAMP_NTZ` which is timezone agnostic. Both types can be imported from `databricks.sqlalchemy` and used in your models. - -The SQLAlchemy documentation indicates that `DateTime()` is not timezone-aware by default. So our dialect maps this type to `TIMESTAMP_NTZ()`. In practice, you should never need to use `TIMESTAMP_NTZ()` directly. Just use `DateTime()`. - -If you need your field to be timezone-aware, you can import `TIMESTAMP()` and use it instead. - -_Note that SQLAlchemy documentation suggests that you can declare a `DateTime()` with `timezone=True` on supported backends. However, if you do this with the Databricks dialect, the `timezone` argument will be ignored._ - -```python -from sqlalchemy import DateTime -from databricks.sqlalchemy import TIMESTAMP - -class SomeModel(Base): - some_date_without_timezone = DateTime() - some_date_with_timezone = TIMESTAMP() -``` - -### `String()`, `Text()`, `Unicode()`, and `UnicodeText()` - -Databricks Runtime doesn't support length limitations for `STRING` fields. 
Therefore `String()` or `String(1)` or `String(255)` will all produce identical DDL. Since `Text()`, `Unicode()`, `UnicodeText()` all use the same underlying type in Databricks SQL, they will generate equivalent DDL. - -### `Time()` - -Databricks Runtime doesn't have a native time-like data type. To implement this type in SQLAlchemy, our dialect stores SQLAlchemy `Time()` values in a `STRING` field. Unlike `DateTime` above, this type can optionally support timezone awareness (since the dialect is in complete control of the strings that we write to the Delta table). - -```python -from sqlalchemy import Time - -class SomeModel(Base): - time_tz = Time(timezone=True) - time_ntz = Time() -``` - - -# Usage Notes - -## `Identity()` and `autoincrement` - -Identity and generated value support is currently limited in this dialect. - -When defining models, SQLAlchemy types can accept an [`autoincrement`](https://docs.sqlalchemy.org/en/20/core/metadata.html#sqlalchemy.schema.Column.params.autoincrement) argument. In our dialect, this argument is currently ignored. To create an auto-incrementing field in your model you can pass in an explicit [`Identity()`](https://docs.sqlalchemy.org/en/20/core/defaults.html#identity-ddl) instead. - -Furthermore, in Databricks Runtime, only `BIGINT` fields can be configured to auto-increment. So in SQLAlchemy, you must use the `BigInteger()` type. - -```python -from sqlalchemy import Identity, String - -class SomeModel(Base): - id = BigInteger(Identity()) - value = String() -``` - -When calling `Base.metadata.create_all()`, the executed DDL will include `GENERATED ALWAYS AS IDENTITY` for the `id` column. This is useful when using SQLAlchemy to generate tables. However, as of this writing, `Identity()` constructs are not captured when SQLAlchemy reflects a table's metadata (support for this is planned). - -## Parameters - -`databricks-sql-connector` supports two approaches to parameterizing SQL queries: native and inline. Our SQLAlchemy 2.0 dialect always uses the native approach and is therefore limited to DBR 14.2 and above. If you are writing parameterized queries to be executed by SQLAlchemy, you must use the "named" paramstyle (`:param`). Read more about parameterization in `docs/parameters.md`. - -## Usage with pandas - -Use [`pandas.DataFrame.to_sql`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html) and [`pandas.read_sql`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_sql.html#pandas.read_sql) to write and read from Databricks SQL. These methods both accept a SQLAlchemy connection to interact with Databricks. 
- -### Read from Databricks SQL into pandas -```python -from sqlalchemy import create_engine -import pandas as pd - -engine = create_engine("databricks://token:dapi***@***.cloud.databricks.com?http_path=***&catalog=main&schema=test") -with engine.connect() as conn: - # This will read the contents of `main.test.some_table` - df = pd.read_sql("some_table", conn) -``` - -### Write to Databricks SQL from pandas - -```python -from sqlalchemy import create_engine -import pandas as pd - -engine = create_engine("databricks://token:dapi***@***.cloud.databricks.com?http_path=***&catalog=main&schema=test") -squares = [(i, i * i) for i in range(100)] -df = pd.DataFrame(data=squares,columns=['x','x_squared']) - -with engine.connect() as conn: - # This will write the contents of `df` to `main.test.squares` - df.to_sql('squares',conn) -``` - -## [`PrimaryKey()`](https://docs.sqlalchemy.org/en/20/core/constraints.html#sqlalchemy.schema.PrimaryKeyConstraint) and [`ForeignKey()`](https://docs.sqlalchemy.org/en/20/core/constraints.html#defining-foreign-keys) - -Unity Catalog workspaces in Databricks support PRIMARY KEY and FOREIGN KEY constraints. _Note that Databricks Runtime does not enforce the integrity of FOREIGN KEY constraints_. You can establish a primary key by setting `primary_key=True` when defining a column. - -When building `ForeignKey` or `ForeignKeyConstraint` objects, you must specify a `name` for the constraint. - -If your model definition requires a self-referential FOREIGN KEY constraint, you must include `use_alter=True` when defining the relationship. - -```python -from sqlalchemy import Table, Column, ForeignKey, BigInteger, String - -users = Table( - "users", - metadata_obj, - Column("id", BigInteger, primary_key=True), - Column("name", String(), nullable=False), - Column("email", String()), - Column("manager_id", ForeignKey("users.id", name="fk_users_manager_id_x_users_id", use_alter=True)) -) -``` diff --git a/src/databricks/sqlalchemy/README.tests.md b/src/databricks/sqlalchemy/README.tests.md deleted file mode 100644 index 3ed92aba..00000000 --- a/src/databricks/sqlalchemy/README.tests.md +++ /dev/null @@ -1,44 +0,0 @@ -## SQLAlchemy Dialect Compliance Test Suite with Databricks - -The contents of the `test/` directory follow the SQLAlchemy developers' [guidance] for running the reusable dialect compliance test suite. Since not every test in the suite is applicable to every dialect, two options are provided to skip tests: - -- Any test can be skipped by subclassing its parent class, re-declaring the test-case and adding a `pytest.mark.skip` directive. -- Any test that is decorated with a `@requires` decorator can be skipped by marking the indicated requirement as `.closed()` in `requirements.py` - -We prefer to skip test cases directly with the first method wherever possible. We only mark requirements as `closed()` if there is no easier option to avoid a test failure. This principally occurs in test cases where the same test in the suite is parametrized, and some parameter combinations are conditionally skipped depending on `requirements.py`. If we skip the entire test method, then we skip _all_ permutations, not just the combinations we don't support. - -## Regression, Unsupported, and Future test cases - -We maintain three files of test cases that we import from the SQLAlchemy source code: - -* **`_regression.py`** contains all the tests cases with tests that we expect to pass for our dialect. 
Each one is marked with `pytest.mark.reiewed` to indicate that we've evaluated it for relevance. This file only contains base class declarations. -* **`_unsupported.py`** contains test cases that fail because of missing features in Databricks. We mark them as skipped with a `SkipReason` enumeration. If Databricks comes to support these features, those test or entire classes can be moved to `_regression.py`. -* **`_future.py`** contains test cases that fail because of missing features in the dialect itself, but which _are_ supported by Databricks generally. We mark them as skipped with a `FutureFeature` enumeration. These are features that have not been prioritised or that do not violate our acceptance criteria. All of these test cases will eventually move to either `_regression.py`. - -In some cases, only certain tests in class should be skipped with a `SkipReason` or `FutureFeature` justification. In those cases, we import the class into `_regression.py`, then import it from there into one or both of `_future.py` and `_unsupported.py`. If a class needs to be "touched" by regression, unsupported, and future, the class will be imported in that order. If an entire class should be skipped, then we do not import it into `_regression.py` at all. - -We maintain `_extra.py` with test cases that depend on SQLAlchemy's reusable dialect test fixtures but which are specific to Databricks (e.g TinyIntegerTest). - -## Running the reusable dialect tests - -``` -poetry shell -cd src/databricks/sqlalchemy/test -python -m pytest test_suite.py --dburi \ - "databricks://token:$access_token@$host?http_path=$http_path&catalog=$catalog&schema=$schema" -``` - -Whatever schema you pass in the `dburi` argument should be empty. Some tests also require the presence of an empty schema named `test_schema`. Note that we plan to implement our own `provision.py` which SQLAlchemy can automatically use to create an empty schema for testing. But for now this is a manual process. - -You can run only reviewed tests by appending `-m "reviewed"` to the test runner invocation. - -You can run only the unreviewed tests by appending `-m "not reviewed"` instead. - -Note that because these tests depend on SQLAlchemy's custom pytest plugin, they are not discoverable by IDE-based test runners like VSCode or PyCharm and must be invoked from a CLI. - -## Running local unit and e2e tests - -Apart from the SQLAlchemy reusable suite, we maintain our own unit and e2e tests under the `test_local/` directory. These can be invoked from a VSCode or Pycharm since they don't depend on a custom pytest plugin. Due to pytest's lookup order, the `pytest.ini` which is required for running the reusable dialect tests, also conflicts with VSCode and Pycharm's default pytest implementation and overrides the settings in `pyproject.toml`. So to run these tests, you can delete or rename `pytest.ini`. 
- - -[guidance]: "https://github.com/sqlalchemy/sqlalchemy/blob/rel_2_0_22/README.dialects.rst" diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py deleted file mode 100644 index 2a17ac3e..00000000 --- a/src/databricks/sqlalchemy/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from databricks.sqlalchemy.base import DatabricksDialect -from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ - -__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"] diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py deleted file mode 100644 index d5d0bf87..00000000 --- a/src/databricks/sqlalchemy/_ddl.py +++ /dev/null @@ -1,100 +0,0 @@ -import re -from sqlalchemy.sql import compiler, sqltypes -import logging - -logger = logging.getLogger(__name__) - - -class DatabricksIdentifierPreparer(compiler.IdentifierPreparer): - """https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html""" - - legal_characters = re.compile(r"^[A-Z0-9_]+$", re.I) - - def __init__(self, dialect): - super().__init__(dialect, initial_quote="`") - - -class DatabricksDDLCompiler(compiler.DDLCompiler): - def post_create_table(self, table): - post = [" USING DELTA"] - if table.comment: - comment = self.sql_compiler.render_literal_value( - table.comment, sqltypes.String() - ) - post.append("COMMENT " + comment) - - post.append("TBLPROPERTIES('delta.feature.allowColumnDefaults' = 'enabled')") - return "\n".join(post) - - def visit_unique_constraint(self, constraint, **kw): - logger.warning("Databricks does not support unique constraints") - pass - - def visit_check_constraint(self, constraint, **kw): - logger.warning("This dialect does not support check constraints") - pass - - def visit_identity_column(self, identity, **kw): - """When configuring an Identity() with Databricks, only the always option is supported. - All other options are ignored. - - Note: IDENTITY columns must always be defined as BIGINT. An exception will be raised if INT is used. - - https://www.databricks.com/blog/2022/08/08/identity-columns-to-generate-surrogate-keys-are-now-available-in-a-lakehouse-near-you.html - """ - text = "GENERATED %s AS IDENTITY" % ( - "ALWAYS" if identity.always else "BY DEFAULT", - ) - return text - - def visit_set_column_comment(self, create, **kw): - return "ALTER TABLE %s ALTER COLUMN %s COMMENT %s" % ( - self.preparer.format_table(create.element.table), - self.preparer.format_column(create.element), - self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String() - ), - ) - - def visit_drop_column_comment(self, create, **kw): - return "ALTER TABLE %s ALTER COLUMN %s COMMENT ''" % ( - self.preparer.format_table(create.element.table), - self.preparer.format_column(create.element), - ) - - def get_column_specification(self, column, **kwargs): - """ - Emit a log message if a user attempts to set autoincrement=True on a column. - See comments in test_suite.py. We may implement implicit IDENTITY using this - feature in the future, similar to the Microsoft SQL Server dialect. - """ - if column is column.table._autoincrement_column or column.autoincrement is True: - logger.warning( - "Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead." 
- ) - - colspec = super().get_column_specification(column, **kwargs) - if column.comment is not None: - literal = self.sql_compiler.render_literal_value( - column.comment, sqltypes.STRINGTYPE - ) - colspec += " COMMENT " + literal - - return colspec - - -class DatabricksStatementCompiler(compiler.SQLCompiler): - def limit_clause(self, select, **kw): - """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, - since Databricks SQL doesn't support the latter. - - https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-limit.html - """ - text = "" - if select._limit_clause is not None: - text += "\n LIMIT " + self.process(select._limit_clause, **kw) - if select._offset_clause is not None: - if select._limit_clause is None: - text += "\n LIMIT ALL" - text += " OFFSET " + self.process(select._offset_clause, **kw) - return text diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py deleted file mode 100644 index 6d38e1e6..00000000 --- a/src/databricks/sqlalchemy/_parse.py +++ /dev/null @@ -1,385 +0,0 @@ -from typing import List, Optional, Dict -import re - -import sqlalchemy -from sqlalchemy.engine import CursorResult -from sqlalchemy.engine.interfaces import ReflectedColumn - -from databricks.sqlalchemy import _types as type_overrides - -""" -This module contains helper functions that can parse the contents -of metadata and exceptions received from DBR. These are mostly just -wrappers around regexes. -""" - - -class DatabricksSqlAlchemyParseException(Exception): - pass - - -def _match_table_not_found_string(message: str) -> bool: - """Return True if the message contains a substring indicating that a table was not found""" - - DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found" - DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND" - return any( - [ - DBR_LTE_12_NOT_FOUND_STRING in message, - DBR_GT_12_NOT_FOUND_STRING in message, - ] - ) - - -def _describe_table_extended_result_to_dict_list( - result: CursorResult, -) -> List[Dict[str, str]]: - """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries""" - - rows_to_return = [] - for row in result.all(): - this_row = {"col_name": row.col_name, "data_type": row.data_type} - rows_to_return.append(this_row) - - return rows_to_return - - -def extract_identifiers_from_string(input_str: str) -> List[str]: - """For a string input resembling (`a`, `b`, `c`) return a list of identifiers ['a', 'b', 'c']""" - - # This matches the valid character list contained in DatabricksIdentifierPreparer - pattern = re.compile(r"`([A-Za-z0-9_]+)`") - matches = pattern.findall(input_str) - return [i for i in matches] - - -def extract_identifier_groups_from_string(input_str: str) -> List[str]: - """For a string input resembling : - - FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_sqlalchemy`.`tb1` (`name`, `id`, `attr`) - - Return ['(`pname`, `pid`, `pattr`)', '(`name`, `id`, `attr`)'] - """ - pattern = re.compile(r"\([`A-Za-z0-9_,\s]*\)") - matches = pattern.findall(input_str) - return [i for i in matches] - - -def extract_three_level_identifier_from_constraint_string(input_str: str) -> dict: - """For a string input resembling : - FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`) - - Return a dict like - { - "catalog": "main", - "schema": "pysql_dialect_compliance", - "table": "users" - } - - Raise a DatabricksSqlAlchemyParseException if a 3L namespace isn't 
found - """ - pat = re.compile(r"REFERENCES\s+(.*?)\s*\(") - matches = pat.findall(input_str) - - if not matches: - raise DatabricksSqlAlchemyParseException( - "3L namespace not found in constraint string" - ) - - first_match = matches[0] - parts = first_match.split(".") - - def strip_backticks(input: str): - return input.replace("`", "") - - try: - return { - "catalog": strip_backticks(parts[0]), - "schema": strip_backticks(parts[1]), - "table": strip_backticks(parts[2]), - } - except IndexError: - raise DatabricksSqlAlchemyParseException( - "Incomplete 3L namespace found in constraint string: " + ".".join(parts) - ) - - -def _parse_fk_from_constraint_string(constraint_str: str) -> dict: - """Build a dictionary of foreign key constraint information from a constraint string. - - For example: - - ``` - FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_dialect_compliance`.`tb1` (`name`, `id`, `attr`) - ``` - - Return a dictionary like: - - ``` - { - "constrained_columns": ["pname", "pid", "pattr"], - "referred_table": "tb1", - "referred_schema": "pysql_dialect_compliance", - "referred_columns": ["name", "id", "attr"] - } - ``` - - Note that the constraint name doesn't appear in the constraint string so it will not - be present in the output of this function. - """ - - referred_table_dict = extract_three_level_identifier_from_constraint_string( - constraint_str - ) - referred_table = referred_table_dict["table"] - referred_schema = referred_table_dict["schema"] - - # _extracted is a tuple of two lists of identifiers - # we assume the first immediately follows "FOREIGN KEY" and the second - # immediately follows REFERENCES $tableName - _extracted = extract_identifier_groups_from_string(constraint_str) - constrained_columns_str, referred_columns_str = ( - _extracted[0], - _extracted[1], - ) - - constrained_columns = extract_identifiers_from_string(constrained_columns_str) - referred_columns = extract_identifiers_from_string(referred_columns_str) - - return { - "constrained_columns": constrained_columns, - "referred_table": referred_table, - "referred_columns": referred_columns, - "referred_schema": referred_schema, - } - - -def build_fk_dict( - fk_name: str, fk_constraint_string: str, schema_name: Optional[str] -) -> dict: - """ - Given a foriegn key name and a foreign key constraint string, return a dictionary - with the following keys: - - name - the name of the foreign key constraint - constrained_columns - a list of column names that make up the foreign key - referred_table - the name of the table that the foreign key references - referred_columns - a list of column names that are referenced by the foreign key - referred_schema - the name of the schema that the foreign key references. - - referred schema will be None if the schema_name argument is None. 
- This is required by SQLAlchey's ComponentReflectionTest::test_get_foreign_keys - """ - - # The foreign key name is not contained in the constraint string so we - # need to add it manually - base_fk_dict = _parse_fk_from_constraint_string(fk_constraint_string) - - if not schema_name: - schema_override_dict = dict(referred_schema=None) - else: - schema_override_dict = {} - - # mypy doesn't like this method of conditionally adding a key to a dictionary - # while keeping everything immutable - complete_foreign_key_dict = { - "name": fk_name, - **base_fk_dict, - **schema_override_dict, # type: ignore - } - - return complete_foreign_key_dict - - -def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]: - """Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED - - For example: - - PRIMARY KEY (`id`, `name`, `email_address`) - - Returns a list like - - ["id", "name", "email_address"] - """ - - _extracted = extract_identifiers_from_string(constraint_str) - - return _extracted - - -def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict: - """Given a primary key name and a primary key constraint string, return a dictionary - with the following keys: - - constrained_columns - A list of string column names that make up the primary key - - name - The name of the primary key constraint - """ - - constrained_columns = _parse_pk_columns_from_constraint_string(pk_constraint_string) - - return {"constrained_columns": constrained_columns, "name": pk_name} - - -def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]: - """Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type` - value contains the match argument. - - Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the fields - a constraint will be found in its output. So we cycle through its output looking - for a match. This is brittle. We could optionally make two roundtrips: the first - would query information_schema for the name of the constraint on this table, and - a second to DESCRIBE TABLE EXTENDED, at which point we would know the name of the - constraint. But for now we instead assume that Python list comprehension is faster - than a network roundtrip - """ - - output_rows = [] - - for row_dict in dte_output: - if match in row_dict["data_type"]: - output_rows.append(row_dict) - - return output_rows - - -def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]: - """Return a list of dictionaries containing only the col_name:data_type pairs where the `col_name` - value contains the match argument. - """ - - output_rows = [] - - for row_dict in dte_output: - if match in row_dict["col_name"]: - output_rows.append(row_dict) - - return output_rows - - -def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]: - """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries, - one dictionary per defined constraint - """ - - output = match_dte_rows_by_value(dte_output, "FOREIGN KEY") - - return output - - -def get_pk_strings_from_dte_output( - dte_output: List[Dict[str, str]] -) -> Optional[List[dict]]: - """If the DESCRIBE TABLE EXTENDED output contains primary key constraints, return a list of dictionaries, - one dictionary per defined constraint. - - Returns None if no primary key constraints are found. 
- """ - - output = match_dte_rows_by_value(dte_output, "PRIMARY KEY") - - return output - - -def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]: - """Returns the value of the first "Comment" col_name data in dte_output""" - output = match_dte_rows_by_key(dte_output, "Comment") - if not output: - return None - else: - return output[0]["data_type"] - - -# The keys of this dictionary are the values we expect to see in a -# TGetColumnsRequest's .TYPE_NAME attribute. -# These are enumerated in ttypes.py as class TTypeId. -# TODO: confirm that all types in TTypeId are included here. -GET_COLUMNS_TYPE_MAP = { - "boolean": sqlalchemy.types.Boolean, - "smallint": sqlalchemy.types.SmallInteger, - "tinyint": type_overrides.TINYINT, - "int": sqlalchemy.types.Integer, - "bigint": sqlalchemy.types.BigInteger, - "float": sqlalchemy.types.Float, - "double": sqlalchemy.types.Float, - "string": sqlalchemy.types.String, - "varchar": sqlalchemy.types.String, - "char": sqlalchemy.types.String, - "binary": sqlalchemy.types.String, - "array": sqlalchemy.types.String, - "map": sqlalchemy.types.String, - "struct": sqlalchemy.types.String, - "uniontype": sqlalchemy.types.String, - "decimal": sqlalchemy.types.Numeric, - "timestamp": type_overrides.TIMESTAMP, - "timestamp_ntz": type_overrides.TIMESTAMP_NTZ, - "date": sqlalchemy.types.Date, -} - - -def parse_numeric_type_precision_and_scale(type_name_str): - """Return an intantiated sqlalchemy Numeric() type that preserves the precision and scale indicated - in the output from TGetColumnsRequest. - - type_name_str - The value of TGetColumnsReq.TYPE_NAME. - - If type_name_str is "DECIMAL(18,5) returns sqlalchemy.types.Numeric(18,5) - """ - - pattern = re.compile(r"DECIMAL\((\d+,\d+)\)") - match = re.search(pattern, type_name_str) - precision_and_scale = match.group(1) - precision, scale = tuple(precision_and_scale.split(",")) - - return sqlalchemy.types.Numeric(int(precision), int(scale)) - - -def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn: - """Returns a dictionary of the ReflectedColumn schema parsed from - a single of the result of a TGetColumnsRequest thrift RPC - """ - - pat = re.compile(r"^\w+") - - # This method assumes a valid TYPE_NAME field in the response. - # TODO: add error handling in case TGetColumnsResponse format changes - - _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() # type: ignore - _col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type] - - if _raw_col_type == "decimal": - final_col_type = parse_numeric_type_precision_and_scale( - thrift_resp_row.TYPE_NAME - ) - else: - final_col_type = _col_type - - # See comments about autoincrement in test_suite.py - # Since Databricks SQL doesn't currently support inline AUTOINCREMENT declarations - # the autoincrement must be manually declared with an Identity() construct in SQLAlchemy - # Other dialects can perform this extra Identity() step automatically. But that is not - # implemented in the Databricks dialect right now. So autoincrement is currently always False. - # It's not clear what IS_AUTO_INCREMENT in the thrift response actually reflects or whether - # it ever returns a `YES`. - - # Per the guidance in SQLAlchemy's docstrings, we prefer to not even include an autoincrement - # key in this dictionary. 
- this_column = { - "name": thrift_resp_row.COLUMN_NAME, - "type": final_col_type, - "nullable": bool(thrift_resp_row.NULLABLE), - "default": thrift_resp_row.COLUMN_DEF, - "comment": thrift_resp_row.REMARKS or None, - } - - # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects - return this_column # type: ignore diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py deleted file mode 100644 index 5fc14a70..00000000 --- a/src/databricks/sqlalchemy/_types.py +++ /dev/null @@ -1,323 +0,0 @@ -from datetime import datetime, time, timezone -from itertools import product -from typing import Any, Union, Optional - -import sqlalchemy -from sqlalchemy.engine.interfaces import Dialect -from sqlalchemy.ext.compiler import compiles - -from databricks.sql.utils import ParamEscaper - - -def process_literal_param_hack(value: Any): - """This method is supposed to accept a Python type and return a string representation of that type. - But due to some weirdness in the way SQLAlchemy's literal rendering works, we have to return - the value itself because, by the time it reaches our custom type code, it's already been converted - into a string. - - TimeTest - DateTimeTest - DateTimeTZTest - - This dynamic only seems to affect the literal rendering of datetime and time objects. - - All fail without this hack in-place. I'm not sure why. But it works. - """ - return value - - -@compiles(sqlalchemy.types.Enum, "databricks") -@compiles(sqlalchemy.types.String, "databricks") -@compiles(sqlalchemy.types.Text, "databricks") -@compiles(sqlalchemy.types.Time, "databricks") -@compiles(sqlalchemy.types.Unicode, "databricks") -@compiles(sqlalchemy.types.UnicodeText, "databricks") -@compiles(sqlalchemy.types.Uuid, "databricks") -def compile_string_databricks(type_, compiler, **kw): - """ - We override the default compilation for Enum(), String(), Text(), and Time() because SQLAlchemy - defaults to incompatible / abnormal compiled names - - Enum -> VARCHAR - String -> VARCHAR[LENGTH] - Text -> VARCHAR[LENGTH] - Time -> TIME - Unicode -> VARCHAR[LENGTH] - UnicodeText -> TEXT - Uuid -> CHAR[32] - - But all of these types will be compiled to STRING in Databricks SQL - """ - return "STRING" - - -@compiles(sqlalchemy.types.Integer, "databricks") -def compile_integer_databricks(type_, compiler, **kw): - """ - We need to override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER" - """ - return "INT" - - -@compiles(sqlalchemy.types.LargeBinary, "databricks") -def compile_binary_databricks(type_, compiler, **kw): - """ - We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB" - """ - return "BINARY" - - -@compiles(sqlalchemy.types.Numeric, "databricks") -def compile_numeric_databricks(type_, compiler, **kw): - """ - We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC" - - The built-in visit_DECIMAL behaviour captures the precision and scale. 
Here we're just mapping calls to compile Numeric - to the SQLAlchemy Decimal() implementation - """ - return compiler.visit_DECIMAL(type_, **kw) - - -@compiles(sqlalchemy.types.DateTime, "databricks") -def compile_datetime_databricks(type_, compiler, **kw): - """ - We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP_NTZ" instead of "DATETIME" - """ - return "TIMESTAMP_NTZ" - - -@compiles(sqlalchemy.types.ARRAY, "databricks") -def compile_array_databricks(type_, compiler, **kw): - """ - SQLAlchemy's default ARRAY can't compile as it's only implemented for Postgresql. - The Postgres implementation works for Databricks SQL, so we duplicate that here. - - :type_: - This is an instance of sqlalchemy.types.ARRAY which always includes an item_type attribute - which is itself an instance of TypeEngine - - https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY - """ - - inner = compiler.process(type_.item_type, **kw) - - return f"ARRAY<{inner}>" - - -class TIMESTAMP_NTZ(sqlalchemy.types.TypeDecorator): - """Represents values comprising values of fields year, month, day, hour, minute, and second. - All operations are performed without taking any time zone into account. - - Our dialect maps sqlalchemy.types.DateTime() to this type, which means that all DateTime() - objects are stored without tzinfo. To read and write timezone-aware datetimes use - databricks.sql.TIMESTAMP instead. - - https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html - """ - - impl = sqlalchemy.types.DateTime - - cache_ok = True - - def process_result_value(self, value: Union[None, datetime], dialect): - if value is None: - return None - return value.replace(tzinfo=None) - - -class TIMESTAMP(sqlalchemy.types.TypeDecorator): - """Represents values comprising values of fields year, month, day, hour, minute, and second, - with the session local time-zone. - - Our dialect maps sqlalchemy.types.DateTime() to TIMESTAMP_NTZ, which means that all DateTime() - objects are stored without tzinfo. To read and write timezone-aware datetimes use - this type instead. - - ```python - # This won't work - `Column(sqlalchemy.DateTime(timezone=True))` - - # But this does - `Column(TIMESTAMP)` - ```` - - https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-type.html - """ - - impl = sqlalchemy.types.DateTime - - cache_ok = True - - def process_result_value(self, value: Union[None, datetime], dialect): - if value is None: - return None - - if not value.tzinfo: - return value.replace(tzinfo=timezone.utc) - return value - - def process_bind_param( - self, value: Union[datetime, None], dialect - ) -> Optional[datetime]: - """pysql can pass datetime.datetime() objects directly to DBR""" - return value - - def process_literal_param( - self, value: Union[datetime, None], dialect: Dialect - ) -> str: - """ """ - return process_literal_param_hack(value) - - -@compiles(TIMESTAMP, "databricks") -def compile_timestamp_databricks(type_, compiler, **kw): - """ - We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP_NTZ" instead of "DATETIME" - """ - return "TIMESTAMP" - - -class DatabricksTimeType(sqlalchemy.types.TypeDecorator): - """Databricks has no native TIME type. 
So we store it as a string.""" - - impl = sqlalchemy.types.Time - cache_ok = True - - BASE_FMT = "%H:%M:%S" - MICROSEC_PART = ".%f" - TIMEZONE_PART = "%z" - - def _generate_fmt_string(self, ms: bool, tz: bool) -> str: - """Return a format string for datetime.strptime() that includes or excludes microseconds and timezone.""" - _ = lambda x, y: x if y else "" - return f"{self.BASE_FMT}{_(self.MICROSEC_PART,ms)}{_(self.TIMEZONE_PART,tz)}" - - @property - def allowed_fmt_strings(self): - """Time strings can be read with or without microseconds and with or without a timezone.""" - - if not hasattr(self, "_allowed_fmt_strings"): - ms_switch = tz_switch = [True, False] - self._allowed_fmt_strings = [ - self._generate_fmt_string(x, y) - for x, y in product(ms_switch, tz_switch) - ] - - return self._allowed_fmt_strings - - def _parse_result_string(self, value: str) -> time: - """Parse a string into a time object. Try all allowed formats until one works.""" - for fmt in self.allowed_fmt_strings: - try: - # We use timetz() here because we want to preserve the timezone information - # Calling .time() will strip the timezone information - return datetime.strptime(value, fmt).timetz() - except ValueError: - pass - - raise ValueError(f"Could not parse time string {value}") - - def _determine_fmt_string(self, value: time) -> str: - """Determine which format string to use to render a time object as a string.""" - ms_bool = value.microsecond > 0 - tz_bool = value.tzinfo is not None - return self._generate_fmt_string(ms_bool, tz_bool) - - def process_bind_param(self, value: Union[time, None], dialect) -> Union[None, str]: - """Values sent to the database are converted to %:H:%M:%S strings.""" - if value is None: - return None - fmt_string = self._determine_fmt_string(value) - return value.strftime(fmt_string) - - # mypy doesn't like this workaround because TypeEngine wants process_literal_param to return a string - def process_literal_param(self, value, dialect) -> time: # type: ignore - """ """ - return process_literal_param_hack(value) - - def process_result_value( - self, value: Union[None, str], dialect - ) -> Union[time, None]: - """Values received from the database are parsed into datetime.time() objects""" - if value is None: - return None - - return self._parse_result_string(value) - - -class DatabricksStringType(sqlalchemy.types.TypeDecorator): - """We have to implement our own String() type because SQLAlchemy's default implementation - wants to escape single-quotes with a doubled single-quote. Databricks uses a backslash for - escaping of literal strings. And SQLAlchemy's default escaping breaks Databricks SQL. - """ - - impl = sqlalchemy.types.String - cache_ok = True - pe = ParamEscaper() - - def process_literal_param(self, value, dialect) -> str: - """SQLAlchemy's default string escaping for backslashes doesn't work for databricks. The logic here - implements the same logic as our legacy inline escaping logic. - """ - - return self.pe.escape_string(value) - - def literal_processor(self, dialect): - """We manually override this method to prevent further processing of the string literal beyond - what happens in the process_literal_param() method. - - The SQLAlchemy docs _specifically_ say to not override this method. - - It appears that any processing that happens from TypeEngine.process_literal_param happens _before_ - and _in addition to_ whatever the class's impl.literal_processor() method does. 
The String.literal_processor() - method performs a string replacement that doubles any single-quote in the contained string. This raises a syntax - error in Databricks. And it's not necessary because ParamEscaper() already implements all the escaping we need. - - We should consider opening an issue on the SQLAlchemy project to see if I'm using it wrong. - - See type_api.py::TypeEngine.literal_processor: - - ```python - def process(value: Any) -> str: - return fixed_impl_processor( - fixed_process_literal_param(value, dialect) - ) - ``` - - That call to fixed_impl_processor wraps the result of fixed_process_literal_param (which is the - process_literal_param defined in our Databricks dialect) - - https://docs.sqlalchemy.org/en/20/core/custom_types.html#sqlalchemy.types.TypeDecorator.literal_processor - """ - - def process(value): - """This is a copy of the default String.literal_processor() method but stripping away - its double-escaping behaviour for single-quotes. - """ - - _step1 = self.process_literal_param(value, dialect="databricks") - if dialect.identifier_preparer._double_percents: - _step2 = _step1.replace("%", "%%") - else: - _step2 = _step1 - - return "%s" % _step2 - - return process - - -class TINYINT(sqlalchemy.types.TypeDecorator): - """Represents 1-byte signed integers - - Acts like a sqlalchemy SmallInteger() in Python but writes to a TINYINT field in Databricks - - https://docs.databricks.com/en/sql/language-manual/data-types/tinyint-type.html - """ - - impl = sqlalchemy.types.SmallInteger - cache_ok = True - - -@compiles(TINYINT, "databricks") -def compile_tinyint(type_, compiler, **kw): - return "TINYINT" diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py deleted file mode 100644 index 9148de7f..00000000 --- a/src/databricks/sqlalchemy/base.py +++ /dev/null @@ -1,436 +0,0 @@ -from typing import Any, List, Optional, Dict, Union - -import databricks.sqlalchemy._ddl as dialect_ddl_impl -import databricks.sqlalchemy._types as dialect_type_impl -from databricks import sql -from databricks.sqlalchemy._parse import ( - _describe_table_extended_result_to_dict_list, - _match_table_not_found_string, - build_fk_dict, - build_pk_dict, - get_fk_strings_from_dte_output, - get_pk_strings_from_dte_output, - get_comment_from_dte_output, - parse_column_info_from_tgetcolumnsresponse, -) - -import sqlalchemy -from sqlalchemy import DDL, event -from sqlalchemy.engine import Connection, Engine, default, reflection -from sqlalchemy.engine.interfaces import ( - ReflectedForeignKeyConstraint, - ReflectedPrimaryKeyConstraint, - ReflectedColumn, - ReflectedTableComment, -) -from sqlalchemy.engine.reflection import ReflectionDefaults -from sqlalchemy.exc import DatabaseError, SQLAlchemyError - -try: - import alembic -except ImportError: - pass -else: - from alembic.ddl import DefaultImpl - - class DatabricksImpl(DefaultImpl): - __dialect__ = "databricks" - - -import logging - -logger = logging.getLogger(__name__) - - -class DatabricksDialect(default.DefaultDialect): - """This dialect implements only those methods required to pass our e2e tests""" - - # See sqlalchemy.engine.interfaces for descriptions of each of these properties - name: str = "databricks" - driver: str = "databricks" - default_schema_name: str = "default" - preparer = dialect_ddl_impl.DatabricksIdentifierPreparer # type: ignore - ddl_compiler = dialect_ddl_impl.DatabricksDDLCompiler - statement_compiler = dialect_ddl_impl.DatabricksStatementCompiler - supports_statement_cache: bool = True - 
supports_multivalues_insert: bool = True - supports_native_decimal: bool = True - supports_sane_rowcount: bool = False - non_native_boolean_check_constraint: bool = False - supports_identity_columns: bool = True - supports_schemas: bool = True - default_paramstyle: str = "named" - div_is_floordiv: bool = False - supports_default_values: bool = False - supports_server_side_cursors: bool = False - supports_sequences: bool = False - supports_native_boolean: bool = True - - colspecs = { - sqlalchemy.types.DateTime: dialect_type_impl.TIMESTAMP_NTZ, - sqlalchemy.types.Time: dialect_type_impl.DatabricksTimeType, - sqlalchemy.types.String: dialect_type_impl.DatabricksStringType, - } - - # SQLAlchemy requires that a table with no primary key - # constraint return a dictionary that looks like this. - EMPTY_PK: Dict[str, Any] = {"constrained_columns": [], "name": None} - - # SQLAlchemy requires that a table with no foreign keys - # defined return an empty list. Same for indexes. - EMPTY_FK: List - EMPTY_INDEX: List - EMPTY_FK = EMPTY_INDEX = [] - - @classmethod - def import_dbapi(cls): - return sql - - def _force_paramstyle_to_native_mode(self): - """This method can be removed after databricks-sql-connector wholly switches to NATIVE ParamApproach. - - This is a hack to trick SQLAlchemy into using a different paramstyle - than the one declared by this module in src/databricks/sql/__init__.py - - This method is called _after_ the dialect has been initialised, which is important because otherwise - our users would need to include a `paramstyle` argument in their SQLAlchemy connection string. - - This dialect is written to support NATIVE queries. Although the INLINE approach can technically work, - the same behaviour can be achieved within SQLAlchemy itself using its literal_processor methods. - """ - - self.paramstyle = self.default_paramstyle - - def create_connect_args(self, url): - # TODO: can schema be provided after HOST? - # Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/*** - - kwargs = { - "server_hostname": url.host, - "access_token": url.password, - "http_path": url.query.get("http_path"), - "catalog": url.query.get("catalog"), - "schema": url.query.get("schema"), - "use_inline_params": False, - } - - self.schema = kwargs["schema"] - self.catalog = kwargs["catalog"] - - self._force_paramstyle_to_native_mode() - - return [], kwargs - - def get_columns( - self, connection, table_name, schema=None, **kwargs - ) -> List[ReflectedColumn]: - """Return information about columns in `table_name`.""" - - with self.get_connection_cursor(connection) as cur: - resp = cur.columns( - catalog_name=self.catalog, - schema_name=schema or self.schema, - table_name=table_name, - ).fetchall() - - if not resp: - # TGetColumnsRequest will not raise an exception if passed a table that doesn't exist - # But Databricks supports tables with no columns. So if the result is an empty list, - # we need to check if the table exists (and raise an exception if not) or simply return - # an empty list. 
- self._describe_table_extended( - connection, - table_name, - self.catalog, - schema or self.schema, - expect_result=False, - ) - return resp - columns = [] - for col in resp: - row_dict = parse_column_info_from_tgetcolumnsresponse(col) - columns.append(row_dict) - - return columns - - def _describe_table_extended( - self, - connection: Connection, - table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - expect_result=True, - ) -> Union[List[Dict[str, str]], None]: - """Run DESCRIBE TABLE EXTENDED on a table and return a list of dictionaries of the result. - - This method is the fastest way to check for the presence of a table in a schema. - - If expect_result is False, this method returns None as the output dict isn't required. - - Raises NoSuchTableError if the table is not present in the schema. - """ - - _target_catalog = catalog_name or self.catalog - _target_schema = schema_name or self.schema - _target = f"`{_target_catalog}`.`{_target_schema}`.`{table_name}`" - - # sql injection risk? - # DESCRIBE TABLE EXTENDED in DBR doesn't support parameterised inputs :( - stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}") - - try: - result = connection.execute(stmt) - except DatabaseError as e: - if _match_table_not_found_string(str(e)): - raise sqlalchemy.exc.NoSuchTableError( - f"No such table {table_name}" - ) from e - raise e - - if not expect_result: - return None - - fmt_result = _describe_table_extended_result_to_dict_list(result) - return fmt_result - - @reflection.cache - def get_pk_constraint( - self, - connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> ReflectedPrimaryKeyConstraint: - """Fetch information about the primary key constraint on table_name. - - Returns a dictionary with these keys: - constrained_columns - a list of column names that make up the primary key. Results is an empty list - if no PRIMARY KEY is defined. - - name - the name of the primary key constraint - """ - - result = self._describe_table_extended( - connection=connection, - table_name=table_name, - schema_name=schema, - ) - - # Type ignore is because mypy knows that self._describe_table_extended *can* - # return None (even though it never will since expect_result defaults to True) - raw_pk_constraints: List = get_pk_strings_from_dte_output(result) # type: ignore - if not any(raw_pk_constraints): - return self.EMPTY_PK # type: ignore - - if len(raw_pk_constraints) > 1: - logger.warning( - "Found more than one primary key constraint in DESCRIBE TABLE EXTENDED output. " - "This is unexpected. Please report this as a bug. " - "Only the first primary key constraint will be returned." 
- ) - - first_pk_constraint = raw_pk_constraints[0] - pk_name = first_pk_constraint.get("col_name") - pk_constraint_string = first_pk_constraint.get("data_type") - - # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects - return build_pk_dict(pk_name, pk_constraint_string) # type: ignore - - def get_foreign_keys( - self, connection, table_name, schema=None, **kw - ) -> List[ReflectedForeignKeyConstraint]: - """Return information about foreign_keys in `table_name`.""" - - result = self._describe_table_extended( - connection=connection, - table_name=table_name, - schema_name=schema, - ) - - # Type ignore is because mypy knows that self._describe_table_extended *can* - # return None (even though it never will since expect_result defaults to True) - raw_fk_constraints: List = get_fk_strings_from_dte_output(result) # type: ignore - - if not any(raw_fk_constraints): - return self.EMPTY_FK - - fk_constraints = [] - for constraint_dict in raw_fk_constraints: - fk_name = constraint_dict.get("col_name") - fk_constraint_string = constraint_dict.get("data_type") - this_constraint_dict = build_fk_dict( - fk_name, fk_constraint_string, schema_name=schema - ) - fk_constraints.append(this_constraint_dict) - - # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects - return fk_constraints # type: ignore - - def get_indexes(self, connection, table_name, schema=None, **kw): - """SQLAlchemy requires this method. Databricks doesn't support indexes.""" - return self.EMPTY_INDEX - - @reflection.cache - def get_table_names(self, connection: Connection, schema=None, **kwargs): - """Return a list of tables in the current schema.""" - - _target_catalog = self.catalog - _target_schema = schema or self.schema - _target = f"`{_target_catalog}`.`{_target_schema}`" - - stmt = DDL(f"SHOW TABLES FROM {_target}") - - tables_result = connection.execute(stmt).all() - views_result = self.get_view_names(connection=connection, schema=schema) - - # In Databricks, SHOW TABLES FROM returns both tables and views. 
- # Potential optimisation: rewrite this to instead query information_schema - tables_minus_views = [ - row.tableName for row in tables_result if row.tableName not in views_result - ] - - return tables_minus_views - - @reflection.cache - def get_view_names( - self, - connection, - schema=None, - only_materialized=False, - only_temp=False, - **kwargs, - ) -> List[str]: - """Returns a list of string view names contained in the schema, if any.""" - - _target_catalog = self.catalog - _target_schema = schema or self.schema - _target = f"`{_target_catalog}`.`{_target_schema}`" - - stmt = DDL(f"SHOW VIEWS FROM {_target}") - result = connection.execute(stmt).all() - - return [ - row.viewName - for row in result - if (not only_materialized or row.isMaterialized) - and (not only_temp or row.isTemporary) - ] - - @reflection.cache - def get_materialized_view_names( - self, connection: Connection, schema: Optional[str] = None, **kw: Any - ) -> List[str]: - """A wrapper around get_view_names that fetches only the names of materialized views""" - return self.get_view_names(connection, schema, only_materialized=True) - - @reflection.cache - def get_temp_view_names( - self, connection: Connection, schema: Optional[str] = None, **kw: Any - ) -> List[str]: - """A wrapper around get_view_names that fetches only the names of temporary views""" - return self.get_view_names(connection, schema, only_temp=True) - - def do_rollback(self, dbapi_connection): - # Databricks SQL Does not support transactions - pass - - @reflection.cache - def has_table( - self, connection, table_name, schema=None, catalog=None, **kwargs - ) -> bool: - """For internal dialect use, check the existence of a particular table - or view in the database. - """ - - try: - self._describe_table_extended( - connection=connection, - table_name=table_name, - catalog_name=catalog, - schema_name=schema, - ) - return True - except sqlalchemy.exc.NoSuchTableError as e: - return False - - def get_connection_cursor(self, connection): - """Added for backwards compatibility with 1.3.x""" - if hasattr(connection, "_dbapi_connection"): - return connection._dbapi_connection.dbapi_connection.cursor() - elif hasattr(connection, "raw_connection"): - return connection.raw_connection().cursor() - elif hasattr(connection, "connection"): - return connection.connection.cursor() - - raise SQLAlchemyError( - "Databricks dialect can't obtain a cursor context manager from the dbapi" - ) - - @reflection.cache - def get_schema_names(self, connection, **kw): - """Return a list of all schema names available in the database.""" - stmt = DDL("SHOW SCHEMAS") - result = connection.execute(stmt) - schema_list = [row[0] for row in result] - return schema_list - - @reflection.cache - def get_table_comment( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> ReflectedTableComment: - result = self._describe_table_extended( - connection=connection, - table_name=table_name, - schema_name=schema, - ) - - if result is None: - return ReflectionDefaults.table_comment() - - comment = get_comment_from_dte_output(result) - - if comment: - return dict(text=comment) - else: - return ReflectionDefaults.table_comment() - - -@event.listens_for(Engine, "do_connect") -def receive_do_connect(dialect, conn_rec, cargs, cparams): - """Helpful for DS on traffic from clients using SQLAlchemy in particular""" - - # Ignore connect invocations that don't use our dialect - if not dialect.name == "databricks": - return - - ua = 
cparams.get("_user_agent_entry", "") - - def add_sqla_tag_if_not_present(val: str): - if not val: - output = "sqlalchemy" - - if val and "sqlalchemy" in val: - output = val - - else: - output = f"sqlalchemy + {val}" - - return output - - cparams["_user_agent_entry"] = add_sqla_tag_if_not_present(ua) - - if sqlalchemy.__version__.startswith("1.3"): - # SQLAlchemy 1.3.x fails to parse the http_path, catalog, and schema from our connection string - # These should be passed in as connect_args when building the Engine - - if "schema" in cparams: - dialect.schema = cparams["schema"] - - if "catalog" in cparams: - dialect.catalog = cparams["catalog"] diff --git a/src/databricks/sqlalchemy/requirements.py b/src/databricks/sqlalchemy/requirements.py deleted file mode 100644 index 5c70c029..00000000 --- a/src/databricks/sqlalchemy/requirements.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -The complete list of requirements is provided by SQLAlchemy here: - -https://github.com/sqlalchemy/sqlalchemy/blob/main/lib/sqlalchemy/testing/requirements.py - -When SQLAlchemy skips a test because a requirement is closed() it gives a generic skip message. -To make these failures more actionable, we only define requirements in this file that we wish to -force to be open(). If a test should be skipped on Databricks, it will be specifically marked skip -in test_suite.py with a Databricks-specific reason. - -See the special note about the array_type exclusion below. -See special note about has_temp_table exclusion below. -""" - -import sqlalchemy.testing.requirements -import sqlalchemy.testing.exclusions - - -class Requirements(sqlalchemy.testing.requirements.SuiteRequirements): - @property - def date_historic(self): - """target dialect supports representation of Python - datetime.datetime() objects with historic (pre 1970) values.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def datetime_historic(self): - """target dialect supports representation of Python - datetime.datetime() objects with historic (pre 1970) values.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def datetime_literals(self): - """target dialect supports rendering of a date, time, or datetime as a - literal string, e.g. via the TypeEngine.literal_processor() method. - - """ - - return sqlalchemy.testing.exclusions.open() - - @property - def timestamp_microseconds(self): - """target dialect supports representation of Python - datetime.datetime() with microsecond objects but only - if TIMESTAMP is used.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def time_microseconds(self): - """target dialect supports representation of Python - datetime.time() with microsecond objects. - - This requirement declaration isn't needed but I've included it here for completeness. - Since Databricks doesn't have a TIME type, SQLAlchemy will compile Time() columns - as STRING Databricks data types. And we use a custom time type to render those strings - between str() and time.time() representations. Therefore we can store _any_ precision - that SQLAlchemy needs. The time_microseconds requirement defaults to ON for all dialects - except mssql, mysql, mariadb, and oracle. 
- """ - - return sqlalchemy.testing.exclusions.open() - - @property - def infinity_floats(self): - """The Float type can persist and load float('inf'), float('-inf').""" - - return sqlalchemy.testing.exclusions.open() - - @property - def precision_numerics_retains_significant_digits(self): - """A precision numeric type will return empty significant digits, - i.e. a value such as 10.000 will come back in Decimal form with - the .000 maintained.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def precision_numerics_many_significant_digits(self): - """target backend supports values with many digits on both sides, - such as 319438950232418390.273596, 87673.594069654243 - - """ - return sqlalchemy.testing.exclusions.open() - - @property - def array_type(self): - """While Databricks does support ARRAY types, pysql cannot bind them. So - we cannot use them with SQLAlchemy - - Due to a bug in SQLAlchemy, we _must_ define this exclusion as closed() here or else the - test runner will crash the pytest process due to an AttributeError - """ - - # TODO: Implement array type using inline? - return sqlalchemy.testing.exclusions.closed() - - @property - def table_ddl_if_exists(self): - """target platform supports IF NOT EXISTS / IF EXISTS for tables.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def identity_columns(self): - """If a backend supports GENERATED { ALWAYS | BY DEFAULT } - AS IDENTITY""" - return sqlalchemy.testing.exclusions.open() - - @property - def identity_columns_standard(self): - """If a backend supports GENERATED { ALWAYS | BY DEFAULT } - AS IDENTITY with a standard syntax. - This is mainly to exclude MSSql. - """ - return sqlalchemy.testing.exclusions.open() - - @property - def has_temp_table(self): - """target dialect supports checking a single temp table name - - unfortunately this is not the same as temp_table_names - - SQLAlchemy's HasTableTest is not normalised in such a way that temp table tests - are separate from temp view and normal table tests. If those tests were split out, - we would just add detailed skip markers in test_suite.py. But since we'd like to - run the HasTableTest group for the features we support, we must set this exclusinon - to closed(). - - It would be ideal if there were a separate requirement for has_temp_view. Without it, - we're in a bind. - """ - return sqlalchemy.testing.exclusions.closed() - - @property - def temporary_views(self): - """target database supports temporary views""" - return sqlalchemy.testing.exclusions.open() - - @property - def views(self): - """Target database must support VIEWs.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def temporary_tables(self): - """target database supports temporary tables - - ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. - This happens because we cannot skip individual combinations used in ComponentReflection test. - """ - return sqlalchemy.testing.exclusions.closed() - - @property - def table_reflection(self): - """target database has general support for table reflection""" - return sqlalchemy.testing.exclusions.open() - - @property - def comment_reflection(self): - """Indicates if the database support table comment reflection""" - return sqlalchemy.testing.exclusions.open() - - @property - def comment_reflection_full_unicode(self): - """Indicates if the database support table comment reflection in the - full unicode range, including emoji etc. 
- """ - return sqlalchemy.testing.exclusions.open() - - @property - def temp_table_reflection(self): - """ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. - This happens because we cannot skip individual combinations used in ComponentReflection test. - """ - return sqlalchemy.testing.exclusions.closed() - - @property - def index_reflection(self): - """ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. - This happens because we cannot skip individual combinations used in ComponentReflection test. - """ - return sqlalchemy.testing.exclusions.closed() - - @property - def unique_constraint_reflection(self): - """ComponentReflection test is intricate and simply cannot function without this exclusion being defined here. - This happens because we cannot skip individual combinations used in ComponentReflection test. - - Databricks doesn't support UNIQUE constraints. - """ - return sqlalchemy.testing.exclusions.closed() - - @property - def reflects_pk_names(self): - """Target driver reflects the name of primary key constraints.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def datetime_implicit_bound(self): - """target dialect when given a datetime object will bind it such - that the database server knows the object is a date, and not - a plain string. - """ - - return sqlalchemy.testing.exclusions.open() - - @property - def tuple_in(self): - return sqlalchemy.testing.exclusions.open() - - @property - def ctes(self): - return sqlalchemy.testing.exclusions.open() - - @property - def ctes_with_update_delete(self): - return sqlalchemy.testing.exclusions.open() - - @property - def delete_from(self): - """Target must support DELETE FROM..FROM or DELETE..USING syntax""" - return sqlalchemy.testing.exclusions.open() - - @property - def table_value_constructor(self): - return sqlalchemy.testing.exclusions.open() - - @property - def reflect_tables_no_columns(self): - return sqlalchemy.testing.exclusions.open() - - @property - def denormalized_names(self): - """Target database must have 'denormalized', i.e. 
- UPPERCASE as case insensitive names.""" - - return sqlalchemy.testing.exclusions.open() - - @property - def time_timezone(self): - """target dialect supports representation of Python - datetime.time() with tzinfo with Time(timezone=True).""" - - return sqlalchemy.testing.exclusions.open() diff --git a/src/databricks/sqlalchemy/setup.cfg b/src/databricks/sqlalchemy/setup.cfg deleted file mode 100644 index ab89d17d..00000000 --- a/src/databricks/sqlalchemy/setup.cfg +++ /dev/null @@ -1,4 +0,0 @@ - -[sqla_testing] -requirement_cls=databricks.sqlalchemy.requirements:Requirements -profile_file=profiles.txt diff --git a/src/databricks/sqlalchemy/test/_extra.py b/src/databricks/sqlalchemy/test/_extra.py deleted file mode 100644 index 2f3e7a7d..00000000 --- a/src/databricks/sqlalchemy/test/_extra.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Additional tests authored by Databricks that use SQLAlchemy's test fixtures -""" - -import datetime - -from sqlalchemy.testing.suite.test_types import ( - _LiteralRoundTripFixture, - fixtures, - testing, - eq_, - select, - Table, - Column, - config, - _DateFixture, - literal, -) -from databricks.sqlalchemy import TINYINT, TIMESTAMP - - -class TinyIntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): - __backend__ = True - - def test_literal(self, literal_round_trip): - literal_round_trip(TINYINT, [5], [5]) - - @testing.fixture - def integer_round_trip(self, metadata, connection): - def run(datatype, data): - int_table = Table( - "tiny_integer_table", - metadata, - Column( - "id", - TINYINT, - primary_key=True, - test_needs_autoincrement=False, - ), - Column("tiny_integer_data", datatype), - ) - - metadata.create_all(config.db) - - connection.execute(int_table.insert(), {"id": 1, "integer_data": data}) - - row = connection.execute(select(int_table.c.integer_data)).first() - - eq_(row, (data,)) - - assert isinstance(row[0], int) - - return run - - -class DateTimeTZTestCustom(_DateFixture, fixtures.TablesTest): - """This test confirms that when a user uses the TIMESTAMP - type to store a datetime object, it retains its timezone - """ - - __backend__ = True - datatype = TIMESTAMP - data = datetime.datetime(2012, 10, 15, 12, 57, 18, tzinfo=datetime.timezone.utc) - - @testing.requires.datetime_implicit_bound - def test_select_direct(self, connection): - - # We need to pass the TIMESTAMP type to the literal function - # so that the value is processed correctly. 
- result = connection.scalar(select(literal(self.data, TIMESTAMP))) - eq_(result, self.data) diff --git a/src/databricks/sqlalchemy/test/_future.py b/src/databricks/sqlalchemy/test/_future.py deleted file mode 100644 index 6e470f60..00000000 --- a/src/databricks/sqlalchemy/test/_future.py +++ /dev/null @@ -1,331 +0,0 @@ -# type: ignore - -from enum import Enum - -import pytest -from databricks.sqlalchemy.test._regression import ( - ExpandingBoundInTest, - IdentityAutoincrementTest, - LikeFunctionsTest, - NormalizedNameTest, -) -from databricks.sqlalchemy.test._unsupported import ( - ComponentReflectionTest, - ComponentReflectionTestExtra, - CTETest, - InsertBehaviorTest, -) -from sqlalchemy.testing.suite import ( - ArrayTest, - BinaryTest, - BizarroCharacterFKResolutionTest, - CollateTest, - ComputedColumnTest, - ComputedReflectionTest, - DifficultParametersTest, - FutureWeCanSetDefaultSchemaWEventsTest, - IdentityColumnTest, - IdentityReflectionTest, - JSONLegacyStringCastIndexTest, - JSONTest, - NativeUUIDTest, - QuotedNameArgumentTest, - RowCountTest, - SimpleUpdateDeleteTest, - WeCanSetDefaultSchemaWEventsTest, -) - - -class FutureFeature(Enum): - ARRAY = "ARRAY column type handling" - BINARY = "BINARY column type handling" - CHECK = "CHECK constraint handling" - COLLATE = "COLLATE DDL generation" - CTE_FEAT = "required CTE features" - EMPTY_INSERT = "empty INSERT support" - FK_OPTS = "foreign key option checking" - GENERATED_COLUMNS = "Delta computed / generated columns support" - IDENTITY = "identity reflection" - JSON = "JSON column type handling" - MULTI_PK = "get_multi_pk_constraint method" - PROVISION = "event-driven engine configuration" - REGEXP = "_visit_regexp" - SANE_ROWCOUNT = "sane_rowcount support" - TBL_OPTS = "get_table_options method" - TEST_DESIGN = "required test-fixture overrides" - TUPLE_LITERAL = "tuple-like IN markers completely" - UUID = "native Uuid() type" - VIEW_DEF = "get_view_definition method" - - -def render_future_feature(rsn: FutureFeature, extra=False) -> str: - postfix = " More detail in _future.py" if extra else "" - return f"[FUTURE][{rsn.name}]: This dialect doesn't implement {rsn.value}.{postfix}" - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.BINARY)) -class BinaryTest(BinaryTest): - """Databricks doesn't support binding of BINARY type values. When DBR supports this, we can implement - in this dialect. - """ - - pass - - -class ExpandingBoundInTest(ExpandingBoundInTest): - @pytest.mark.skip(render_future_feature(FutureFeature.TUPLE_LITERAL)) - def test_empty_heterogeneous_tuples_bindparam(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.TUPLE_LITERAL)) - def test_empty_heterogeneous_tuples_direct(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.TUPLE_LITERAL)) - def test_empty_homogeneous_tuples_bindparam(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.TUPLE_LITERAL)) - def test_empty_homogeneous_tuples_direct(self): - pass - - -class NormalizedNameTest(NormalizedNameTest): - @pytest.mark.skip(render_future_feature(FutureFeature.TEST_DESIGN, True)) - def test_get_table_names(self): - """I'm not clear how this test can ever pass given that it's assertion looks like this: - - ```python - eq_(tablenames[0].upper(), tablenames[0].lower()) - eq_(tablenames[1].upper(), tablenames[1].lower()) - ``` - - It's forcibly calling .upper() and .lower() on the same string and expecting them to be equal. 
- """ - pass - - -class CTETest(CTETest): - @pytest.mark.skip(render_future_feature(FutureFeature.CTE_FEAT, True)) - def test_delete_from_round_trip(self): - """Databricks dialect doesn't implement multiple-table criteria within DELETE""" - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.TEST_DESIGN, True)) -class IdentityColumnTest(IdentityColumnTest): - """Identity works. Test needs rewrite for Databricks. See comments in test_suite.py - - The setup for these tests tries to create a table with a DELTA IDENTITY column but has two problems: - 1. It uses an Integer() type for the column. Whereas DELTA IDENTITY columns must be BIGINT. - 2. It tries to set the start == 42, which Databricks doesn't support - - I can get the tests to _run_ by patching the table fixture to use BigInteger(). But it asserts that the - identity of two rows are 42 and 43, which is not possible since they will be rows 1 and 2 instead. - - I'm satisified through manual testing that our implementation of visit_identity_column works but a better test is needed. - """ - - pass - - -class IdentityAutoincrementTest(IdentityAutoincrementTest): - @pytest.mark.skip(render_future_feature(FutureFeature.TEST_DESIGN, True)) - def test_autoincrement_with_identity(self): - """This test has the same issue as IdentityColumnTest.test_select_all in that it creates a table with identity - using an Integer() rather than a BigInteger(). If I override this behaviour to use a BigInteger() instead, the - test passes. - """ - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.TEST_DESIGN)) -class BizarroCharacterFKResolutionTest(BizarroCharacterFKResolutionTest): - """Some of the combinations in this test pass. Others fail. Given the esoteric nature of these failures, - we have opted to defer implementing fixes to a later time, guided by customer feedback. Passage of - these tests is not an acceptance criteria for our dialect. - """ - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.TEST_DESIGN)) -class DifficultParametersTest(DifficultParametersTest): - """Some of the combinations in this test pass. Others fail. Given the esoteric nature of these failures, - we have opted to defer implementing fixes to a later time, guided by customer feedback. Passage of - these tests is not an acceptance criteria for our dialect. - """ - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.IDENTITY, True)) -class IdentityReflectionTest(IdentityReflectionTest): - """It's not clear _how_ to implement this for SQLAlchemy. Columns created with GENERATED ALWAYS AS IDENTITY - are not specially demarked in the output of TGetColumnsResponse or DESCRIBE TABLE EXTENDED. - - We could theoretically parse this from the contents of `SHOW CREATE TABLE` but that feels like a hack. 
- """ - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.JSON)) -class JSONTest(JSONTest): - """Databricks supports JSON path expressions in queries it's just not implemented in this dialect.""" - - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.JSON)) -class JSONLegacyStringCastIndexTest(JSONLegacyStringCastIndexTest): - """Same comment applies as JSONTest""" - - pass - - -class LikeFunctionsTest(LikeFunctionsTest): - @pytest.mark.skip(render_future_feature(FutureFeature.REGEXP)) - def test_not_regexp_match(self): - """The defaul dialect doesn't implement _visit_regexp methods so we don't get them automatically.""" - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.REGEXP)) - def test_regexp_match(self): - """The defaul dialect doesn't implement _visit_regexp methods so we don't get them automatically.""" - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.COLLATE)) -class CollateTest(CollateTest): - """This is supported in Databricks. Not implemented here.""" - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.UUID, True)) -class NativeUUIDTest(NativeUUIDTest): - """Type implementation will be straightforward. Since Databricks doesn't have a native UUID type we can use - a STRING field, create a custom TypeDecorator for sqlalchemy.types.Uuid and add it to the dialect's colspecs. - - Then mark requirements.uuid_data_type as open() so this test can run. - """ - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.SANE_ROWCOUNT)) -class RowCountTest(RowCountTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.SANE_ROWCOUNT)) -class SimpleUpdateDeleteTest(SimpleUpdateDeleteTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.PROVISION, True)) -class WeCanSetDefaultSchemaWEventsTest(WeCanSetDefaultSchemaWEventsTest): - """provision.py allows us to define event listeners that emit DDL for things like setting up a test schema - or, in this case, changing the default schema for the connection after it's been built. This would override - the schema defined in the sqlalchemy connection string. This support is possible but is not implemented - in the dialect. Deferred for now. - """ - - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.PROVISION, True)) -class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsTest): - """provision.py allows us to define event listeners that emit DDL for things like setting up a test schema - or, in this case, changing the default schema for the connection after it's been built. This would override - the schema defined in the sqlalchemy connection string. This support is possible but is not implemented - in the dialect. Deferred for now. - """ - - pass - - -class ComponentReflectionTest(ComponentReflectionTest): - @pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True)) - def test_multi_get_table_options_tables(self): - """It's not clear what the expected ouput from this method would even _be_. 
Requires research.""" - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.VIEW_DEF)) - def test_get_view_definition(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.VIEW_DEF)) - def test_get_view_definition_does_not_exist(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.MULTI_PK)) - def test_get_multi_pk_constraint(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.CHECK)) - def test_get_multi_check_constraints(self): - pass - - -class ComponentReflectionTestExtra(ComponentReflectionTestExtra): - @pytest.mark.skip(render_future_feature(FutureFeature.CHECK)) - def test_get_check_constraints(self): - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.FK_OPTS)) - def test_get_foreign_key_options(self): - """It's not clear from the test code what the expected output is here. Further research required.""" - pass - - -class InsertBehaviorTest(InsertBehaviorTest): - @pytest.mark.skip(render_future_feature(FutureFeature.EMPTY_INSERT, True)) - def test_empty_insert(self): - """Empty inserts are possible using DEFAULT VALUES on Databricks. To implement it, we need - to hook into the SQLCompiler to render a no-op column list. With SQLAlchemy's default implementation - the request fails with a syntax error - """ - pass - - @pytest.mark.skip(render_future_feature(FutureFeature.EMPTY_INSERT, True)) - def test_empty_insert_multiple(self): - """Empty inserts are possible using DEFAULT VALUES on Databricks. To implement it, we need - to hook into the SQLCompiler to render a no-op column list. With SQLAlchemy's default implementation - the request fails with a syntax error - """ - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.ARRAY)) -class ArrayTest(ArrayTest): - """While Databricks supports ARRAY types, DBR cannot handle bound parameters of this type. - This makes them unusable to SQLAlchemy without some workaround. Potentially we could inline - the values of these parameters (which risks sql injection). - """ - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.TEST_DESIGN, True)) -class QuotedNameArgumentTest(QuotedNameArgumentTest): - """These tests are challenging. The whole test setup depends on a table with a name like `quote ' one` - which will never work on Databricks because table names can't contains spaces. But QuotedNamedArgumentTest - also checks the behaviour of DDL identifier preparation process. We need to override some of IdentifierPreparer - methods because these are the ultimate control for whether or not CHECK and UNIQUE constraints are emitted. 
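Until the SQLCompiler hook sketched in the InsertBehaviorTest comments above exists, the DEFAULT VALUES form can be issued as raw SQL. This is a hedged sketch only: the table name is illustrative and `engine` is assumed to point at a Databricks warehouse.

```python
from sqlalchemy import text

empty_insert = text("INSERT INTO some_table DEFAULT VALUES")

# with engine.begin() as conn:
#     conn.execute(empty_insert)
```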
- """ - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_future_feature(FutureFeature.GENERATED_COLUMNS)) -class ComputedColumnTest(ComputedColumnTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_future_feature(FutureFeature.GENERATED_COLUMNS)) -class ComputedReflectionTest(ComputedReflectionTest): - pass diff --git a/src/databricks/sqlalchemy/test/_regression.py b/src/databricks/sqlalchemy/test/_regression.py deleted file mode 100644 index 4dbc5ec2..00000000 --- a/src/databricks/sqlalchemy/test/_regression.py +++ /dev/null @@ -1,311 +0,0 @@ -# type: ignore - -import pytest -from sqlalchemy.testing.suite import ( - ArgSignatureTest, - BooleanTest, - CastTypeDecoratorTest, - ComponentReflectionTestExtra, - CompositeKeyReflectionTest, - CompoundSelectTest, - DateHistoricTest, - DateTest, - DateTimeCoercedToDateTimeTest, - DateTimeHistoricTest, - DateTimeMicrosecondsTest, - DateTimeTest, - DeprecatedCompoundSelectTest, - DistinctOnTest, - EscapingTest, - ExistsTest, - ExpandingBoundInTest, - FetchLimitOffsetTest, - FutureTableDDLTest, - HasTableTest, - IdentityAutoincrementTest, - InsertBehaviorTest, - IntegerTest, - IsOrIsNotDistinctFromTest, - JoinTest, - LikeFunctionsTest, - NormalizedNameTest, - NumericTest, - OrderByLabelTest, - PingTest, - PostCompileParamsTest, - ReturningGuardsTest, - RowFetchTest, - SameNamedSchemaTableTest, - StringTest, - TableDDLTest, - TableNoColumnsTest, - TextTest, - TimeMicrosecondsTest, - TimestampMicrosecondsTest, - TimeTest, - TimeTZTest, - TrueDivTest, - UnicodeTextTest, - UnicodeVarcharTest, - UuidTest, - ValuesExpressionTest, -) - -from databricks.sqlalchemy.test.overrides._ctetest import CTETest -from databricks.sqlalchemy.test.overrides._componentreflectiontest import ( - ComponentReflectionTest, -) - - -@pytest.mark.reviewed -class NumericTest(NumericTest): - pass - - -@pytest.mark.reviewed -class HasTableTest(HasTableTest): - pass - - -@pytest.mark.reviewed -class ComponentReflectionTestExtra(ComponentReflectionTestExtra): - pass - - -@pytest.mark.reviewed -class InsertBehaviorTest(InsertBehaviorTest): - pass - - -@pytest.mark.reviewed -class ComponentReflectionTest(ComponentReflectionTest): - """This test requires two schemas be present in the target Databricks workspace: - - The schema set in --dburi - - A second schema named "test_schema" - - Note that test_get_multi_foreign keys is flaky because DBR does not guarantee the order of data returned in DESCRIBE TABLE EXTENDED - - _Most_ of these tests pass if we manually override the bad test setup. 
- """ - - pass - - -@pytest.mark.reviewed -class TableDDLTest(TableDDLTest): - pass - - -@pytest.mark.reviewed -class FutureTableDDLTest(FutureTableDDLTest): - pass - - -@pytest.mark.reviewed -class FetchLimitOffsetTest(FetchLimitOffsetTest): - pass - - -@pytest.mark.reviewed -class UuidTest(UuidTest): - pass - - -@pytest.mark.reviewed -class ValuesExpressionTest(ValuesExpressionTest): - pass - - -@pytest.mark.reviewed -class BooleanTest(BooleanTest): - pass - - -@pytest.mark.reviewed -class PostCompileParamsTest(PostCompileParamsTest): - pass - - -@pytest.mark.reviewed -class TimeMicrosecondsTest(TimeMicrosecondsTest): - pass - - -@pytest.mark.reviewed -class TextTest(TextTest): - pass - - -@pytest.mark.reviewed -class StringTest(StringTest): - pass - - -@pytest.mark.reviewed -class DateTimeMicrosecondsTest(DateTimeMicrosecondsTest): - pass - - -@pytest.mark.reviewed -class TimestampMicrosecondsTest(TimestampMicrosecondsTest): - pass - - -@pytest.mark.reviewed -class DateTimeCoercedToDateTimeTest(DateTimeCoercedToDateTimeTest): - pass - - -@pytest.mark.reviewed -class TimeTest(TimeTest): - pass - - -@pytest.mark.reviewed -class DateTimeTest(DateTimeTest): - pass - - -@pytest.mark.reviewed -class DateTimeHistoricTest(DateTimeHistoricTest): - pass - - -@pytest.mark.reviewed -class DateTest(DateTest): - pass - - -@pytest.mark.reviewed -class DateHistoricTest(DateHistoricTest): - pass - - -@pytest.mark.reviewed -class RowFetchTest(RowFetchTest): - pass - - -@pytest.mark.reviewed -class CompositeKeyReflectionTest(CompositeKeyReflectionTest): - pass - - -@pytest.mark.reviewed -class TrueDivTest(TrueDivTest): - pass - - -@pytest.mark.reviewed -class ArgSignatureTest(ArgSignatureTest): - pass - - -@pytest.mark.reviewed -class CompoundSelectTest(CompoundSelectTest): - pass - - -@pytest.mark.reviewed -class DeprecatedCompoundSelectTest(DeprecatedCompoundSelectTest): - pass - - -@pytest.mark.reviewed -class CastTypeDecoratorTest(CastTypeDecoratorTest): - pass - - -@pytest.mark.reviewed -class DistinctOnTest(DistinctOnTest): - pass - - -@pytest.mark.reviewed -class EscapingTest(EscapingTest): - pass - - -@pytest.mark.reviewed -class ExistsTest(ExistsTest): - pass - - -@pytest.mark.reviewed -class IntegerTest(IntegerTest): - pass - - -@pytest.mark.reviewed -class IsOrIsNotDistinctFromTest(IsOrIsNotDistinctFromTest): - pass - - -@pytest.mark.reviewed -class JoinTest(JoinTest): - pass - - -@pytest.mark.reviewed -class OrderByLabelTest(OrderByLabelTest): - pass - - -@pytest.mark.reviewed -class PingTest(PingTest): - pass - - -@pytest.mark.reviewed -class ReturningGuardsTest(ReturningGuardsTest): - pass - - -@pytest.mark.reviewed -class SameNamedSchemaTableTest(SameNamedSchemaTableTest): - pass - - -@pytest.mark.reviewed -class UnicodeTextTest(UnicodeTextTest): - pass - - -@pytest.mark.reviewed -class UnicodeVarcharTest(UnicodeVarcharTest): - pass - - -@pytest.mark.reviewed -class TableNoColumnsTest(TableNoColumnsTest): - pass - - -@pytest.mark.reviewed -class ExpandingBoundInTest(ExpandingBoundInTest): - pass - - -@pytest.mark.reviewed -class CTETest(CTETest): - pass - - -@pytest.mark.reviewed -class NormalizedNameTest(NormalizedNameTest): - pass - - -@pytest.mark.reviewed -class IdentityAutoincrementTest(IdentityAutoincrementTest): - pass - - -@pytest.mark.reviewed -class LikeFunctionsTest(LikeFunctionsTest): - pass - - -@pytest.mark.reviewed -class TimeTZTest(TimeTZTest): - pass diff --git a/src/databricks/sqlalchemy/test/_unsupported.py b/src/databricks/sqlalchemy/test/_unsupported.py deleted file 
mode 100644 index c1f81205..00000000 --- a/src/databricks/sqlalchemy/test/_unsupported.py +++ /dev/null @@ -1,450 +0,0 @@ -# type: ignore - -from enum import Enum - -import pytest -from databricks.sqlalchemy.test._regression import ( - ComponentReflectionTest, - ComponentReflectionTestExtra, - CTETest, - FetchLimitOffsetTest, - FutureTableDDLTest, - HasTableTest, - InsertBehaviorTest, - NumericTest, - TableDDLTest, - UuidTest, -) - -# These are test suites that are fully skipped with a SkipReason -from sqlalchemy.testing.suite import ( - AutocommitIsolationTest, - DateTimeTZTest, - ExceptionTest, - HasIndexTest, - HasSequenceTest, - HasSequenceTestEmpty, - IsolationLevelTest, - LastrowidTest, - LongNameBlowoutTest, - PercentSchemaNamesTest, - ReturningTest, - SequenceCompilerTest, - SequenceTest, - ServerSideCursorsTest, - UnicodeSchemaTest, -) - - -class SkipReason(Enum): - AUTO_INC = "implicit AUTO_INCREMENT" - CTE_FEAT = "required CTE features" - CURSORS = "server-side cursors" - DECIMAL_FEAT = "required decimal features" - ENFORCE_KEYS = "enforcing primary or foreign key restraints" - FETCH = "fetch clauses" - IDENTIFIER_LENGTH = "identifiers > 255 characters" - IMPL_FLOAT_PREC = "required implicit float precision" - IMPLICIT_ORDER = "deterministic return order if ORDER BY is not present" - INDEXES = "SQL INDEXes" - RETURNING = "INSERT ... RETURNING syntax" - SEQUENCES = "SQL SEQUENCES" - STRING_FEAT = "required STRING type features" - SYMBOL_CHARSET = "symbols expected by test" - TEMP_TBL = "temporary tables" - TIMEZONE_OPT = "timezone-optional TIMESTAMP fields" - TRANSACTIONS = "transactions" - UNIQUE = "UNIQUE constraints" - - -def render_skip_reason(rsn: SkipReason, setup_error=False, extra=False) -> str: - prefix = "[BADSETUP]" if setup_error else "" - postfix = " More detail in _unsupported.py" if extra else "" - return f"[UNSUPPORTED]{prefix}[{rsn.name}]: Databricks does not support {rsn.value}.{postfix}" - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.ENFORCE_KEYS)) -class ExceptionTest(ExceptionTest): - """Per Databricks documentation, primary and foreign key constraints are informational only - and are not enforced. 
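To make the skip messages produced by render_skip_reason above concrete, the snippet below re-declares the helper (with the enum trimmed to one member) so it runs standalone and prints both flag combinations.

```python
from enum import Enum


class SkipReason(Enum):
    # trimmed to a single member for the example; the full enum appears above
    FETCH = "fetch clauses"


def render_skip_reason(rsn: SkipReason, setup_error=False, extra=False) -> str:
    prefix = "[BADSETUP]" if setup_error else ""
    postfix = " More detail in _unsupported.py" if extra else ""
    return f"[UNSUPPORTED]{prefix}[{rsn.name}]: Databricks does not support {rsn.value}.{postfix}"


print(render_skip_reason(SkipReason.FETCH))
# [UNSUPPORTED][FETCH]: Databricks does not support fetch clauses.

print(render_skip_reason(SkipReason.FETCH, setup_error=True, extra=True))
# [UNSUPPORTED][BADSETUP][FETCH]: Databricks does not support fetch clauses. More detail in _unsupported.py
```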
- - https://docs.databricks.com/api/workspace/tableconstraints - """ - - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.IDENTIFIER_LENGTH)) -class LongNameBlowoutTest(LongNameBlowoutTest): - """These tests all include assertions that the tested name > 255 characters""" - - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.SEQUENCES)) -class HasSequenceTest(HasSequenceTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.SEQUENCES)) -class HasSequenceTestEmpty(HasSequenceTestEmpty): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.INDEXES)) -class HasIndexTest(HasIndexTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.SYMBOL_CHARSET)) -class UnicodeSchemaTest(UnicodeSchemaTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.CURSORS)) -class ServerSideCursorsTest(ServerSideCursorsTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.SYMBOL_CHARSET)) -class PercentSchemaNamesTest(PercentSchemaNamesTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.TRANSACTIONS)) -class IsolationLevelTest(IsolationLevelTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.TRANSACTIONS)) -class AutocommitIsolationTest(AutocommitIsolationTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.RETURNING)) -class ReturningTest(ReturningTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.SEQUENCES)) -class SequenceTest(SequenceTest): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(reason=render_skip_reason(SkipReason.SEQUENCES)) -class SequenceCompilerTest(SequenceCompilerTest): - pass - - -class FetchLimitOffsetTest(FetchLimitOffsetTest): - @pytest.mark.flaky - @pytest.mark.skip(reason=render_skip_reason(SkipReason.IMPLICIT_ORDER, extra=True)) - def test_limit_render_multiple_times(self): - """This test depends on the order that records are inserted into the table. It's passing criteria requires that - a record inserted with id=1 is the first record returned when no ORDER BY clause is specified. But Databricks occasionally - INSERTS in a different order, which makes this test seem to fail. The test is flaky, but the underlying functionality - (can multiple LIMIT clauses be rendered) is not broken. - - Unclear if this is a bug in Databricks, Delta, or some race-condition in the test itself. 
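Related to the ordering caveat above and the FETCH skips that follow: pagination against Databricks can be expressed with limit()/offset() plus an explicit ORDER BY. A hedged sketch with an illustrative table, compiled here against SQLAlchemy's generic dialect:

```python
from sqlalchemy import Column, Integer, MetaData, Table, select

metadata = MetaData()
t = Table("some_table", metadata, Column("id", Integer, primary_key=True))

# The ORDER BY keeps results deterministic; LIMIT/OFFSET stand in for the
# FETCH FIRST ... ROWS ONLY clauses that the skipped tests exercise.
stmt = select(t).order_by(t.c.id).limit(10).offset(20)
print(stmt)  # renders roughly: ... ORDER BY some_table.id LIMIT :param_1 OFFSET :param_2
```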
- """ - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_bound_fetch_offset(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_fetch_offset_no_order(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_fetch_offset_nobinds(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_simple_fetch(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_simple_fetch_offset(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_simple_fetch_percent(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_simple_fetch_percent_ties(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_simple_fetch_ties(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_expr_fetch_offset(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_fetch_offset_percent(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_fetch_offset_percent_ties(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_fetch_offset_ties(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.FETCH)) - def test_fetch_offset_ties_exact_number(self): - pass - - -class UuidTest(UuidTest): - @pytest.mark.skip(reason=render_skip_reason(SkipReason.RETURNING)) - def test_uuid_returning(self): - pass - - -class FutureTableDDLTest(FutureTableDDLTest): - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def test_create_index_if_not_exists(self): - """We could use requirements.index_reflection and requirements.index_ddl_if_exists - here to disable this but prefer a more meaningful skip message - """ - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def test_drop_index_if_exists(self): - """We could use requirements.index_reflection and requirements.index_ddl_if_exists - here to disable this but prefer a more meaningful skip message - """ - pass - - -class TableDDLTest(TableDDLTest): - @pytest.mark.skip(reason=render_skip_reason(SkipReason.INDEXES)) - def test_create_index_if_not_exists(self, connection): - """We could use requirements.index_reflection and requirements.index_ddl_if_exists - here to disable this but prefer a more meaningful skip message - """ - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.INDEXES)) - def test_drop_index_if_exists(self, connection): - """We could use requirements.index_reflection and requirements.index_ddl_if_exists - here to disable this but prefer a more meaningful skip message - """ - pass - - -class ComponentReflectionTest(ComponentReflectionTest): - """This test requires two schemas be present in the target Databricks workspace: - - The schema set in --dburi - - A second schema named "test_schema" - - Note that test_get_multi_foreign keys is flaky because DBR does not guarantee the order of data returned in DESCRIBE TABLE EXTENDED - """ - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.UNIQUE)) - def test_get_multi_unique_constraints(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.TEMP_TBL, True, True)) - def test_get_temp_view_names(self): - """While Databricks supports temporary views, this test creates a temp view aimed at a temp table. 
- Databricks doesn't support temp tables. So the test can never pass. - """ - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.TEMP_TBL)) - def test_get_temp_table_columns(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.TEMP_TBL)) - def test_get_temp_table_indexes(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.TEMP_TBL)) - def test_get_temp_table_names(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.TEMP_TBL)) - def test_get_temp_table_unique_constraints(self): - pass - - @pytest.mark.skip(reason=render_skip_reason(SkipReason.TEMP_TBL)) - def test_reflect_table_temp_table(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def test_get_indexes(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def test_multi_indexes(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def get_noncol_index(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.UNIQUE)) - def test_get_unique_constraints(self): - pass - - -class NumericTest(NumericTest): - @pytest.mark.skip(render_skip_reason(SkipReason.DECIMAL_FEAT)) - def test_enotation_decimal(self): - """This test automatically runs if requirements.precision_numerics_enotation_large is open()""" - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.DECIMAL_FEAT)) - def test_enotation_decimal_large(self): - """This test automatically runs if requirements.precision_numerics_enotation_large is open()""" - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.IMPL_FLOAT_PREC, extra=True)) - def test_float_coerce_round_trip(self): - """ - This automatically runs if requirements.literal_float_coercion is open() - - Without additional work, Databricks returns 15.75629997253418 when you SELECT 15.7563. - This is a potential area where we could override the Float literal processor to add a CAST. - Will leave to a PM to decide if we should do so. - """ - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.IMPL_FLOAT_PREC, extra=True)) - def test_float_custom_scale(self): - """This test automatically runs if requirements.precision_generic_float_type is open()""" - pass - - -class HasTableTest(HasTableTest): - """Databricks does not support temporary tables.""" - - @pytest.mark.skip(render_skip_reason(SkipReason.TEMP_TBL)) - def test_has_table_temp_table(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.TEMP_TBL, True, True)) - def test_has_table_temp_view(self): - """Databricks supports temporary views but this test depends on requirements.has_temp_table, which we - explicitly close so that we can run other tests in this group. See the comment under has_temp_table in - requirements.py for details. - - From what I can see, there is no way to run this test since it will fail during setup if we mark has_temp_table - open(). It _might_ be possible to hijack this behaviour by implementing temp_table_keyword_args in our own - provision.py. Doing so would mean creating a real table during this class setup instead of a temp table. Then - we could just skip the temp table tests but run the temp view tests. But this test fixture doesn't cleanup its - temp tables and has no hook to do so. - - It would be ideal for SQLAlchemy to define a separate requirements.has_temp_views. 
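A hedged sketch of the literal-processor idea floated in test_float_coerce_round_trip above; the class name is illustrative and nothing like this ships in the dialect today, it only shows the shape such an override could take.

```python
from sqlalchemy import types


class CastingFloat(types.TypeDecorator):
    """Float variant whose inline literals carry an explicit CAST."""

    impl = types.Float
    cache_ok = True

    def process_literal_param(self, value, dialect):
        # only affects literal_binds rendering; bound parameters pass through untouched
        return f"CAST({value} AS FLOAT)"
```

Whether the extra CAST is worth the noise in rendered SQL is exactly the product decision the comment above defers.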
- """ - pass - - -class ComponentReflectionTestExtra(ComponentReflectionTestExtra): - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def test_reflect_covering_index(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.INDEXES)) - def test_reflect_expression_based_indexes(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.STRING_FEAT, extra=True)) - def test_varchar_reflection(self): - """Databricks doesn't enforce string length limitations like STRING(255).""" - pass - - -class InsertBehaviorTest(InsertBehaviorTest): - @pytest.mark.skip(render_skip_reason(SkipReason.AUTO_INC, True, True)) - def test_autoclose_on_insert(self): - """The setup for this test creates a column with implicit autoincrement enabled. - This dialect does not implement implicit autoincrement - users must declare Identity() explicitly. - """ - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.AUTO_INC, True, True)) - def test_insert_from_select_autoinc(self): - """Implicit autoincrement is not implemented in this dialect.""" - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.AUTO_INC, True, True)) - def test_insert_from_select_autoinc_no_rows(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.RETURNING)) - def test_autoclose_on_insert_implicit_returning(self): - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_skip_reason(SkipReason.AUTO_INC, extra=True)) -class LastrowidTest(LastrowidTest): - """SQLAlchemy docs describe that a column without an explicit Identity() may implicitly create one if autoincrement=True. - That is what this method tests. Databricks supports auto-incrementing IDENTITY columns but they must be explicitly - declared. This limitation is present in our dialect as well. Which means that SQLAlchemy's autoincrement setting of a column - is ignored. We emit a logging.WARN message if you try it. - - In the future we could handle this autoincrement by implicitly calling the visit_identity_column() method of our DDLCompiler - when autoincrement=True. There is an example of this in the Microsoft SQL Server dialect: MSSDDLCompiler.get_column_specification - - For now, if you need to create a SQLAlchemy column with an auto-incrementing identity, you must set this explicitly in your column - definition by passing an Identity() to the column constructor. - """ - - pass - - -class CTETest(CTETest): - """During the teardown for this test block, it tries to drop a constraint that it never named which raises - a compilation error. This could point to poor constraint reflection but our other constraint reflection - tests pass. Requires investigation. - """ - - @pytest.mark.skip(render_skip_reason(SkipReason.CTE_FEAT, extra=True)) - def test_select_recursive_round_trip(self): - pass - - @pytest.mark.skip(render_skip_reason(SkipReason.CTE_FEAT, extra=True)) - def test_delete_scalar_subq_round_trip(self): - """Error received is [UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY] - - This suggests a limitation of the platform. But a workaround may be possible if customers require it. - """ - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_skip_reason(SkipReason.TIMEZONE_OPT, True)) -class DateTimeTZTest(DateTimeTZTest): - """Test whether the sqlalchemy.DateTime() type can _optionally_ include timezone info. - This dialect maps DateTime() → TIMESTAMP, which _always_ includes tzinfo. - - Users can use databricks.sqlalchemy.TIMESTAMP_NTZ for a tzinfo-less timestamp. 
The SQLA docs - acknowledge this is expected for some dialects. - - https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.DateTime - """ - - pass diff --git a/src/databricks/sqlalchemy/test/conftest.py b/src/databricks/sqlalchemy/test/conftest.py deleted file mode 100644 index ea43e8d3..00000000 --- a/src/databricks/sqlalchemy/test/conftest.py +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy.dialects import registry -import pytest - -registry.register("databricks", "databricks.sqlalchemy", "DatabricksDialect") -# sqlalchemy's dialect-testing machinery wants an entry like this. -# This seems to be based around dialects maybe having multiple drivers -# and wanting to test driver-specific URLs, but doesn't seem to make -# much sense for dialects with only one driver. -registry.register("databricks.databricks", "databricks.sqlalchemy", "DatabricksDialect") - -pytest.register_assert_rewrite("sqlalchemy.testing.assertions") - -from sqlalchemy.testing.plugin.pytestplugin import * diff --git a/src/databricks/sqlalchemy/test/overrides/_componentreflectiontest.py b/src/databricks/sqlalchemy/test/overrides/_componentreflectiontest.py deleted file mode 100644 index a1f58fa6..00000000 --- a/src/databricks/sqlalchemy/test/overrides/_componentreflectiontest.py +++ /dev/null @@ -1,189 +0,0 @@ -"""The default test setup uses self-referential foreign keys and indexes for a test table. -We override to remove these assumptions. - -Note that test_multi_foreign_keys currently does not pass for all combinations due to -an ordering issue. The dialect returns the expected information. But this test makes assertions -on the order of the returned results. We can't guarantee that order at the moment. - -The test fixture actually tries to sort the outputs, but this sort isn't working. Will need -to follow-up on this later. -""" -import sqlalchemy as sa -from sqlalchemy.testing import config -from sqlalchemy.testing.schema import Column -from sqlalchemy.testing.schema import Table -from sqlalchemy import ForeignKey -from sqlalchemy import testing - -from sqlalchemy.testing.suite.test_reflection import ComponentReflectionTest - - -class ComponentReflectionTest(ComponentReflectionTest): # type: ignore - @classmethod - def define_reflected_tables(cls, metadata, schema): - if schema: - schema_prefix = schema + "." 
- else: - schema_prefix = "" - - if testing.requires.self_referential_foreign_keys.enabled: - parent_id_args = ( - ForeignKey( - "%susers.user_id" % schema_prefix, name="user_id_fk", use_alter=True - ), - ) - else: - parent_id_args = () - users = Table( - "users", - metadata, - Column("user_id", sa.INT, primary_key=True), - Column("test1", sa.CHAR(5), nullable=False), - Column("test2", sa.Float(), nullable=False), - Column("parent_user_id", sa.Integer, *parent_id_args), - sa.CheckConstraint( - "test2 > 0", - name="zz_test2_gt_zero", - comment="users check constraint", - ), - sa.CheckConstraint("test2 <= 1000"), - schema=schema, - test_needs_fk=True, - ) - - Table( - "dingalings", - metadata, - Column("dingaling_id", sa.Integer, primary_key=True), - Column( - "address_id", - sa.Integer, - ForeignKey( - "%semail_addresses.address_id" % schema_prefix, - name="zz_email_add_id_fg", - comment="di fk comment", - ), - ), - Column( - "id_user", - sa.Integer, - ForeignKey("%susers.user_id" % schema_prefix), - ), - Column("data", sa.String(30), unique=True), - sa.CheckConstraint( - "address_id > 0 AND address_id < 1000", - name="address_id_gt_zero", - ), - sa.UniqueConstraint( - "address_id", - "dingaling_id", - name="zz_dingalings_multiple", - comment="di unique comment", - ), - schema=schema, - test_needs_fk=True, - ) - Table( - "email_addresses", - metadata, - Column("address_id", sa.Integer), - Column("remote_user_id", sa.Integer, ForeignKey(users.c.user_id)), - Column("email_address", sa.String(20)), - sa.PrimaryKeyConstraint( - "address_id", name="email_ad_pk", comment="ea pk comment" - ), - schema=schema, - test_needs_fk=True, - ) - Table( - "comment_test", - metadata, - Column("id", sa.Integer, primary_key=True, comment="id comment"), - Column("data", sa.String(20), comment="data % comment"), - Column( - "d2", - sa.String(20), - comment=r"""Comment types type speedily ' " \ '' Fun!""", - ), - Column("d3", sa.String(42), comment="Comment\nwith\rescapes"), - schema=schema, - comment=r"""the test % ' " \ table comment""", - ) - Table( - "no_constraints", - metadata, - Column("data", sa.String(20)), - schema=schema, - comment="no\nconstraints\rhas\fescaped\vcomment", - ) - - if testing.requires.cross_schema_fk_reflection.enabled: - if schema is None: - Table( - "local_table", - metadata, - Column("id", sa.Integer, primary_key=True), - Column("data", sa.String(20)), - Column( - "remote_id", - ForeignKey("%s.remote_table_2.id" % testing.config.test_schema), - ), - test_needs_fk=True, - schema=config.db.dialect.default_schema_name, - ) - else: - Table( - "remote_table", - metadata, - Column("id", sa.Integer, primary_key=True), - Column( - "local_id", - ForeignKey( - "%s.local_table.id" % config.db.dialect.default_schema_name - ), - ), - Column("data", sa.String(20)), - schema=schema, - test_needs_fk=True, - ) - Table( - "remote_table_2", - metadata, - Column("id", sa.Integer, primary_key=True), - Column("data", sa.String(20)), - schema=schema, - test_needs_fk=True, - ) - - if testing.requires.index_reflection.enabled: - Index("users_t_idx", users.c.test1, users.c.test2, unique=True) - Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) - - if not schema: - # test_needs_fk is at the moment to force MySQL InnoDB - noncol_idx_test_nopk = Table( - "noncol_idx_test_nopk", - metadata, - Column("q", sa.String(5)), - test_needs_fk=True, - ) - - noncol_idx_test_pk = Table( - "noncol_idx_test_pk", - metadata, - Column("id", sa.Integer, primary_key=True), - Column("q", sa.String(5)), - 
test_needs_fk=True, - ) - - if ( - testing.requires.indexes_with_ascdesc.enabled - and testing.requires.reflect_indexes_with_ascdesc.enabled - ): - Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) - Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) - - if testing.requires.view_column_reflection.enabled: - cls.define_views(metadata, schema) - if not schema and testing.requires.temp_table_reflection.enabled: - cls.define_temp_tables(metadata) diff --git a/src/databricks/sqlalchemy/test/overrides/_ctetest.py b/src/databricks/sqlalchemy/test/overrides/_ctetest.py deleted file mode 100644 index 3cdae036..00000000 --- a/src/databricks/sqlalchemy/test/overrides/_ctetest.py +++ /dev/null @@ -1,33 +0,0 @@ -"""The default test setup uses a self-referential foreign key. With our dialect this requires -`use_alter=True` and the fk constraint to be named. So we override this to make the test pass. -""" - -from sqlalchemy.testing.suite import CTETest - -from sqlalchemy.testing.schema import Column -from sqlalchemy.testing.schema import Table -from sqlalchemy import ForeignKey -from sqlalchemy import Integer -from sqlalchemy import String - - -class CTETest(CTETest): # type: ignore - @classmethod - def define_tables(cls, metadata): - Table( - "some_table", - metadata, - Column("id", Integer, primary_key=True), - Column("data", String(50)), - Column( - "parent_id", ForeignKey("some_table.id", name="fk_test", use_alter=True) - ), - ) - - Table( - "some_other_table", - metadata, - Column("id", Integer, primary_key=True), - Column("data", String(50)), - Column("parent_id", Integer), - ) diff --git a/src/databricks/sqlalchemy/test/test_suite.py b/src/databricks/sqlalchemy/test/test_suite.py deleted file mode 100644 index 2b40a432..00000000 --- a/src/databricks/sqlalchemy/test/test_suite.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -The order of these imports is important. Test cases are imported first from SQLAlchemy, -then are overridden by our local skip markers in _regression, _unsupported, and _future. -""" - - -# type: ignore -# fmt: off -from sqlalchemy.testing.suite import * -from databricks.sqlalchemy.test._regression import * -from databricks.sqlalchemy.test._unsupported import * -from databricks.sqlalchemy.test._future import * -from databricks.sqlalchemy.test._extra import TinyIntegerTest, DateTimeTZTestCustom diff --git a/src/databricks/sqlalchemy/test_local/__init__.py b/src/databricks/sqlalchemy/test_local/__init__.py deleted file mode 100644 index eca1cf55..00000000 --- a/src/databricks/sqlalchemy/test_local/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -This module contains tests entirely maintained by Databricks. - -These tests do not rely on SQLAlchemy's custom test runner. 
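The import-order comment in test_suite.py above relies on plain Python name shadowing; a condensed, hypothetical illustration of the mechanism (the real suite does this across modules via star imports):

```python
class SomeTest:            # stands in for the upstream sqlalchemy.testing.suite class
    overridden = False


class SomeTest(SomeTest):  # stands in for the re-declared class in _regression/_unsupported
    overridden = True


assert SomeTest.overridden  # the later definition wins, so the local markers apply
```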
-""" diff --git a/src/databricks/sqlalchemy/test_local/conftest.py b/src/databricks/sqlalchemy/test_local/conftest.py deleted file mode 100644 index c8b350be..00000000 --- a/src/databricks/sqlalchemy/test_local/conftest.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import pytest - - -@pytest.fixture(scope="session") -def host(): - return os.getenv("DATABRICKS_SERVER_HOSTNAME") - - -@pytest.fixture(scope="session") -def http_path(): - return os.getenv("DATABRICKS_HTTP_PATH") - - -@pytest.fixture(scope="session") -def access_token(): - return os.getenv("DATABRICKS_TOKEN") - - -@pytest.fixture(scope="session") -def ingestion_user(): - return os.getenv("DATABRICKS_USER") - - -@pytest.fixture(scope="session") -def catalog(): - return os.getenv("DATABRICKS_CATALOG") - - -@pytest.fixture(scope="session") -def schema(): - return os.getenv("DATABRICKS_SCHEMA", "default") - - -@pytest.fixture(scope="session", autouse=True) -def connection_details(host, http_path, access_token, ingestion_user, catalog, schema): - return { - "host": host, - "http_path": http_path, - "access_token": access_token, - "ingestion_user": ingestion_user, - "catalog": catalog, - "schema": schema, - } diff --git a/src/databricks/sqlalchemy/test_local/e2e/MOCK_DATA.xlsx b/src/databricks/sqlalchemy/test_local/e2e/MOCK_DATA.xlsx deleted file mode 100644 index e080689a..00000000 Binary files a/src/databricks/sqlalchemy/test_local/e2e/MOCK_DATA.xlsx and /dev/null differ diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py deleted file mode 100644 index ce0b5d89..00000000 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ /dev/null @@ -1,543 +0,0 @@ -import datetime -import decimal -from typing import Tuple, Union, List -from unittest import skipIf - -import pytest -from sqlalchemy import ( - Column, - MetaData, - Table, - Text, - create_engine, - insert, - select, - text, -) -from sqlalchemy.engine import Engine -from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column -from sqlalchemy.schema import DropColumnComment, SetColumnComment -from sqlalchemy.types import BOOLEAN, DECIMAL, Date, Integer, String - -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base - - -USER_AGENT_TOKEN = "PySQL e2e Tests" - - -def sqlalchemy_1_3(): - import sqlalchemy - - return sqlalchemy.__version__.startswith("1.3") - - -def version_agnostic_select(object_to_select, *args, **kwargs): - """ - SQLAlchemy==1.3.x requires arguments to select() to be a Python list - - https://docs.sqlalchemy.org/en/20/changelog/migration_14.html#orm-query-is-internally-unified-with-select-update-delete-2-0-style-execution-available - """ - - if sqlalchemy_1_3(): - return select([object_to_select], *args, **kwargs) - else: - return select(object_to_select, *args, **kwargs) - - -def version_agnostic_connect_arguments(connection_details) -> Tuple[str, dict]: - HOST = connection_details["host"] - HTTP_PATH = connection_details["http_path"] - ACCESS_TOKEN = connection_details["access_token"] - CATALOG = connection_details["catalog"] - SCHEMA = connection_details["schema"] - - ua_connect_args = {"_user_agent_entry": USER_AGENT_TOKEN} - - if sqlalchemy_1_3(): - conn_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}" - connect_args = { - **ua_connect_args, - "http_path": HTTP_PATH, - "server_hostname": HOST, - "catalog": CATALOG, - "schema": SCHEMA, 
- } - - return conn_string, connect_args - else: - return ( - f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}", - ua_connect_args, - ) - - -@pytest.fixture -def db_engine(connection_details) -> Engine: - conn_string, connect_args = version_agnostic_connect_arguments(connection_details) - return create_engine(conn_string, connect_args=connect_args) - - -def run_query(db_engine: Engine, query: Union[str, Text]): - if not isinstance(query, Text): - _query = text(query) # type: ignore - else: - _query = query # type: ignore - with db_engine.begin() as conn: - return conn.execute(_query).fetchall() - - -@pytest.fixture -def samples_engine(connection_details) -> Engine: - details = connection_details.copy() - details["catalog"] = "samples" - details["schema"] = "nyctaxi" - conn_string, connect_args = version_agnostic_connect_arguments(details) - return create_engine(conn_string, connect_args=connect_args) - - -@pytest.fixture() -def base(db_engine): - return declarative_base() - - -@pytest.fixture() -def session(db_engine): - return Session(db_engine) - - -@pytest.fixture() -def metadata_obj(db_engine): - return MetaData() - - -def test_can_connect(db_engine): - simple_query = "SELECT 1" - result = run_query(db_engine, simple_query) - assert len(result) == 1 - - -def test_connect_args(db_engine): - """Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI - - This will most commonly happen when partners supply a user agent entry - """ - - conn = db_engine.connect() - connection_headers = conn.connection.thrift_backend._transport._headers - user_agent = connection_headers["User-Agent"] - - expected = f"(sqlalchemy + {USER_AGENT_TOKEN})" - assert expected in user_agent - - -@pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4") -@pytest.mark.skip( - reason="DBR is currently limited to 256 parameters per call to .execute(). Test cannot pass." 
-) -def test_pandas_upload(db_engine, metadata_obj): - import pandas as pd - - SCHEMA = "default" - try: - df = pd.read_excel( - "src/databricks/sqlalchemy/test_local/e2e/demo_data/MOCK_DATA.xlsx" - ) - df.to_sql( - "mock_data", - db_engine, - schema=SCHEMA, - index=False, - method="multi", - if_exists="replace", - ) - - df_after = pd.read_sql_table("mock_data", db_engine, schema=SCHEMA) - assert len(df) == len(df_after) - except Exception as e: - raise e - finally: - db_engine.execute("DROP TABLE mock_data") - - -def test_create_table_not_null(db_engine, metadata_obj: MetaData): - table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) - - SampleTable = Table( - table_name, - metadata_obj, - Column("name", String(255)), - Column("episodes", Integer), - Column("some_bool", BOOLEAN, nullable=False), - ) - - metadata_obj.create_all(db_engine) - - columns = db_engine.dialect.get_columns( - connection=db_engine.connect(), table_name=table_name - ) - - name_column_description = columns[0] - some_bool_column_description = columns[2] - - assert name_column_description.get("nullable") is True - assert some_bool_column_description.get("nullable") is False - - metadata_obj.drop_all(db_engine) - - -def test_column_comment(db_engine, metadata_obj: MetaData): - table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) - - column = Column("name", String(255), comment="some comment") - SampleTable = Table(table_name, metadata_obj, column) - - metadata_obj.create_all(db_engine) - connection = db_engine.connect() - - columns = db_engine.dialect.get_columns( - connection=connection, table_name=table_name - ) - - assert columns[0].get("comment") == "some comment" - - column.comment = "other comment" - connection.execute(SetColumnComment(column)) - - columns = db_engine.dialect.get_columns( - connection=connection, table_name=table_name - ) - - assert columns[0].get("comment") == "other comment" - - connection.execute(DropColumnComment(column)) - - columns = db_engine.dialect.get_columns( - connection=connection, table_name=table_name - ) - - assert columns[0].get("comment") == None - - metadata_obj.drop_all(db_engine) - - -def test_bulk_insert_with_core(db_engine, metadata_obj, session): - import random - - # Maximum number of parameter is 256. 
256/4 == 64 - num_to_insert = 64 - - table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) - - names = ["Bim", "Miki", "Sarah", "Ira"] - - SampleTable = Table( - table_name, metadata_obj, Column("name", String(255)), Column("number", Integer) - ) - - rows = [ - {"name": names[i % 3], "number": random.choice(range(64))} - for i in range(num_to_insert) - ] - - metadata_obj.create_all(db_engine) - with db_engine.begin() as conn: - conn.execute(insert(SampleTable).values(rows)) - - with db_engine.begin() as conn: - rows = conn.execute(version_agnostic_select(SampleTable)).fetchall() - - assert len(rows) == num_to_insert - - -def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData): - """ """ - - SampleTable = Table( - "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")), - metadata_obj, - Column("name", String(255)), - Column("episodes", Integer), - Column("some_bool", BOOLEAN), - Column("dollars", DECIMAL(10, 2)), - ) - - metadata_obj.create_all(db_engine) - - insert_stmt = insert(SampleTable).values( - name="Bim Adewunmi", episodes=6, some_bool=True, dollars=decimal.Decimal(125) - ) - - with db_engine.connect() as conn: - conn.execute(insert_stmt) - - select_stmt = version_agnostic_select(SampleTable) - with db_engine.begin() as conn: - resp = conn.execute(select_stmt) - - result = resp.fetchall() - - assert len(result) == 1 - - metadata_obj.drop_all(db_engine) - - -# ORM tests are made following this tutorial -# https://docs.sqlalchemy.org/en/14/orm/quickstart.html - - -@skipIf(False, "Unity catalog must be supported") -def test_create_insert_drop_table_orm(db_engine): - """ORM classes built on the declarative base class must have a primary key. - This is restricted to Unity Catalog. - """ - - class Base(DeclarativeBase): - pass - - class SampleObject(Base): - __tablename__ = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) - - name: Mapped[str] = mapped_column(String(255), primary_key=True) - episodes: Mapped[int] = mapped_column(Integer) - some_bool: Mapped[bool] = mapped_column(BOOLEAN) - - Base.metadata.create_all(db_engine) - - sample_object_1 = SampleObject(name="Bim Adewunmi", episodes=6, some_bool=True) - sample_object_2 = SampleObject(name="Miki Meek", episodes=12, some_bool=False) - - session = Session(db_engine) - session.add(sample_object_1) - session.add(sample_object_2) - session.flush() - - stmt = version_agnostic_select(SampleObject).where( - SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"]) - ) - - if sqlalchemy_1_3(): - output = [i for i in session.execute(stmt)] - else: - output = [i for i in session.scalars(stmt)] - - assert len(output) == 2 - - Base.metadata.drop_all(db_engine) - - -def test_dialect_type_mappings(db_engine, metadata_obj: MetaData): - """Confirms that we get back the same time we declared in a model and inserted using Core""" - - class Base(DeclarativeBase): - pass - - SampleTable = Table( - "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")), - metadata_obj, - Column("string_example", String(255)), - Column("integer_example", Integer), - Column("boolean_example", BOOLEAN), - Column("decimal_example", DECIMAL(10, 2)), - Column("date_example", Date), - ) - - string_example = "" - integer_example = 100 - boolean_example = True - decimal_example = decimal.Decimal(125) - date_example = datetime.date(2013, 1, 1) - - metadata_obj.create_all(db_engine) - - insert_stmt = insert(SampleTable).values( - string_example=string_example, - integer_example=integer_example, - 
boolean_example=boolean_example, - decimal_example=decimal_example, - date_example=date_example, - ) - - with db_engine.connect() as conn: - conn.execute(insert_stmt) - - select_stmt = version_agnostic_select(SampleTable) - with db_engine.begin() as conn: - resp = conn.execute(select_stmt) - - result = resp.fetchall() - this_row = result[0] - - assert this_row.string_example == string_example - assert this_row.integer_example == integer_example - assert this_row.boolean_example == boolean_example - assert this_row.decimal_example == decimal_example - assert this_row.date_example == date_example - - metadata_obj.drop_all(db_engine) - - -def test_inspector_smoke_test(samples_engine: Engine): - """It does not appear that 3L namespace is supported here""" - - schema, table = "nyctaxi", "trips" - - try: - inspector = Inspector.from_engine(samples_engine) - except Exception as e: - assert False, f"Could not build inspector: {e}" - - # Expect six columns - columns = inspector.get_columns(table, schema=schema) - - # Expect zero views, but the method should return - views = inspector.get_view_names(schema=schema) - - assert ( - len(columns) == 6 - ), "Dialect did not find the expected number of columns in samples.nyctaxi.trips" - assert len(views) == 0, "Views could not be fetched" - - -@pytest.mark.skip(reason="engine.table_names has been removed in sqlalchemy verison 2") -def test_get_table_names_smoke_test(samples_engine: Engine): - with samples_engine.connect() as conn: - _names = samples_engine.table_names(schema="nyctaxi", connection=conn) # type: ignore - _names is not None, "get_table_names did not succeed" - - -def test_has_table_across_schemas( - db_engine: Engine, samples_engine: Engine, catalog: str, schema: str -): - """For this test to pass these conditions must be met: - - Table samples.nyctaxi.trips must exist - - Table samples.tpch.customer must exist - - The `catalog` and `schema` environment variables must be set and valid - """ - - with samples_engine.connect() as conn: - # 1) Check for table within schema declared at engine creation time - assert samples_engine.dialect.has_table(connection=conn, table_name="trips") - - # 2) Check for table within another schema in the same catalog - assert samples_engine.dialect.has_table( - connection=conn, table_name="customer", schema="tpch" - ) - - # 3) Check for a table within a different catalog - # Create a table in a different catalog - with db_engine.connect() as conn: - conn.execute(text("CREATE TABLE test_has_table (numbers_are_cool INT);")) - - try: - # Verify that this table is not found in the samples catalog - assert not samples_engine.dialect.has_table( - connection=conn, table_name="test_has_table" - ) - # Verify that this table is found in a separate catalog - assert samples_engine.dialect.has_table( - connection=conn, - table_name="test_has_table", - schema=schema, - catalog=catalog, - ) - finally: - conn.execute(text("DROP TABLE test_has_table;")) - - -def test_user_agent_adjustment(db_engine): - # If .connect() is called multiple times on an engine, don't keep pre-pending the user agent - # https://github.com/databricks/databricks-sql-python/issues/192 - c1 = db_engine.connect() - c2 = db_engine.connect() - - def get_conn_user_agent(conn): - return conn.connection.dbapi_connection.thrift_backend._transport._headers.get( - "User-Agent" - ) - - ua1 = get_conn_user_agent(c1) - ua2 = get_conn_user_agent(c2) - same_ua = ua1 == ua2 - - c1.close() - c2.close() - - assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}" - - 
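A hedged distillation of what the user-agent tests above depend on: extra connect_args passed to create_engine flow through to the DBAPI connection unchanged. The URL uses the same masked placeholders as the DDL tests later in this diff and assumes the dialect package is installed.

```python
from sqlalchemy import create_engine

engine = create_engine(
    "databricks://token:****@****?http_path=****&catalog=****&schema=****",
    connect_args={"_user_agent_entry": "PySQL e2e Tests"},
)
# Connections from this engine advertise "(sqlalchemy + PySQL e2e Tests)" in their
# User-Agent header, and repeated .connect() calls must not re-prepend that entry.
```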
-@pytest.fixture -def sample_table(metadata_obj: MetaData, db_engine: Engine): - """This fixture creates a sample table and cleans it up after the test is complete.""" - from databricks.sqlalchemy._parse import GET_COLUMNS_TYPE_MAP - - table_name = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")) - - args: List[Column] = [ - Column(colname, coltype) for colname, coltype in GET_COLUMNS_TYPE_MAP.items() - ] - - SampleTable = Table(table_name, metadata_obj, *args) - - metadata_obj.create_all(db_engine) - - yield table_name - - metadata_obj.drop_all(db_engine) - - -def test_get_columns(db_engine, sample_table: str): - """Created after PECO-1297 and Github Issue #295 to verify that get_columsn behaves like it should for all known SQLAlchemy types""" - - inspector = Inspector.from_engine(db_engine) - - # this raises an exception if `parse_column_info_from_tgetcolumnsresponse` fails a lookup - columns = inspector.get_columns(sample_table) - - assert True - - -class TestCommentReflection: - @pytest.fixture(scope="class") - def engine(self, connection_details: dict): - HOST = connection_details["host"] - HTTP_PATH = connection_details["http_path"] - ACCESS_TOKEN = connection_details["access_token"] - CATALOG = connection_details["catalog"] - SCHEMA = connection_details["schema"] - - connection_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}" - connect_args = {"_user_agent_entry": USER_AGENT_TOKEN} - - engine = create_engine(connection_string, connect_args=connect_args) - return engine - - @pytest.fixture - def inspector(self, engine: Engine) -> Inspector: - return Inspector.from_engine(engine) - - @pytest.fixture(scope="class") - def table(self, engine): - md = MetaData() - tbl = Table( - "foo", - md, - Column("bar", String, comment="column comment"), - comment="table comment", - ) - md.create_all(bind=engine) - - yield tbl - - md.drop_all(bind=engine) - - def test_table_comment_reflection(self, inspector: Inspector, table: Table): - comment = inspector.get_table_comment(table.name) - assert comment == {"text": "table comment"} - - def test_column_comment(self, inspector: Inspector, table: Table): - result = inspector.get_columns(table.name)[0].get("comment") - assert result == "column comment" diff --git a/src/databricks/sqlalchemy/test_local/test_ddl.py b/src/databricks/sqlalchemy/test_local/test_ddl.py deleted file mode 100644 index f596dffa..00000000 --- a/src/databricks/sqlalchemy/test_local/test_ddl.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest -from sqlalchemy import Column, MetaData, String, Table, create_engine -from sqlalchemy.schema import ( - CreateTable, - DropColumnComment, - DropTableComment, - SetColumnComment, - SetTableComment, -) - - -class DDLTestBase: - engine = create_engine( - "databricks://token:****@****?http_path=****&catalog=****&schema=****" - ) - - def compile(self, stmt): - return str(stmt.compile(bind=self.engine)) - - -class TestColumnCommentDDL(DDLTestBase): - @pytest.fixture - def metadata(self) -> MetaData: - """Assemble a metadata object with one table containing one column.""" - metadata = MetaData() - - column = Column("foo", String, comment="bar") - table = Table("foobar", metadata, column) - - return metadata - - @pytest.fixture - def table(self, metadata) -> Table: - return metadata.tables.get("foobar") - - @pytest.fixture - def column(self, table) -> Column: - return table.columns[0] - - def test_create_table_with_column_comment(self, table): - stmt = CreateTable(table) - output 
= self.compile(stmt)
-
-        # output is a CREATE TABLE statement
-        assert "foo STRING COMMENT 'bar'" in output
-
-    def test_alter_table_add_column_comment(self, column):
-        stmt = SetColumnComment(column)
-        output = self.compile(stmt)
-        assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT 'bar'"
-
-    def test_alter_table_drop_column_comment(self, column):
-        stmt = DropColumnComment(column)
-        output = self.compile(stmt)
-        assert output == "ALTER TABLE foobar ALTER COLUMN foo COMMENT ''"
-
-
-class TestTableCommentDDL(DDLTestBase):
-    @pytest.fixture
-    def metadata(self) -> MetaData:
-        """Assemble a metadata object with one table containing one column."""
-        metadata = MetaData()
-
-        col1 = Column("foo", String)
-        col2 = Column("foo", String)
-        tbl_w_comment = Table("martin", metadata, col1, comment="foobar")
-        tbl_wo_comment = Table("prs", metadata, col2)
-
-        return metadata
-
-    @pytest.fixture
-    def table_with_comment(self, metadata) -> Table:
-        return metadata.tables.get("martin")
-
-    @pytest.fixture
-    def table_without_comment(self, metadata) -> Table:
-        return metadata.tables.get("prs")
-
-    def test_create_table_with_comment(self, table_with_comment):
-        stmt = CreateTable(table_with_comment)
-        output = self.compile(stmt)
-        assert "USING DELTA" in output
-        assert "COMMENT 'foobar'" in output
-
-    def test_alter_table_add_comment(self, table_without_comment: Table):
-        table_without_comment.comment = "wireless mechanical keyboard"
-        stmt = SetTableComment(table_without_comment)
-        output = self.compile(stmt)
-
-        assert output == "COMMENT ON TABLE prs IS 'wireless mechanical keyboard'"
-
-    def test_alter_table_drop_comment(self, table_with_comment):
-        """The syntax for COMMENT ON is here: https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-comment.html"""
-        stmt = DropTableComment(table_with_comment)
-        output = self.compile(stmt)
-        assert output == "COMMENT ON TABLE martin IS NULL"
diff --git a/src/databricks/sqlalchemy/test_local/test_parsing.py b/src/databricks/sqlalchemy/test_local/test_parsing.py
deleted file mode 100644
index c8ab443d..00000000
--- a/src/databricks/sqlalchemy/test_local/test_parsing.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import pytest
-from databricks.sqlalchemy._parse import (
-    extract_identifiers_from_string,
-    extract_identifier_groups_from_string,
-    extract_three_level_identifier_from_constraint_string,
-    build_fk_dict,
-    build_pk_dict,
-    match_dte_rows_by_value,
-    get_comment_from_dte_output,
-    DatabricksSqlAlchemyParseException,
-)
-
-
-# These are outputs from DESCRIBE TABLE EXTENDED
-@pytest.mark.parametrize(
-    "input, expected",
-    [
-        ("PRIMARY KEY (`pk1`, `pk2`)", ["pk1", "pk2"]),
-        ("PRIMARY KEY (`a`, `b`, `c`)", ["a", "b", "c"]),
-        ("PRIMARY KEY (`name`, `id`, `attr`)", ["name", "id", "attr"]),
-    ],
-)
-def test_extract_identifiers(input, expected):
-    assert (
-        extract_identifiers_from_string(input) == expected
-    ), "Failed to extract identifiers from string"
-
-
-@pytest.mark.parametrize(
-    "input, expected",
-    [
-        (
-            "FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_sqlalchemy`.`tb1` (`name`, `id`, `attr`)",
-            [
-                "(`pname`, `pid`, `pattr`)",
-                "(`name`, `id`, `attr`)",
-            ],
-        )
-    ],
-)
-def test_extract_identifer_batches(input, expected):
-    assert (
-        extract_identifier_groups_from_string(input) == expected
-    ), "Failed to extract identifier groups from string"
-
-
-def test_extract_3l_namespace_from_constraint_string():
-    input = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)"
-    expected = {
-        "catalog": "main",
-        "schema": "pysql_dialect_compliance",
-        "table": "users",
-    }
-
-    assert (
-        extract_three_level_identifier_from_constraint_string(input) == expected
-    ), "Failed to extract 3L namespace from constraint string"
-
-
-def test_extract_3l_namespace_from_bad_constraint_string():
-    input = "FOREIGN KEY (`parent_user_id`) REFERENCES `pysql_dialect_compliance`.`users` (`user_id`)"
-
-    with pytest.raises(DatabricksSqlAlchemyParseException):
-        extract_three_level_identifier_from_constraint_string(input)
-
-
-@pytest.mark.parametrize("tschema", [None, "some_schema"])
-def test_build_fk_dict(tschema):
-    fk_constraint_string = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`some_schema`.`users` (`user_id`)"
-
-    result = build_fk_dict("some_fk_name", fk_constraint_string, schema_name=tschema)
-
-    assert result == {
-        "name": "some_fk_name",
-        "constrained_columns": ["parent_user_id"],
-        "referred_schema": tschema,
-        "referred_table": "users",
-        "referred_columns": ["user_id"],
-    }
-
-
-def test_build_pk_dict():
-    pk_constraint_string = "PRIMARY KEY (`id`, `name`, `email_address`)"
-    pk_name = "pk1"
-
-    result = build_pk_dict(pk_name, pk_constraint_string)
-
-    assert result == {
-        "constrained_columns": ["id", "name", "email_address"],
-        "name": "pk1",
-    }
-
-
-# This is a real example of the output from DESCRIBE TABLE EXTENDED as of 15 October 2023
-RAW_SAMPLE_DTE_OUTPUT = [
-    ["id", "int"],
-    ["name", "string"],
-    ["", ""],
-    ["# Detailed Table Information", ""],
-    ["Catalog", "main"],
-    ["Database", "pysql_sqlalchemy"],
-    ["Table", "exampleexampleexample"],
-    ["Created Time", "Sun Oct 15 21:12:54 UTC 2023"],
-    ["Last Access", "UNKNOWN"],
-    ["Created By", "Spark "],
-    ["Type", "MANAGED"],
-    ["Location", "s3://us-west-2-****-/19a85dee-****/tables/ccb7***"],
-    ["Provider", "delta"],
-    ["Comment", "some comment"],
-    ["Owner", "some.user@example.com"],
-    ["Is_managed_location", "true"],
-    ["Predictive Optimization", "ENABLE (inherited from CATALOG main)"],
-    [
-        "Table Properties",
-        "[delta.checkpoint.writeStatsAsJson=false,delta.checkpoint.writeStatsAsStruct=true,delta.minReaderVersion=1,delta.minWriterVersion=2]",
-    ],
-    ["", ""],
-    ["# Constraints", ""],
-    ["exampleexampleexample_pk", "PRIMARY KEY (`id`)"],
-    [
-        "exampleexampleexample_fk",
-        "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)",
-    ],
-]
-
-FMT_SAMPLE_DT_OUTPUT = [
-    {"col_name": i[0], "data_type": i[1]} for i in RAW_SAMPLE_DTE_OUTPUT
-]
-
-
-@pytest.mark.parametrize(
-    "match, output",
-    [
-        (
-            "PRIMARY KEY",
-            [
-                {
-                    "col_name": "exampleexampleexample_pk",
-                    "data_type": "PRIMARY KEY (`id`)",
-                }
-            ],
-        ),
-        (
-            "FOREIGN KEY",
-            [
-                {
-                    "col_name": "exampleexampleexample_fk",
-                    "data_type": "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)",
-                }
-            ],
-        ),
-    ],
-)
-def test_filter_dict_by_value(match, output):
-    result = match_dte_rows_by_value(FMT_SAMPLE_DT_OUTPUT, match)
-    assert result == output
-
-
-def test_get_comment_from_dte_output():
-    assert get_comment_from_dte_output(FMT_SAMPLE_DT_OUTPUT) == "some comment"
diff --git a/src/databricks/sqlalchemy/test_local/test_types.py b/src/databricks/sqlalchemy/test_local/test_types.py
deleted file mode 100644
index b91217ed..00000000
--- a/src/databricks/sqlalchemy/test_local/test_types.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import enum
-
-import pytest
-import sqlalchemy
-
-from databricks.sqlalchemy.base import DatabricksDialect
-from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
-
-
-class DatabricksDataType(enum.Enum):
-    """https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html"""
-
-    BIGINT = enum.auto()
-    BINARY = enum.auto()
-    BOOLEAN = enum.auto()
-    DATE = enum.auto()
-    DECIMAL = enum.auto()
-    DOUBLE = enum.auto()
-    FLOAT = enum.auto()
-    INT = enum.auto()
-    INTERVAL = enum.auto()
-    VOID = enum.auto()
-    SMALLINT = enum.auto()
-    STRING = enum.auto()
-    TIMESTAMP = enum.auto()
-    TIMESTAMP_NTZ = enum.auto()
-    TINYINT = enum.auto()
-    ARRAY = enum.auto()
-    MAP = enum.auto()
-    STRUCT = enum.auto()
-
-
-# Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
-# Note: I wish I could define this within the TestCamelCaseTypesCompilation class, but pytest doesn't like that.
-camel_case_type_map = {
-    sqlalchemy.types.BigInteger: DatabricksDataType.BIGINT,
-    sqlalchemy.types.LargeBinary: DatabricksDataType.BINARY,
-    sqlalchemy.types.Boolean: DatabricksDataType.BOOLEAN,
-    sqlalchemy.types.Date: DatabricksDataType.DATE,
-    sqlalchemy.types.DateTime: DatabricksDataType.TIMESTAMP_NTZ,
-    sqlalchemy.types.Double: DatabricksDataType.DOUBLE,
-    sqlalchemy.types.Enum: DatabricksDataType.STRING,
-    sqlalchemy.types.Float: DatabricksDataType.FLOAT,
-    sqlalchemy.types.Integer: DatabricksDataType.INT,
-    sqlalchemy.types.Interval: DatabricksDataType.TIMESTAMP_NTZ,
-    sqlalchemy.types.Numeric: DatabricksDataType.DECIMAL,
-    sqlalchemy.types.PickleType: DatabricksDataType.BINARY,
-    sqlalchemy.types.SmallInteger: DatabricksDataType.SMALLINT,
-    sqlalchemy.types.String: DatabricksDataType.STRING,
-    sqlalchemy.types.Text: DatabricksDataType.STRING,
-    sqlalchemy.types.Time: DatabricksDataType.STRING,
-    sqlalchemy.types.Unicode: DatabricksDataType.STRING,
-    sqlalchemy.types.UnicodeText: DatabricksDataType.STRING,
-    sqlalchemy.types.Uuid: DatabricksDataType.STRING,
-}
-
-
-def dict_as_tuple_list(d: dict):
-    """Return a list of [(key, value), ...] from a dictionary."""
-    return [(key, value) for key, value in d.items()]
-
-
-class CompilationTestBase:
-    dialect = DatabricksDialect()
-
-    def _assert_compiled_value(
-        self, type_: sqlalchemy.types.TypeEngine, expected: DatabricksDataType
-    ):
-        """Assert that when type_ is compiled for the databricks dialect, it renders the DatabricksDataType name.
-
-        This method initialises the type_ with no arguments.
-        """
-        compiled_result = type_().compile(dialect=self.dialect)  # type: ignore
-        assert compiled_result == expected.name
-
-    def _assert_compiled_value_explicit(
-        self, type_: sqlalchemy.types.TypeEngine, expected: str
-    ):
-        """Assert that when type_ is compiled for the databricks dialect, it renders the expected string.
-
-        This method expects an initialised type_ so that we can test how a TypeEngine created with arguments
-        is compiled.
-        """
-        compiled_result = type_.compile(dialect=self.dialect)
-        assert compiled_result == expected
-
-
-class TestCamelCaseTypesCompilation(CompilationTestBase):
-    """Per the sqlalchemy documentation[^1] here, the camel case members of sqlalchemy.types are
-    are expected to work across all dialects. These tests verify that the types compile into valid
-    Databricks SQL type strings. For example, the sqlalchemy.types.Integer() should compile as "INT".
-
-    Truly custom types like STRUCT (notice the uppercase) are not expected to work across all dialects.
-    We test these separately.
-
-    Note that these tests have to do with type **name** compiliation. Which is separate from actually
-    mapping values between Python and Databricks.
-
-    Note: SchemaType and MatchType are not tested because it's not used in table definitions
-
-    [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types
-    """
-
-    @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(camel_case_type_map))
-    def test_bare_camel_case_types_compile(self, type_, expected):
-        self._assert_compiled_value(type_, expected)
-
-    def test_numeric_renders_as_decimal_with_precision(self):
-        self._assert_compiled_value_explicit(
-            sqlalchemy.types.Numeric(10), "DECIMAL(10)"
-        )
-
-    def test_numeric_renders_as_decimal_with_precision_and_scale(self):
-        self._assert_compiled_value_explicit(
-            sqlalchemy.types.Numeric(10, 2), "DECIMAL(10, 2)"
-        )
-
-
-uppercase_type_map = {
-    sqlalchemy.types.ARRAY: DatabricksDataType.ARRAY,
-    sqlalchemy.types.BIGINT: DatabricksDataType.BIGINT,
-    sqlalchemy.types.BINARY: DatabricksDataType.BINARY,
-    sqlalchemy.types.BOOLEAN: DatabricksDataType.BOOLEAN,
-    sqlalchemy.types.DATE: DatabricksDataType.DATE,
-    sqlalchemy.types.DECIMAL: DatabricksDataType.DECIMAL,
-    sqlalchemy.types.DOUBLE: DatabricksDataType.DOUBLE,
-    sqlalchemy.types.FLOAT: DatabricksDataType.FLOAT,
-    sqlalchemy.types.INT: DatabricksDataType.INT,
-    sqlalchemy.types.SMALLINT: DatabricksDataType.SMALLINT,
-    sqlalchemy.types.TIMESTAMP: DatabricksDataType.TIMESTAMP,
-    TINYINT: DatabricksDataType.TINYINT,
-    TIMESTAMP: DatabricksDataType.TIMESTAMP,
-    TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ,
-}
-
-
-class TestUppercaseTypesCompilation(CompilationTestBase):
-    """Per the sqlalchemy documentation[^1], uppercase types are considered to be specific to some
-    database backends. These tests verify that the types compile into valid Databricks SQL type strings.
-
-    [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#backend-specific-uppercase-datatypes
-    """
-
-    @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(uppercase_type_map))
-    def test_bare_uppercase_types_compile(self, type_, expected):
-        if isinstance(type_, type(sqlalchemy.types.ARRAY)):
-            # ARRAY cannot be initialised without passing an item definition so we test separately
-            # I preserve it in the uppercase_type_map for clarity
-            assert True
-        else:
-            self._assert_compiled_value(type_, expected)
-
-    def test_array_string_renders_as_array_of_string(self):
-        """SQLAlchemy's ARRAY type requires an item definition. And their docs indicate that they've only tested
-        it with Postgres since that's the only first-class dialect with support for ARRAY.
-
-        https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY
-        """
-        self._assert_compiled_value_explicit(
-            sqlalchemy.types.ARRAY(sqlalchemy.types.String), "ARRAY"
-        )