Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 8554128

Browse files
author
David Robertson
committed
Use Concatenate to annotate do_execute
I'm not sure this gives us a huge amount of type safety, see this comment: #12312 (comment) In any case, it's a nice bit of practice with `ParamSpec`.
1 parent 0ce2201 commit 8554128

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ netaddr = ">=0.7.18"
142142
# add a lower bound to the Jinja2 dependency.
143143
Jinja2 = ">=3.0"
144144
bleach = ">=1.4.3"
145-
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
145+
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
146146
typing-extensions = ">=3.10.0"
147147
# We enforce that we have a `cryptography` version that bundles an `openssl`
148148
# with the latest security patches.

synapse/storage/database.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
import attr
4040
from prometheus_client import Histogram
41-
from typing_extensions import Literal
41+
from typing_extensions import Concatenate, Literal, ParamSpec
4242

4343
from twisted.enterprise import adbapi
4444

@@ -194,7 +194,7 @@ def __getattr__(self, name):
194194
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
195195
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
196196

197-
197+
P = ParamSpec("P")
198198
R = TypeVar("R")
199199

200200

@@ -339,7 +339,13 @@ def _make_sql_one_line(self, sql: str) -> str:
339339
"Strip newlines out of SQL so that the loggers in the DB are on one line"
340340
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
341341

342-
def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
342+
def _do_execute(
343+
self,
344+
func: Callable[Concatenate[str, P], R],
345+
sql: str,
346+
*args: P.args,
347+
**kwargs: P.kwargs,
348+
) -> R:
343349
sql = self._make_sql_one_line(sql)
344350

345351
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,7 +354,10 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
348354
sql = self.database_engine.convert_param_style(sql)
349355
if args:
350356
try:
351-
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
357+
# The type-ignore should be redundant once mypy releases a version with
358+
# https://github.com/python/mypy/pull/12668. (`args` might be empty,
359+
# (but we'll catch the index error if so.)
360+
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index]
352361
except Exception:
353362
# Don't let logging failures stop SQL from working
354363
pass
@@ -363,7 +372,7 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
363372
opentracing.tags.DATABASE_STATEMENT: sql,
364373
},
365374
):
366-
return func(sql, *args)
375+
return func(sql, *args, **kwargs)
367376
except Exception as e:
368377
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
369378
raise

0 commit comments

Comments
 (0)