Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Type Annotations #8536

Merged
merged 5 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _get_tests_for_node(manifest: Manifest, unique_id: UniqueID) -> List[UniqueI


class Linker:
def __init__(self, data=None):
def __init__(self, data=None) -> None:
if data is None:
data = {}
self.graph = nx.DiGraph(**data)
Expand Down Expand Up @@ -274,7 +274,7 @@ def get_graph_summary(self, manifest: Manifest) -> Dict[int, Dict[str, Any]]:


class Compiler:
def __init__(self, config):
def __init__(self, config) -> None:
self.config = config

def initialize(self):
Expand Down
12 changes: 6 additions & 6 deletions core/dbt/events/adapter_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,32 @@
class AdapterLogger:
name: str

def debug(self, msg, *args):
def debug(self, msg, *args) -> None:
event = AdapterEventDebug(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def info(self, msg, *args):
def info(self, msg, *args) -> None:
event = AdapterEventInfo(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def warning(self, msg, *args):
def warning(self, msg, *args) -> None:
event = AdapterEventWarning(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def error(self, msg, *args):
def error(self, msg, *args) -> None:
event = AdapterEventError(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

# The default exc_info=True is what makes this method different
def exception(self, msg, *args):
def exception(self, msg, *args) -> None:
exc_info = str(traceback.format_exc())
event = AdapterEventError(
name=self.name,
Expand All @@ -51,7 +51,7 @@ def exception(self, msg, *args):
)
fire_event(event)

def critical(self, msg, *args):
def critical(self, msg, *args) -> None:
event = AdapterEventError(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/events/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_pid() -> int:
return os.getpid()


# in theory threads can change so we don't cache them.
# in theory threads can change, so we don't cache them.
def get_thread_name() -> str:
return threading.current_thread().name

Expand All @@ -55,7 +55,7 @@ class EventLevel(str, Enum):
class BaseEvent:
"""BaseEvent for proto message generated python events"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
class_name = type(self).__name__
msg_cls = getattr(types_pb2, class_name)
if class_name == "Formatting" and len(args) > 0:
Expand Down Expand Up @@ -100,9 +100,9 @@ def to_dict(self):
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)

def to_json(self):
def to_json(self) -> str:
return MessageToJson(
self.pb_msg, preserving_proto_field_name=True, including_default_valud_fields=True
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=False
peterallenwebb marked this conversation as resolved.
Show resolved Hide resolved
)

def level_tag(self) -> EventLevel:
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/events/eventmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from logging.handlers import RotatingFileHandler
import threading
import traceback
from typing import Any, Callable, List, Optional, TextIO, Protocol
from typing import Any, Callable, List, Optional, Protocol, TextIO, Tuple
from uuid import uuid4
from dbt.events.format import timestamp_to_datetime_string

Expand Down Expand Up @@ -215,14 +215,15 @@ def add_logger(self, config: LoggerConfig) -> None:
logger.event_manager = self
self.loggers.append(logger)

def flush(self):
def flush(self) -> None:
for logger in self.loggers:
logger.flush()


class IEventManager(Protocol):
callbacks: List[Callable[[EventMsg], None]]
invocation_id: str
loggers: List[_Logger]

def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
...
Expand All @@ -232,8 +233,9 @@ def add_logger(self, config: LoggerConfig) -> None:


class TestEventManager(IEventManager):
def __init__(self):
self.event_history = []
def __init__(self) -> None:
self.event_history: List[Tuple[BaseEvent, Optional[EventLevel]]] = []
self.loggers = []

def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
self.event_history.append((e, level))
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/events/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def _pluralize(string: Union[str, NodeType]) -> str:
return convert.pluralize()


def pluralize(count, string: Union[str, NodeType]):
def pluralize(count, string: Union[str, NodeType]) -> str:
pluralized: str = str(string)
if count != 1:
pluralized = _pluralize(string)
return f"{count} {pluralized}"


def timestamp_to_datetime_string(ts):
def timestamp_to_datetime_string(ts) -> str:
timestamp_dt = datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9)
return timestamp_dt.strftime("%H:%M:%S.%f")
10 changes: 5 additions & 5 deletions core/dbt/events/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def env_scrubber(msg: str) -> str:
return scrub_secrets(msg, env_secrets())


def cleanup_event_logger():
def cleanup_event_logger() -> None:
# Reset to a no-op manager to release streams associated with logs. This is
# especially important for tests, since pytest replaces the stdout stream
# during test runs, and closes the stream after the test is over.
Expand All @@ -192,12 +192,12 @@ def cleanup_event_logger():


# used for integration tests
def capture_stdout_logs(stream: TextIO):
def capture_stdout_logs(stream: TextIO) -> None:
global _CAPTURE_STREAM
_CAPTURE_STREAM = stream


def stop_capture_stdout_logs():
def stop_capture_stdout_logs() -> None:
global _CAPTURE_STREAM
_CAPTURE_STREAM = None

Expand Down Expand Up @@ -231,7 +231,7 @@ def msg_to_dict(msg: EventMsg) -> dict:
return msg_dict


def warn_or_error(event, node=None):
def warn_or_error(event, node=None) -> None:
flags = get_flags()
if flags.WARN_ERROR or flags.WARN_ERROR_OPTIONS.includes(type(event).__name__):

Expand Down Expand Up @@ -293,6 +293,6 @@ def set_invocation_id() -> None:
EVENT_MANAGER.invocation_id = str(uuid.uuid4())


def ctx_set_event_manager(event_manager: IEventManager):
def ctx_set_event_manager(event_manager: IEventManager) -> None:
global EVENT_MANAGER
EVENT_MANAGER = event_manager
Loading