Skip to content

SEA: Cleanup #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: sea-migration
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,8 @@ def _results_message_to_execute_response(
command_id=CommandId.from_sea_statement_id(response.statement_id),
status=response.status.state,
description=description,
has_been_closed_server_side=False,
lz4_compressed=lz4_compressed,
is_staging_operation=response.manifest.is_volume_operation,
arrow_schema_bytes=None,
result_format=response.manifest.format,
)

Expand Down Expand Up @@ -624,7 +622,6 @@ def get_execution_result(
return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC
from typing import List, Optional, Tuple

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.client import SeaDatabricksClient
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.utils.constants import ResultFormat
from databricks.sql.exc import ProgrammingError
Expand Down
64 changes: 45 additions & 19 deletions src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.client import SeaDatabricksClient
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

Expand All @@ -15,10 +15,10 @@

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.exc import ProgrammingError
from databricks.sql.exc import CursorAlreadyClosedError, ProgrammingError, RequestError
from databricks.sql.types import Row
from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
from databricks.sql.backend.types import ExecuteResponse
from databricks.sql.backend.types import CommandState, ExecuteResponse
from databricks.sql.result_set import ResultSet

logger = logging.getLogger(__name__)
Expand All @@ -31,7 +31,6 @@ def __init__(
self,
connection: Connection,
execute_response: ExecuteResponse,
sea_client: SeaDatabricksClient,
result_data: ResultData,
manifest: ResultManifest,
buffer_size_bytes: int = 104857600,
Expand All @@ -43,7 +42,6 @@ def __init__(
Args:
connection: The parent connection
execute_response: Response from the execute command
sea_client: The SeaDatabricksClient instance for direct access
buffer_size_bytes: Buffer size for fetching results
arraysize: Default number of rows to fetch
result_data: Result data from SEA response
Expand All @@ -56,32 +54,36 @@ def __init__(
if statement_id is None:
raise ValueError("Command ID is not a SEA statement ID")

results_queue = SeaResultSetQueueFactory.build_queue(
result_data,
self.manifest,
statement_id,
description=execute_response.description,
max_download_threads=sea_client.max_download_threads,
sea_client=sea_client,
lz4_compressed=execute_response.lz4_compressed,
)

# Call parent constructor with common attributes
super().__init__(
connection=connection,
backend=sea_client,
arraysize=arraysize,
buffer_size_bytes=buffer_size_bytes,
command_id=execute_response.command_id,
status=execute_response.status,
has_been_closed_server_side=execute_response.has_been_closed_server_side,
results_queue=results_queue,
description=execute_response.description,
is_staging_operation=execute_response.is_staging_operation,
lz4_compressed=execute_response.lz4_compressed,
arrow_schema_bytes=execute_response.arrow_schema_bytes,
)

# Assert that the backend is of the correct type
assert isinstance(
self.backend, SeaDatabricksClient
), "Backend must be a SeaDatabricksClient"

results_queue = SeaResultSetQueueFactory.build_queue(
result_data,
self.manifest,
statement_id,
description=execute_response.description,
max_download_threads=self.backend.max_download_threads,
sea_client=self.backend,
lz4_compressed=execute_response.lz4_compressed,
)

# Set the results queue
self.results = results_queue

def _convert_json_types(self, row: List[str]) -> List[Any]:
"""
Convert string values in the row to appropriate Python types based on column metadata.
Expand Down Expand Up @@ -160,6 +162,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
if size < 0:
raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

if self.results is None:
raise RuntimeError("Results queue is not initialized")

results = self.results.next_n_rows(size)
self._next_row_index += len(results)

Expand All @@ -173,6 +178,9 @@ def fetchall_json(self) -> List[List[str]]:
Columnar table containing all remaining rows
"""

if self.results is None:
raise RuntimeError("Results queue is not initialized")

results = self.results.remaining_rows()
self._next_row_index += len(results)

Expand Down Expand Up @@ -264,3 +272,21 @@ def fetchall(self) -> List[Row]:
return self._create_json_table(self.fetchall_json())
else:
raise NotImplementedError("fetchall only supported for JSON data")

def close(self) -> None:
    """
    Close the result set.

    Closes the client-side results queue (if one was built), then — when the
    command is not already in a CLOSED state and the parent connection is
    still open — issues a request to the server to close the statement.

    A ``RequestError`` wrapping ``CursorAlreadyClosedError`` is treated as
    benign (the server already closed the cursor): it is logged and ignored.
    The local status is always forced to ``CommandState.CLOSED``, even when
    the server-side close fails.
    """
    try:
        # Release any locally buffered result data first.
        if self.results is not None:
            self.results.close()
        # Only round-trip to the server when there is something left to close.
        if self.status != CommandState.CLOSED and self.connection.open:
            self.backend.close_command(self.command_id)
    except RequestError as e:
        # NOTE(review): assumes RequestError.args[1] carries the wrapped
        # cause exception — confirm; an args tuple shorter than 2 would
        # raise IndexError here instead of handling the error.
        if isinstance(e.args[1], CursorAlreadyClosedError):
            logger.info("Operation was canceled by a prior request")
    finally:
        # Mark closed locally regardless of the server-side outcome, so
        # subsequent close() calls and fetches see a terminal state.
        self.status = CommandState.CLOSED
11 changes: 5 additions & 6 deletions src/databricks/sql/backend/sea/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
Optional,
Any,
Callable,
cast,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.types import ExecuteResponse
from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState

logger = logging.getLogger(__name__)

Expand All @@ -45,6 +44,9 @@ def _filter_sea_result_set(
"""

# Get all remaining rows
if result_set.results is None:
raise RuntimeError("Results queue is not initialized")

all_rows = result_set.results.remaining_rows()

# Filter rows
Expand All @@ -58,9 +60,7 @@ def _filter_sea_result_set(
command_id=command_id,
status=result_set.status,
description=result_set.description,
has_been_closed_server_side=result_set.has_been_closed_server_side,
lz4_compressed=result_set.lz4_compressed,
arrow_schema_bytes=result_set._arrow_schema_bytes,
is_staging_operation=False,
)

Expand All @@ -69,7 +69,7 @@ def _filter_sea_result_set(

result_data = ResultData(data=filtered_rows, external_links=None)

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.client import SeaDatabricksClient
from databricks.sql.backend.sea.result_set import SeaResultSet

# Create a new SeaResultSet with the filtered data
Expand All @@ -79,7 +79,6 @@ def _filter_sea_result_set(
filtered_result_set = SeaResultSet(
connection=result_set.connection,
execute_response=execute_response,
sea_client=cast(SeaDatabricksClient, result_set.backend),
result_data=result_data,
manifest=manifest,
buffer_size_bytes=result_set.buffer_size_bytes,
Expand Down
Loading
Loading