Skip to content

Commit

Permalink
Add a warning rudimentary system
Browse files Browse the repository at this point in the history
In this initial iteration of it, warnings are encoded in JSON as a
header for CommandDataDescription.

A `warning_handler` field is added to `Options`, and it is invoked if
any warnings occured. We provide two built-in functions:
`log_warnings` and `raise_warnings`. The default is `log_warnings`.
  • Loading branch information
msullivan committed Oct 1, 2024
1 parent 65d5cdc commit af8eb76
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 8 deletions.
20 changes: 20 additions & 0 deletions edgedb/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class QueryContext(typing.NamedTuple):
query_options: QueryOptions
retry_options: typing.Optional[options.RetryOptions]
state: typing.Optional[options.State]
warning_handler: options.WarningHandler

def lower(
self, *, allow_capabilities: enums.Capability
Expand All @@ -86,6 +87,7 @@ class ExecuteContext(typing.NamedTuple):
query: QueryWithArgs
cache: QueryCache
state: typing.Optional[options.State]
warning_handler: options.WarningHandler

def lower(
self, *, allow_capabilities: enums.Capability
Expand Down Expand Up @@ -181,6 +183,10 @@ def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
def _get_state(self) -> options.State:
...

@abc.abstractmethod
def _get_warning_handler(self) -> options.WarningHandler:
...


class ReadOnlyExecutor(BaseReadOnlyExecutor):
"""Subclasses can execute *at least* read-only queries"""
Expand All @@ -198,6 +204,7 @@ def query(self, query: str, *args, **kwargs) -> list:
query_options=_query_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handlerr=self._get_warning_handler(),
))

def query_single(
Expand All @@ -209,6 +216,7 @@ def query_single(
query_options=_query_single_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -218,6 +226,7 @@ def query_required_single(self, query: str, *args, **kwargs) -> typing.Any:
query_options=_query_required_single_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -227,6 +236,7 @@ def query_json(self, query: str, *args, **kwargs) -> str:
query_options=_query_json_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -236,6 +246,7 @@ def query_single_json(self, query: str, *args, **kwargs) -> str:
query_options=_query_single_json_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

def query_required_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -245,6 +256,7 @@ def query_required_single_json(self, query: str, *args, **kwargs) -> str:
query_options=_query_required_single_json_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

@abc.abstractmethod
Expand All @@ -256,6 +268,7 @@ def execute(self, commands: str, *args, **kwargs) -> None:
query=QueryWithArgs(commands, args, kwargs),
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))


Expand All @@ -281,6 +294,7 @@ async def query(self, query: str, *args, **kwargs) -> list:
query_options=_query_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
Expand All @@ -290,6 +304,7 @@ async def query_single(self, query: str, *args, **kwargs) -> typing.Any:
query_options=_query_single_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

