Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Type Annotations #8536

Merged
merged 5 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230831-164435.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Added more type annotations.
time: 2023-08-31T16:44:35.737954-04:00
custom:
Author: peterallenwebb
Issue: "8537"
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):

TYPE: str = NotImplemented

def __init__(self, profile: AdapterRequiredConfig):
def __init__(self, profile: AdapterRequiredConfig) -> None:
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = flags.MP_CONTEXT.RLock()
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class BaseAdapter(metaclass=AdapterMeta):
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

def __init__(self, config):
def __init__(self, config) -> None:
self.config = config
self.cache = RelationsCache()
self.connections = self.ConnectionManager(config)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
credentials: Type[Credentials],
include_path: str,
dependencies: Optional[List[str]] = None,
):
) -> None:

self.adapter: Type[AdapterProtocol] = adapter
self.credentials: Type[Credentials] = credentials
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class NodeWrapper:
def __init__(self, node):
def __init__(self, node) -> None:
self._inner_node = node

def __getattr__(self, name):
Expand Down Expand Up @@ -57,7 +57,7 @@ def set(self, comment: Optional[str], append: bool):


class MacroQueryStringSetter:
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
self.manifest = manifest
self.config = config

Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class _CachedRelation:
:attr BaseRelation inner: The underlying dbt relation.
"""

def __init__(self, inner):
self.referenced_by = {}
def __init__(self, inner) -> None:
self.referenced_by: Dict[_ReferenceKey, _CachedRelation] = {}
self.inner = inner

def __str__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AdapterProtocol( # type: ignore[misc]
ConnectionManager: Type[ConnectionManager_T]
connections: ConnectionManager_T

def __init__(self, config: AdapterRequiredConfig):
def __init__(self, config: AdapterRequiredConfig) -> None:
...

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self,
manifest: Optional[Manifest] = None,
callbacks: Optional[List[Callable[[EventMsg], None]]] = None,
):
) -> None:
self.manifest = manifest

if callbacks is None:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Implementation from: https://stackoverflow.com/a/48394004
# Note MultiOption options must be specified with type=tuple or type=ChoiceTuple (https://github.com/pallets/click/issues/2012)
class MultiOption(click.Option):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.save_other_options = kwargs.pop("save_other_options", True)
nargs = kwargs.pop("nargs", -1)
assert nargs == -1, "nargs, if set, must be -1 not {}".format(nargs)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class _NullMarker:


class ColumnTypeBuilder(Dict[str, NullableAgateType]):
def __init__(self):
def __init__(self) -> None:
super().__init__()

def __setitem__(self, key, value):
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _get_tests_for_node(manifest: Manifest, unique_id: UniqueID) -> List[UniqueI


class Linker:
def __init__(self, data=None):
def __init__(self, data=None) -> None:
if data is None:
data = {}
self.graph = nx.DiGraph(**data)
Expand Down Expand Up @@ -274,7 +274,7 @@ def get_graph_summary(self, manifest: Manifest) -> Dict[int, Dict[str, Any]]:


class Compiler:
def __init__(self, config):
def __init__(self, config) -> None:
self.config = config

def initialize(self):
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
user_config: UserConfig,
threads: int,
credentials: Credentials,
):
) -> None:
"""Explicitly defining `__init__` to work around bug in Python 3.9.7
https://bugs.python.org/issue45081
"""
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _list_if_none_or_string(value):


class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
def __init__(self):
def __init__(self) -> None:
super().__init__()

self[("on-run-start",)] = _list_if_none_or_string
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _get_project_directories(self) -> Iterator[Path]:


class UnsetCredentials(Credentials):
def __init__(self):
def __init__(self) -> None:
super().__init__("", "")

@property
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LazyHandle:
connection, updating the handle on the Connection.
"""

def __init__(self, opener: Callable[[Connection], Connection]):
def __init__(self, opener: Callable[[Connection], Connection]) -> None:
self.opener = opener

def resolve(self, connection: Connection) -> Connection:
Expand Down
16 changes: 8 additions & 8 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def find_unique_id_for_package(storage, key, package: Optional[PackageName]):


class DocLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -119,7 +119,7 @@ def perform_lookup(self, unique_id: UniqueID, manifest) -> Documentation:


class SourceLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -156,7 +156,7 @@ class RefableLookup(dbtClassMixin):
_lookup_types: ClassVar[set] = set(NodeType.refable())
_versioned_types: ClassVar[set] = set(NodeType.versioned())

def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -267,7 +267,7 @@ def _find_unique_ids_for_package(self, key, package: Optional[PackageName]) -> L


class MetricLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -306,7 +306,7 @@ class SemanticModelByMeasureLookup(dbtClassMixin):
the semantic models in a manifest.
"""

def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: DefaultDict[str, Dict[PackageName, UniqueID]] = defaultdict(dict)
self.populate(manifest)

Expand Down Expand Up @@ -355,7 +355,7 @@ def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SemanticM

# This handles both models/seeds/snapshots and sources/metrics/exposures/semantic_models
class DisabledLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, List[Any]]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -1427,12 +1427,12 @@ def __reduce_ex__(self, protocol):


class MacroManifest(MacroMethods):
def __init__(self, macros):
def __init__(self, macros) -> None:
self.macros = macros
self.metadata = ManifestMetadata()
# This is returned by the 'graph' context property
# in the ProviderContext class.
self.flat_graph = {}
self.flat_graph: Dict[str, Any] = {}


AnyManifest = Union[Manifest, MacroManifest]
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class MetricReference(object):
def __init__(self, metric_name, package_name=None):
def __init__(self, metric_name, package_name=None) -> None:
self.metric_name = metric_name
self.package_name = package_name

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/semantic_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class SemanticManifest:
def __init__(self, manifest):
def __init__(self, manifest) -> None:
self.manifest = manifest

def validate(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def to_msg_dict(self):

# This is a context manager
class collect_timing_info:
def __init__(self, name: str, callback: Callable[[TimingInfo], None]):
def __init__(self, name: str, callback: Callable[[TimingInfo], None]) -> None:
self.timing_info = TimingInfo(name=name)
self.callback = callback

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class PreviousState:
def __init__(self, state_path: Path, target_path: Path, project_root: Path):
def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None:
self.state_path: Path = state_path
self.target_path: Path = target_path
self.project_root: Path = project_root
Expand Down
12 changes: 6 additions & 6 deletions core/dbt/events/adapter_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,32 @@
class AdapterLogger:
name: str

def debug(self, msg, *args):
def debug(self, msg, *args) -> None:
event = AdapterEventDebug(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def info(self, msg, *args):
def info(self, msg, *args) -> None:
event = AdapterEventInfo(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def warning(self, msg, *args):
def warning(self, msg, *args) -> None:
event = AdapterEventWarning(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def error(self, msg, *args):
def error(self, msg, *args) -> None:
event = AdapterEventError(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

# The default exc_info=True is what makes this method different
def exception(self, msg, *args):
def exception(self, msg, *args) -> None:
exc_info = str(traceback.format_exc())
event = AdapterEventError(
name=self.name,
Expand All @@ -51,7 +51,7 @@ def exception(self, msg, *args):
)
fire_event(event)

def critical(self, msg, *args):
def critical(self, msg, *args) -> None:
event = AdapterEventError(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/events/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_pid() -> int:
return os.getpid()


# in theory threads can change so we don't cache them.
# in theory threads can change, so we don't cache them.
def get_thread_name() -> str:
return threading.current_thread().name

Expand All @@ -55,7 +55,7 @@ class EventLevel(str, Enum):
class BaseEvent:
"""BaseEvent for proto message generated python events"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
class_name = type(self).__name__
msg_cls = getattr(types_pb2, class_name)
if class_name == "Formatting" and len(args) > 0:
Expand Down Expand Up @@ -100,9 +100,9 @@ def to_dict(self):
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)

def to_json(self):
def to_json(self) -> str:
return MessageToJson(
self.pb_msg, preserving_proto_field_name=True, including_default_valud_fields=True
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)

def level_tag(self) -> EventLevel:
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/events/eventmgr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import traceback
from typing import Callable, List, Optional, Protocol
from typing import Callable, List, Optional, Protocol, Tuple
from uuid import uuid4

from dbt.events.base_types import BaseEvent, EventLevel, msg_from_base_event, EventMsg
Expand Down Expand Up @@ -38,14 +38,15 @@ def add_logger(self, config: LoggerConfig) -> None:
)
self.loggers.append(logger)

def flush(self):
def flush(self) -> None:
for logger in self.loggers:
logger.flush()


class IEventManager(Protocol):
callbacks: List[Callable[[EventMsg], None]]
invocation_id: str
loggers: List[_Logger]

def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
...
Expand All @@ -55,8 +56,9 @@ def add_logger(self, config: LoggerConfig) -> None:


class TestEventManager(IEventManager):
def __init__(self):
self.event_history = []
def __init__(self) -> None:
self.event_history: List[Tuple[BaseEvent, Optional[EventLevel]]] = []
self.loggers = []

def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
self.event_history.append((e, level))
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/events/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def _pluralize(string: Union[str, NodeType]) -> str:
return convert.pluralize()


def pluralize(count, string: Union[str, NodeType]):
def pluralize(count, string: Union[str, NodeType]) -> str:
pluralized: str = str(string)
if count != 1:
pluralized = _pluralize(string)
return f"{count} {pluralized}"


def timestamp_to_datetime_string(ts):
def timestamp_to_datetime_string(ts) -> str:
timestamp_dt = datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9)
return timestamp_dt.strftime("%H:%M:%S.%f")
Loading