From 1e4286a62dddd49e251443e51529dda5ab1ff012 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Wed, 10 Jan 2024 09:06:06 -0800 Subject: [PATCH] Fix full-refresh and vars for retry (#9328) Co-authored-by: Peter Allen Webb --- .../unreleased/Fixes-20231213-220449.yaml | 6 ++ core/dbt/cli/flags.py | 8 +- core/dbt/cli/main.py | 2 +- core/dbt/cli/requires.py | 25 ++---- core/dbt/contracts/state.py | 17 ++-- core/dbt/parser/manifest.py | 26 ++++-- core/dbt/task/retry.py | 84 +++++++++++-------- tests/functional/retry/test_retry.py | 38 ++++++++- 8 files changed, 133 insertions(+), 73 deletions(-) create mode 100644 .changes/unreleased/Fixes-20231213-220449.yaml diff --git a/.changes/unreleased/Fixes-20231213-220449.yaml b/.changes/unreleased/Fixes-20231213-220449.yaml new file mode 100644 index 00000000000..6da9f7ddcaa --- /dev/null +++ b/.changes/unreleased/Fixes-20231213-220449.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Preserve the value of vars and the --full-refresh flags when using retry. +time: 2023-12-13T22:04:49.228294-05:00 +custom: + Author: peterallenwebb, ChenyuLInx + Issue: "9112" diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index 580bddcd372..ffc73323df8 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -28,6 +28,7 @@ FLAGS_DEFAULTS = { "INDIRECT_SELECTION": "eager", "TARGET_PATH": None, + "WARN_ERROR": None, # Cli args without project_flags or env var option. "FULL_REFRESH": False, "STRICT_MODE": False, @@ -84,7 +85,6 @@ class Flags: def __init__( self, ctx: Optional[Context] = None, project_flags: Optional[ProjectFlags] = None ) -> None: - # Set the default flags. for key, value in FLAGS_DEFAULTS.items(): object.__setattr__(self, key, value) @@ -126,7 +126,6 @@ def _assign_params( # respected over DBT_PRINT or --print. new_name: Union[str, None] = None if param_name in DEPRECATED_PARAMS: - # Deprecated env vars can only be set via env var. # We use the deprecated option in click to serialize the value # from the env var string. @@ -346,7 +345,6 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar default_args = set([x.lower() for x in FLAGS_DEFAULTS.keys()]) res = command.to_list() - for k, v in args_dict.items(): k = k.lower() # if a "which" value exists in the args dict, it should match the command provided @@ -358,7 +356,9 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar continue # param was assigned from defaults and should not be included - if k not in (cmd_args | prnt_args) - default_args: + if k not in (cmd_args | prnt_args) or ( + k in default_args and v == FLAGS_DEFAULTS[k.upper()] + ): continue # if the param is in parent args, it should come before the arg name diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 748612acee9..43d0aa3501f 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -633,12 +633,12 @@ def run(ctx, **kwargs): @p.target @p.state @p.threads +@p.full_refresh @requires.postflight @requires.preflight @requires.profile @requires.project @requires.runtime_config -@requires.manifest def retry(ctx, **kwargs): """Retry the nodes that failed in the previous run.""" task = RetryTask( diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index a1ed39d3581..c00daddba9c 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -1,8 +1,7 @@ import dbt.tracking from dbt.common.invocation import reset_invocation_id -from dbt.mp_context import get_mp_context from dbt.version import installed as installed_version -from dbt.adapters.factory import adapter_management, register_adapter, get_adapter +from dbt.adapters.factory import adapter_management from dbt.flags import set_flags, get_flag_dict from dbt.cli.exceptions import ( ExceptionExit, @@ -11,7 +10,6 @@ from dbt.cli.flags import Flags from dbt.config import RuntimeConfig from dbt.config.runtime import load_project, load_profile, UnsetProfile -from dbt.context.providers import generate_runtime_macro_context from dbt.common.events.base_types import EventLevel from dbt.common.events.functions import ( @@ -28,11 +26,11 @@ from dbt.events.types import CommandCompleted, MainEncounteredError, MainStackTrace, ResourceReport from dbt.common.exceptions import DbtBaseException as DbtException from dbt.exceptions import DbtProjectError, FailFastError -from dbt.parser.manifest import ManifestLoader, write_manifest +from dbt.parser.manifest import parse_manifest from dbt.profiler import profiler from dbt.tracking import active_user, initialize_from_flags, track_run from dbt.common.utils import cast_dict_to_dict_of_strings -from dbt.plugins import set_up_plugin_manager, get_plugin_manager +from dbt.plugins import set_up_plugin_manager from click import Context from functools import update_wrapper @@ -273,25 +271,12 @@ def wrapper(*args, **kwargs): raise DbtProjectError("profile, project, and runtime_config required for manifest") runtime_config = ctx.obj["runtime_config"] - register_adapter(runtime_config, get_mp_context()) - adapter = get_adapter(runtime_config) - adapter.set_macro_context_generator(generate_runtime_macro_context) # a manifest has already been set on the context, so don't overwrite it if ctx.obj.get("manifest") is None: - manifest = ManifestLoader.get_full_manifest( - runtime_config, - write_perf_info=write_perf_info, + ctx.obj["manifest"] = parse_manifest( + runtime_config, write_perf_info, write, ctx.obj["flags"].write_json ) - - ctx.obj["manifest"] = manifest - if write and ctx.obj["flags"].write_json: - write_manifest(manifest, runtime_config.project_target_path) - pm = get_plugin_manager(runtime_config.project_name) - plugin_artifacts = pm.get_manifest_artifacts(manifest) - for path, plugin_artifact in plugin_artifacts.items(): - plugin_artifact.write(path) - return func(*args, **kwargs) return update_wrapper(wrapper, func) diff --git a/core/dbt/contracts/state.py b/core/dbt/contracts/state.py index 707109e7457..022f4833c35 100644 --- a/core/dbt/contracts/state.py +++ b/core/dbt/contracts/state.py @@ -9,6 +9,16 @@ from dbt.exceptions import IncompatibleSchemaError +def load_result_state(results_path) -> Optional[RunResultsArtifact]: + if results_path.exists() and results_path.is_file(): + try: + return RunResultsArtifact.read_and_check_versions(str(results_path)) + except IncompatibleSchemaError as exc: + exc.add_filename(str(results_path)) + raise + return None + + class PreviousState: def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None: self.state_path: Path = state_path @@ -32,12 +42,7 @@ def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> N raise results_path = self.project_root / self.state_path / "run_results.json" - if results_path.exists() and results_path.is_file(): - try: - self.results = RunResultsArtifact.read_and_check_versions(str(results_path)) - except IncompatibleSchemaError as exc: - exc.add_filename(str(results_path)) - raise + self.results = load_result_state(results_path) sources_path = self.project_root / self.state_path / "sources.json" if sources_path.exists() and sources_path.is_file(): diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 02b6fb69a95..6c2041be084 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -24,6 +24,7 @@ from dbt.common.events.base_types import EventLevel import json import pprint +from dbt.mp_context import get_mp_context import msgpack import dbt.exceptions @@ -35,6 +36,7 @@ get_adapter, get_relation_class_by_name, get_adapter_package_names, + register_adapter, ) from dbt.constants import ( MANIFEST_FILE_NAME, @@ -75,7 +77,7 @@ from dbt.context.docs import generate_runtime_docs_context from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace from dbt.context.configured import generate_macro_context -from dbt.context.providers import ParseProvider +from dbt.context.providers import ParseProvider, generate_runtime_macro_context from dbt.contracts.files import FileHash, ParseFileType, SchemaSourceFile from dbt.parser.read_files import ( ReadFilesFromFileSystem, @@ -281,7 +283,6 @@ def get_full_manifest( reset: bool = False, write_perf_info=False, ) -> Manifest: - adapter = get_adapter(config) # type: ignore # reset is set in a TaskManager load_manifest call, since # the config and adapter may be persistent. @@ -593,7 +594,6 @@ def check_for_model_deprecations(self): node.depends_on for resolved_ref in resolved_model_refs: if resolved_ref.deprecation_date: - if resolved_ref.deprecation_date < datetime.datetime.now().astimezone(): event_cls = DeprecatedReference else: @@ -1738,7 +1738,6 @@ def _process_sources_for_metric(manifest: Manifest, current_project: str, metric def _process_sources_for_node(manifest: Manifest, current_project: str, node: ManifestNode): - if isinstance(node, SeedNode): return @@ -1780,7 +1779,6 @@ def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> No # This is called in task.rpc.sql_commands when a "dynamic" node is # created in the manifest, in 'add_refs' def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode): - _process_sources_for_node(manifest, config.project_name, node) _process_refs(manifest, config.project_name, node, config.dependencies) ctx = generate_runtime_docs_context(config, node, manifest, config.project_name) @@ -1798,3 +1796,21 @@ def write_manifest(manifest: Manifest, target_path: str): manifest.write(path) write_semantic_manifest(manifest=manifest, target_path=target_path) + + +def parse_manifest(runtime_config, write_perf_info, write, write_json): + register_adapter(runtime_config, get_mp_context()) + adapter = get_adapter(runtime_config) + adapter.set_macro_context_generator(generate_runtime_macro_context) + manifest = ManifestLoader.get_full_manifest( + runtime_config, + write_perf_info=write_perf_info, + ) + + if write and write_json: + write_manifest(manifest, runtime_config.project_target_path) + pm = plugins.get_plugin_manager(runtime_config.project_name) + plugin_artifacts = pm.get_manifest_artifacts(manifest) + for path, plugin_artifact in plugin_artifacts.items(): + plugin_artifact.write(path) + return manifest diff --git a/core/dbt/task/retry.py b/core/dbt/task/retry.py index f02381b789f..764b2dbf19e 100644 --- a/core/dbt/task/retry.py +++ b/core/dbt/task/retry.py @@ -1,10 +1,13 @@ from pathlib import Path +from click import get_current_context +from click.core import ParameterSource from dbt.cli.flags import Flags +from dbt.flags import set_flags, get_flags from dbt.cli.types import Command as CliCommand from dbt.config import RuntimeConfig from dbt.contracts.results import NodeStatus -from dbt.contracts.state import PreviousState +from dbt.contracts.state import load_result_state from dbt.common.exceptions import DbtRuntimeError from dbt.graph import GraphQueue from dbt.task.base import ConfiguredTask @@ -17,9 +20,10 @@ from dbt.task.seed import SeedTask from dbt.task.snapshot import SnapshotTask from dbt.task.test import TestTask +from dbt.parser.manifest import parse_manifest RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr} -OVERRIDE_PARENT_FLAGS = { +IGNORE_PARENT_FLAGS = { "log_path", "output_path", "profiles_dir", @@ -28,8 +32,11 @@ "defer_state", "deprecated_state", "target_path", + "warn_error", } +ALLOW_CLI_OVERRIDE_FLAGS = {"vars"} + TASK_DICT = { "build": BuildTask, "compile": CompileTask, @@ -57,59 +64,64 @@ class RetryTask(ConfiguredTask): def __init__(self, args, config, manifest) -> None: - super().__init__(args, config, manifest) - - state_path = self.args.state or self.config.target_path - - if self.args.warn_error: - RETRYABLE_STATUSES.add(NodeStatus.Warn) - - self.previous_state = PreviousState( - state_path=Path(state_path), - target_path=Path(self.config.target_path), - project_root=Path(self.config.project_root), + # load previous run results + state_path = args.state or config.target_path + self.previous_results = load_result_state( + Path(config.project_root) / Path(state_path) / "run_results.json" ) - - if not self.previous_state.results: + if not self.previous_results: raise DbtRuntimeError( f"Could not find previous run in '{state_path}' target directory" ) - - self.previous_args = self.previous_state.results.args + self.previous_args = self.previous_results.args self.previous_command_name = self.previous_args.get("which") - self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore - - def run(self): - unique_ids = set( - [ - result.unique_id - for result in self.previous_state.results.results - if result.status in RETRYABLE_STATUSES - ] - ) - cli_command = CMD_DICT.get(self.previous_command_name) + # Reslove flags and config + if args.warn_error: + RETRYABLE_STATUSES.add(NodeStatus.Warn) + cli_command = CMD_DICT.get(self.previous_command_name) # type: ignore # Remove these args when their default values are present, otherwise they'll raise an exception args_to_remove = { "show": lambda x: True, "resource_types": lambda x: x == [], "warn_error_options": lambda x: x == {"exclude": [], "include": []}, } - for k, v in args_to_remove.items(): if k in self.previous_args and v(self.previous_args[k]): del self.previous_args[k] - previous_args = { - k: v for k, v in self.previous_args.items() if k not in OVERRIDE_PARENT_FLAGS + k: v for k, v in self.previous_args.items() if k not in IGNORE_PARENT_FLAGS + } + click_context = get_current_context() + current_args = { + k: v + for k, v in args.__dict__.items() + if k in IGNORE_PARENT_FLAGS + or ( + click_context.get_parameter_source(k) == ParameterSource.COMMANDLINE + and k in ALLOW_CLI_OVERRIDE_FLAGS + ) } - current_args = {k: v for k, v in self.args.__dict__.items() if k in OVERRIDE_PARENT_FLAGS} combined_args = {**previous_args, **current_args} - - retry_flags = Flags.from_dict(cli_command, combined_args) + retry_flags = Flags.from_dict(cli_command, combined_args) # type: ignore + set_flags(retry_flags) retry_config = RuntimeConfig.from_args(args=retry_flags) + # Parse manifest using resolved config/flags + manifest = parse_manifest(retry_config, False, True, retry_flags.write_json) # type: ignore + super().__init__(args, retry_config, manifest) + self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore + + def run(self): + unique_ids = set( + [ + result.unique_id + for result in self.previous_results.results + if result.status in RETRYABLE_STATUSES + ] + ) + class TaskWrapper(self.task_class): def get_graph_queue(self): new_graph = self.graph.get_subset_graph(unique_ids) @@ -120,8 +132,8 @@ def get_graph_queue(self): ) task = TaskWrapper( - retry_flags, - retry_config, + get_flags(), + self.config, self.manifest, ) diff --git a/tests/functional/retry/test_retry.py b/tests/functional/retry/test_retry.py index c0a8cbc13e4..8890a99ac16 100644 --- a/tests/functional/retry/test_retry.py +++ b/tests/functional/retry/test_retry.py @@ -126,7 +126,10 @@ def test_previous_run(self, project): write_file(models__sample_model, "models", "sample_model.sql") def test_warn_error(self, project): - # Regular build + # Our test command should succeed when run normally... + results = run_dbt(["build", "--select", "second_model"]) + + # ...but it should fail when run with warn-error, due to a warning... results = run_dbt(["--warn-error", "build", "--select", "second_model"], expect_pass=False) expected_statuses = { @@ -291,3 +294,36 @@ def test_retry(self, project): run_dbt(["run", "--project-dir", "proj_location_1"], expect_pass=False) move(proj_location_1, proj_location_2) run_dbt(["retry", "--project-dir", "proj_location_2"], expect_pass=False) + + +class TestRetryVars: + @pytest.fixture(scope="class") + def models(self): + return { + "sample_model.sql": "select {{ var('myvar_a', '1') + var('myvar_b', '2') }} as mycol", + } + + def test_retry(self, project): + # pass because default vars works + run_dbt(["run"]) + run_dbt(["run", "--vars", '{"myvar_a": "12", "myvar_b": "3 4"}'], expect_pass=False) + # fail because vars are invalid, this shows that the last passed vars are being used + # instead of using the default vars + run_dbt(["retry"], expect_pass=False) + results = run_dbt(["retry", "--vars", '{"myvar_a": "12", "myvar_b": "34"}']) + assert len(results) == 1 + + +class TestRetryFullRefresh: + @pytest.fixture(scope="class") + def models(self): + return { + "sample_model.sql": "{% if flags.FULL_REFRESH %} this is invalid sql {% else %} select 1 as mycol {% endif %}", + } + + def test_retry(self, project): + # This run should fail with invalid sql... + run_dbt(["run", "--full-refresh"], expect_pass=False) + # ...and so should this one, since the effect of the full-refresh parameter should persist. + results = run_dbt(["retry"], expect_pass=False) + assert len(results) == 1