async def query_required_single(
Expand All @@ -304,6 +319,7 @@ async def query_required_single(
query_options=_query_required_single_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

async def query_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -313,6 +329,7 @@ async def query_json(self, query: str, *args, **kwargs) -> str:
query_options=_query_json_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

async def query_single_json(self, query: str, *args, **kwargs) -> str:
Expand All @@ -322,6 +339,7 @@ async def query_single_json(self, query: str, *args, **kwargs) -> str:
query_options=_query_single_json_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

async def query_required_single_json(
Expand All @@ -336,6 +354,7 @@ async def query_required_single_json(
query_options=_query_required_single_json_opts,
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))

@abc.abstractmethod
Expand All @@ -347,6 +366,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None:
query=QueryWithArgs(commands, args, kwargs),
cache=self._get_query_cache(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
))


Expand Down
18 changes: 13 additions & 5 deletions edgedb/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,11 @@ async def raw_query(self, query_context: abstract.QueryContext):
if self._protocol.is_legacy:
return await self._protocol.legacy_execute_anonymous(ctx)
else:
return await self._protocol.query(ctx)
res = await self._protocol.query(ctx)
if ctx.warnings:
res = query_context.warning_handler(ctx.warnings, res)
return res

except errors.EdgeDBError as e:
if query_context.retry_options is None:
raise
Expand Down Expand Up @@ -246,11 +250,12 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
execute_context.query.query, enums.Capability.LEGACY_EXECUTE
)
else:
await self._protocol.execute(
execute_context.lower(
allow_capabilities=enums.Capability.EXECUTE
)
ctx = execute_context.lower(
allow_capabilities=enums.Capability.EXECUTE
)
res = await self._protocol.execute(ctx)
if ctx.warnings:
res = execute_context.warning_handler(ctx.warnings, res)

async def describe(
self, describe_context: abstract.DescribeContext
Expand Down Expand Up @@ -684,6 +689,9 @@ def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]:
def _get_state(self) -> _options.State:
return self._options.state

def _get_warning_handler(self) -> _options.WarningHandler:
return self._options.warning_handler

@property
def max_concurrency(self) -> int:
"""Max number of connections in the pool."""
Expand Down
25 changes: 24 additions & 1 deletion edgedb/errors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ def _details(self):

def _read_str_field(self, key, default=None):
val = self._attrs.get(key)
if val:
if isinstance(val, bytes):
return val.decode('utf-8')
elif val is not None:
return val
return default

def get_code(self):
Expand All @@ -149,6 +151,16 @@ def _from_code(code, *args, **kwargs):
exc._code = code
return exc

@staticmethod
def _from_json(data):
exc = EdgeDBError._from_code(data['code'], data['message'])
exc._attrs = {
field: data[name]
for name, field in _JSON_FIELDS.items()
if name in data
}
return exc

def __str__(self):
msg = super().__str__()
if SHOW_HINT and self._query and self._position_start >= 0:
Expand Down Expand Up @@ -323,6 +335,17 @@ def _unicode_width(text):
EDGE_SEVERITY_PANIC = 255


# Fields to include in the json dump of the type
_JSON_FIELDS = {
'hint': FIELD_HINT,
'details': FIELD_DETAILS,
'start': FIELD_CHARACTER_START,
'end': FIELD_CHARACTER_END,
'line': FIELD_LINE_START,
'col': FIELD_COLUMN_START,
}


LINESEP = os.linesep

try:
Expand Down
70 changes: 69 additions & 1 deletion edgedb/options.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,47 @@
import abc
import enum
import logging
import random
import typing
import sys
from collections import namedtuple

from . import errors


logger = logging.getLogger('edgedb')


_RetryRule = namedtuple("_RetryRule", ["attempts", "backoff"])


def default_backoff(attempt):
return (2 ** attempt) * 0.1 + random.randrange(100) * 0.001


WarningHandler = typing.Callable[
[typing.Tuple[errors.EdgeDBError, ...], typing.Any],
typing.Any,
]

def raise_warnings(warnings, res):
if (
len(warnings) > 1
and sys.version_info >= (3, 11)
):
raise ExceptionGroup(
"Query produced warnings", warnings
)
else:
raise warnings[0]


def log_warnings(warnings, res):
for w in warnings:
logger.warning("EdgeDB warning: %s", str(w))
return res


class RetryCondition:
"""Specific condition to retry on for fine-grained control"""
TransactionConflict = enum.auto()
Expand Down Expand Up @@ -311,6 +339,25 @@ def with_retry_options(self, options: RetryOptions=None):
result._options = self._options.with_retry_options(options)
return result

def with_warning_handler(self, warning_handler: WarningHandler=None):
"""Returns object with adjusted options for handling warnings.
:param warning_handler WarningHandler:
Function for handling warnings. It is passed a tuple of warnings
and the query result and returns a potentially updated query
result.
This method returns a "shallow copy" of the current object
with modified retry options.
Both ``self`` and returned object can be used after, but when using
them retry options applied will be different.
"""

result = self._shallow_clone()
result._options = self._options.with_warning_handler(warning_handler)
return result

def with_state(self, state: State):
result = self._shallow_clone()
result._options = self._options.with_state(state)
Expand Down Expand Up @@ -369,17 +416,22 @@ def without_globals(self, *global_names):
class _Options:
"""Internal class for storing connection options"""

__slots__ = ['_retry_options', '_transaction_options', '_state']
__slots__ = [
'_retry_options', '_transaction_options', '_state',
'_warning_handler'
]

def __init__(
self,
retry_options: RetryOptions,
transaction_options: TransactionOptions,
state: State,
warning_handler: WarningHandler,
):
self._retry_options = retry_options
self._transaction_options = transaction_options
self._state = state
self._warning_handler = warning_handler

@property
def retry_options(self):
Expand All @@ -393,25 +445,40 @@ def transaction_options(self):
def state(self):
return self._state

@property
def warning_handler(self):
return self._warning_handler

def with_retry_options(self, options: RetryOptions):
return _Options(
options,
self._transaction_options,
self._state,
self._warning_handler,
)

def with_transaction_options(self, options: TransactionOptions):
return _Options(
self._retry_options,
options,
self._state,
self._warning_handler,
)

def with_state(self, state: State):
return _Options(
self._retry_options,
self._transaction_options,
state,
self._warning_handler,
)

def with_warning_handler(self, warning_handler: WarningHandler):
return _Options(
self._retry_options,
self._transaction_options,
self._state,
warning_handler,
)

@classmethod
Expand All @@ -420,4 +487,5 @@ def defaults(cls):
RetryOptions.defaults(),
TransactionOptions.defaults(),
State.defaults(),
log_warnings,
)
2 changes: 2 additions & 0 deletions edgedb/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ cdef class ExecuteContext:
readonly BaseCodec in_dc
readonly BaseCodec out_dc
readonly uint64_t capabilities
readonly tuple warnings

cdef inline bint has_na_cardinality(self)
cdef bint load_from_cache(self)
Expand Down Expand Up @@ -142,6 +143,7 @@ cdef class SansIOProtocol:
)

cdef inline ignore_headers(self)
cdef inline dict read_headers(self)
cdef dict parse_error_headers(self)

cdef parse_error_message(self)
Expand Down
Loading

0 comments on commit af8eb76

Please sign in to comment.