diff --git a/asyncpg/protocol/buffer.pxd b/asyncpg/protocol/buffer.pxd index caca282e..2da85f64 100644 --- a/asyncpg/protocol/buffer.pxd +++ b/asyncpg/protocol/buffer.pxd @@ -105,13 +105,14 @@ cdef class ReadBuffer: cdef inline int32_t read_int32(self) except? -1 cdef inline int16_t read_int16(self) except? -1 cdef inline read_cstr(self) - cdef int32_t has_message(self) except -1 - cdef inline int32_t has_message_type(self, char mtype) except -1 + cdef int32_t take_message(self) except -1 + cdef inline int32_t take_message_type(self, char mtype) except -1 + cdef int32_t put_message(self) except -1 cdef inline const char* try_consume_message(self, ssize_t* len) cdef Memory consume_message(self) cdef bytearray consume_messages(self, char mtype) - cdef discard_message(self) - cdef inline _discard_message(self) + cdef finish_message(self) + cdef inline _finish_message(self) cdef inline char get_message_type(self) cdef inline int32_t get_message_length(self) diff --git a/asyncpg/protocol/buffer.pyx b/asyncpg/protocol/buffer.pyx index 6270b6ca..efed819f 100644 --- a/asyncpg/protocol/buffer.pyx +++ b/asyncpg/protocol/buffer.pyx @@ -489,7 +489,7 @@ cdef class ReadBuffer: self._ensure_first_buf() - cdef int32_t has_message(self) except -1: + cdef int32_t take_message(self) except -1: cdef: const char *cbuf @@ -525,8 +525,24 @@ cdef class ReadBuffer: self._current_message_ready = 1 return 1 - cdef inline int32_t has_message_type(self, char mtype) except -1: - return self.has_message() and self.get_message_type() == mtype + cdef inline int32_t take_message_type(self, char mtype) except -1: + cdef const char *buf0 + + if self._current_message_ready: + return self._current_message_type == mtype + elif self._length >= 1: + self._ensure_first_buf() + buf0 = cpython.PyBytes_AS_STRING(self._buf0) + + return buf0[self._pos0] == mtype and self.take_message() + else: + return 0 + + cdef int32_t put_message(self) except -1: + if not self._current_message_ready: + raise BufferError('cannot put message: no message taken') + self._current_message_ready = False + return 0 cdef inline const char* try_consume_message(self, ssize_t* len): cdef: @@ -541,7 +557,7 @@ cdef class ReadBuffer: buf = self._try_read_bytes(buf_len) if buf != NULL: len[0] = buf_len - self._discard_message() + self._finish_message() return buf cdef Memory consume_message(self): @@ -551,7 +567,7 @@ cdef class ReadBuffer: mem = self.read(self._current_message_len_unread) else: mem = None - self._discard_message() + self._finish_message() return mem cdef bytearray consume_messages(self, char mtype): @@ -562,7 +578,7 @@ cdef class ReadBuffer: ssize_t total_bytes = 0 bytearray result - if not self.has_message_type(mtype): + if not self.take_message_type(mtype): return None # consume_messages is a volume-oriented method, so @@ -571,26 +587,24 @@ cdef class ReadBuffer: result = PyByteArray_FromStringAndSize(NULL, self._length) buf = PyByteArray_AsString(result) - while self.has_message_type(mtype): + while self.take_message_type(mtype): nbytes = self._current_message_len_unread self._read(buf, nbytes) buf += nbytes total_bytes += nbytes - self._discard_message() + self._finish_message() # Clamp the result to an actual size read. PyByteArray_Resize(result, total_bytes) return result - cdef discard_message(self): - if self._current_message_type == 0: - # Already discarded + cdef finish_message(self): + if self._current_message_type == 0 or not self._current_message_ready: + # The message has already been finished (e.g by consume_message()), + # or has been put back by put_message(). return - if not self._current_message_ready: - raise BufferError('no message to discard') - if self._current_message_len_unread: if ASYNCPG_DEBUG: mtype = chr(self._current_message_type) @@ -602,9 +616,9 @@ cdef class ReadBuffer: mtype, (discarded).as_bytes())) - self._discard_message() + self._finish_message() - cdef inline _discard_message(self): + cdef inline _finish_message(self): self._current_message_type = 0 self._current_message_len = 0 self._current_message_ready = 0 diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 21498e7d..2927d4a9 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -22,8 +22,6 @@ cdef class CoreProtocol: self.xact_status = PQTRANS_IDLE self.encoding = 'utf-8' - self._skip_discard = False - # executemany support data self._execute_iter = None self._execute_portal_name = None @@ -36,7 +34,7 @@ cdef class CoreProtocol: char mtype ProtocolState state - while self.buffer.has_message() == 1: + while self.buffer.take_message() == 1: mtype = self.buffer.get_message_type() state = self.state @@ -150,10 +148,7 @@ cdef class CoreProtocol: self._push_result() finally: - if self._skip_discard: - self._skip_discard = False - else: - self.buffer.discard_message() + self.buffer.finish_message() cdef _process__auth(self, char mtype): if mtype == b'R': @@ -319,8 +314,6 @@ cdef class CoreProtocol: self.result = buf.consume_messages(b'd') - self._skip_discard = True - # By this point we have consumed all CopyData messages # in the inbound buffer. If there are no messages left # in the buffer, we need to push the accumulated data @@ -328,9 +321,13 @@ cdef class CoreProtocol: # batch. If there _are_ non-CopyData messages left, # we must not push the result here and let the # _process__copy_out_data subprotocol do the job. - if not buf.has_message(): + if not buf.take_message(): self._on_result() self.result = None + else: + # If there is a message in the buffer, put it back to + # be processed by the next protocol iteration. + buf.put_message() cdef _write_copy_data_msg(self, object data): cdef: @@ -385,11 +382,9 @@ cdef class CoreProtocol: '_parse_data_msgs: first message is not "D"') if self._discard_data: - while True: + while buf.take_message_type(b'D'): buf.consume_message() - if not buf.has_message() or buf.get_message_type() != b'D': - self._skip_discard = True - return + return if ASYNCPG_DEBUG: if type(self.result) is not list: @@ -398,7 +393,7 @@ cdef class CoreProtocol: format(self.result)) rows = self.result - while True: + while buf.take_message_type(b'D'): cbuf = buf.try_consume_message(&cbuf_len) if cbuf != NULL: row = decoder(self, cbuf, cbuf_len) @@ -408,10 +403,6 @@ cdef class CoreProtocol: cpython.PyList_Append(rows, row) - if not buf.has_message() or buf.get_message_type() != b'D': - self._skip_discard = True - return - cdef _parse_msg_backend_key_data(self): self.backend_pid = self.buffer.read_int32() self.backend_secret = self.buffer.read_int32()