Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to logging/context.py #6309

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6309.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `logging/context.py`.
121 changes: 74 additions & 47 deletions synapse/logging/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@
import logging
import threading
import types
from typing import Any, List
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union

from typing_extensions import Literal

from twisted.internet import defer, threads

if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm missing the point of this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids the import at runtime, but type checkers become aware of the name. In sufficiently recent versions this uses typing.TYPE_CHECKING, so I've changed the commit to use this.


logger = logging.getLogger(__name__)

try:
Expand Down Expand Up @@ -91,7 +96,7 @@ class ContextResourceUsage(object):
"evt_db_fetch_count",
]

def __init__(self, copy_from=None):
def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
"""Create a new ContextResourceUsage
Args:
Expand All @@ -101,27 +106,28 @@ def __init__(self, copy_from=None):
if copy_from is None:
self.reset()
else:
self.ru_utime = copy_from.ru_utime
self.ru_stime = copy_from.ru_stime
self.db_txn_count = copy_from.db_txn_count
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
self.ru_utime = copy_from.ru_utime # type: float
self.ru_stime = copy_from.ru_stime # type: float
self.db_txn_count = copy_from.db_txn_count # type: int

self.db_txn_duration_sec = copy_from.db_txn_duration_sec
self.db_sched_duration_sec = copy_from.db_sched_duration_sec
self.evt_db_fetch_count = copy_from.evt_db_fetch_count
self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int

def copy(self):
def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self)

def reset(self):
def reset(self) -> None:
self.ru_stime = 0.0
self.ru_utime = 0.0
self.db_txn_count = 0

self.db_txn_duration_sec = 0
self.db_sched_duration_sec = 0
self.db_txn_duration_sec = 0.0
self.db_sched_duration_sec = 0.0
self.evt_db_fetch_count = 0

def __repr__(self):
def __repr__(self) -> str:
return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
Expand All @@ -135,7 +141,7 @@ def __repr__(self):
self.evt_db_fetch_count,
)

def __iadd__(self, other):
def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
"""Add another ContextResourceUsage's stats to this one's.
Args:
Expand All @@ -149,7 +155,7 @@ def __iadd__(self, other):
self.evt_db_fetch_count += other.evt_db_fetch_count
return self

def __isub__(self, other):
def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
self.ru_utime -= other.ru_utime
self.ru_stime -= other.ru_stime
self.db_txn_count -= other.db_txn_count
Expand All @@ -158,17 +164,20 @@ def __isub__(self, other):
self.evt_db_fetch_count -= other.evt_db_fetch_count
return self

def __add__(self, other):
def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res += other
return res

def __sub__(self, other):
def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res -= other
return res


LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]


class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
Expand Down Expand Up @@ -201,7 +210,14 @@ class LoggingContext(object):
class Sentinel(object):
"""Sentinel to represent the root context"""

__slots__ = [] # type: List[Any]
__slots__ = ["previous_context", "alive", "request", "scope"]

def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None

def __str__(self):
return "sentinel"
Expand Down Expand Up @@ -235,7 +251,7 @@ def __nonzero__(self):

sentinel = Sentinel()

def __init__(self, name=None, parent_context=None, request=None):
def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context()
self.name = name

Expand All @@ -250,7 +266,7 @@ def __init__(self, name=None, parent_context=None, request=None):
self.request = None
self.tag = ""
self.alive = True
self.scope = None
self.scope = None # type: Optional[_LogContextScope]

self.parent_context = parent_context

Expand All @@ -261,13 +277,13 @@ def __init__(self, name=None, parent_context=None, request=None):
# the request param overrides the request from the parent context
self.request = request

def __str__(self):
def __str__(self) -> str:
if self.request:
return str(self.request)
return "%s@%x" % (self.name, id(self))

@classmethod
def current_context(cls):
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
Expand All @@ -276,7 +292,9 @@ def current_context(cls):
return getattr(cls.thread_local, "current_context", cls.sentinel)

@classmethod
def set_current_context(cls, context):
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Expand All @@ -291,7 +309,7 @@ def set_current_context(cls, context):
context.start()
return current

