From 8bfe4848d82990610d8f68160dab9b9f3b547a09 Mon Sep 17 00:00:00 2001 From: James Sun Date: Fri, 11 Aug 2017 11:48:50 -0700 Subject: [PATCH] Handle early cancel --- integration_tests/test_dbapi.py | 5 +++++ prestodb/client.py | 17 +++++++++++++++-- prestodb/dbapi.py | 3 +++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/integration_tests/test_dbapi.py b/integration_tests/test_dbapi.py index a255caa..6943d75 100644 --- a/integration_tests/test_dbapi.py +++ b/integration_tests/test_dbapi.py @@ -115,6 +115,11 @@ def test_cancel_query(presto_connection): cur.fetchall() assert 'Query was canceled' in str(cancel_error.value) + cur = presto_connection.cursor() + with pytest.raises(Exception) as cancel_error: + cur.cancel() + assert 'Cancel query failed; no running query' in str(cancel_error.value) + def test_session_properties(run_presto): _, host, port = run_presto diff --git a/prestodb/client.py b/prestodb/client.py index 71448be..e86f5f6 100644 --- a/prestodb/client.py +++ b/prestodb/client.py @@ -434,8 +434,10 @@ def __init__( self._columns = None self._finished = False + self._cancelled = False self._request = request self._sql = sql + self._result = PrestoResult(self) @property def columns(self): @@ -445,6 +447,10 @@ def columns(self): def stats(self): return self._stats + @property + def result(self): + return self._result + def execute(self): # type: () -> PrestoResult """Initiate a Presto query by sending the SQL statement @@ -454,6 +460,8 @@ def execute(self): track the rows returned by the query. To fetch all rows, call fetch() until is_finished is true. """ + if self._cancelled: + raise exceptions.PrestoUserError("Query has been cancelled") response = self._request.post(self._sql) status = self._request.process(response) @@ -462,8 +470,8 @@ def execute(self): self._stats.update(status.stats) if status.next_uri is None: self._finished = True - self.result = PrestoResult(self, status.rows) - return self.result + self._result = PrestoResult(self, status.rows) + return self._result def fetch(self): # type: () -> List[List[Any]] @@ -483,6 +491,11 @@ def cancel(self): # type: None -> None if self.is_finished(): return + + self._cancelled = True + if self._request.next_uri is None: + return + response = self._request.delete(self._request.next_uri) if response.status_code == requests.codes.no_content: return diff --git a/prestodb/dbapi.py b/prestodb/dbapi.py index 9002070..f9ebf8f 100644 --- a/prestodb/dbapi.py +++ b/prestodb/dbapi.py @@ -129,6 +129,7 @@ def __init__( self.arraysize = 1 self._iterator = None + self._query = None @property def description(self): @@ -245,4 +246,6 @@ def fetchall(self): return list(self.genall()) def cancel(self): + if self._query is None: + raise OperationalError("Cancel query failed; no running query") self._query.cancel()