Skip to content

Commit

Permalink
Merge branch 'main' into freshness_artifact_move
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Jan 10, 2024
2 parents 77d0685 + 1e4286a commit 6df056d
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 73 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231213-220449.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Preserve the value of vars and the --full-refresh flags when using retry.
time: 2023-12-13T22:04:49.228294-05:00
custom:
Author: peterallenwebb, ChenyuLInx
Issue: "9112"
8 changes: 4 additions & 4 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
"WARN_ERROR": None,
# Cli args without project_flags or env var option.
"FULL_REFRESH": False,
"STRICT_MODE": False,
Expand Down Expand Up @@ -84,7 +85,6 @@ class Flags:
def __init__(
self, ctx: Optional[Context] = None, project_flags: Optional[ProjectFlags] = None
) -> None:

# Set the default flags.
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)
Expand Down Expand Up @@ -126,7 +126,6 @@ def _assign_params(
# respected over DBT_PRINT or --print.
new_name: Union[str, None] = None
if param_name in DEPRECATED_PARAMS:

# Deprecated env vars can only be set via env var.
# We use the deprecated option in click to serialize the value
# from the env var string.
Expand Down Expand Up @@ -346,7 +345,6 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
default_args = set([x.lower() for x in FLAGS_DEFAULTS.keys()])

res = command.to_list()

for k, v in args_dict.items():
k = k.lower()
# if a "which" value exists in the args dict, it should match the command provided
Expand All @@ -358,7 +356,9 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
continue

# param was assigned from defaults and should not be included
if k not in (cmd_args | prnt_args) - default_args:
if k not in (cmd_args | prnt_args) or (
k in default_args and v == FLAGS_DEFAULTS[k.upper()]
):
continue

# if the param is in parent args, it should come before the arg name
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,12 +631,12 @@ def run(ctx, **kwargs):
@p.target
@p.state
@p.threads
@p.full_refresh
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
task = RetryTask(
Expand Down
25 changes: 5 additions & 20 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import dbt.tracking
from dbt.common.invocation import reset_invocation_id
from dbt.mp_context import get_mp_context
from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management, register_adapter, get_adapter
from dbt.adapters.factory import adapter_management
from dbt.flags import set_flags, get_flag_dict
from dbt.cli.exceptions import (
ExceptionExit,
Expand All @@ -11,7 +10,6 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.context.providers import generate_runtime_macro_context

from dbt.common.events.base_types import EventLevel
from dbt.common.events.functions import (
Expand All @@ -28,11 +26,11 @@
from dbt.events.types import CommandCompleted, MainEncounteredError, MainStackTrace, ResourceReport
from dbt.common.exceptions import DbtBaseException as DbtException
from dbt.exceptions import DbtProjectError, FailFastError
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.parser.manifest import parse_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.common.utils import cast_dict_to_dict_of_strings
from dbt.plugins import set_up_plugin_manager, get_plugin_manager
from dbt.plugins import set_up_plugin_manager

from click import Context
from functools import update_wrapper
Expand Down Expand Up @@ -273,25 +271,12 @@ def wrapper(*args, **kwargs):
raise DbtProjectError("profile, project, and runtime_config required for manifest")

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
)

ctx.obj["manifest"] = manifest
if write and ctx.obj["flags"].write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)

return func(*args, **kwargs)

return update_wrapper(wrapper, func)
Expand Down
17 changes: 11 additions & 6 deletions core/dbt/contracts/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
from dbt.exceptions import IncompatibleSchemaError


def load_result_state(results_path) -> Optional[RunResultsArtifact]:
if results_path.exists() and results_path.is_file():
try:
return RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise
return None


class PreviousState:
def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None:
self.state_path: Path = state_path
Expand All @@ -32,12 +42,7 @@ def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> N
raise

results_path = self.project_root / self.state_path / "run_results.json"
if results_path.exists() and results_path.is_file():
try:
self.results = RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise
self.results = load_result_state(results_path)

sources_path = self.project_root / self.state_path / "sources.json"
if sources_path.exists() and sources_path.is_file():
Expand Down
26 changes: 21 additions & 5 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from dbt.common.events.base_types import EventLevel
import json
import pprint
from dbt.mp_context import get_mp_context
import msgpack

import dbt.exceptions
Expand All @@ -35,6 +36,7 @@
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
register_adapter,
)
from dbt.constants import (
MANIFEST_FILE_NAME,
Expand Down Expand Up @@ -75,7 +77,7 @@
from dbt.context.docs import generate_runtime_docs_context
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from dbt.context.configured import generate_macro_context
from dbt.context.providers import ParseProvider
from dbt.context.providers import ParseProvider, generate_runtime_macro_context
from dbt.contracts.files import FileHash, ParseFileType, SchemaSourceFile
from dbt.parser.read_files import (
ReadFilesFromFileSystem,
Expand Down Expand Up @@ -281,7 +283,6 @@ def get_full_manifest(
reset: bool = False,
write_perf_info=False,
) -> Manifest:

adapter = get_adapter(config) # type: ignore
# reset is set in a TaskManager load_manifest call, since
# the config and adapter may be persistent.
Expand Down Expand Up @@ -593,7 +594,6 @@ def check_for_model_deprecations(self):
node.depends_on
for resolved_ref in resolved_model_refs:
if resolved_ref.deprecation_date:

if resolved_ref.deprecation_date < datetime.datetime.now().astimezone():
event_cls = DeprecatedReference
else:
Expand Down Expand Up @@ -1738,7 +1738,6 @@ def _process_sources_for_metric(manifest: Manifest, current_project: str, metric


def _process_sources_for_node(manifest: Manifest, current_project: str, node: ManifestNode):

if isinstance(node, SeedNode):
return

Expand Down Expand Up @@ -1780,7 +1779,6 @@ def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> No
# This is called in task.rpc.sql_commands when a "dynamic" node is
# created in the manifest, in 'add_refs'
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):

_process_sources_for_node(manifest, config.project_name, node)
_process_refs(manifest, config.project_name, node, config.dependencies)
ctx = generate_runtime_docs_context(config, node, manifest, config.project_name)
Expand All @@ -1798,3 +1796,21 @@ def write_manifest(manifest: Manifest, target_path: str):
manifest.write(path)

write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
)

if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)
return manifest
84 changes: 48 additions & 36 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path
from click import get_current_context
from click.core import ParameterSource

from dbt.cli.flags import Flags
from dbt.flags import set_flags, get_flags
from dbt.cli.types import Command as CliCommand
from dbt.config import RuntimeConfig
from dbt.artifacts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.contracts.state import load_result_state
from dbt.common.exceptions import DbtRuntimeError
from dbt.graph import GraphQueue
from dbt.task.base import ConfiguredTask
Expand All @@ -17,9 +20,10 @@
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.parser.manifest import parse_manifest

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
OVERRIDE_PARENT_FLAGS = {
IGNORE_PARENT_FLAGS = {
"log_path",
"output_path",
"profiles_dir",
Expand All @@ -28,8 +32,11 @@
"defer_state",
"deprecated_state",
"target_path",
"warn_error",
}

ALLOW_CLI_OVERRIDE_FLAGS = {"vars"}

TASK_DICT = {
"build": BuildTask,
"compile": CompileTask,
Expand Down Expand Up @@ -57,59 +64,64 @@

class RetryTask(ConfiguredTask):
def __init__(self, args, config, manifest) -> None:
super().__init__(args, config, manifest)

state_path = self.args.state or self.config.target_path

if self.args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

self.previous_state = PreviousState(
state_path=Path(state_path),
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
# load previous run results
state_path = args.state or config.target_path
self.previous_results = load_result_state(
Path(config.project_root) / Path(state_path) / "run_results.json"
)

if not self.previous_state.results:
if not self.previous_results:
raise DbtRuntimeError(
f"Could not find previous run in '{state_path}' target directory"
)

self.previous_args = self.previous_state.results.args
self.previous_args = self.previous_results.args
self.previous_command_name = self.previous_args.get("which")
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_state.results.results
if result.status in RETRYABLE_STATUSES
]
)

cli_command = CMD_DICT.get(self.previous_command_name)
# Reslove flags and config
if args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

cli_command = CMD_DICT.get(self.previous_command_name) # type: ignore
# Remove these args when their default values are present, otherwise they'll raise an exception
args_to_remove = {
"show": lambda x: True,
"resource_types": lambda x: x == [],
"warn_error_options": lambda x: x == {"exclude": [], "include": []},
}

for k, v in args_to_remove.items():
if k in self.previous_args and v(self.previous_args[k]):
del self.previous_args[k]

previous_args = {
k: v for k, v in self.previous_args.items() if k not in OVERRIDE_PARENT_FLAGS
k: v for k, v in self.previous_args.items() if k not in IGNORE_PARENT_FLAGS
}
click_context = get_current_context()
current_args = {
k: v
for k, v in args.__dict__.items()
if k in IGNORE_PARENT_FLAGS
or (
click_context.get_parameter_source(k) == ParameterSource.COMMANDLINE
and k in ALLOW_CLI_OVERRIDE_FLAGS
)
}
current_args = {k: v for k, v in self.args.__dict__.items() if k in OVERRIDE_PARENT_FLAGS}
combined_args = {**previous_args, **current_args}

retry_flags = Flags.from_dict(cli_command, combined_args)
retry_flags = Flags.from_dict(cli_command, combined_args) # type: ignore
set_flags(retry_flags)
retry_config = RuntimeConfig.from_args(args=retry_flags)

# Parse manifest using resolved config/flags
manifest = parse_manifest(retry_config, False, True, retry_flags.write_json) # type: ignore
super().__init__(args, retry_config, manifest)
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_results.results
if result.status in RETRYABLE_STATUSES
]
)

class TaskWrapper(self.task_class):
def get_graph_queue(self):
new_graph = self.graph.get_subset_graph(unique_ids)
Expand All @@ -120,8 +132,8 @@ def get_graph_queue(self):
)

task = TaskWrapper(
retry_flags,
retry_config,
get_flags(),
self.config,
self.manifest,
)

Expand Down
Loading

0 comments on commit 6df056d

Please sign in to comment.