Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to asynchronously prepare CQL statements #1239

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 72 additions & 27 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,58 @@ def _on_analytics_master_result(self, response, master_future, query_future):

self.submit(query_future.send_request)

def prepare_async(self, query, custom_payload=None, keyspace=None):
"""
Prepare the given query and return a :class:`~.ResponseFuture`
object. You may also call :meth:`~.ResponseFuture.result()`
on the :class:`.ResponseFuture` to synchronously block for
prepared statement object at any time.

See :meth:`Session.prepare` for parameter definitions.

Example usage::

>>> future = session.prepare_async("SELECT * FROM mycf")
>>> # do other stuff...

>>> try:
... prepared_statement = future.result()
... except Exception:
... log.exception("Operation failed:")
"""
future = self._create_prepare_response_future(query, keyspace, custom_payload)
future._protocol_handler = self.client_protocol_handler
self._on_request(future)
future.send_request()
return future

def _create_prepare_response_future(self, query, keyspace, custom_payload):
message = PrepareMessage(query=query, keyspace=keyspace)
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)

def _prepare_result_processor(future, response):
prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query,
prepared_keyspace,
self._protocol_version, response.column_metadata, response.result_metadata_id,
self.cluster.column_encryption_policy)
prepared_statement.custom_payload = custom_payload
self.cluster.add_prepared(response.query_id, prepared_statement)
if self.cluster.prepare_on_all_hosts:
# prepare statement on all hosts
host = future._current_host
try:
self.prepare_on_all_nodes(future.message.query, host, future.message.keyspace)
except Exception:
log.exception("Error preparing query on all hosts:")

return prepared_statement

future._set_result_processor(_prepare_result_processor)
return future


def _create_response_future(self, query, parameters, trace, custom_payload,
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
paging_state=None, host=None):
Expand Down Expand Up @@ -3118,36 +3170,18 @@ def prepare(self, query, custom_payload=None, keyspace=None):
**Important**: PreparedStatements should be prepared only once.
Preparing the same query more than once will likely affect performance.

When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method
attempts to prepare given query on all hosts and waits for each node to respond.
Preparing CQL query on other nodes may fail, but error is not propagated
to the caller.

`custom_payload` is a key value map to be passed along with the prepare
message. See :ref:`custom_payload`.
"""
message = PrepareMessage(query=query, keyspace=keyspace)
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)
try:
future.send_request()
response = future.result().one()
except Exception:
log.exception("Error preparing query:")
raise
future = self.prepare_async(query, custom_payload, keyspace)
return future.result()

prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload

self.cluster.add_prepared(response.query_id, prepared_statement)

if self.cluster.prepare_on_all_hosts:
host = future._current_host
try:
self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace)
except Exception:
log.exception("Error preparing query on all hosts:")

return prepared_statement

def prepare_on_all_hosts(self, query, excluded_host, keyspace=None):
def prepare_on_all_nodes(self, query, excluded_host, keyspace=None):
"""
Prepare the given query on all hosts, excluding ``excluded_host``.
Intended for internal use only.
Expand Down Expand Up @@ -4320,6 +4354,7 @@ class ResponseFuture(object):
_col_types = None
_final_exception = None
_query_traces = None
_result_processor = None
_callbacks = None
_errbacks = None
_current_host = None
Expand Down Expand Up @@ -4951,10 +4986,20 @@ def result(self):
"""
self._event.wait()
if self._final_result is not _NOT_SET:
return ResultSet(self, self._final_result)
if self._result_processor is not None:
return self._result_processor(self, self._final_result)
else:
return ResultSet(self, self._final_result)
else:
raise self._final_exception

def _set_result_processor(self, result_processor):
"""
Sets internal result processor which allows to control object
returned by :meth:`ResponseFuture.result()` method.
"""
self._result_processor = result_processor

def get_query_trace_ids(self):
"""
Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data).
Expand Down
78 changes: 78 additions & 0 deletions tests/integration/standard/test_prepared_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cassandra import InvalidRequest, DriverException

