[SC-110400] Enabling compression in Python SQL Connector #49

Merged (18 commits) on Oct 13, 2022
230 changes: 207 additions & 23 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -13,6 +13,7 @@ python = "^3.7.1"
thrift = "^0.13.0"
pandas = "^1.3.0"
pyarrow = "^9.0.0"
lz4 = "^4.0.2"
requests=">2.18.1"
oauthlib=">=3.1.0"

4 changes: 4 additions & 0 deletions src/databricks/sql/client.py
@@ -152,6 +152,7 @@ def read(self) -> Optional[OAuthToken]:
self.host = server_hostname
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
@@ -318,6 +319,7 @@ def execute(
session_handle=self.connection._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
)
self.active_result_set = ResultSet(
@@ -614,6 +616,7 @@ def __init__(
self.has_been_closed_server_side = execute_response.has_been_closed_server_side
self.has_more_rows = execute_response.has_more_rows
self.buffer_size_bytes = result_buffer_size_bytes
self.lz4_compressed = execute_response.lz4_compressed
self.arraysize = arraysize
self.thrift_backend = thrift_backend
self.description = execute_response.description
@@ -642,6 +645,7 @@ def _fill_results_buffer(self):
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
expected_row_start_offset=self._next_row_index,
lz4_compressed=self.lz4_compressed,
arrow_schema_bytes=self._arrow_schema_bytes,
description=self.description,
)
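For reviewers, a minimal usage sketch (not part of the diff) of the new connection option wired up in client.py above. The only new argument is `enable_query_result_lz4_compression`; the other `connect()` parameters are the connector's usual ones, and the hostname, HTTP path, and token values are placeholders.

```python
from databricks import sql

# Compression is on by default (see the new kwarg above); pass False to opt out.
connection = sql.connect(
    server_hostname="<workspace-hostname>",   # placeholder
    http_path="<http-path>",                  # placeholder
    access_token="<personal-access-token>",   # placeholder
    enable_query_result_lz4_compression=False,
)

cursor = connection.cursor()
cursor.execute("SELECT 1")
print(cursor.fetchall())
cursor.close()
connection.close()
```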
41 changes: 28 additions & 13 deletions src/databricks/sql/thrift_backend.py
@@ -4,6 +4,7 @@
import math
import time
import threading
import lz4.frame
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context

import pyarrow
@@ -451,7 +452,7 @@ def open_session(self, session_configuration, catalog, schema):
initial_namespace = None

open_session_req = ttypes.TOpenSessionReq(
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6,
client_protocol=None,
initialNamespace=initial_namespace,
canUseMultipleCatalogs=True,
@@ -507,7 +508,7 @@ def _poll_for_status(self, op_handle):
)
return self.make_request(self._client.GetOperationStatus, req)

def _create_arrow_table(self, t_row_set, schema_bytes, description):
def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
if t_row_set.columns is not None:
(
arrow_table,
@@ -520,7 +521,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
arrow_table,
num_rows,
) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, schema_bytes
t_row_set.arrowBatches, lz4_compressed, schema_bytes
)
else:
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
@@ -545,13 +546,20 @@ def _convert_decimals_in_arrow_table(table, description):
return table

@staticmethod
def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
def _convert_arrow_based_set_to_arrow_table(
arrow_batches, lz4_compressed, schema_bytes
):
ba = bytearray()
ba += schema_bytes
n_rows = 0
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += arrow_batch.batch
if lz4_compressed:
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += lz4.frame.decompress(arrow_batch.batch)
else:
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += arrow_batch.batch
arrow_table = pyarrow.ipc.open_stream(ba).read_all()
return arrow_table, n_rows

@@ -708,7 +716,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
]
)
)

direct_results = resp.directResults
has_been_closed_server_side = direct_results and direct_results.closeOperation
has_more_rows = (
@@ -725,12 +732,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
.serialize()
.to_pybytes()
)

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
if direct_results and direct_results.resultSet:
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

arrow_results, n_rows = self._create_arrow_table(
direct_results.resultSet.results, schema_bytes, description
direct_results.resultSet.results,
lz4_compressed,
schema_bytes,
description,
)
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
else:
@@ -740,6 +751,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
status=operation_state,
has_been_closed_server_side=has_been_closed_server_side,
has_more_rows=has_more_rows,
lz4_compressed=lz4_compressed,
command_handle=resp.operationHandle,
description=description,
arrow_schema_bytes=schema_bytes,
@@ -783,7 +795,9 @@ def _check_direct_results_for_error(t_spark_direct_results):
t_spark_direct_results.closeOperation
)

def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor):
def execute_command(
self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
):
assert session_handle is not None

spark_arrow_types = ttypes.TSparkArrowTypes(
@@ -802,7 +816,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
maxRows=max_rows, maxBytes=max_bytes
),
canReadArrowResult=True,
canDecompressLZ4Result=False,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=False,
confOverlay={
# We want to receive proper Timestamp arrow types.
@@ -916,6 +930,7 @@ def fetch_results(
max_rows,
max_bytes,
expected_row_start_offset,
lz4_compressed,
arrow_schema_bytes,
description,
):
@@ -941,7 +956,7 @@
)
)
arrow_results, n_rows = self._create_arrow_table(
resp.results, arrow_schema_bytes, description
resp.results, lz4_compressed, arrow_schema_bytes, description
)
arrow_queue = ArrowQueue(arrow_results, n_rows)

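For context on the change to `_convert_arrow_based_set_to_arrow_table` above: when the server reports `lz4Compressed`, each Arrow batch arrives as an LZ4-frame-compressed IPC record batch, so the client decompresses it before appending it to the serialized schema and reading the whole buffer as an Arrow IPC stream. Below is a standalone round-trip sketch of that mechanism, using the lz4 and pyarrow packages pinned in pyproject.toml; the table contents are made up.

```python
import lz4.frame
import pyarrow

# Fake "server side": serialize a schema and one record batch,
# then LZ4-frame-compress the batch the way the service would.
table = pyarrow.table({"id": [1, 2, 3]})
schema_bytes = table.schema.serialize().to_pybytes()
compressed_batch = lz4.frame.compress(table.to_batches()[0].serialize().to_pybytes())

# Client side, mirroring the decompression branch added above.
ba = bytearray()
ba += schema_bytes
ba += lz4.frame.decompress(compressed_batch)
roundtripped = pyarrow.ipc.open_stream(ba).read_all()
assert roundtripped.equals(table)
```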
2 changes: 1 addition & 1 deletion src/databricks/sql/utils.py
@@ -40,7 +40,7 @@ def remaining_rows(self) -> pyarrow.Table:

ExecuteResponse = namedtuple(
"ExecuteResponse",
"status has_been_closed_server_side has_more_rows description "
"status has_been_closed_server_side has_more_rows description lz4_compressed "
"command_handle arrow_queue arrow_schema_bytes",
)

22 changes: 13 additions & 9 deletions tests/e2e/common/large_queries_mixin.py
@@ -49,11 +49,15 @@ def test_query_with_large_wide_result_set(self):
# This is used by PyHive tests to determine the buffer size
self.arraysize = 1000
with self.cursor() as cursor:
uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle.
self.assertEqual(len(row[1]), 36)
for lz4_compression in [False, True]:
cursor.connection.lz4_compression=lz4_compression
uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
self.assertEqual(lz4_compression, cursor.active_result_set.lz4_compressed)
for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle.
self.assertEqual(len(row[1]), 36)


def test_query_with_large_narrow_result_set(self):
resultSize = 300 * 1000 * 1000 # 300 MB
@@ -85,10 +89,10 @@ def test_long_running_query(self):
start = time.time()

cursor.execute("""SELECT count(*)
FROM RANGE({scale}) x
JOIN RANGE({scale0}) y
ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
""".format(scale=scale_factor * scale0, scale0=scale0))
FROM RANGE({scale}) x
JOIN RANGE({scale0}) y
ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
""".format(scale=scale_factor * scale0, scale0=scale0))

n, = cursor.fetchone()
self.assertEqual(n, 0)
14 changes: 14 additions & 0 deletions tests/e2e/driver_tests.py
@@ -510,6 +510,20 @@ def test_timezone_with_timestamp(self):
self.assertEqual(arrow_result_table.field(0).type, ts_type)
self.assertEqual(arrow_result_value, expected.timestamp() * 1000000)

@skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support')
def test_can_flip_compression(self):
with self.cursor() as cursor:
cursor.execute("SELECT array(1,2,3,4)")
cursor.fetchall()
lz4_compressed = cursor.active_result_set.lz4_compressed
#The endpoint should support compression
self.assertEqual(lz4_compressed, True)
cursor.connection.lz4_compression=False
cursor.execute("SELECT array(1,2,3,4)")
cursor.fetchall()
lz4_compressed = cursor.active_result_set.lz4_compressed
self.assertEqual(lz4_compressed, False)

def _should_have_native_complex_types(self):
return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)

4 changes: 3 additions & 1 deletion tests/unit/test_fetches.py
@@ -38,6 +38,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
has_been_closed_server_side=True,
has_more_rows=False,
description=Mock(),
lz4_compressed=Mock(),
command_handle=None,
arrow_queue=arrow_queue,
arrow_schema_bytes=schema.serialize().to_pybytes()))
@@ -50,7 +51,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
def make_dummy_result_set_from_batch_list(batch_list):
batch_index = 0

def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset,
def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4_compressed,
arrow_schema_bytes, description):
nonlocal batch_index
results = FetchTests.make_arrow_queue(batch_list[batch_index])
@@ -71,6 +72,7 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset,
has_more_rows=True,
description=[(f'col{col_id}', 'integer', None, None, None, None, None)
for col_id in range(num_cols)],
lz4_compressed=Mock(),
command_handle=None,
arrow_queue=None,
arrow_schema_bytes=None))