From cc053fe500d8914566c89c6db7e182518f0ced27 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 13 Sep 2018 17:31:49 -0400 Subject: [PATCH] Make ReadBuffer interface more robust The current ReadBuffer interface is somewhat error-prone. There is no way to "peek" at the next message, so various subprotocols that have nested message processing loops have to resort to the `_skip_discard` kludge on the protocol to inform the main loop that it shouldn't skip over to the next message because the subprotocol already read too much. Fix this by adding a way to iterate over the messages without over-reading, and a way to "put" a message back into the buffer when necessary. This also renames `has_message()` to `take_message()` to make it clear that it changes the buffer state. --- asyncpg/protocol/buffer.pxd | 9 ++++--- asyncpg/protocol/buffer.pyx | 46 ++++++++++++++++++++++------------ asyncpg/protocol/coreproto.pyx | 29 ++++++++------------- 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/asyncpg/protocol/buffer.pxd b/asyncpg/protocol/buffer.pxd index caca282e..2da85f64 100644 --- a/asyncpg/protocol/buffer.pxd +++ b/asyncpg/protocol/buffer.pxd @@ -105,13 +105,14 @@ cdef class ReadBuffer: cdef inline int32_t read_int32(self) except? -1 cdef inline int16_t read_int16(self) except? 
-1 cdef inline read_cstr(self) - cdef int32_t has_message(self) except -1 - cdef inline int32_t has_message_type(self, char mtype) except -1 + cdef int32_t take_message(self) except -1 + cdef inline int32_t take_message_type(self, char mtype) except -1 + cdef int32_t put_message(self) except -1 cdef inline const char* try_consume_message(self, ssize_t* len) cdef Memory consume_message(self) cdef bytearray consume_messages(self, char mtype) - cdef discard_message(self) - cdef inline _discard_message(self) + cdef finish_message(self) + cdef inline _finish_message(self) cdef inline char get_message_type(self) cdef inline int32_t get_message_length(self) diff --git a/asyncpg/protocol/buffer.pyx b/asyncpg/protocol/buffer.pyx index 6270b6ca..efed819f 100644 --- a/asyncpg/protocol/buffer.pyx +++ b/asyncpg/protocol/buffer.pyx @@ -489,7 +489,7 @@ cdef class ReadBuffer: self._ensure_first_buf() - cdef int32_t has_message(self) except -1: + cdef int32_t take_message(self) except -1: cdef: const char *cbuf @@ -525,8 +525,24 @@ cdef class ReadBuffer: self._current_message_ready = 1 return 1 - cdef inline int32_t has_message_type(self, char mtype) except -1: - return self.has_message() and self.get_message_type() == mtype + cdef inline int32_t take_message_type(self, char mtype) except -1: + cdef const char *buf0 + + if self._current_message_ready: + return self._current_message_type == mtype + elif self._length >= 1: + self._ensure_first_buf() + buf0 = cpython.PyBytes_AS_STRING(self._buf0) + + return buf0[self._pos0] == mtype and self.take_message() + else: + return 0 + + cdef int32_t put_message(self) except -1: + if not self._current_message_ready: + raise BufferError('cannot put message: no message taken') + self._current_message_ready = False + return 0 cdef inline const char* try_consume_message(self, ssize_t* len): cdef: @@ -541,7 +557,7 @@ cdef class ReadBuffer: buf = self._try_read_bytes(buf_len) if buf != NULL: len[0] = buf_len - self._discard_message() + 
self._finish_message() return buf cdef Memory consume_message(self): @@ -551,7 +567,7 @@ cdef class ReadBuffer: mem = self.read(self._current_message_len_unread) else: mem = None - self._discard_message() + self._finish_message() return mem cdef bytearray consume_messages(self, char mtype): @@ -562,7 +578,7 @@ cdef class ReadBuffer: ssize_t total_bytes = 0 bytearray result - if not self.has_message_type(mtype): + if not self.take_message_type(mtype): return None # consume_messages is a volume-oriented method, so @@ -571,26 +587,24 @@ cdef class ReadBuffer: result = PyByteArray_FromStringAndSize(NULL, self._length) buf = PyByteArray_AsString(result) - while self.has_message_type(mtype): + while self.take_message_type(mtype): nbytes = self._current_message_len_unread self._read(buf, nbytes) buf += nbytes total_bytes += nbytes - self._discard_message() + self._finish_message() # Clamp the result to an actual size read. PyByteArray_Resize(result, total_bytes) return result - cdef discard_message(self): - if self._current_message_type == 0: - # Already discarded + cdef finish_message(self): + if self._current_message_type == 0 or not self._current_message_ready: + # The message has already been finished (e.g by consume_message()), + # or has been put back by put_message(). 
return - if not self._current_message_ready: - raise BufferError('no message to discard') - if self._current_message_len_unread: if ASYNCPG_DEBUG: mtype = chr(self._current_message_type) @@ -602,9 +616,9 @@ cdef class ReadBuffer: mtype, (discarded).as_bytes())) - self._discard_message() + self._finish_message() - cdef inline _discard_message(self): + cdef inline _finish_message(self): self._current_message_type = 0 self._current_message_len = 0 self._current_message_ready = 0 diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 21498e7d..2927d4a9 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -22,8 +22,6 @@ cdef class CoreProtocol: self.xact_status = PQTRANS_IDLE self.encoding = 'utf-8' - self._skip_discard = False - # executemany support data self._execute_iter = None self._execute_portal_name = None @@ -36,7 +34,7 @@ cdef class CoreProtocol: char mtype ProtocolState state - while self.buffer.has_message() == 1: + while self.buffer.take_message() == 1: mtype = self.buffer.get_message_type() state = self.state @@ -150,10 +148,7 @@ cdef class CoreProtocol: self._push_result() finally: - if self._skip_discard: - self._skip_discard = False - else: - self.buffer.discard_message() + self.buffer.finish_message() cdef _process__auth(self, char mtype): if mtype == b'R': @@ -319,8 +314,6 @@ cdef class CoreProtocol: self.result = buf.consume_messages(b'd') - self._skip_discard = True - # By this point we have consumed all CopyData messages # in the inbound buffer. If there are no messages left # in the buffer, we need to push the accumulated data @@ -328,9 +321,13 @@ cdef class CoreProtocol: # batch. If there _are_ non-CopyData messages left, # we must not push the result here and let the # _process__copy_out_data subprotocol do the job. 
- if not buf.has_message(): + if not buf.take_message(): self._on_result() self.result = None + else: + # If there is a message in the buffer, put it back to + # be processed by the next protocol iteration. + buf.put_message() cdef _write_copy_data_msg(self, object data): cdef: @@ -385,11 +382,9 @@ cdef class CoreProtocol: '_parse_data_msgs: first message is not "D"') if self._discard_data: - while True: + while buf.take_message_type(b'D'): buf.consume_message() - if not buf.has_message() or buf.get_message_type() != b'D': - self._skip_discard = True - return + return if ASYNCPG_DEBUG: if type(self.result) is not list: @@ -398,7 +393,7 @@ cdef class CoreProtocol: format(self.result)) rows = self.result - while True: + while buf.take_message_type(b'D'): cbuf = buf.try_consume_message(&cbuf_len) if cbuf != NULL: row = decoder(self, cbuf, cbuf_len) @@ -408,10 +403,6 @@ cdef class CoreProtocol: cpython.PyList_Append(rows, row) - if not buf.has_message() or buf.get_message_type() != b'D': - self._skip_discard = True - return - cdef _parse_msg_backend_key_data(self): self.backend_pid = self.buffer.read_int32() self.backend_secret = self.buffer.read_int32()