Optimize arrow table performance #636

Draft · wants to merge 12 commits into base: main
55 changes: 12 additions & 43 deletions src/databricks/sql/client.py
@@ -9,7 +9,6 @@
 import requests
 import json
 import os
-import decimal
 from uuid import UUID

 from databricks.sql import __version__
@@ -1389,7 +1388,7 @@ def _fill_results_buffer(self):
         self.results = results
         self.has_more_rows = has_more_rows

-    def _convert_columnar_table(self, table):
+    def _convert_columnar_table(self, table: ColumnTable):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
         result = []
@@ -1401,14 +1400,14 @@ def _convert_columnar_table(self, table):

         return result

-    def _convert_arrow_table(self, table):
+    def _convert_arrow_table(self, table: "pyarrow.Table"):

         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)

         if self.connection.disable_pandas is True:
-            return [
-                ResultRow(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())
-            ]
+            columns_as_lists = [col.to_pylist() for col in table.itercolumns()]
+            return [ResultRow(*row) for row in zip(*columns_as_lists)]

         # Need to use nullable types, as otherwise type can change when there are missing values.
         # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
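Note: the `disable_pandas` fast path above is the core row-conversion change. The old code called `v.as_py()` once per cell from a Python loop; the new code converts each column with a single `to_pylist()` call, pushing the per-value work into Arrow's native code, then re-assembles rows with a cheap `zip`. A minimal standalone sketch of the before/after; the `Row` namedtuple and sample table here are invented stand-ins, not the driver's own types:

```python
# Sketch only: compares per-cell as_py() with per-column to_pylist().
from collections import namedtuple
import pyarrow

table = pyarrow.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
Row = namedtuple("Row", table.column_names)

# Old approach: one as_py() call per cell (Python-level loop over every value).
rows_slow = [Row(*[v.as_py() for v in r]) for r in zip(*table.itercolumns())]

# New approach: one to_pylist() call per column (bulk conversion in Arrow's
# C++ core), then a zip to rebuild rows from the converted columns.
columns_as_lists = [col.to_pylist() for col in table.itercolumns()]
rows_fast = [Row(*row) for row in zip(*columns_as_lists)]

assert rows_slow == rows_fast  # same rows, far fewer Python-level calls
```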
@@ -1434,6 +1433,7 @@ def _convert_arrow_table(self, table):
             types_mapper=dtype_mapping.get,
             date_as_object=True,
             timestamp_as_object=True,
+            self_destruct=True,
         )

         res = df.to_numpy(na_value=None, dtype="object")
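Note on `self_destruct=True`: this is a documented `pyarrow.Table.to_pandas` option that releases each Arrow column's buffers as soon as that column has been converted, so the Arrow and pandas copies of the data need not coexist at peak. The table must be treated as consumed afterwards, which holds here because `_convert_arrow_table` never touches `table` again after the conversion. A small sketch with an invented throwaway table:

```python
# Sketch only: self_destruct trades table reusability for lower peak memory.
import pyarrow

table = pyarrow.table({"x": list(range(1000))})

# With self_destruct=True, to_pandas() may free each Arrow buffer as soon as
# its column is converted, instead of holding both representations at once.
df = table.to_pandas(self_destruct=True)
del table  # the self-destructed table must not be used again

print(df["x"].sum())
```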
@@ -1454,36 +1454,18 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows

         while (
             n_remaining_rows > 0
             and not self.has_been_closed_server_side
             and self.has_more_rows
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = pyarrow.concat_tables([results, partial_results])
+            results.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows

-        return results
-
-    def merge_columnar(self, result1, result2):
-        """
-        Function to merge / combining the columnar results into a single result
-        :param result1:
-        :param result2:
-        :return:
-        """
-
-        if result1.column_names != result2.column_names:
-            raise ValueError("The columns in the results don't match")
-
-        merged_result = [
-            result1.column_table[i] + result2.column_table[i]
-            for i in range(result1.num_columns)
-        ]
-        return ColumnTable(merged_result, result1.column_names)
+        return results.to_arrow_table()

     def fetchmany_columnar(self, size: int):
         """
@@ -1504,7 +1486,7 @@ def fetchmany_columnar(self, size: int):
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = self.merge_columnar(results, partial_results)
+            results.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows

@@ -1518,23 +1500,10 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            if isinstance(results, ColumnTable) and isinstance(
-                partial_results, ColumnTable
-            ):
-                results = self.merge_columnar(results, partial_results)
-            else:
-                results = pyarrow.concat_tables([results, partial_results])
+            results.append(partial_results)
             self._next_row_index += partial_results.num_rows

-        # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
-        # Valid only for metadata commands result set
-        if isinstance(results, ColumnTable) and pyarrow:
-            data = {
-                name: col
-                for name, col in zip(results.column_names, results.column_table)
-            }
-            return pyarrow.Table.from_pydict(data)
-        return results
+        return results.to_arrow_table()

     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
@@ -1544,7 +1513,7 @@ def fetchall_columnar(self):
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            results = self.merge_columnar(results, partial_results)
+            results.append(partial_results)
             self._next_row_index += partial_results.num_rows

         return results
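The four fetch loops above now share one accumulation pattern: append each fetched batch to the results queue and materialize once at the end via `to_arrow_table()`, instead of rebuilding the combined result on every iteration. The old columnar path (`merge_columnar`) copied whole Python lists each pass, so total copying grew quadratically with the number of batches; the old Arrow path re-ran `pyarrow.concat_tables` per pass; and `fetchall_arrow` needed `ColumnTable`-vs-Arrow branching that now collapses into the queue implementation. A hedged sketch of the pattern; `append` and `to_arrow_table` are methods the diff assumes on the driver's result-queue classes, and the accumulator below is a hypothetical stand-in, not the real queue:

```python
# Sketch only: accumulate batches, concatenate once at the end.
from typing import List
import pyarrow


class ArrowBatchAccumulator:
    """Hypothetical stand-in for a result queue with append()/to_arrow_table()."""

    def __init__(self, first: pyarrow.Table):
        self._batches: List[pyarrow.Table] = [first]

    def append(self, partial: pyarrow.Table) -> None:
        # O(1) per fetch iteration: just remember the batch.
        self._batches.append(partial)

    def to_arrow_table(self) -> pyarrow.Table:
        # One concatenation touching each batch once, instead of one per loop pass.
        return pyarrow.concat_tables(self._batches)


acc = ArrowBatchAccumulator(pyarrow.table({"x": [1, 2]}))
acc.append(pyarrow.table({"x": [3]}))
print(acc.to_arrow_table().num_rows)  # 3
```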
1 change: 0 additions & 1 deletion src/databricks/sql/cloudfetch/download_manager.py
@@ -78,7 +78,6 @@ def get_next_downloaded_file(
                     next_row_offset, file.start_row_offset, file.row_count
                 )
             )
-
         return file

     def _schedule_downloads(self):
26 changes: 10 additions & 16 deletions src/databricks/sql/cloudfetch/downloader.py
@@ -9,6 +9,7 @@
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
+from databricks.sql.common.http import DatabricksHttpClient, HttpMethod

 logger = logging.getLogger(__name__)
@@ -70,6 +71,7 @@ def __init__(
         self.settings = settings
         self.link = link
         self._ssl_options = ssl_options
+        self._http_client = DatabricksHttpClient.get_instance()

     def run(self) -> DownloadedFile:
         """
@@ -90,19 +92,14 @@ def run(self) -> DownloadedFile:
             self.link, self.settings.link_expiry_buffer_secs
         )

-        session = requests.Session()
-        session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
-        session.mount("https://", HTTPAdapter(max_retries=retryPolicy))
-
-        try:
-            # Get the file via HTTP request
-            response = session.get(
-                self.link.fileLink,
-                timeout=self.settings.download_timeout,
-                verify=self._ssl_options.tls_verify,
-                headers=self.link.httpHeaders
-                # TODO: Pass cert from `self._ssl_options`
-            )
+        with self._http_client.execute(
+            method=HttpMethod.GET,
+            url=self.link.fileLink,
+            timeout=self.settings.download_timeout,
+            verify=self._ssl_options.tls_verify,
+            headers=self.link.httpHeaders
+            # TODO: Pass cert from `self._ssl_options`
+        ) as response:
             response.raise_for_status()

             # Save (and decompress if needed) the downloaded file
@@ -132,9 +129,6 @@ def run(self) -> DownloadedFile:
                 self.link.startRowOffset,
                 self.link.rowCount,
             )
-        finally:
-            if session:
-                session.close()

     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
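The downloader previously built a fresh `requests.Session` (with retry adapters) for every cloud-fetch file and tore it down in `finally`; it now reuses a process-wide `DatabricksHttpClient` whose `execute` returns a context manager, so concurrent downloads share one connection pool. One presumable follow-on is that the per-download `max_retries=retryPolicy` adapters move into the shared client. A hedged sketch of the shape such a client could have; the real implementation lives in `databricks/sql/common/http.py` and may differ in retry and pooling details:

```python
# Sketch only: a shared session with a context-managed execute(), modelled on
# how the diff calls DatabricksHttpClient; internals here are assumptions.
import threading
from contextlib import contextmanager
from enum import Enum
from typing import Generator

import requests
from requests.adapters import HTTPAdapter


class HttpMethod(str, Enum):
    GET = "GET"
    POST = "POST"


class SharedHttpClient:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls) -> "SharedHttpClient":
        # Double-checked locking: every downloader thread shares one session
        # (and its connection pool) instead of paying per-file session setup.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._session = requests.Session()
        self._session.mount("http://", HTTPAdapter(pool_maxsize=10))
        self._session.mount("https://", HTTPAdapter(pool_maxsize=10))

    @contextmanager
    def execute(
        self, method: HttpMethod, url: str, **kwargs
    ) -> Generator[requests.Response, None, None]:
        # Forwards timeout/verify/headers kwargs, yields the response, and
        # closes it when the with-block exits.
        response = self._session.request(method.value, url, **kwargs)
        try:
            yield response
        finally:
            response.close()


with SharedHttpClient.get_instance().execute(HttpMethod.GET, "https://example.com") as r:
    r.raise_for_status()
```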
1 change: 1 addition & 0 deletions src/databricks/sql/common/http.py
@@ -7,6 +7,7 @@
 from contextlib import contextmanager
 from typing import Generator
 import logging
+import time

 logger = logging.getLogger(__name__)
20 changes: 0 additions & 20 deletions src/databricks/sql/thrift_backend.py
@@ -36,9 +36,6 @@
     RequestErrorInfo,
     NoRetryReason,
     ResultSetQueueFactory,
-    convert_arrow_based_set_to_arrow_table,
-    convert_decimals_in_arrow_table,
-    convert_column_based_set_to_arrow_table,
 )
 from databricks.sql.types import SSLOptions

@@ -633,23 +630,6 @@ def _poll_for_status(self, op_handle):
         )
         return self.make_request(self._client.GetOperationStatus, req)

-    def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
-        if t_row_set.columns is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
-        elif t_row_set.arrowBatches is not None:
-            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
-                t_row_set.arrowBatches, lz4_compressed, schema_bytes
-            )
-        else:
-            raise OperationalError(
-                "Unsupported TRowSet instance {}".format(t_row_set),
-                session_id_hex=self._session_id_hex,
-            )
-        return convert_decimals_in_arrow_table(arrow_table, description), num_rows
-
     def _get_metadata_resp(self, op_handle):
         req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
         return self.make_request(self._client.GetResultSetMetadata, req)