Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Store error on transaction stream #248

Merged
merged 3 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions typedb/connection/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@

import typedb_protocol.common.transaction_pb2 as transaction_proto
from grpc import RpcError

from typedb.api.connection.options import TypeDBOptions
from typedb.api.query.future import QueryFuture
from typedb.api.connection.transaction import _TypeDBTransactionExtended, TransactionType
from typedb.api.query.future import QueryFuture
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, TRANSACTION_CLOSED_WITH_ERRORS
from typedb.common.rpc.request_builder import transaction_commit_req, transaction_rollback_req, transaction_open_req
from typedb.concept.concept_manager import _ConceptManager
Expand Down Expand Up @@ -113,8 +112,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False

def _raise_transaction_closed(self):
errors = self._bidirectional_stream.get_errors()
if len(errors) == 0:
error = self._bidirectional_stream.get_error()
if error is None:
raise TypeDBClientException.of(TRANSACTION_CLOSED)
else:
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, errors)
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, error)
11 changes: 8 additions & 3 deletions typedb/stream/bidirectional_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, stub: TypeDBStub, transmitter: RequestTransmitter):
self._response_iterator = stub.transaction(self._request_iterator)
self._dispatcher = transmitter.dispatcher(self._request_iterator)
self._is_open = AtomicBoolean(True)
self._error: TypeDBClientException = None

def single(self, req: transaction_proto.Transaction.Req, batch: bool) -> "BidirectionalStream.Single[transaction_proto.Transaction.Res]":
request_id = uuid4()
Expand All @@ -60,7 +61,7 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction
req.req_id = request_id.bytes
self._response_collector.new_queue(request_id)
self._dispatcher.dispatch(req)
return ResponsePartIterator(request_id, self, self._dispatcher)
return ResponsePartIterator(request_id, self)

def done(self, request_id: UUID):
self._response_collector.remove(request_id)
Expand Down Expand Up @@ -104,11 +105,15 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio
else:
raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id)

def get_errors(self) -> List[TypeDBClientException]:
return self._response_collector.get_errors()
def dispatcher(self):
return self._dispatcher

def get_error(self) -> TypeDBClientException:
return self._error

def close(self, error: TypeDBClientException = None):
if self._is_open.compare_and_set(True, False):
self._error = error
self._response_collector.close(error)
try:
self._dispatcher.close()
Expand Down
11 changes: 0 additions & 11 deletions typedb/stream/response_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ def close(self, error: Optional[TypeDBClientException]):
for collector in self._response_queues.values():
collector.close(error)

def get_errors(self) -> [TypeDBClientException]:
errors = []
with self._collectors_lock:
for collector in self._response_queues.values():
error = collector.get_error()
if error is not None:
errors.append(error)
return errors

class Queue(Generic[R]):

def __init__(self):
Expand All @@ -87,8 +78,6 @@ def close(self, error: Optional[TypeDBClientException]):
self._error = error
self._response_queue.put(DoneResponse())

def get_error(self) -> TypeDBClientException:
return self._error


class Response:
Expand Down
15 changes: 8 additions & 7 deletions typedb/stream/response_part_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,15 @@
from enum import Enum
from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE
from typedb.common.rpc.request_builder import transaction_stream_req
from typedb.stream.request_transmitter import RequestTransmitter

if TYPE_CHECKING:
from typedb.stream.bidirectional_stream import BidirectionalStream


class ResponsePartIterator(Iterator[transaction_proto.Transaction.ResPart]):

def __init__(self, request_id: UUID, bidirectional_stream: "BidirectionalStream", request_dispatcher: RequestTransmitter.Dispatcher):
def __init__(self, request_id: UUID, bidirectional_stream: "BidirectionalStream"):
self._request_id = request_id
self._dispatcher = request_dispatcher
self._bidirectional_stream = bidirectional_stream
self._state = ResponsePartIterator.State.EMPTY
self._next: transaction_proto.Transaction.ResPart = None
Expand All @@ -54,7 +52,7 @@ def _fetch_and_check(self) -> bool:
self._state = ResponsePartIterator.State.DONE
return False
elif state == transaction_proto.Transaction.Stream.State.Value("CONTINUE"):
self._dispatcher.dispatch(transaction_stream_req(self._request_id))
self._bidirectional_stream.dispatcher().dispatch(transaction_stream_req(self._request_id))
return self._fetch_and_check()
else:
raise TypeDBClientException.of(ILLEGAL_ARGUMENT)
Expand All @@ -76,8 +74,11 @@ def _has_next(self) -> bool:
raise TypeDBClientException.of(ILLEGAL_STATE)

def __next__(self) -> transaction_proto.Transaction.ResPart:
if not self._has_next():
if self._bidirectional_stream.get_error() is not None:
raise self._bidirectional_stream.get_error()
elif not self._has_next():
self._bidirectional_stream.done(self._request_id)
raise StopIteration
self._state = ResponsePartIterator.State.EMPTY
return self._next
else:
self._state = ResponsePartIterator.State.EMPTY
return self._next