-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
resolver.py
134 lines (110 loc) · 4.49 KB
/
resolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from dataclasses import dataclass, field
from typing import Dict, List, NoReturn, Union, Type, Iterator, Set
from dbt.exceptions import (
DuplicateDependencyToRootError,
DuplicateProjectDependencyError,
MismatchedDependencyTypeError,
DbtInternalError,
)
from dbt.config import Project, RuntimeConfig
from dbt.config.renderer import DbtProjectYamlRenderer
from dbt.deps.base import BasePackage, PinnedPackage, UnpinnedPackage
from dbt.deps.local import LocalUnpinnedPackage
from dbt.deps.tarball import TarballUnpinnedPackage
from dbt.deps.git import GitUnpinnedPackage
from dbt.deps.registry import RegistryUnpinnedPackage
from dbt.contracts.project import (
LocalPackage,
TarballPackage,
GitPackage,
RegistryPackage,
)
# Any of the package-contract types that PackageListing.update_from knows how
# to convert into an UnpinnedPackage.
PackageContract = Union[
    LocalPackage,
    TarballPackage,
    GitPackage,
    RegistryPackage,
]
@dataclass
class PackageListing:
    """A name-keyed collection of unpinned packages.

    A package may be known under several names (``BasePackage.all_names()``).
    Lookups and inserts use the first of those names that is already a key in
    ``packages``, falling back to ``key.name``, so that aliases of one package
    share a single entry.
    """

    # Maps the chosen key name to the (possibly merged) unpinned package.
    packages: Dict[str, UnpinnedPackage] = field(default_factory=dict)

    def __len__(self):
        return len(self.packages)

    def __bool__(self):
        # Truthy while at least one package is held; resolve_packages-style
        # loops rely on this to detect an empty work list.
        return bool(self.packages)

    def _pick_key(self, key: BasePackage) -> str:
        """Return the existing entry name matching ``key``, else ``key.name``."""
        for name in key.all_names():
            if name in self.packages:
                return name
        return key.name

    def __contains__(self, key: BasePackage) -> bool:
        """Return True if any of ``key``'s names is already in the listing."""
        # Always return a real bool rather than an implicit None so that
        # direct calls to __contains__ behave the same as the `in` operator.
        return any(name in self.packages for name in key.all_names())

    def __getitem__(self, key: BasePackage):
        """Return the entry for ``key``; raises KeyError if it is absent."""
        key_str: str = self._pick_key(key)
        return self.packages[key_str]

    def __setitem__(self, key: BasePackage, value):
        """Store ``value`` under the key chosen by ``_pick_key(key)``."""
        key_str: str = self._pick_key(key)
        self.packages[key_str] = value

    def _mismatched_types(self, old: UnpinnedPackage, new: UnpinnedPackage) -> NoReturn:
        """Raise MismatchedDependencyTypeError for incompatible entries."""
        # Note the argument order expected by the exception: (new, old).
        raise MismatchedDependencyTypeError(new, old)

    def incorporate(self, package: UnpinnedPackage):
        """Add ``package``, merging it with an existing entry if one matches.

        When an entry already exists under one of ``package``'s names, it is
        replaced by ``existing.incorporate(package)``.

        Raises:
            MismatchedDependencyTypeError: the existing entry is not an
                instance of ``type(package)``.
        """
        key: str = self._pick_key(package)
        if key in self.packages:
            existing: UnpinnedPackage = self.packages[key]
            if not isinstance(existing, type(package)):
                self._mismatched_types(existing, package)
            self.packages[key] = existing.incorporate(package)
        else:
            self.packages[key] = package

    def update_from(self, src: List[PackageContract]) -> None:
        """Convert each contract in ``src`` to an unpinned package and add it.

        Raises:
            DbtInternalError: a contract is not one of the supported types.
        """
        pkg: UnpinnedPackage
        for contract in src:
            if isinstance(contract, LocalPackage):
                pkg = LocalUnpinnedPackage.from_contract(contract)
            elif isinstance(contract, TarballPackage):
                pkg = TarballUnpinnedPackage.from_contract(contract)
            elif isinstance(contract, GitPackage):
                pkg = GitUnpinnedPackage.from_contract(contract)
            elif isinstance(contract, RegistryPackage):
                pkg = RegistryUnpinnedPackage.from_contract(contract)
            else:
                raise DbtInternalError("Invalid package type {}".format(type(contract)))
            self.incorporate(pkg)

    @classmethod
    def from_contracts(
        cls: Type["PackageListing"], src: List[PackageContract]
    ) -> "PackageListing":
        """Build a new listing populated from the contracts in ``src``."""
        self = cls({})
        self.update_from(src)
        return self

    def resolved(self) -> List[PinnedPackage]:
        """Return ``resolved()`` of every held package, in insertion order."""
        return [p.resolved() for p in self.packages.values()]

    def __iter__(self) -> Iterator[UnpinnedPackage]:
        """Iterate over the held unpinned packages (not their key names)."""
        return iter(self.packages.values())
def _check_for_duplicate_project_names(
    final_deps: List[PinnedPackage],
    config: Project,
    renderer: DbtProjectYamlRenderer,
):
    """Ensure every resolved dependency reports a distinct project name.

    Each dependency's name comes from ``get_project_name(config, renderer)``.

    Raises:
        DuplicateProjectDependencyError: a name repeats one already seen.
        DuplicateDependencyToRootError: a name equals ``config.project_name``.
    """
    names_so_far: Set[str] = set()
    for dep in final_deps:
        name = dep.get_project_name(config, renderer)
        if name in names_so_far:
            raise DuplicateProjectDependencyError(name)
        if name == config.project_name:
            raise DuplicateDependencyToRootError(name)
        names_so_far.add(name)
def resolve_packages(
    packages: List[PackageContract], config: RuntimeConfig
) -> List[PinnedPackage]:
    """Resolve ``packages`` and their transitive dependencies to pinned packages.

    Works breadth-first. Each round incorporates the current frontier into an
    accumulated listing and resolves each entry. It then fetches that entry's
    metadata and collects the packages it declares into the next frontier.
    Rounds continue until a frontier is empty.

    The final pinned list is checked for duplicate project names before it is
    returned (see ``_check_for_duplicate_project_names``).

    NOTE(review): the loop only stops when a round discovers no packages. If
    dependency metadata can form a cycle, it looks like the frontier would
    never empty. Confirm whether that case is guarded elsewhere.
    """
    frontier = PackageListing.from_contracts(packages)
    accumulated = PackageListing()
    yaml_renderer = DbtProjectYamlRenderer(config, config.cli_vars)

    while frontier:
        discovered = PackageListing()
        for unpinned in frontier:
            # Merge first, so that metadata is fetched for the merged entry.
            accumulated.incorporate(unpinned)
            pinned = accumulated[unpinned].resolved()
            metadata = pinned.fetch_metadata(config, yaml_renderer)
            discovered.update_from(metadata.packages)
        frontier = discovered

    pinned_packages = accumulated.resolved()
    _check_for_duplicate_project_names(pinned_packages, config, yaml_renderer)
    return pinned_packages