Skip to content

Commit

Permalink
Add type hints to logging/context.py (matrix-org#6309)
Browse files Browse the repository at this point in the history
* Add type hints to logging/context.py

Signed-off-by: neiljp (Neil Pilgrim) <github@kepier.clara.net>
  • Loading branch information
neiljp authored and phil-flex committed Mar 27, 2020
1 parent ffa9d7a commit 08d51b0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog.d/6309.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `logging/context.py`.
121 changes: 74 additions & 47 deletions synapse/logging/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@
import logging
import threading
import types
from typing import Any, List
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union

from typing_extensions import Literal

from twisted.internet import defer, threads

if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope

logger = logging.getLogger(__name__)

try:
Expand Down Expand Up @@ -91,7 +96,7 @@ class ContextResourceUsage(object):
"evt_db_fetch_count",
]

def __init__(self, copy_from=None):
def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
"""Create a new ContextResourceUsage
Args:
Expand All @@ -101,27 +106,28 @@ def __init__(self, copy_from=None):
if copy_from is None:
self.reset()
else:
self.ru_utime = copy_from.ru_utime
self.ru_stime = copy_from.ru_stime
self.db_txn_count = copy_from.db_txn_count
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
self.ru_utime = copy_from.ru_utime # type: float
self.ru_stime = copy_from.ru_stime # type: float
self.db_txn_count = copy_from.db_txn_count # type: int

self.db_txn_duration_sec = copy_from.db_txn_duration_sec
self.db_sched_duration_sec = copy_from.db_sched_duration_sec
self.evt_db_fetch_count = copy_from.evt_db_fetch_count
self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int

def copy(self):
def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self)

def reset(self):
def reset(self) -> None:
self.ru_stime = 0.0
self.ru_utime = 0.0
self.db_txn_count = 0

self.db_txn_duration_sec = 0
self.db_sched_duration_sec = 0
self.db_txn_duration_sec = 0.0
self.db_sched_duration_sec = 0.0
self.evt_db_fetch_count = 0

def __repr__(self):
def __repr__(self) -> str:
return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
Expand All @@ -135,7 +141,7 @@ def __repr__(self):
self.evt_db_fetch_count,
)

def __iadd__(self, other):
def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
"""Add another ContextResourceUsage's stats to this one's.
Args:
Expand All @@ -149,7 +155,7 @@ def __iadd__(self, other):
self.evt_db_fetch_count += other.evt_db_fetch_count
return self

def __isub__(self, other):
def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
self.ru_utime -= other.ru_utime
self.ru_stime -= other.ru_stime
self.db_txn_count -= other.db_txn_count
Expand All @@ -158,17 +164,20 @@ def __isub__(self, other):
self.evt_db_fetch_count -= other.evt_db_fetch_count
return self

def __add__(self, other):
def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res += other
return res

def __sub__(self, other):
def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res -= other
return res


LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]


class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
Expand Down Expand Up @@ -201,7 +210,14 @@ class LoggingContext(object):
class Sentinel(object):
"""Sentinel to represent the root context"""

__slots__ = [] # type: List[Any]
__slots__ = ["previous_context", "alive", "request", "scope"]

def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None

def __str__(self):
return "sentinel"
Expand Down Expand Up @@ -235,7 +251,7 @@ def __nonzero__(self):

sentinel = Sentinel()

def __init__(self, name=None, parent_context=None, request=None):
def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context()
self.name = name

Expand All @@ -250,7 +266,7 @@ def __init__(self, name=None, parent_context=None, request=None):
self.request = None
self.tag = ""
self.alive = True
self.scope = None
self.scope = None # type: Optional[_LogContextScope]

self.parent_context = parent_context

Expand All @@ -261,13 +277,13 @@ def __init__(self, name=None, parent_context=None, request=None):
# the request param overrides the request from the parent context
self.request = request

def __str__(self):
def __str__(self) -> str:
if self.request:
return str(self.request)
return "%s@%x" % (self.name, id(self))

@classmethod
def current_context(cls):
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
Expand All @@ -276,7 +292,9 @@ def current_context(cls):
return getattr(cls.thread_local, "current_context", cls.sentinel)

