Skip to content

Commit

Permalink
Add removal of the inline node so that the dbt-server compile-query flow works
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenyuLInx committed Apr 11, 2023
1 parent 78b1a0d commit 46d1781
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 216 deletions.
138 changes: 2 additions & 136 deletions core/dbt/lib.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import os
from dbt.config.project import Project
from dbt.contracts.results import RunningStatus, collect_timing_info
from dbt.events.functions import fire_event
from dbt.events.types import NodeCompiling, NodeExecuting
from dbt.exceptions import DbtRuntimeError
from dbt.task.sql import SqlCompileRunner
from dataclasses import dataclass
from dbt.cli.resolvers import default_profiles_dir
from dbt.config.runtime import load_profile, load_project
Expand All @@ -20,54 +15,6 @@ class RuntimeArgs:
target: str


class SqlCompileRunnerNoIntrospection(SqlCompileRunner):
    # Variant of SqlCompileRunner for environments where connecting to the
    # warehouse at compile time is not permitted (see compile_and_execute).
    def compile_and_execute(self, manifest, ctx):
        """
        This version of this method does not connect to the data warehouse.
        As a result, introspective queries at compilation will not be supported
        and will throw an error.
        TODO: This is a temporary solution to more complex permissions requirements
        for the semantic layer, and thus largely duplicates the code in the parent class
        method. Once conditional credential usage is enabled, this should be removed.
        """
        result = None
        # Mark the node as compiling and emit the matching structured event.
        ctx.node.update_event_status(node_status=RunningStatus.Compiling)
        fire_event(
            NodeCompiling(
                node_info=ctx.node.node_info,
            )
        )
        with collect_timing_info("compile") as timing_info:
            # if we fail here, we still have a compiled node to return
            # this has the benefit of showing a build path for the errant
            # model
            ctx.node = self.compile(manifest)
        ctx.timing.append(timing_info)

        # for ephemeral nodes, we only want to compile, not run
        if not ctx.node.is_ephemeral_model:
            ctx.node.update_event_status(node_status=RunningStatus.Executing)
            fire_event(
                NodeExecuting(
                    node_info=ctx.node.node_info,
                )
            )
            with collect_timing_info("execute") as timing_info:
                result = self.run(ctx.node, manifest)
                ctx.node = result.node

            # execute timing is only recorded when the node was actually run
            ctx.timing.append(timing_info)

        # result stays None for ephemeral models (compile-only path).
        return result


def load_profile_project(project_dir, profile_name_override=None):
    """Load the profile and project for *project_dir*.

    Returns a ``(profile, project)`` tuple; *profile_name_override*, when
    given, selects a profile other than the one named in the project file.
    """
    loaded_profile = load_profile(project_dir, {}, profile_name_override)
    loaded_project = load_project(project_dir, False, loaded_profile, {})
    return loaded_profile, loaded_project


def get_dbt_config(project_dir, args=None, single_threaded=False):
from dbt.config.runtime import RuntimeConfig
import dbt.adapters.factory
Expand All @@ -90,7 +37,8 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):

# set global flags from arguments
set_from_args(runtime_args, None)
profile, project = load_profile_project(project_dir, profile_name)
profile = load_profile(project_dir, {}, profile_name)
project = load_project(project_dir, False, profile, {})
assert type(project) is Project

config = RuntimeConfig.from_parts(project, profile, runtime_args)
Expand All @@ -111,88 +59,6 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
return config


def get_task_by_type(type):
    """Return the dbt task class registered for the given task-type name.

    Raises DbtRuntimeError when the name is not a recognized task type.
    """
    from dbt.task.run import RunTask
    from dbt.task.list import ListTask
    from dbt.task.seed import SeedTask
    from dbt.task.test import TestTask
    from dbt.task.build import BuildTask
    from dbt.task.snapshot import SnapshotTask
    from dbt.task.run_operation import RunOperationTask

    # Dispatch table instead of an if/elif chain.
    task_classes = {
        "run": RunTask,
        "test": TestTask,
        "list": ListTask,
        "seed": SeedTask,
        "build": BuildTask,
        "snapshot": SnapshotTask,
        "run_operation": RunOperationTask,
    }
    if type in task_classes:
        return task_classes[type]

    raise DbtRuntimeError("not a valid task")


def create_task(type, args, manifest, config):
    """Build a task of the given type wired to a pre-parsed manifest.

    The task's ``load_manifest`` hook is stubbed out because the caller
    supplies *manifest* directly, so re-loading it would be wasted work.
    """
    task_cls = get_task_by_type(type)
    task = task_cls(args, config)

    def _do_nothing(*_args, **_kwargs):
        pass

    task.load_manifest = _do_nothing
    task.manifest = manifest
    return task


