
[PECO-1803] Splitting the PySQL connector into the core and the non-core part #417

Merged
Changes from 53 commits
Commits (54)
335fc0c
Implemented ColumnQueue to test fetchall without pyarrow
jprakash-db Jul 24, 2024
ad2b014
order of fields in row corrected
shivam2680 Jul 24, 2024
882e080
Changed the folder structure and tested the basic setup to work
jprakash-db Aug 6, 2024
98c4cd6
Refactored the code to make the connector work
jprakash-db Aug 6, 2024
25a006d
Basic Setup of connector, core and sqlalchemy is working
jprakash-db Aug 7, 2024
1cfaae2
Basic integration of core, connector and sqlalchemy is working
jprakash-db Aug 7, 2024
c576110
Setup working dynamic change from ColumnQueue to ArrowQueue
jprakash-db Aug 7, 2024
0ddca9d
Refactored the test code and moved it to the respective folders
jprakash-db Aug 8, 2024
786dc1e
Added the unit test for column_queue
jprakash-db Aug 14, 2024
c91d43d
venv_main added to git ignore
jprakash-db Aug 20, 2024
b67b739
Added code for merging columnar table
jprakash-db Aug 21, 2024
32e3dcd
Merging code for columnar
jprakash-db Aug 21, 2024
d062887
Fixed the retry_close session test issue with logging
jprakash-db Aug 22, 2024
9908976
Fixed the databricks_sqlalchemy tests and introduced pytest.ini for t…
jprakash-db Aug 22, 2024
af0cd4f
Added pyarrow_test mark on pytest
jprakash-db Aug 25, 2024
a07c232
Fixed databricks.sqlalchemy to databricks_sqlalchemy imports
jprakash-db Aug 25, 2024
f0c8e7a
Added poetry.lock
jprakash-db Aug 26, 2024
46ae0f7
Added dist folder
jprakash-db Aug 26, 2024
8c5e7dd
Changed the pyproject.toml
jprakash-db Aug 26, 2024
24730dd
Minor Fix
jprakash-db Aug 27, 2024
61de281
Added the pyarrow skip tag on unit tests and verified they work
jprakash-db Aug 30, 2024
bad7c0b
Fixed the Decimal and timestamp conversion issue in the non-arrow pipeline
jprakash-db Sep 1, 2024
520d2c8
Removed not required files and reformatted
jprakash-db Sep 2, 2024
051bdce
Fixed test_retry error
jprakash-db Sep 2, 2024
4d64034
Changed the folder structure to src / databricks
jprakash-db Sep 13, 2024
93848d7
Moved the columnar non-arrow flow to another PR
jprakash-db Sep 13, 2024
f37f42c
Moved the README to the root
jprakash-db Sep 15, 2024
0e4b599
removed columnQueue instance
jprakash-db Sep 15, 2024
8199832
Removed the databricks_sqlalchemy dependency in core
jprakash-db Sep 15, 2024
613e7dc
Changed the pysql_supports_arrow predicate, introduced changes in the…
jprakash-db Sep 17, 2024
c4a2e08
Ran the black formatter with the original version
jprakash-db Sep 17, 2024
6f10ec6
Extra .py removed from all __init__.py file names
jprakash-db Sep 17, 2024
2115db1
Undo formatting check
jprakash-db Sep 18, 2024
36c0f95
Check
jprakash-db Sep 18, 2024
15d5047
Check
jprakash-db Sep 18, 2024
8415230
Check
jprakash-db Sep 18, 2024
15df683
Check
jprakash-db Sep 18, 2024
b4c3029
Check
jprakash-db Sep 18, 2024
bc758d8
Check
jprakash-db Sep 18, 2024
31da868
Check
jprakash-db Sep 18, 2024
f79fc69
Check
jprakash-db Sep 18, 2024
d121fed
Check
jprakash-db Sep 18, 2024
6e393b0
Check
jprakash-db Sep 18, 2024
2383caf
Check
jprakash-db Sep 18, 2024
dd4c487
Check
jprakash-db Sep 18, 2024
6be085c
Check
jprakash-db Sep 18, 2024
e7cf5c3
Check
jprakash-db Sep 18, 2024
b6a5668
BIG UPDATE
jprakash-db Sep 18, 2024
4496a04
Refactor code
jprakash-db Sep 19, 2024
726b8ed
Refactor
jprakash-db Sep 19, 2024
66bfa6d
Fixed versioning
jprakash-db Sep 19, 2024
2c1cfbd
Minor refactoring
jprakash-db Sep 19, 2024
ae20a65
Minor refactoring
jprakash-db Sep 19, 2024
f056c80
Merged databricks:PECO-1803/connector-split into jprakash-db/PECO-1803
jprakash-db Sep 24, 2024
23 changes: 23 additions & 0 deletions databricks_sql_connector/pyproject.toml
@@ -0,0 +1,23 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.5.0"
description = "Databricks SQL Connector for Python"
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
license = "Apache-2.0"


[tool.poetry.dependencies]
databricks_sql_connector_core = { version = ">=1.0.0", extras=["all"]}
databricks_sqlalchemy = { version = ">=1.0.0", optional = true }

[tool.poetry.extras]
databricks_sqlalchemy = ["databricks_sqlalchemy"]

[tool.poetry.urls]
"Homepage" = "https://github.com/databricks/databricks-sql-python"
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
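
With this split, `databricks-sql-connector` becomes a thin meta-package that pins `databricks-sql-connector-core` and optionally pulls in `databricks_sqlalchemy`, so existing user code should be unaffected. A minimal sketch of the unchanged user-facing API; the hostname, HTTP path, and token are placeholders:

```python
from databricks import sql

# Placeholder connection parameters -- substitute real workspace values.
with sql.connect(
    server_hostname="<workspace-host>.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/<warehouse-id>",
    access_token="<personal-access-token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1 AS probe")
        print(cursor.fetchall())  # standard DB-API usage, same as before the split
```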


24 changes: 6 additions & 18 deletions pyproject.toml → databricks_sql_connector_core/pyproject.toml
@@ -1,36 +1,26 @@
[tool.poetry]
name = "databricks-sql-connector"
version = "3.3.0"
description = "Databricks SQL Connector for Python"
name = "databricks-sql-connector-core"
version = "1.0.0"
description = "Databricks SQL Connector core for Python"
authors = ["Databricks <databricks-sql-connector-maintainers@databricks.com>"]
license = "Apache-2.0"
readme = "README.md"
packages = [{ include = "databricks", from = "src" }]
include = ["CHANGELOG.md"]

[tool.poetry.dependencies]
python = "^3.8.0"
thrift = ">=0.16.0,<0.21.0"
pandas = [
{ version = ">=1.2.5,<2.2.0", python = ">=3.8" }
]
pyarrow = ">=14.0.1,<17"

lz4 = "^4.0.2"
requests = "^2.18.1"
oauthlib = "^3.1.0"
numpy = [
{ version = "^1.16.6", python = ">=3.8,<3.11" },
{ version = "^1.23.4", python = ">=3.11" },
]
sqlalchemy = { version = ">=2.0.21", optional = true }
openpyxl = "^3.0.10"
alembic = { version = "^1.0.11", optional = true }
urllib3 = ">=1.26"
pyarrow = {version = ">=14.0.1,<17", optional = true}

[tool.poetry.extras]
sqlalchemy = ["sqlalchemy"]
alembic = ["sqlalchemy", "alembic"]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
@@ -43,8 +33,6 @@ pytest-dotenv = "^0.5.2"
"Homepage" = "https://github.com/databricks/databricks-sql-python"
"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

[tool.poetry.plugins."sqlalchemy.dialects"]
"databricks" = "databricks.sqlalchemy:DatabricksDialect"

[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -62,5 +50,5 @@ markers = {"reviewed" = "Test case has been reviewed by Databricks"}
minversion = "6.0"
log_cli = "false"
log_cli_level = "INFO"
testpaths = ["tests", "src/databricks/sqlalchemy/test_local"]
testpaths = ["tests", "databricks_sql_connector_core/tests"]
env_files = ["test.env"]
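
Two substantive changes hide in this pyproject diff: `pyarrow` moves from a required dependency to an opt-in extra (e.g. `pip install "databricks-sql-connector-core[pyarrow]"`), and the SQLAlchemy dialect entry point leaves core entirely. A small sketch of how a consumer might verify which mode is active; `HAS_PYARROW` is an illustrative name, not part of the package:

```python
import importlib.util

# True when the optional "pyarrow" extra was installed alongside core.
HAS_PYARROW = importlib.util.find_spec("pyarrow") is not None
print("Arrow result sets available:", HAS_PYARROW)
```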
@@ -1,7 +1,6 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas
import pyarrow
import requests
import json
import os
@@ -43,6 +42,10 @@
TSparkParameter,
)

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)
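
This try/except fallback is the mechanism the whole PR leans on: `pyarrow` is bound to `None` when the extra is absent, and Arrow-only code paths check that sentinel at runtime instead of failing at import. A self-contained sketch of the same pattern; `require_pyarrow` and `to_arrow_table` are hypothetical helpers, not names from this diff:

```python
try:
    import pyarrow
except ImportError:
    pyarrow = None  # optional dependency: degrade gracefully


def require_pyarrow() -> None:
    # Hypothetical guard: fail fast with an actionable message when an
    # Arrow-only code path is reached without the optional extra.
    if pyarrow is None:
        raise ImportError(
            "pyarrow is required for this operation; "
            'install "databricks-sql-connector-core[pyarrow]"'
        )


def to_arrow_table(columns: dict) -> "pyarrow.Table":
    require_pyarrow()
    # pyarrow.table() builds a Table from a mapping of column name -> values.
    return pyarrow.table(columns)


print(to_arrow_table({"id": [1, 2, 3]}) if pyarrow else "pyarrow not installed")
```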

@@ -977,14 +980,14 @@ def fetchmany(self, size: int) -> List[Row]:
else:
raise Error("There is no active result set")

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchall_arrow()
else:
raise Error("There is no active result set")

def fetchmany_arrow(self, size) -> pyarrow.Table:
def fetchmany_arrow(self, size) -> "pyarrow.Table":
self._check_not_closed()
if self.active_result_set:
return self.active_result_set.fetchmany_arrow(size)
@@ -1171,7 +1174,7 @@ def _convert_arrow_table(self, table):
def rownumber(self):
return self._next_row_index

def fetchmany_arrow(self, size: int) -> pyarrow.Table:
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.

@@ -1196,7 +1199,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

return results

def fetchall_arrow(self) -> pyarrow.Table:
def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
results = self.results.remaining_rows()
self._next_row_index += results.num_rows
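Throughout this file, return annotations change from `pyarrow.Table` to the string literal `"pyarrow.Table"`. The distinction matters once `pyarrow` can be `None`: an unquoted annotation is evaluated when the `def` statement runs and would raise `AttributeError` (`None` has no attribute `Table`), while a quoted one is stored as plain text and only resolved by tools that explicitly ask for type hints. A minimal sketch of the difference, with a toy function body:

```python
try:
    import pyarrow
except ImportError:
    pyarrow = None

# def broken() -> pyarrow.Table: ...  # AttributeError at import when pyarrow is None


def fetchall_arrow_demo() -> "pyarrow.Table":  # lazy: safe even when pyarrow is None
    if pyarrow is None:
        raise ImportError("pyarrow is required for Arrow result sets")
    return pyarrow.table({"value": [1, 2, 3]})  # toy stand-in for real results
```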
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

import pyarrow
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
@@ -37,6 +36,11 @@
convert_column_based_set_to_arrow_table,
)

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)

unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -652,6 +656,12 @@ def _get_metadata_resp(self, op_handle):

@staticmethod
def _hive_schema_to_arrow_schema(t_table_schema):

if pyarrow is None:
raise ImportError(
"pyarrow is required to convert Hive schema to Arrow schema"
)

def map_type(t_type_entry):
if t_type_entry.primitiveEntry:
return {
@@ -858,7 +868,7 @@ def execute_command(
getDirectResults=ttypes.TSparkGetDirectResults(
maxRows=max_rows, maxBytes=max_bytes
),
canReadArrowResult=True,
canReadArrowResult=True if pyarrow else False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=use_cloud_fetch,
confOverlay={
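A behavioral consequence hides in `canReadArrowResult=True if pyarrow else False`: the client now advertises Arrow support to the server only when pyarrow is importable, so a core-only install presumably receives column-based result sets instead of Arrow batches. A sketch of that flag derivation in isolation; `build_result_capabilities` is an illustrative name, not part of the diff:

```python
try:
    import pyarrow
except ImportError:
    pyarrow = None


def build_result_capabilities(lz4_compression: bool, use_cloud_fetch: bool) -> dict:
    # Mirrors the execute_command request flags assembled above: Arrow is
    # only requested when the optional dependency is actually present.
    return {
        "canReadArrowResult": pyarrow is not None,
        "canDecompressLZ4Result": lz4_compression,
        "canDownloadResult": use_cloud_fetch,
    }
```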
@@ -12,7 +12,6 @@
from ssl import SSLContext

import lz4.frame
import pyarrow

from databricks.sql import OperationalError, exc
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -28,16 +27,21 @@

import logging

try:
import pyarrow
except ImportError:
pyarrow = None

logger = logging.getLogger(__name__)


class ResultSetQueue(ABC):
@abstractmethod
def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int):
pass

@abstractmethod
def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self):
pass


@@ -100,7 +104,7 @@ def build_queue(
class ArrowQueue(ResultSetQueue):
def __init__(
self,
arrow_table: pyarrow.Table,
arrow_table: "pyarrow.Table",
n_valid_rows: int,
start_row_index: int = 0,
):
@@ -115,7 +119,7 @@ def __init__(
self.arrow_table = arrow_table
self.n_valid_rows = n_valid_rows

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
"""Get upto the next n rows of the Arrow dataframe"""
length = min(num_rows, self.n_valid_rows - self.cur_row_index)
# Note that the table.slice API is not the same as Python's slice
@@ -124,7 +128,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
self.cur_row_index += slice.num_rows
return slice

def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self) -> "pyarrow.Table":
slice = self.arrow_table.slice(
self.cur_row_index, self.n_valid_rows - self.cur_row_index
)
@@ -184,7 +188,7 @@ def __init__(
self.table = self._create_next_table()
self.table_row_index = 0

def next_n_rows(self, num_rows: int) -> pyarrow.Table:
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
"""
Get up to the next n rows of the cloud fetch Arrow dataframes.

@@ -216,7 +220,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
return results

def remaining_rows(self) -> pyarrow.Table:
def remaining_rows(self) -> "pyarrow.Table":
"""
Get all remaining rows of the cloud fetch Arrow dataframes.

@@ -237,7 +241,7 @@ def remaining_rows(self) -> pyarrow.Table:
self.table_row_index = 0
return results

def _create_next_table(self) -> Union[pyarrow.Table, None]:
def _create_next_table(self) -> Union["pyarrow.Table", None]:
logger.debug(
"CloudFetchQueue: Trying to get downloaded file for row {}".format(
self.start_row_index
@@ -276,7 +280,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:

return arrow_table

def _create_empty_table(self) -> pyarrow.Table:
def _create_empty_table(self) -> "pyarrow.Table":
# Create a 0-row table with just the schema bytes
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)

@@ -515,7 +519,7 @@ def transform_paramstyle(
return output


def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
return convert_decimals_in_arrow_table(arrow_table, description)

@@ -542,7 +546,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
return arrow_table, n_rows


def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
for i, col in enumerate(table.itercolumns()):
if description[i][1] == "decimal":
decimal_col = col.to_pandas().apply(
@@ -0,0 +1,6 @@
try:
from databricks_sqlalchemy import *
except ImportError:
import warnings

warnings.warn("Install databricks-sqlalchemy plugin before using this")
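
The old `databricks.sqlalchemy` import path survives as a shim over the separate `databricks_sqlalchemy` package: with the plugin installed it re-exports everything, and without it the import succeeds but warns. A hedged sketch of how a caller would observe that behavior, assuming the shim above is what ships:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import databricks.sqlalchemy  # noqa: F401  # the shim shown above

# Empty when databricks-sqlalchemy is installed; one warning otherwise.
for w in caught:
    print(w.message)
```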
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,11 +1,20 @@
from decimal import Decimal

import pyarrow
import pytest

try:
import pyarrow
except ImportError:
pyarrow = None

class DecimalTestsMixin:
decimal_and_expected_results = [
from tests.e2e.common.predicates import pysql_supports_arrow

def decimal_and_expected_results():

if pyarrow is None:
return []

return [
("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
@@ -17,7 +26,12 @@ class DecimalTestsMixin:
("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
]

multi_decimals_and_expected_results = [
def multi_decimals_and_expected_results():

if pyarrow is None:
return []

return [
(
["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
[Decimal("1.00"), Decimal("100.001"), None],
@@ -30,7 +44,9 @@ class DecimalTestsMixin:
),
]

@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results)
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class DecimalTestsMixin:
@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results())
def test_decimals(self, decimal, expected_value, expected_type):
with self.cursor({}) as cursor:
query = "SELECT CAST ({})".format(decimal)
Expand All @@ -39,9 +55,7 @@ def test_decimals(self, decimal, expected_value, expected_type):
assert table.field(0).type == expected_type
assert table.to_pydict().popitem()[1][0] == expected_value

@pytest.mark.parametrize(
"decimals, expected_values, expected_type", multi_decimals_and_expected_results
)
@pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results())
def test_multi_decimals(self, decimals, expected_values, expected_type):
with self.cursor({}) as cursor:
union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
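The tests now gate on `pysql_supports_arrow()` from `tests.e2e.common.predicates`. Its body is not part of this diff; presumably it reduces to an importability check along these lines (an assumed sketch, not the repository's actual code):

```python
def pysql_supports_arrow() -> bool:
    """Assumed implementation: True when the optional pyarrow extra is present."""
    try:
        import pyarrow  # noqa: F401
    except ImportError:
        return False
    return True
```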
@@ -1,6 +1,10 @@
import logging
import math
import time
from unittest import skipUnless

import pytest
from tests.e2e.common.predicates import pysql_supports_arrow

log = logging.getLogger(__name__)

@@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
+ "assuming 10K fetch size."
)

@pytest.mark.skipif(not pysql_supports_arrow(), reason="Without pyarrow lz4 compression is not supported")
def test_query_with_large_wide_result_set(self):
resultSize = 300 * 1000 * 1000 # 300 MB
width = 8192 # B