def __enter__(self):
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
if self.previous_context != old_context:
Expand All @@ -304,7 +322,7 @@ def __enter__(self):

return self

def __exit__(self, type, value, traceback):
def __exit__(self, type, value, traceback) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
Expand All @@ -318,7 +336,6 @@ def __exit__(self, type, value, traceback):
logger.warning(
"Expected logging context %s but found %s", self, current
)
self.previous_context = None
self.alive = False

# if we have a parent, pass our CPU usage stats on
Expand All @@ -330,7 +347,7 @@ def __exit__(self, type, value, traceback):
# reset them in case we get entered again
self._resource_usage.reset()

def copy_to(self, record):
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
Expand All @@ -341,14 +358,14 @@ def copy_to(self, record):
# we also track the current scope:
record.scope = self.scope

def copy_to_twisted_log_entry(self, record):
def copy_to_twisted_log_entry(self, record) -> None:
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope

def start(self):
def start(self) -> None:
if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self)
return
Expand All @@ -358,7 +375,7 @@ def start(self):
if not self.usage_start:
self.usage_start = get_thread_resource_usage()

def stop(self):
def stop(self) -> None:
if get_thread_id() != self.main_thread:
logger.warning("Stopped logcontext %s on different thread", self)
return
Expand All @@ -378,7 +395,7 @@ def stop(self):

self.usage_start = None

def get_resource_usage(self):
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
Returns:
Expand All @@ -398,11 +415,13 @@ def get_resource_usage(self):

return res

def _get_cputime(self):
def _get_cputime(self) -> Tuple[float, float]:
"""Get the cpu usage time so far
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
"""
assert self.usage_start is not None

current = get_thread_resource_usage()

# Indicate to mypy that we know that self.usage_start is None.
Expand Down Expand Up @@ -430,13 +449,13 @@ def _get_cputime(self):

return utime_delta, stime_delta

def add_database_transaction(self, duration_sec):
def add_database_transaction(self, duration_sec: float) -> None:
if duration_sec < 0:
raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec

def add_database_scheduled(self, sched_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
"""Record a use of the database pool
Args:
Expand All @@ -447,7 +466,7 @@ def add_database_scheduled(self, sched_sec):
raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec

def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
"""Record a number of events being fetched from the db
Args:
Expand All @@ -464,10 +483,10 @@ class LoggingContextFilter(logging.Filter):
missing fields
"""

def __init__(self, **defaults):
def __init__(self, **defaults) -> None:
self.defaults = defaults

def filter(self, record):
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
Expand All @@ -492,12 +511,13 @@ class PreserveLoggingContext(object):

__slots__ = ["current_context", "new_context", "has_parent"]

def __init__(self, new_context=None):
def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
if new_context is None:
new_context = LoggingContext.sentinel
self.new_context = new_context
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:
self.new_context = new_context

def __enter__(self):
def __enter__(self) -> None:
"""Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context)

Expand All @@ -506,7 +526,7 @@ def __enter__(self):
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)

def __exit__(self, type, value, traceback):
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context)

Expand All @@ -525,7 +545,9 @@ def __exit__(self, type, value, traceback):
logger.debug("Restoring dead context: %s", self.current_context)


def nested_logging_context(suffix, parent_context=None):
def nested_logging_context(
suffix: str, parent_context: Optional[LoggingContext] = None
) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
Expand All @@ -546,10 +568,12 @@ def nested_logging_context(suffix, parent_context=None):
Returns:
LoggingContext: new logging context.
"""
if parent_context is None:
parent_context = LoggingContext.current_context()
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
return LoggingContext(
parent_context=parent_context, request=parent_context.request + "-" + suffix
parent_context=context, request=str(context.request) + "-" + suffix
)


Expand Down Expand Up @@ -654,7 +678,10 @@ def make_deferred_yieldable(deferred):
return deferred


def _set_context_cb(result, context):
ResultT = TypeVar("ResultT")


def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
LoggingContext.set_current_context(context)
return result
Expand Down