14
14
15
15
import logging
16
16
from functools import wraps
17
- from typing import Any , Callable , Optional , TypeVar , cast
17
+ from types import TracebackType
18
+ from typing import Any , Callable , Optional , Protocol , Type , TypeVar , cast
18
19
19
20
from prometheus_client import Counter
20
21
24
25
current_context ,
25
26
)
26
27
from synapse .metrics import InFlightGauge
28
+ from synapse .util import Clock
27
29
28
30
logger = logging .getLogger (__name__ )
29
31
61
63
sub_metrics = ["real_time_max" , "real_time_sum" ],
62
64
)
63
65
64
- T = TypeVar ("T" , bound = Callable [..., Any ])
66
+ R = TypeVar ("R" )
67
+ F = Callable [..., R ]
65
68
66
69
67
- def measure_func (name : Optional [str ] = None ) -> Callable [[T ], T ]:
70
+ class HasClock (Protocol ):
71
+ clock : Clock
72
+
73
+
74
+ def measure_func (name : Optional [str ] = None ) -> Callable [[F ], F ]:
68
75
"""
69
76
Used to decorate an async function with a `Measure` context manager.
70
77
@@ -82,16 +89,16 @@ async def foo(...):
82
89
83
90
"""
84
91
85
- def wrapper (func : T ) -> T :
92
+ def wrapper (func : F ) -> F :
86
93
block_name = func .__name__ if name is None else name
87
94
88
95
@wraps (func )
89
- async def measured_func (self , * args , ** kwargs ) :
96
+ async def measured_func (self : HasClock , * args : Any , ** kwargs : Any ) -> R :
90
97
with Measure (self .clock , block_name ):
91
98
r = await func (self , * args , ** kwargs )
92
99
return r
93
100
94
- return cast (T , measured_func )
101
+ return cast (F , measured_func )
95
102
96
103
return wrapper
97
104
@@ -104,10 +111,10 @@ class Measure:
104
111
"start" ,
105
112
]
106
113
107
- def __init__ (self , clock , name : str ):
114
+ def __init__ (self , clock : Clock , name : str ) -> None :
108
115
"""
109
116
Args:
110
- clock: A n object with a "time()" method, which returns the current
117
+ clock: An object with a "time()" method, which returns the current
111
118
time in seconds.
112
119
name: The name of the metric to report.
113
120
"""
@@ -124,7 +131,7 @@ def __init__(self, clock, name: str):
124
131
assert isinstance (curr_context , LoggingContext )
125
132
parent_context = curr_context
126
133
self ._logging_context = LoggingContext (str (curr_context ), parent_context )
127
- self .start : Optional [int ] = None
134
+ self .start : Optional [float ] = None
128
135
129
136
def __enter__ (self ) -> "Measure" :
130
137
if self .start is not None :
@@ -138,7 +145,12 @@ def __enter__(self) -> "Measure":
138
145
139
146
return self
140
147
141
- def __exit__ (self , exc_type , exc_val , exc_tb ):
148
+ def __exit__ (
149
+ self ,
150
+ exc_type : Optional [Type [BaseException ]],
151
+ exc_val : Optional [BaseException ],
152
+ exc_tb : Optional [TracebackType ],
153
+ ) -> None :
142
154
if self .start is None :
143
155
raise RuntimeError ("Measure() block exited without being entered" )
144
156
@@ -168,8 +180,9 @@ def get_resource_usage(self) -> ContextResourceUsage:
168
180
"""
169
181
return self ._logging_context .get_resource_usage ()
170
182
171
- def _update_in_flight (self , metrics ):
183
+ def _update_in_flight (self , metrics ) -> None :
172
184
"""Gets called when processing in flight metrics"""
185
+ assert self .start is not None
173
186
duration = self .clock .time () - self .start
174
187
175
188
metrics .real_time_max = max (metrics .real_time_max , duration )
0 commit comments