Skip to content

Commit

Permalink
Add RelationConfig Protocol for use in Relation.create_from (#9210)
Browse files Browse the repository at this point in the history
* move relation contract to dbt.adapters

* changelog entry

* first pass: clean up relation.create_from

* type ignores

* type ignore

* changelog entry

* update RelationConfig variable names
  • Loading branch information
MichelleArk authored Dec 6, 2023
1 parent ed8f5d3 commit eb96e3d
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 82 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231205-170725.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Introduce RelationConfig Protocol, consolidate Relation.create_from
time: 2023-12-05T17:07:25.33861+09:00
custom:
Author: michelleark
Issue: "9215"
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]:
"""
# the cache only cares about executable nodes
return {
self.Relation.create_from(self.config, node).without_identifier()
self.Relation.create_from(self.config, node).without_identifier() # type: ignore[arg-type]
for node in manifest.nodes.values()
if (node.is_relational and not node.is_ephemeral_model and not node.is_external_node)
}
Expand Down Expand Up @@ -470,7 +470,7 @@ def _get_catalog_relations(self, manifest: Manifest) -> List[BaseRelation]:
manifest.sources.values(),
)

relations = [self.Relation.create_from(self.config, n) for n in nodes]
relations = [self.Relation.create_from(self.config, n) for n in nodes] # type: ignore[arg-type]
return relations

def _relations_cache_for_schemas(
Expand Down
77 changes: 21 additions & 56 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
from dataclasses import dataclass, field
from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet

from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode
from dbt.adapters.contracts.relation import (
RelationConfig,
RelationType,
ComponentName,
HasQuoting,
FakeAPIObject,
Policy,
Path,
)
from dbt.common.exceptions import DbtInternalError
from dbt.adapters.exceptions import MultipleDatabasesNotAllowedError, ApproximateMatchError
from dbt.node_types import NodeType
from dbt.common.utils import filter_null_values, deep_merge
from dbt.adapters.utils import classproperty

Expand Down Expand Up @@ -198,83 +196,50 @@ def quoted(self, identifier):
identifier=identifier,
)

@classmethod
def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
quote_policy = deep_merge(
cls.get_default_quote_policy().to_dict(omit_none=True),
source_quoting,
kwargs.get("quote_policy", {}),
)

return cls.create(
database=source.database,
schema=source.schema,
identifier=source.identifier,
quote_policy=quote_policy,
**kwargs,
)

@staticmethod
def add_ephemeral_prefix(name: str):
return f"__dbt__cte__{name}"

@classmethod
def create_ephemeral_from_node(
def create_ephemeral_from(
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
relation_config: RelationConfig,
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
identifier = cls.add_ephemeral_prefix(relation_config.name)
return cls.create(
type=cls.CTE,
identifier=identifier,
).quote(identifier=False)

@classmethod
def create_from_node(
def create_from(
cls: Type[Self],
config: HasQuoting,
node,
quote_policy: Optional[Dict[str, bool]] = None,
quoting: HasQuoting,
relation_config: RelationConfig,
**kwargs: Any,
) -> Self:
if quote_policy is None:
quote_policy = {}
quote_policy = kwargs.pop("quote_policy", {})

config_quoting = relation_config.quoting_dict
config_quoting.pop("column", None)

quote_policy = dbt.common.utils.merge(config.quoting, quote_policy)
# precedence: kwargs quoting > relation config quoting > base quoting > default quoting
quote_policy = deep_merge(
cls.get_default_quote_policy().to_dict(omit_none=True),
quoting.quoting,
config_quoting,
quote_policy,
)

return cls.create(
database=node.database,
schema=node.schema,
identifier=node.alias,
database=relation_config.database,
schema=relation_config.schema,
identifier=relation_config.identifier,
quote_policy=quote_policy,
**kwargs,
)

@classmethod
def create_from(
cls: Type[Self],
config: HasQuoting,
node: ResultNode,
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
if not isinstance(node, SourceDefinition):
raise DbtInternalError(
"type mismatch, expected SourceDefinition but got {}".format(type(node))
)
return cls.create_from_source(node, **kwargs)
else:
# Can't use ManifestNode here because of parameterized generics
if not isinstance(node, (ParsedNode)):
raise DbtInternalError(
f"type mismatch, expected ManifestNode but got {type(node)}"
)
return cls.create_from_node(config, node, **kwargs)

@classmethod
def create(
cls: Type[Self],
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions core/dbt/adapters/contracts/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class RelationType(StrEnum):
Ephemeral = "ephemeral"


class RelationConfig(Protocol):
name: str
database: str
schema: str
identifier: str
quoting_dict: Dict[str, bool]


class ComponentName(StrEnum):
Database = "database"
Schema = "schema"
Expand Down
18 changes: 5 additions & 13 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from dataclasses import dataclass
from typing import (
Type,
Hashable,
Optional,
ContextManager,
List,
Generic,
TypeVar,
Tuple,
)
from typing import Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, Tuple, Any
from typing_extensions import Protocol

import agate

from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.adapters.contracts.relation import Policy, HasQuoting
from dbt.contracts.graph.nodes import ResultNode
from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
from dbt.contracts.graph.model_config import BaseConfig
from dbt.contracts.graph.manifest import Manifest

Expand All @@ -42,7 +32,9 @@ def get_default_quote_policy(cls) -> Policy:
...

@classmethod
def create_from(cls: Type[Self], config: HasQuoting, node: ResultNode) -> Self:
def create_from(
cls: Type[Self], quoting: HasQuoting, relation_config: RelationConfig, **kwargs: Any
) -> Self:
...


Expand Down
11 changes: 3 additions & 8 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,6 @@ def __init__(self, adapter):
def __getattr__(self, key):
return getattr(self._relation_type, key)

def create_from_source(self, *args, **kwargs):
# bypass our create when creating from source so as not to mess up
# the source quoting
return self._relation_type.create_from_source(*args, **kwargs)

def create(self, *args, **kwargs):
kwargs["quote_policy"] = merge(self._quoting_config, kwargs.pop("quote_policy", {}))
return self._relation_type.create(*args, **kwargs)
Expand Down Expand Up @@ -529,7 +524,7 @@ def resolve(
def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from_node(self.config, target_model)
return self.Relation.create_ephemeral_from(target_model)
else:
return self.Relation.create_from(self.config, target_model)

Expand Down Expand Up @@ -588,7 +583,7 @@ def resolve(self, source_name: str, table_name: str):
target_kind="source",
disabled=(isinstance(target_source, Disabled)),
)
return self.Relation.create_from_source(target_source)
return self.Relation.create_from(self.config, target_source)


# metric` implementations
Expand Down Expand Up @@ -1475,7 +1470,7 @@ def defer_relation(self) -> Optional[RelationProxy]:
object for that stateful other
"""
if getattr(self.model, "defer_relation", None):
return self.db_wrapper.Relation.create_from_node(
return self.db_wrapper.Relation.create_from(
self.config, self.model.defer_relation # type: ignore
)
else:
Expand Down
7 changes: 7 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ def __pre_deserialize__(cls, data):
data["database"] = None
return data

@property
def quoting_dict(self) -> Dict[str, bool]:
if hasattr(self, "quoting"):
return self.quoting.to_dict(omit_none=True)
else:
return {}


@dataclass
class MacroDependsOn(dbtClassMixin, Replaceable):
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ def _check_resource_uniqueness(

# the full node name is really defined by the adapter's relation
relation_cls = get_relation_class_by_name(config.credentials.type)
relation = relation_cls.create_from(config=config, node=node)
relation = relation_cls.create_from(quoting=config, relation_config=node) # type: ignore[arg-type]
full_node_name = str(relation)

existing_alias = alias_resources.get(full_node_name)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe

# cache the 'other' schemas too!
if node.defer_relation: # type: ignore
other_relation = adapter.Relation.create_from_node(
other_relation = adapter.Relation.create_from(
self.config, node.defer_relation # type: ignore
)
result.add(other_relation.without_identifier())
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/freshness.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def from_run_result(self, result, start_time, timing_info):
return result

def execute(self, compiled_node, manifest):
relation = self.adapter.Relation.create_from_source(compiled_node)
relation = self.adapter.Relation.create_from(self.config, compiled_node)
# given a Source, calculate its freshness.
with self.adapter.connection_named(compiled_node.unique_id, compiled_node):
self.adapter.clear_transaction()
Expand Down

0 comments on commit eb96e3d

Please sign in to comment.