def _get_operation_node(manifest, project_path, sql, node_name):
    """Parse *sql* into a node attached to *manifest*.

    Returns a ``(config, sql_node, adapter)`` tuple ready to hand to a
    SQL runner.
    """
    from dbt.parser.manifest import process_node
    from dbt.parser.sql import SqlBlockParser
    import dbt.adapters.factory

    runtime_config = get_dbt_config(project_path)
    parser = SqlBlockParser(
        project=runtime_config, manifest=manifest, root_project=runtime_config
    )

    adapter = dbt.adapters.factory.get_adapter(runtime_config)
    node = parser.parse_remote(sql, node_name)
    # Resolve refs/sources/etc. so the node is fully linked into the manifest.
    process_node(runtime_config, manifest, node)
    return runtime_config, node, adapter


def compile_sql(manifest, project_path, sql, node_name="query"):
    """Compile a SQL snippet against the project and return the run result.

    The ``__DBT_ALLOW_INTROSPECTION`` environment variable (default on)
    decides whether compilation may connect to the warehouse.
    """
    config, node, adapter = _get_operation_node(manifest, project_path, sql, node_name)

    flag_value = str(os.environ.get("__DBT_ALLOW_INTROSPECTION", "1")).lower()
    introspection_allowed = flag_value in ("true", "1", "on")

    runner_cls = SqlCompileRunner if introspection_allowed else SqlCompileRunnerNoIntrospection
    runner = runner_cls(config, adapter, node, 1, 1)
    return runner.safe_run(manifest)


def execute_sql(manifest, project_path, sql, node_name="query"):
    """Execute a SQL snippet against the warehouse and return the run result."""
    from dbt.task.sql import SqlExecuteRunner

    config, node, adapter = _get_operation_node(manifest, project_path, sql, node_name)
    executor = SqlExecuteRunner(config, adapter, node, 1, 1)
    return executor.safe_run(manifest)


def parse_to_manifest(config):
from dbt.parser.manifest import ManifestLoader

Expand Down
19 changes: 1 addition & 18 deletions core/dbt/parser/sql.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import os
from dataclasses import dataclass
from typing import Iterable

from dbt.contracts.graph.manifest import SourceFile
from dbt.contracts.graph.nodes import SqlNode, Macro
from dbt.contracts.graph.unparsed import UnparsedMacro
from dbt.contracts.graph.nodes import SqlNode
from dbt.exceptions import DbtInternalError
from dbt.node_types import NodeType
from dbt.parser.base import SimpleSQLParser
from dbt.parser.macros import MacroParser
from dbt.parser.search import FileBlock


Expand Down Expand Up @@ -46,17 +43,3 @@ def parse_remote(self, sql: str, name: str) -> SqlNode:
source_file = SourceFile.remote(sql, self.project.project_name, "sql")
contents = SqlBlock(block_name=name, file=source_file)
return self.parse_node(contents)


class SqlMacroParser(MacroParser):
    """Parses macro source received remotely rather than read from a project file."""

    def parse_remote(self, contents) -> Iterable[Macro]:
        """Yield Macro nodes parsed from the raw macro text in *contents*."""
        unparsed = UnparsedMacro(
            path="from remote system",
            original_file_path="from remote system",
            package_name=self.project.project_name,
            raw_code=contents,
            language="sql",
            resource_type=NodeType.Macro,
        )
        yield from self.parse_unparsed_macros(unparsed)
20 changes: 20 additions & 0 deletions core/dbt/task/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def compile(self, manifest):


class CompileTask(GraphRunnableTask):
# We add a new inline node to the manifest during initialization
# it should be removed before the task is complete
_inline_node_id = None

# TODO remove when stu's PR about skip adapter cache is merged
def before_run(self, adapter, selected_uids: AbstractSet[str]):
if bool(getattr(self.args, "inline", None)):
# don't populate adapter cache when doing inline queries
pass
else:
super().before_run(adapter, selected_uids)

    def raise_on_first_error(self):
        # Compile failures abort the task immediately rather than continuing.
        return True

Expand Down Expand Up @@ -130,9 +142,17 @@ def _runtime_initialize(self):
)
sql_node = block_parser.parse_remote(self.args.inline, "inline_query")
process_node(self.config, self.manifest, sql_node)
# keep track of the node added to the manifest
self._inline_node_id = sql_node.unique_id

super()._runtime_initialize()

def after_run(self, adapter, results):
# remove inline node from manifest
if self._inline_node_id:
self.manifest.nodes.pop(self._inline_node_id)
super().after_run(adapter, results)

def _handle_result(self, result):
super()._handle_result(result)

Expand Down
62 changes: 0 additions & 62 deletions test/unit/test_lib.py

This file was deleted.

0 comments on commit 46d1781

Please sign in to comment.