@classmethod
def set_current_context(cls, context):
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Expand All @@ -291,7 +309,7 @@ def set_current_context(cls, context):
context.start()
return current

def __enter__(self):
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
if self.previous_context != old_context:
Expand All @@ -304,7 +322,7 @@ def __enter__(self):

return self

def __exit__(self, type, value, traceback):
def __exit__(self, type, value, traceback) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
Expand All @@ -318,7 +336,6 @@ def __exit__(self, type, value, traceback):
logger.warning(
"Expected logging context %s but found %s", self, current
)
self.previous_context = None
self.alive = False

# if we have a parent, pass our CPU usage stats on
Expand All @@ -330,7 +347,7 @@ def __exit__(self, type, value, traceback):
# reset them in case we get entered again
self._resource_usage.reset()

def copy_to(self, record):
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
Expand All @@ -341,14 +358,14 @@ def copy_to(self, record):
# we also track the current scope:
record.scope = self.scope

def copy_to_twisted_log_entry(self, record):
def copy_to_twisted_log_entry(self, record) -> None:
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope

def start(self):
def start(self) -> None:
if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self)
return
Expand All @@ -358,7 +375,7 @@ def start(self):
if not self.usage_start:
self.usage_start = get_thread_resource_usage()

def stop(self):
def stop(self) -> None:
if get_thread_id() != self.main_thread:
logger.warning("Stopped logcontext %s on different thread", self)
return
Expand All @@ -378,7 +395,7 @@ def stop(self):

self.usage_start = None

def get_resource_usage(self):
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
Returns:
Expand All @@ -398,11 +415,13 @@ def get_resource_usage(self):

return res

def _get_cputime(self):
def _get_cputime(self) -> Tuple[float, float]:
"""Get the cpu usage time so far
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
"""
assert self.usage_start is not None

current = get_thread_resource_usage()

# Indicate to mypy that we know that self.usage_start is None.
Expand Down Expand Up @@ -430,13 +449,13 @@ def _get_cputime(self):

return utime_delta, stime_delta

def add_database_transaction(self, duration_sec):
def add_database_transaction(self, duration_sec: float) -> None:
if duration_sec < 0:
raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec

def add_database_scheduled(self, sched_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
"""Record a use of the database pool
Args:
Expand All @@ -447,7 +466,7 @@ def add_database_scheduled(self, sched_sec):
raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec

def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
"""Record a number of events being fetched from the db
Args:
Expand All @@ -464,10 +483,10 @@ class LoggingContextFilter(logging.Filter):
missing fields
"""

def __init__(self, **defaults):
def __init__(self, **defaults) -> None:
self.defaults = defaults

def filter(self, record):
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
Expand All @@ -492,12 +511,13 @@ class PreserveLoggingContext(object):

__slots__ = ["current_context", "new_context", "has_parent"]

def __init__(self, new_context=None):
def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
if new_context is None:
new_context = LoggingContext.sentinel
self.new_context = new_context
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:
self.new_context = new_context

def __enter__(self):
def __enter__(self) -> None:
"""Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context)

Expand All @@ -506,7 +526,7 @@ def __enter__(self):
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)

def __exit__(self, type, value, traceback):
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context)

Expand All @@ -525,7 +545,9 @@ def __exit__(self, type, value, traceback):
logger.debug("Restoring dead context: %s", self.current_context)


def nested_logging_context(suffix, parent_context=None):
def nested_logging_context(
suffix: str, parent_context: Optional[LoggingContext] = None
) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
Expand All @@ -546,10 +568,12 @@ def nested_logging_context(suffix, parent_context=None):
Returns:
LoggingContext: new logging context.
"""
if parent_context is None:
parent_context = LoggingContext.current_context()
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
return LoggingContext(
parent_context=parent_context, request=parent_context.request + "-" + suffix
parent_context=context, request=str(context.request) + "-" + suffix
)


Expand Down Expand Up @@ -654,7 +678,10 @@ def make_deferred_yieldable(deferred):
return deferred


def _set_context_cb(result, context):
ResultT = TypeVar("ResultT")


def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
LoggingContext.set_current_context(context)
return result
Expand Down

0 comments on commit 08d51b0

Please sign in to comment.