Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions google/cloud/spanner_dbapi/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,21 @@
}


def _execute_insert_heterogenous(transaction, sql_params_list):
def _execute_insert_heterogenous(
transaction,
sql_params_list,
request_options=None,
):
for sql, params in sql_params_list:
sql, params = sql_pyformat_args_to_spanner(sql, params)
transaction.execute_update(sql, params, get_param_types(params))
transaction.execute_update(
sql, params, get_param_types(params), request_options=request_options
)


def handle_insert(connection, sql, params):
return connection.database.run_in_transaction(
_execute_insert_heterogenous, ((sql, params),)
_execute_insert_heterogenous, ((sql, params),), connection.request_options
)


Expand Down
25 changes: 17 additions & 8 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ def read_only(self, value):
)
self._read_only = value

@property
def request_options(self):
"""Options for the next SQL operations.

Returns:
google.cloud.spanner_v1.RequestOptions:
Request options.
"""
if self.request_priority is None:
return

req_opts = RequestOptions(priority=self.request_priority)
self.request_priority = None
return req_opts
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Request options are used only for next one request

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spanner is designed this way, otherwise there is no sense in request options.
Let's say there is a request priority. If it'll be set for a connection/session, than all the requests will have the same priority. It's technically the same that not setting priority at all.


@property
def staleness(self):
"""Current read staleness option value of this `Connection`.
Expand Down Expand Up @@ -437,25 +452,19 @@ def run_statement(self, statement, retried=False):

if statement.is_insert:
_execute_insert_heterogenous(
transaction, ((statement.sql, statement.params),)
transaction, ((statement.sql, statement.params),), self.request_options
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

if self.request_priority is not None:
req_opts = RequestOptions(priority=self.request_priority)
self.request_priority = None
else:
req_opts = None

return (
transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=req_opts,
request_options=self.request_options,
),
ResultsChecksum() if retried else statement.checksum,
)
Expand Down
16 changes: 13 additions & 3 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def close(self):

def _do_execute_update(self, transaction, sql, params):
result = transaction.execute_update(
sql, params=params, param_types=get_param_types(params)
sql,
params=params,
param_types=get_param_types(params),
request_options=self.connection.request_options,
)
self._itr = None
if type(result) == int:
Expand Down Expand Up @@ -278,7 +281,9 @@ def execute(self, sql, args=None):
_helpers.handle_insert(self.connection, sql, args or None)
else:
self.connection.database.run_in_transaction(
self._do_execute_update, sql, args or None
self._do_execute_update,
sql,
args or None,
)
except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
raise IntegrityError(getattr(e, "details", e)) from e
Expand Down Expand Up @@ -421,7 +426,12 @@ def fetchmany(self, size=None):
return items

def _handle_DQL_with_snapshot(self, snapshot, sql, params):
self._result_set = snapshot.execute_sql(sql, params, get_param_types(params))
self._result_set = snapshot.execute_sql(
sql,
params,
get_param_types(params),
request_options=self.connection.request_options,
)
# Read the first element so that the StreamedResultSet can
# return the metadata after a DQL statement.
self._itr = PeekIterator(self._result_set)
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/spanner_dbapi/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test__execute_insert_heterogenous(self):

mock_pyformat.assert_called_once_with(params[0], params[1])
mock_param_types.assert_called_once_with(None)
mock_update.assert_called_once_with(sql, None, None)
mock_update.assert_called_once_with(
sql, None, None, request_options=None
)

def test__execute_insert_heterogenous_error(self):
from google.cloud.spanner_dbapi import _helpers
Expand All @@ -62,7 +64,9 @@ def test__execute_insert_heterogenous_error(self):

mock_pyformat.assert_called_once_with(params[0], params[1])
mock_param_types.assert_called_once_with(None)
mock_update.assert_called_once_with(sql, None, None)
mock_update.assert_called_once_with(
sql, None, None, request_options=None
)

def test_handle_insert(self):
from google.cloud.spanner_dbapi import _helpers
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,29 @@ def test_handle_dql(self):
self.assertIsInstance(cursor._itr, utils.PeekIterator)
self.assertEqual(cursor._row_count, _UNSET_COUNT)

def test_handle_dql_priority(self):
from google.cloud.spanner_dbapi import utils
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
from google.cloud.spanner_v1 import RequestOptions

connection = self._make_connection(self.INSTANCE, mock.MagicMock())
connection.database.snapshot.return_value.__enter__.return_value = (
mock_snapshot
) = mock.MagicMock()
connection.request_priority = 1

cursor = self._make_one(connection)

sql = "sql"
mock_snapshot.execute_sql.return_value = ["0"]
cursor._handle_DQL(sql, params=None)
self.assertEqual(cursor._result_set, ["0"])
self.assertIsInstance(cursor._itr, utils.PeekIterator)
self.assertEqual(cursor._row_count, _UNSET_COUNT)
mock_snapshot.execute_sql.assert_called_with(
sql, None, None, request_options=RequestOptions(priority=1)
)

def test_context(self):
connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
Expand Down