|
| 1 | +from logging import getLogger |
1 | 2 | from typing import Any, Union
|
2 | 3 |
|
3 | 4 | from ..exceptions import ConnectionError, InvalidResponse, ResponseError
|
|
9 | 10 | class _RESP3Parser(_RESPBase):
|
10 | 11 | """RESP3 protocol implementation"""
|
11 | 12 |
|
12 |
| - def read_response(self, disable_decoding=False): |
| 13 | + def __init__(self, socket_read_size): |
| 14 | + super().__init__(socket_read_size) |
| 15 | + self.push_handler_func = self.handle_push_response |
| 16 | + |
| 17 | + def handle_push_response(self, response): |
| 18 | + logger = getLogger("push_response") |
| 19 | + logger.info("Push response: " + str(response)) |
| 20 | + return response |
| 21 | + |
| 22 | + def read_response(self, disable_decoding=False, push_request=False): |
13 | 23 | pos = self._buffer.get_pos()
|
14 | 24 | try:
|
15 |
| - result = self._read_response(disable_decoding=disable_decoding) |
| 25 | + result = self._read_response( |
| 26 | + disable_decoding=disable_decoding, push_request=push_request |
| 27 | + ) |
16 | 28 | except BaseException:
|
17 | 29 | self._buffer.rewind(pos)
|
18 | 30 | raise
|
19 | 31 | else:
|
20 | 32 | self._buffer.purge()
|
21 | 33 | return result
|
22 | 34 |
|
23 |
| - def _read_response(self, disable_decoding=False): |
| 35 | + def _read_response(self, disable_decoding=False, push_request=False): |
24 | 36 | raw = self._buffer.readline()
|
25 | 37 | if not raw:
|
26 | 38 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
|
@@ -77,31 +89,64 @@ def _read_response(self, disable_decoding=False):
|
77 | 89 | response = {
|
78 | 90 | self._read_response(
|
79 | 91 | disable_decoding=disable_decoding
|
80 |
| - ): self._read_response(disable_decoding=disable_decoding) |
| 92 | + ): self._read_response( |
| 93 | + disable_decoding=disable_decoding, push_request=push_request |
| 94 | + ) |
81 | 95 | for _ in range(int(response))
|
82 | 96 | }
|
| 97 | + # push response |
| 98 | + elif byte == b">": |
| 99 | + response = [ |
| 100 | + self._read_response( |
| 101 | + disable_decoding=disable_decoding, push_request=push_request |
| 102 | + ) |
| 103 | + for _ in range(int(response)) |
| 104 | + ] |
| 105 | + res = self.push_handler_func(response) |
| 106 | + if not push_request: |
| 107 | + return self._read_response( |
| 108 | + disable_decoding=disable_decoding, push_request=push_request |
| 109 | + ) |
| 110 | + else: |
| 111 | + return res |
83 | 112 | else:
|
84 | 113 | raise InvalidResponse(f"Protocol Error: {raw!r}")
|
85 | 114 |
|
86 | 115 | if isinstance(response, bytes) and disable_decoding is False:
|
87 | 116 | response = self.encoder.decode(response)
|
88 | 117 | return response
|
89 | 118 |
|
| 119 | + def set_push_handler(self, push_handler_func): |
| 120 | + self.push_handler_func = push_handler_func |
| 121 | + |
90 | 122 |
|
91 | 123 | class _AsyncRESP3Parser(_AsyncRESPBase):
|
92 |
| - async def read_response(self, disable_decoding: bool = False): |
| 124 | + def __init__(self, socket_read_size): |
| 125 | + super().__init__(socket_read_size) |
| 126 | + self.push_handler_func = self.handle_push_response |
| 127 | + |
| 128 | + def handle_push_response(self, response): |
| 129 | + logger = getLogger("push_response") |
| 130 | + logger.info("Push response: " + str(response)) |
| 131 | + return response |
| 132 | + |
| 133 | + async def read_response( |
| 134 | + self, disable_decoding: bool = False, push_request: bool = False |
| 135 | + ): |
93 | 136 | if self._chunks:
|
94 | 137 | # augment parsing buffer with previously read data
|
95 | 138 | self._buffer += b"".join(self._chunks)
|
96 | 139 | self._chunks.clear()
|
97 | 140 | self._pos = 0
|
98 |
| - response = await self._read_response(disable_decoding=disable_decoding) |
| 141 | + response = await self._read_response( |
| 142 | + disable_decoding=disable_decoding, push_request=push_request |
| 143 | + ) |
99 | 144 | # Successfully parsing a response allows us to clear our parsing buffer
|
100 | 145 | self._clear()
|
101 | 146 | return response
|
102 | 147 |
|
103 | 148 | async def _read_response(
|
104 |
| - self, disable_decoding: bool = False |
| 149 | + self, disable_decoding: bool = False, push_request: bool = False |
105 | 150 | ) -> Union[EncodableT, ResponseError, None]:
|
106 | 151 | if not self._stream or not self.encoder:
|
107 | 152 | raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
|
@@ -166,9 +211,31 @@ async def _read_response(
|
166 | 211 | )
|
167 | 212 | for _ in range(int(response))
|
168 | 213 | }
|
| 214 | + # push response |
| 215 | + elif byte == b">": |
| 216 | + response = [ |
| 217 | + ( |
| 218 | + await self._read_response( |
| 219 | + disable_decoding=disable_decoding, push_request=push_request |
| 220 | + ) |
| 221 | + ) |
| 222 | + for _ in range(int(response)) |
| 223 | + ] |
| 224 | + res = self.push_handler_func(response) |
| 225 | + if not push_request: |
| 226 | + return await ( |
| 227 | + self._read_response( |
| 228 | + disable_decoding=disable_decoding, push_request=push_request |
| 229 | + ) |
| 230 | + ) |
| 231 | + else: |
| 232 | + return res |
169 | 233 | else:
|
170 | 234 | raise InvalidResponse(f"Protocol Error: {raw!r}")
|
171 | 235 |
|
172 | 236 | if isinstance(response, bytes) and disable_decoding is False:
|
173 | 237 | response = self.encoder.decode(response)
|
174 | 238 | return response
|
| 239 | + |
| 240 | + def set_push_handler(self, push_handler_func): |
| 241 | + self.push_handler_func = push_handler_func |
0 commit comments