Skip to content

Commit e1484c2

Browse files
authored
Merge branch 'main' into jprakash-db/arrow-optim
2 parents 8cdfd88 + ba1eab3 commit e1484c2

16 files changed

+2124
-680
lines changed
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Dict, List, Optional, Any, Union, TYPE_CHECKING
5+
6+
if TYPE_CHECKING:
7+
from databricks.sql.client import Cursor
8+
from databricks.sql.result_set import ResultSet
9+
10+
from databricks.sql.thrift_api.TCLIService import ttypes
11+
from databricks.sql.backend.types import SessionId, CommandId, CommandState
12+
13+
14+
class DatabricksClient(ABC):
15+
"""
16+
Abstract client interface for interacting with Databricks SQL services.
17+
18+
Implementations of this class are responsible for:
19+
- Managing connections to Databricks SQL services
20+
- Executing SQL queries and commands
21+
- Retrieving query results
22+
- Fetching metadata about catalogs, schemas, tables, and columns
23+
"""
24+
25+
# == Connection and Session Management ==
26+
@abstractmethod
27+
def open_session(
28+
self,
29+
session_configuration: Optional[Dict[str, Any]],
30+
catalog: Optional[str],
31+
schema: Optional[str],
32+
) -> SessionId:
33+
"""
34+
Opens a new session with the Databricks SQL service.
35+
36+
This method establishes a new session with the server and returns a session
37+
identifier that can be used for subsequent operations.
38+
39+
Args:
40+
session_configuration: Optional dictionary of configuration parameters for the session
41+
catalog: Optional catalog name to use as the initial catalog for the session
42+
schema: Optional schema name to use as the initial schema for the session
43+
44+
Returns:
45+
SessionId: A session identifier object that can be used for subsequent operations
46+
47+
Raises:
48+
Error: If the session configuration is invalid
49+
OperationalError: If there's an error establishing the session
50+
InvalidServerResponseError: If the server response is invalid or unexpected
51+
"""
52+
pass
53+
54+
@abstractmethod
55+
def close_session(self, session_id: SessionId) -> None:
56+
"""
57+
Closes an existing session with the Databricks SQL service.
58+
59+
This method terminates the session identified by the given session ID and
60+
releases any resources associated with it.
61+
62+
Args:
63+
session_id: The session identifier returned by open_session()
64+
65+
Raises:
66+
ValueError: If the session ID is invalid
67+
OperationalError: If there's an error closing the session
68+
"""
69+
pass
70+
71+
# == Query Execution, Command Management ==
72+
@abstractmethod
73+
def execute_command(
74+
self,
75+
operation: str,
76+
session_id: SessionId,
77+
max_rows: int,
78+
max_bytes: int,
79+
lz4_compression: bool,
80+
cursor: Cursor,
81+
use_cloud_fetch: bool,
82+
parameters: List[ttypes.TSparkParameter],
83+
async_op: bool,
84+
enforce_embedded_schema_correctness: bool,
85+
) -> Union[ResultSet, None]:
86+
"""
87+
Executes a SQL command or query within the specified session.
88+
89+
This method sends a SQL command to the server for execution and handles
90+
the response. It can operate in both synchronous and asynchronous modes.
91+
92+
Args:
93+
operation: The SQL command or query to execute
94+
session_id: The session identifier in which to execute the command
95+
max_rows: Maximum number of rows to fetch in a single fetch batch
96+
max_bytes: Maximum number of bytes to fetch in a single fetch batch
97+
lz4_compression: Whether to use LZ4 compression for result data
98+
cursor: The cursor object that will handle the results. The command id is set in this cursor.
99+
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
100+
parameters: List of parameters to bind to the query
101+
async_op: Whether to execute the command asynchronously
102+
enforce_embedded_schema_correctness: Whether to enforce schema correctness
103+
104+
Returns:
105+
If async_op is False, returns a ResultSet object containing the
106+
query results and metadata. If async_op is True, returns None and the
107+
results must be fetched later using get_execution_result().
108+
109+
Raises:
110+
ValueError: If the session ID is invalid
111+
OperationalError: If there's an error executing the command
112+
ServerOperationError: If the server encounters an error during execution
113+
"""
114+
pass
115+
116+
@abstractmethod
117+
def cancel_command(self, command_id: CommandId) -> None:
118+
"""
119+
Cancels a running command or query.
120+
121+
This method attempts to cancel a command that is currently being executed.
122+
It can be called from a different thread than the one executing the command.
123+
124+
Args:
125+
command_id: The command identifier to cancel
126+
127+
Raises:
128+
ValueError: If the command ID is invalid
129+
OperationalError: If there's an error canceling the command
130+
"""
131+
pass
132+
133+
@abstractmethod
134+
def close_command(self, command_id: CommandId) -> None:
135+
"""
136+
Closes a command and releases associated resources.
137+
138+
This method informs the server that the client is done with the command
139+
and any resources associated with it can be released.
140+
141+
Args:
142+
command_id: The command identifier to close
143+
144+
Raises:
145+
ValueError: If the command ID is invalid
146+
OperationalError: If there's an error closing the command
147+
"""
148+
pass
149+
150+
@abstractmethod
151+
def get_query_state(self, command_id: CommandId) -> CommandState:
152+
"""
153+
Gets the current state of a query or command.
154+
155+
This method retrieves the current execution state of a command from the server.
156+
157+
Args:
158+
command_id: The command identifier to check
159+
160+
Returns:
161+
CommandState: The current state of the command
162+
163+
Raises:
164+
ValueError: If the command ID is invalid
165+
OperationalError: If there's an error retrieving the state
166+
ServerOperationError: If the command is in an error state
167+
DatabaseError: If the command has been closed unexpectedly
168+
"""
169+
pass
170+
171+
@abstractmethod
172+
def get_execution_result(
173+
self,
174+
command_id: CommandId,
175+
cursor: Cursor,
176+
) -> ResultSet:
177+
"""
178+
Retrieves the results of a previously executed command.
179+
180+
This method fetches the results of a command that was executed asynchronously
181+
or retrieves additional results from a command that has more rows available.
182+
183+
Args:
184+
command_id: The command identifier for which to retrieve results
185+
cursor: The cursor object that will handle the results
186+
187+
Returns:
188+
ResultSet: An object containing the query results and metadata
189+
190+
Raises:
191+
ValueError: If the command ID is invalid
192+
OperationalError: If there's an error retrieving the results
193+
"""
194+
pass
195+
196+
# == Metadata Operations ==
197+
@abstractmethod
198+
def get_catalogs(
199+
self,
200+
session_id: SessionId,
201+
max_rows: int,
202+
max_bytes: int,
203+
cursor: Cursor,
204+
) -> ResultSet:
205+
"""
206+
Retrieves a list of available catalogs.
207+
208+
This method fetches metadata about all catalogs available in the current
209+
session's context.
210+
211+
Args:
212+
session_id: The session identifier
213+
max_rows: Maximum number of rows to fetch in a single batch
214+
max_bytes: Maximum number of bytes to fetch in a single batch
215+
cursor: The cursor object that will handle the results
216+
217+
Returns:
218+
ResultSet: An object containing the catalog metadata
219+
220+
Raises:
221+
ValueError: If the session ID is invalid
222+
OperationalError: If there's an error retrieving the catalogs
223+
"""
224+
pass
225+
226+
@abstractmethod
227+
def get_schemas(
228+
self,
229+
session_id: SessionId,
230+
max_rows: int,
231+
max_bytes: int,
232+
cursor: Cursor,
233+
catalog_name: Optional[str] = None,
234+
schema_name: Optional[str] = None,
235+
) -> ResultSet:
236+
"""
237+
Retrieves a list of schemas, optionally filtered by catalog and schema name patterns.
238+
239+
This method fetches metadata about schemas available in the specified catalog
240+
or all catalogs if no catalog is specified.
241+
242+
Args:
243+
session_id: The session identifier
244+
max_rows: Maximum number of rows to fetch in a single batch
245+
max_bytes: Maximum number of bytes to fetch in a single batch
246+
cursor: The cursor object that will handle the results
247+
catalog_name: Optional catalog name pattern to filter by
248+
schema_name: Optional schema name pattern to filter by
249+
250+
Returns:
251+
ResultSet: An object containing the schema metadata
252+
253+
Raises:
254+
ValueError: If the session ID is invalid
255+
OperationalError: If there's an error retrieving the schemas
256+
"""
257+
pass
258+
259+
@abstractmethod
260+
def get_tables(
261+
self,
262+
session_id: SessionId,
263+
max_rows: int,
264+
max_bytes: int,
265+
cursor: Cursor,
266+
catalog_name: Optional[str] = None,
267+
schema_name: Optional[str] = None,
268+
table_name: Optional[str] = None,
269+
table_types: Optional[List[str]] = None,
270+
) -> ResultSet:
271+
"""
272+
Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types.
273+
274+
This method fetches metadata about tables available in the specified catalog
275+
and schema, or all catalogs and schemas if not specified.
276+
277+
Args:
278+
session_id: The session identifier
279+
max_rows: Maximum number of rows to fetch in a single batch
280+
max_bytes: Maximum number of bytes to fetch in a single batch
281+
cursor: The cursor object that will handle the results
282+
catalog_name: Optional catalog name pattern to filter by
283+
if catalog_name is None, we fetch across all catalogs
284+
schema_name: Optional schema name pattern to filter by
285+
if schema_name is None, we fetch across all schemas
286+
table_name: Optional table name pattern to filter by
287+
table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW'])
288+
289+
Returns:
290+
ResultSet: An object containing the table metadata
291+
292+
Raises:
293+
ValueError: If the session ID is invalid
294+
OperationalError: If there's an error retrieving the tables
295+
"""
296+
pass
297+
298+
@abstractmethod
299+
def get_columns(
300+
self,
301+
session_id: SessionId,
302+
max_rows: int,
303+
max_bytes: int,
304+
cursor: Cursor,
305+
catalog_name: Optional[str] = None,
306+
schema_name: Optional[str] = None,
307+
table_name: Optional[str] = None,
308+
column_name: Optional[str] = None,
309+
) -> ResultSet:
310+
"""
311+
Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns.
312+
313+
This method fetches metadata about columns available in the specified table,
314+
or all tables if not specified.
315+
316+
Args:
317+
session_id: The session identifier
318+
max_rows: Maximum number of rows to fetch in a single batch
319+
max_bytes: Maximum number of bytes to fetch in a single batch
320+
cursor: The cursor object that will handle the results
321+
catalog_name: Optional catalog name pattern to filter by
322+
schema_name: Optional schema name pattern to filter by
323+
table_name: Optional table name pattern to filter by
324+
if table_name is None, we fetch across all tables
325+
column_name: Optional column name pattern to filter by
326+
327+
Returns:
328+
ResultSet: An object containing the column metadata
329+
330+
Raises:
331+
ValueError: If the session ID is invalid
332+
OperationalError: If there's an error retrieving the columns
333+
"""
334+
pass
335+
336+
@property
337+
@abstractmethod
338+
def max_download_threads(self) -> int:
339+
"""
340+
Gets the maximum number of download threads for cloud fetch operations.
341+
342+
Returns:
343+
int: The maximum number of download threads
344+
"""
345+
pass

0 commit comments

Comments
 (0)