Skip to content

Commit

Permalink
add macro_context_generator on adapter (#9251)
Browse files Browse the repository at this point in the history
* moving types_pb2.py to common/events

* remove manifest from adapter.execute_macro, replace with MacroResolver + remove lazy loading

* rename to MacroResolverProtocol

* pass MacroResolverProtocol in adapter.calculate_freshness_from_metadata

* changelog entry

* fix adapter.calculate_freshness call

* add macro_context_generator on adapter

* fix adapter test setup

* changelog entry

* Update parser to support conversion metrics (#9173)

* added ConversionTypeParams classes

* updated parser for ConversionTypeParams

* added step to populate input_measure for conversion metrics

* version bump on DSI

* comment back manifest generating line

* updated v12 schemas

* added tests

* added changelog

* Add typing for macro_context_generator, fix query_header_context

---------

Co-authored-by: Colin <colin.rogers@dbtlabs.com>
Co-authored-by: William Deng <33618746+WilliamDee@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 11, 2023
1 parent c7b9b1a commit a68e427
Show file tree
Hide file tree
Showing 20 changed files with 767 additions and 57 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20231206-181458.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Adds support for parsing conversion metric related properties for the semantic
layer.
time: 2023-12-06T18:14:58.688221-05:00
custom:
Author: WilliamDee
Issue: "9203"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231208-004854.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: add macro_context_generator on adapter
time: 2023-12-08T00:48:54.506911+09:00
custom:
Author: michelleark
Issue: "9247"
36 changes: 22 additions & 14 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@
MacroResultError,
)

from dbt.adapters.protocol import AdapterConfig
from dbt.adapters.protocol import (
AdapterConfig,
MacroContextGeneratorCallable,
)
from dbt.common.clients.agate_helper import (
empty_table,
get_column_value_uncased,
Expand All @@ -76,7 +79,11 @@
from dbt.common.utils import filter_null_values, executor, cast_to_str, AttrDict

from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.base.connections import Connection, AdapterResponse, BaseConnectionManager
from dbt.adapters.base.connections import (
Connection,
AdapterResponse,
BaseConnectionManager,
)
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import (
ComponentName,
Expand Down Expand Up @@ -258,6 +265,7 @@ def __init__(self, config, mp_context: SpawnContext) -> None:
self.cache = RelationsCache(log_cache_events=config.log_cache_events)
self.connections = self.ConnectionManager(config, mp_context)
self._macro_resolver: Optional[MacroResolverProtocol] = None
self._macro_context_generator: Optional[MacroContextGeneratorCallable] = None

###
# Methods to set / access a macro resolver
Expand All @@ -272,6 +280,12 @@ def clear_macro_resolver(self) -> None:
if self._macro_resolver is not None:
self._macro_resolver = None

def set_macro_context_generator(
    self,
    macro_context_generator: MacroContextGeneratorCallable,
) -> None:
    """Install the callable used to build macro execution contexts.

    The callable is stored on the adapter and later invoked by
    ``execute_macro`` to construct the context dict a macro is
    rendered/executed in.
    """
    self._macro_context_generator = macro_context_generator

###
# Methods that pass through to the connection manager
###
Expand Down Expand Up @@ -1057,7 +1071,10 @@ def execute_macro(

resolver = macro_resolver or self._macro_resolver
if resolver is None:
raise DbtInternalError("macro resolver was None when calling execute_macro!")
raise DbtInternalError("Macro resolver was None when calling execute_macro!")

if self._macro_context_generator is None:
raise DbtInternalError("Macro context generator was None when calling execute_macro!")

macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project)
if macro is None:
Expand All @@ -1071,17 +1088,8 @@ def execute_macro(
macro_name, package_name
)
)
# This causes a reference cycle, as generate_runtime_macro_context()
# ends up calling get_adapter, so the import has to be here.
from dbt.context.providers import generate_runtime_macro_context

macro_context = generate_runtime_macro_context(
# TODO CT-211
macro=macro,
config=self.config,
manifest=resolver, # type: ignore[arg-type]
package_name=project,
)

macro_context = self._macro_context_generator(macro, self.config, resolver, project)
macro_context.update(context_override)

macro_function = CallableMacroGenerator(macro, macro_context)
Expand Down
18 changes: 18 additions & 0 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
from dbt.common.contracts.config.base import BaseConfig
from dbt.common.clients.jinja import MacroProtocol


@dataclass
Expand Down Expand Up @@ -55,6 +56,17 @@ def create_from(
Column_T = TypeVar("Column_T", bound=ColumnProtocol)


class MacroContextGeneratorCallable(Protocol):
    """Structural type of the callable an adapter uses to build the
    context in which a macro is executed.

    Called with the macro itself, the adapter's required config, a macro
    resolver, and an optional package name; returns the context dict
    passed to the macro at execution time.
    """

    def __call__(
        self,
        macro_protocol: MacroProtocol,
        config: AdapterRequiredConfig,
        macro_resolver: MacroResolverProtocol,
        package_name: Optional[str],
    ) -> Dict[str, Any]:
        ...


# TODO CT-211
class AdapterProtocol( # type: ignore[misc]
Protocol,
Expand Down Expand Up @@ -86,6 +98,12 @@ def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
def clear_macro_resolver(self) -> None:
...

def set_macro_context_generator(
    self,
    macro_context_generator: MacroContextGeneratorCallable,
) -> None:
    """Register the callable that builds macro execution contexts."""
    ...

@classmethod
def type(cls) -> str:
pass
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dbt.tracking
from dbt.common.invocation import reset_invocation_id
from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management, register_adapter
from dbt.adapters.factory import adapter_management, register_adapter, get_adapter
from dbt.flags import set_flags, get_flag_dict
from dbt.cli.exceptions import (
ExceptionExit,
Expand All @@ -10,6 +10,7 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.context.providers import generate_runtime_macro_context

from dbt.common.events.base_types import EventLevel
from dbt.common.events.functions import (
Expand Down Expand Up @@ -274,6 +275,8 @@ def wrapper(*args, **kwargs):

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config)
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


DERIVED_METRICS = [MetricType.DERIVED, MetricType.RATIO]
BASE_METRICS = [MetricType.SIMPLE, MetricType.CUMULATIVE]
BASE_METRICS = [MetricType.SIMPLE, MetricType.CUMULATIVE, MetricType.CONVERSION]


class MetricReference(object):
Expand Down
18 changes: 17 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SourceFileMetadata,
)
from dbt.contracts.graph.unparsed import (
ConstantPropertyInput,
Docs,
ExposureType,
ExternalTable,
Expand Down Expand Up @@ -64,7 +65,11 @@
TimeDimensionReference,
)
from dbt_semantic_interfaces.references import MetricReference as DSIMetricReference
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity
from dbt_semantic_interfaces.type_enums import (
ConversionCalculationType,
MetricType,
TimeGranularity,
)

from .model_config import (
NodeConfig,
Expand Down Expand Up @@ -1409,6 +1414,16 @@ def post_aggregation_reference(self) -> DSIMetricReference:
return DSIMetricReference(element_name=self.alias or self.name)


@dataclass
class ConversionTypeParams(dbtClassMixin):
    """Parsed type parameters for a conversion metric (MetricType.CONVERSION).

    During parsing both ``base_measure`` and ``conversion_measure`` are
    appended to the metric's ``input_measures``.
    """

    # Measure for the base side of the conversion.
    base_measure: MetricInputMeasure
    # Measure for the conversion side.
    conversion_measure: MetricInputMeasure
    # Entity relating base and conversion events — presumably; confirm against DSI semantics.
    entity: str
    # How the conversion is computed; defaults to a conversion rate.
    calculation: ConversionCalculationType = ConversionCalculationType.CONVERSION_RATE
    # Optional time window constraining the conversion.
    window: Optional[MetricTimeWindow] = None
    # Optional properties held constant between base and conversion — TODO confirm semantics.
    constant_properties: Optional[List[ConstantPropertyInput]] = None


@dataclass
class MetricTypeParams(dbtClassMixin):
measure: Optional[MetricInputMeasure] = None
Expand All @@ -1419,6 +1434,7 @@ class MetricTypeParams(dbtClassMixin):
window: Optional[MetricTimeWindow] = None
grain_to_date: Optional[TimeGranularity] = None
metrics: Optional[List[MetricInput]] = None
conversion_type_params: Optional[ConversionTypeParams] = None


@dataclass
Expand Down
21 changes: 21 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
ValidationError,
)

from dbt_semantic_interfaces.type_enums import ConversionCalculationType

from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
Expand Down Expand Up @@ -600,6 +602,24 @@ class UnparsedMetricInput(dbtClassMixin):
offset_to_grain: Optional[str] = None # str is really a TimeGranularity Enum


@dataclass
class ConstantPropertyInput(dbtClassMixin):
    """Property-name pair used in a conversion metric's
    ``constant_properties`` list, pairing a property on the base measure
    with one on the conversion measure.
    """

    base_property: str
    conversion_property: str

@dataclass
class UnparsedConversionTypeParams(dbtClassMixin):
    """Unparsed (schema-file) form of conversion metric type parameters.

    Measures may be given either as full unparsed input-measure objects
    or as bare measure-name strings; string-typed fields are resolved to
    their typed counterparts (see ConversionTypeParams) at parse time.
    """

    base_measure: Union[UnparsedMetricInputMeasure, str]
    conversion_measure: Union[UnparsedMetricInputMeasure, str]
    entity: str
    calculation: str = (
        ConversionCalculationType.CONVERSION_RATE.value
    )  # str is really a ConversionCalculationType Enum
    window: Optional[str] = None  # str is really a MetricTimeWindow
    constant_properties: Optional[List[ConstantPropertyInput]] = None


@dataclass
class UnparsedMetricTypeParams(dbtClassMixin):
measure: Optional[Union[UnparsedMetricInputMeasure, str]] = None
Expand All @@ -609,6 +629,7 @@ class UnparsedMetricTypeParams(dbtClassMixin):
window: Optional[str] = None
grain_to_date: Optional[str] = None # str is really a TimeGranularity Enum
metrics: Optional[List[Union[UnparsedMetricInput, str]]] = None
conversion_type_params: Optional[UnparsedConversionTypeParams] = None


@dataclass
Expand Down
69 changes: 47 additions & 22 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from itertools import chain
import time

from dbt.context.manifest import generate_query_header_context
from dbt.contracts.graph.semantic_manifest import SemanticManifest
from dbt.common.events.base_types import EventLevel
import json
Expand Down Expand Up @@ -72,7 +73,6 @@
)
from dbt.config import Project, RuntimeConfig
from dbt.context.docs import generate_runtime_docs_context
from dbt.context.manifest import generate_query_header_context
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from dbt.context.configured import generate_macro_context
from dbt.context.providers import ParseProvider
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(
self,
root_project: RuntimeConfig,
all_projects: Mapping[str, Project],
macro_hook: Optional[Callable[[Dict[str, Any]], Any]] = None,
macro_hook: Optional[Callable[[Manifest], Any]] = None,
file_diff: Optional[FileDiff] = None,
) -> None:
self.root_project: RuntimeConfig = root_project
Expand All @@ -251,9 +251,9 @@ def __init__(
# This is a MacroQueryStringSetter callable, which is called
# later after we set the MacroManifest in the adapter. It sets
# up the query headers.
self.macro_hook: Callable[[Dict[str, Any]], Any]
self.macro_hook: Callable[[Manifest], Any]
if macro_hook is None:
self.macro_hook = lambda c: None
self.macro_hook = lambda m: None
else:
self.macro_hook = macro_hook

Expand Down Expand Up @@ -1002,6 +1002,8 @@ def build_manifest_state_check(self):
def save_macros_to_adapter(self, adapter):
macro_manifest = MacroManifest(self.manifest.macros)
adapter.set_macro_resolver(macro_manifest)
# This executes the callable macro_hook and sets the
# query headers
# This executes the callable macro_hook and sets the query headers
query_header_context = generate_query_header_context(adapter.config, macro_manifest)
self.macro_hook(query_header_context)
Expand Down Expand Up @@ -1033,7 +1035,7 @@ def create_macro_manifest(self):
def load_macros(
cls,
root_config: RuntimeConfig,
macro_hook: Callable[[Dict[str, Any]], Any],
macro_hook: Callable[[Manifest], Any],
base_macros_only=False,
) -> Manifest:
with PARSING_STATE:
Expand Down Expand Up @@ -1530,43 +1532,66 @@ def _process_refs(
node.depends_on.add_node(target_model_id)


def _process_metric_node(
def _process_metric_depends_on(
manifest: Manifest,
current_project: str,
metric: Metric,
) -> None:
"""Sets a metric's `input_measures` and `depends_on` properties"""

# This ensures that if this metrics input_measures have already been set
# we skip the work. This could happen either due to recursion or if multiple
# metrics derive from another given metric.
# NOTE: This does not protect against infinite loops
if len(metric.type_params.input_measures) > 0:
return
"""For a given metric, set the `depends_on` property"""

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE:
assert (
metric.type_params.measure is not None
), f"{metric} should have a measure defined, but it does not."
metric.type_params.input_measures.append(metric.type_params.measure)
assert len(metric.type_params.input_measures) > 0
for input_measure in metric.type_params.input_measures:
target_semantic_model = manifest.resolve_semantic_model_for_measure(
target_measure_name=metric.type_params.measure.name,
target_measure_name=input_measure.name,
current_project=current_project,
node_package=metric.package_name,
)
if target_semantic_model is None:
raise dbt.exceptions.ParsingError(
f"A semantic model having a measure `{metric.type_params.measure.name}` does not exist but was referenced.",
f"A semantic model having a measure `{input_measure.name}` does not exist but was referenced.",
node=metric,
)
if target_semantic_model.config.enabled is False:
raise dbt.exceptions.ParsingError(
f"The measure `{metric.type_params.measure.name}` is referenced on disabled semantic model `{target_semantic_model.name}`.",
f"The measure `{input_measure.name}` is referenced on disabled semantic model `{target_semantic_model.name}`.",
node=metric,
)

metric.depends_on.add_node(target_semantic_model.unique_id)


def _process_metric_node(
manifest: Manifest,
current_project: str,
metric: Metric,
) -> None:
"""Sets a metric's `input_measures` and `depends_on` properties"""

# This ensures that if this metrics input_measures have already been set
# we skip the work. This could happen either due to recursion or if multiple
# metrics derive from another given metric.
# NOTE: This does not protect against infinite loops
if len(metric.type_params.input_measures) > 0:
return

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE:
assert (
metric.type_params.measure is not None
), f"{metric} should have a measure defined, but it does not."
metric.type_params.input_measures.append(metric.type_params.measure)
_process_metric_depends_on(
manifest=manifest, current_project=current_project, metric=metric
)
elif metric.type is MetricType.CONVERSION:
conversion_type_params = metric.type_params.conversion_type_params
assert (
conversion_type_params
), f"{metric.name} is a conversion metric and must have conversion_type_params defined."
metric.type_params.input_measures.append(conversion_type_params.base_measure)
metric.type_params.input_measures.append(conversion_type_params.conversion_measure)
_process_metric_depends_on(
manifest=manifest, current_project=current_project, metric=metric
)
elif metric.type is MetricType.DERIVED or metric.type is MetricType.RATIO:
input_metrics = metric.input_metrics
if metric.type is MetricType.RATIO:
Expand Down
Loading

0 comments on commit a68e427

Please sign in to comment.