From af8eb76661abb5e9bca3d410ba0553d5ab126bee Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 1 Oct 2024 16:42:38 -0700 Subject: [PATCH] Add a warning rudimentary system In this initial iteration of it, warnings are encoded in JSON as a header for CommandDataDescription. A `warning_handler` field is added to `Options`, and it is invoked if any warnings occured. We provide two built-in functions: `log_warnings` and `raise_warnings`. The default is `log_warnings`. --- edgedb/abstract.py | 20 +++++++++++ edgedb/base_client.py | 18 +++++++--- edgedb/errors/_base.py | 25 ++++++++++++- edgedb/options.py | 70 +++++++++++++++++++++++++++++++++++- edgedb/protocol/protocol.pxd | 2 ++ edgedb/protocol/protocol.pyx | 29 ++++++++++++++- edgedb/transaction.py | 3 ++ 7 files changed, 159 insertions(+), 8 deletions(-) diff --git a/edgedb/abstract.py b/edgedb/abstract.py index dbac4503..4575f22a 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -64,6 +64,7 @@ class QueryContext(typing.NamedTuple): query_options: QueryOptions retry_options: typing.Optional[options.RetryOptions] state: typing.Optional[options.State] + warning_handler: options.WarningHandler def lower( self, *, allow_capabilities: enums.Capability @@ -86,6 +87,7 @@ class ExecuteContext(typing.NamedTuple): query: QueryWithArgs cache: QueryCache state: typing.Optional[options.State] + warning_handler: options.WarningHandler def lower( self, *, allow_capabilities: enums.Capability @@ -181,6 +183,10 @@ def _get_retry_options(self) -> typing.Optional[options.RetryOptions]: def _get_state(self) -> options.State: ... + @abc.abstractmethod + def _get_warning_handler(self) -> options.WarningHandler: + ... + class ReadOnlyExecutor(BaseReadOnlyExecutor): """Subclasses can execute *at least* read-only queries""" @@ -198,6 +204,7 @@ def query(self, query: str, *args, **kwargs) -> list: query_options=_query_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handlerr=self._get_warning_handler(), )) def query_single( @@ -209,6 +216,7 @@ def query_single( query_options=_query_single_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: @@ -218,6 +226,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: query_options=_query_required_single_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) def query_json(self, query: str, *args, **kwargs) -> str: @@ -227,6 +236,7 @@ def query_json(self, query: str, *args, **kwargs) -> str: query_options=_query_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) def query_single_json(self, query: str, *args, **kwargs) -> str: @@ -236,6 +246,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str: query_options=_query_single_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) def query_required_single_json(self, query: str, *args, **kwargs) -> str: @@ -245,6 +256,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str: query_options=_query_required_single_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) @abc.abstractmethod @@ -256,6 +268,7 @@ def execute(self, commands: str, *args, **kwargs) -> None: query=QueryWithArgs(commands, args, kwargs), cache=self._get_query_cache(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) @@ -281,6 +294,7 @@ async def query(self, query: str, *args, **kwargs) -> list: query_options=_query_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) async def query_single(self, query: str, *args, **kwargs) -> typing.Any: @@ -290,6 +304,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any: query_options=_query_single_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) async def query_required_single( @@ -304,6 +319,7 @@ async def query_required_single( query_options=_query_required_single_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) async def query_json(self, query: str, *args, **kwargs) -> str: @@ -313,6 +329,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str: query_options=_query_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) async def query_single_json(self, query: str, *args, **kwargs) -> str: @@ -322,6 +339,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str: query_options=_query_single_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) async def query_required_single_json( @@ -336,6 +354,7 @@ async def query_required_single_json( query_options=_query_required_single_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) @abc.abstractmethod @@ -347,6 +366,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None: query=QueryWithArgs(commands, args, kwargs), cache=self._get_query_cache(), state=self._get_state(), + warning_handler=self._get_warning_handler(), )) diff --git a/edgedb/base_client.py b/edgedb/base_client.py index 382da052..7ab2d76c 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -216,7 +216,11 @@ async def raw_query(self, query_context: abstract.QueryContext): if self._protocol.is_legacy: return await self._protocol.legacy_execute_anonymous(ctx) else: - return await self._protocol.query(ctx) + res = await self._protocol.query(ctx) + if ctx.warnings: + res = query_context.warning_handler(ctx.warnings, res) + return res + except errors.EdgeDBError as e: if query_context.retry_options is None: raise @@ -246,11 +250,12 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None: execute_context.query.query, enums.Capability.LEGACY_EXECUTE ) else: - await self._protocol.execute( - execute_context.lower( - allow_capabilities=enums.Capability.EXECUTE - ) + ctx = execute_context.lower( + allow_capabilities=enums.Capability.EXECUTE ) + res = await self._protocol.execute(ctx) + if ctx.warnings: + res = execute_context.warning_handler(ctx.warnings, res) async def describe( self, describe_context: abstract.DescribeContext @@ -684,6 +689,9 @@ def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]: def _get_state(self) -> _options.State: return self._options.state + def _get_warning_handler(self) -> _options.WarningHandler: + return self._options.warning_handler + @property def max_concurrency(self) -> int: """Max number of connections in the pool.""" diff --git a/edgedb/errors/_base.py b/edgedb/errors/_base.py index 11f5c26d..03ce547c 100644 --- a/edgedb/errors/_base.py +++ b/edgedb/errors/_base.py @@ -132,8 +132,10 @@ def _details(self): def _read_str_field(self, key, default=None): val = self._attrs.get(key) - if val: + if isinstance(val, bytes): return val.decode('utf-8') + elif val is not None: + return val return default def get_code(self): @@ -149,6 +151,16 @@ def _from_code(code, *args, **kwargs): exc._code = code return exc + @staticmethod + def _from_json(data): + exc = EdgeDBError._from_code(data['code'], data['message']) + exc._attrs = { + field: data[name] + for name, field in _JSON_FIELDS.items() + if name in data + } + return exc + def __str__(self): msg = super().__str__() if SHOW_HINT and self._query and self._position_start >= 0: @@ -323,6 +335,17 @@ def _unicode_width(text): EDGE_SEVERITY_PANIC = 255 +# Fields to include in the json dump of the type +_JSON_FIELDS = { + 'hint': FIELD_HINT, + 'details': FIELD_DETAILS, + 'start': FIELD_CHARACTER_START, + 'end': FIELD_CHARACTER_END, + 'line': FIELD_LINE_START, + 'col': FIELD_COLUMN_START, +} + + LINESEP = os.linesep try: diff --git a/edgedb/options.py b/edgedb/options.py index f325033c..274eca4f 100644 --- a/edgedb/options.py +++ b/edgedb/options.py @@ -1,12 +1,17 @@ import abc import enum +import logging import random import typing +import sys from collections import namedtuple from . import errors +logger = logging.getLogger('edgedb') + + _RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"]) @@ -14,6 +19,29 @@ def default_backoff(attempt): return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001 +WarningHandler = typing.Callable[ + [typing.Tuple[errors.EdgeDBError, ...], typing.Any], + typing.Any, +] + +def raise_warnings(warnings, res): + if ( + len(warnings) > 1 + and sys.version_info >= (3, 11) + ): + raise ExceptionGroup( + "Query produced warnings", warnings + ) + else: + raise warnings[0] + + +def log_warnings(warnings, res): + for w in warnings: + logger.warning("EdgeDB warning: %s", str(w)) + return res + + class RetryCondition: """Specific condition to retry on for fine-grained control""" TransactionConflict = enum.auto() @@ -311,6 +339,25 @@ def with_retry_options(self, options: RetryOptions=None): result._options = self._options.with_retry_options(options) return result + def with_warning_handler(self, warning_handler: WarningHandler=None): + """Returns object with adjusted options for handling warnings. + + :param warning_handler WarningHandler: + Function for handling warnings. It is passed a tuple of warnings + and the query result and returns a potentially updated query + result. + + This method returns a "shallow copy" of the current object + with modified retry options. + + Both ``self`` and returned object can be used after, but when using + them retry options applied will be different. + """ + + result = self._shallow_clone() + result._options = self._options.with_warning_handler(warning_handler) + return result + def with_state(self, state: State): result = self._shallow_clone() result._options = self._options.with_state(state) @@ -369,17 +416,22 @@ def without_globals(self, *global_names): class _Options: """Internal class for storing connection options""" - __slots__ = ['_retry_options', '_transaction_options', '_state'] + __slots__ = [ + '_retry_options', '_transaction_options', '_state', + '_warning_handler' + ] def __init__( self, retry_options: RetryOptions, transaction_options: TransactionOptions, state: State, + warning_handler: WarningHandler, ): self._retry_options = retry_options self._transaction_options = transaction_options self._state = state + self._warning_handler = warning_handler @property def retry_options(self): @@ -393,11 +445,16 @@ def transaction_options(self): def state(self): return self._state + @property + def warning_handler(self): + return self._warning_handler + def with_retry_options(self, options: RetryOptions): return _Options( options, self._transaction_options, self._state, + self._warning_handler, ) def with_transaction_options(self, options: TransactionOptions): @@ -405,6 +462,7 @@ def with_transaction_options(self, options: TransactionOptions): self._retry_options, options, self._state, + self._warning_handler, ) def with_state(self, state: State): @@ -412,6 +470,15 @@ def with_state(self, state: State): self._retry_options, self._transaction_options, state, + self._warning_handler, + ) + + def with_warning_handler(self, warning_handler: WarningHandler): + return _Options( + self._retry_options, + self._transaction_options, + self._state, + warning_handler, ) @classmethod @@ -420,4 +487,5 @@ def defaults(cls): RetryOptions.defaults(), TransactionOptions.defaults(), State.defaults(), + log_warnings, ) diff --git a/edgedb/protocol/protocol.pxd b/edgedb/protocol/protocol.pxd index 7140ae4e..bc1f553e 100644 --- a/edgedb/protocol/protocol.pxd +++ b/edgedb/protocol/protocol.pxd @@ -89,6 +89,7 @@ cdef class ExecuteContext: readonly BaseCodec in_dc readonly BaseCodec out_dc readonly uint64_t capabilities + readonly tuple warnings cdef inline bint has_na_cardinality(self) cdef bint load_from_cache(self) @@ -142,6 +143,7 @@ cdef class SansIOProtocol: ) cdef inline ignore_headers(self) + cdef inline dict read_headers(self) cdef dict parse_error_headers(self) cdef parse_error_message(self) diff --git a/edgedb/protocol/protocol.pyx b/edgedb/protocol/protocol.pyx index 3e0643d3..11421043 100644 --- a/edgedb/protocol/protocol.pyx +++ b/edgedb/protocol/protocol.pyx @@ -126,6 +126,7 @@ cdef class ExecuteContext: self.cardinality = None self.in_dc = self.out_dc = None self.capabilities = 0 + self.warnings = () cdef inline bint has_na_cardinality(self): return self.cardinality == CARDINALITY_NOT_APPLICABLE @@ -230,6 +231,23 @@ cdef class SansIOProtocol: self.buffer.read_len_prefixed_bytes() # value num_fields -= 1 + cdef inline dict read_headers(self): + cdef uint16_t num_fields = self.buffer.read_int16() + headers = {} + if self.is_legacy: + while num_fields: + self.buffer.read_int16() # key + self.buffer.read_len_prefixed_bytes() # value + num_fields -= 1 + else: + while num_fields: + key = self.buffer.read_len_prefixed_utf8() + value = self.buffer.read_len_prefixed_utf8() + headers[key] = value + num_fields -= 1 + + return headers + cdef ensure_connected(self): if self.cancelled: raise errors.ClientConnectionClosedError( @@ -987,7 +1005,16 @@ cdef class SansIOProtocol: assert self.buffer.get_message_type() == COMMAND_DATA_DESC_MSG try: - self.ignore_headers() + headers = self.read_headers() + if headers and 'warnings' in headers: + warnings = tuple([ + errors.EdgeDBError._from_json(w) + for w in json.loads(headers['warnings']) + ]) + for w in warnings: + w._query = ctx.query + ctx.warnings = warnings + ctx.capabilities = self.buffer.read_int64() ctx.cardinality = self.buffer.read_byte() ctx.in_dc, ctx.out_dc = self.parse_type_data(ctx.reg) diff --git a/edgedb/transaction.py b/edgedb/transaction.py index 511b8f42..17f3d2ff 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -185,6 +185,9 @@ def _get_query_cache(self) -> abstract.QueryCache: def _get_state(self) -> options.State: return self._client._get_state() + def _get_warning_handler(self) -> options.WarningHandler: + return self._client._get_warning_handler() + async def _query(self, query_context: abstract.QueryContext): await self._ensure_transaction() return await self._connection.raw_query(query_context)