from cassandra import ConsistencyLevel, ProtocolVersion
from cassandra.cluster import ResponseFuture
from cassandra.query import PreparedStatement, UNSET_VALUE
from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50,
requirecassandra, BasicSharedKeyspaceUnitTestCase)
Expand Down Expand Up @@ -121,6 +122,83 @@ def test_basic(self):
results = self.session.execute(bound)
self.assertEqual(results, [('x', 'y', 'z')])

def test_basic_async(self):
"""
Test basic asynchronous PreparedStatement usage
"""
self.session.execute(
"""
DROP KEYSPACE IF EXISTS preparedtests
"""
)
self.session.execute(
"""
CREATE KEYSPACE preparedtests
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
""")

self.session.set_keyspace("preparedtests")
self.session.execute(
"""
CREATE TABLE cf0 (
a text,
b text,
c text,
PRIMARY KEY (a, b)
)
""")

prepared_future = self.session.prepare_async(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared_future, ResponseFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind(('a', 'b', 'c'))
self.session.execute(bound)

prepared_future = self.session.prepare_async(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared_future, ResponseFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind(('a'))
results = self.session.execute(bound)
self.assertEqual(results, [('a', 'b', 'c')])

# test with new dict binding
prepared_future = self.session.prepare_async(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared_future, ResponseFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind({
'a': 'x',
'b': 'y',
'c': 'z'
})
self.session.execute(bound)

prepared_future = self.session.prepare_async(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared_future, ResponseFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind({'a': 'x'})
results = self.session.execute(bound)
self.assertEqual(results, [('x', 'y', 'z')])

def test_missing_primary_key(self):
"""
Ensure an InvalidRequest is thrown
Expand Down
27 changes: 25 additions & 2 deletions tests/integration/standard/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,19 @@ def test_prepare_on_all_hosts(self):
session.execute(select_statement, (1, ), host=host)
self.assertEqual(2, self.mock_handler.get_message_count('debug', "Re-preparing"))

def test_prepare_async_on_all_hosts(self):
"""
Test to validate prepare_on_all_hosts flag is honored during prepare_async execution.
"""
clus = TestCluster(prepare_on_all_hosts=True)
self.addCleanup(clus.shutdown)

session = clus.connect(wait_for_all_pools=True)
select_statement = session.prepare_async("SELECT k FROM test3rf.test WHERE k = ? AND v = ? ALLOW FILTERING").result()
for host in clus.metadata.all_hosts():
session.execute(select_statement, (1, 1), host=host)
self.assertEqual(0, self.mock_handler.get_message_count('debug', "Re-preparing"))

def test_prepare_batch_statement(self):
"""
Test to validate a prepared statement used inside a batch statement is correctly handled
Expand Down Expand Up @@ -647,7 +660,6 @@ def test_prepared_statement(self):

prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
prepared.consistency_level = ConsistencyLevel.ONE

self.assertEqual(str(prepared),
'<PreparedStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", consistency=ONE>')

Expand Down Expand Up @@ -717,6 +729,17 @@ def test_prepared_statements(self):
self.session.execute_async(batch).result()
self.confirm_results()

def test_prepare_async(self):
prepared = self.session.prepare_async("INSERT INTO test3rf.test (k, v) VALUES (?, ?)").result()

batch = BatchStatement(BatchType.LOGGED)
for i in range(10):
batch.add(prepared, (i, i))

self.session.execute(batch)
self.session.execute_async(batch).result()
self.confirm_results()

def test_bound_statements(self):
prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

Expand Down Expand Up @@ -942,7 +965,7 @@ def test_no_connection_refused_on_timeout(self):
exception_type = type(result).__name__
if exception_type == "NoHostAvailable":
self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message)
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]:
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub", "ErrorMessage"]:
if type(result).__name__ in ["WriteTimeout", "WriteFailure"]:
received_timeout = True
continue
Expand Down