Skip to content

Commit

Permalink
Merge branch 'main' into revert-3140-fix/airflow-daemon-detect
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao authored Sep 17, 2024
2 parents a77806b + f7c7d0a commit 07e3fe4
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 24 deletions.
29 changes: 13 additions & 16 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from io import StringIO
from pathlib import Path
from shutil import rmtree
from types import MappingProxyType
from types import MappingProxyType, SimpleNamespace

import pandas as pd
from sqlglot import Dialect, exp
Expand Down Expand Up @@ -318,9 +318,7 @@ def __init__(
self.configs = (
config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths)
)
self._loaders: UniqueKeyDict[str, t.Dict[str, Loader | t.Dict[Path, C]]] = UniqueKeyDict(
"loaders"
)
self._loaders: UniqueKeyDict[str, SimpleNamespace] = UniqueKeyDict("loaders")
self.dag: DAG[str] = DAG()
self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits")
Expand All @@ -337,11 +335,10 @@ def __init__(
for path, config in self.configs.items():
project_type = c.DBT if issubclass(config.loader, DbtLoader) else c.NATIVE
if project_type not in self._loaders:
self._loaders[project_type] = {
"loader": (loader or config.loader)(**config.loader_kwargs),
"configs": {},
}
self._loaders[project_type]["configs"][path] = config
self._loaders[project_type] = SimpleNamespace(
loader=(loader or config.loader)(**config.loader_kwargs), configs={}
)
self._loaders[project_type].configs[path] = config

self.project_type = c.HYBRID if len(self._loaders) > 1 else project_type
self._all_dialects: t.Set[str] = {self.config.dialect or ""}
Expand Down Expand Up @@ -528,17 +525,17 @@ def state_reader(self) -> StateReader:

def refresh(self) -> None:
"""Refresh all models that have been updated."""
if any(loader_dict["loader"].reload_needed() for loader_dict in self._loaders.values()):
if any(context_loader.loader.reload_needed() for context_loader in self._loaders.values()):
self.load()

def load(self, update_schemas: bool = True) -> GenericContext[C]:
"""Load all files in the context's path."""
load_start_ts = time.perf_counter()

projects = []
for loader_dict in self._loaders.values():
with sys_path(*loader_dict["configs"]):
projects.append(loader_dict["loader"].load(self, update_schemas))
for context_loader in self._loaders.values():
with sys_path(*context_loader.configs):
projects.append(context_loader.loader.load(self, update_schemas))

self._standalone_audits.clear()
self._audits.clear()
Expand Down Expand Up @@ -620,9 +617,9 @@ def run(

if not self._loaded:
# Signals should be loaded to run correctly.
for loader_dict in self._loaders.values():
with sys_path(*loader_dict["configs"]):
loader_dict["loader"].load_signals(self)
for context_loader in self._loaders.values():
with sys_path(*context_loader.configs):
context_loader.loader.load_signals(self)

success = False
try:
Expand Down
6 changes: 4 additions & 2 deletions sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,14 @@ def _load_sql_models(
) -> UniqueKeyDict[str, Model]:
"""Loads the sql models into a Dict"""
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
for context_path, config in self._context._loaders[c.NATIVE]["configs"].items():
for context_path, config in self._context._loaders[c.NATIVE].configs.items():
cache = SqlMeshLoader._Cache(self, context_path)
variables = self._variables(config)

for path in self._glob_paths(
context_path / c.MODELS, ignore_patterns=config.ignore_patterns, extension=".sql"
context_path / c.MODELS,
ignore_patterns=config.ignore_patterns,
extension=".sql",
):
if not os.path.getsize(path):
continue
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _load_projects(self) -> t.List[Project]:

self._projects = []

for path, config in self._context._loaders[c.DBT]["configs"].items():
for path, config in self._context._loaders[c.DBT].configs.items():
project = Project.load(
DbtContext(
project_root=path,
Expand Down
2 changes: 1 addition & 1 deletion tests/dbt/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@pytest.fixture()
def sushi_test_project(sushi_test_dbt_context: Context) -> Project:
return sushi_test_dbt_context._loaders[c.DBT]["loader"]._load_projects()[0] # type: ignore
return sushi_test_dbt_context._loaders[c.DBT].loader._load_projects()[0] # type: ignore


@pytest.fixture()
Expand Down
8 changes: 4 additions & 4 deletions web/server/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ async def watch_project() -> None:
if context:
in_paths = any(is_relative_to(path, p) for p in paths)
is_modified_new_file = change == Change.modified and any(
path not in loader_dict["loader"]._path_mtimes
for loader_dict in context._loaders.values()
path not in context_loader.loader._path_mtimes
for context_loader in context._loaders.values()
)
should_track_file = path.is_file() and in_paths
should_reset_mtime = Change.added or is_modified_new_file
if should_track_file and should_reset_mtime:
for loader_dict in context._loaders:
loader_dict["loader"]._path_mtimes[path] = 0
for context_loader in context._loaders.values():
context_loader.loader._path_mtimes[path] = 0

except Exception:
error = ApiException(
Expand Down

0 comments on commit 07e3fe4

Please sign in to comment.