Skip to content

Commit

Permalink
Merge pull request #287 from ClickHouse/upgrade-to-dbt-core-1.8
Browse files Browse the repository at this point in the history
Upgrade to dbt core 1.8
  • Loading branch information
BentsiLeviav authored Jun 6, 2024
2 parents 9c212f2 + dc1534d commit 506bd18
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 91 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/clickhouse/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = '1.7.7'
version = '1.8.0'
20 changes: 9 additions & 11 deletions dbt/adapters/clickhouse/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from dbt.events.functions import fire_event, fire_event_if
from dbt.events.types import CacheAction, CacheDumpGraph
from dbt.exceptions import (
from dbt.adapters.events.types import CacheAction, CacheDumpGraph
from dbt.adapters.exceptions import (
NewNameAlreadyInCacheError,
NoneRelationFoundError,
TruncatedModelNameCausedCollisionError,
)
from dbt.flags import get_flags
from dbt_common.events.functions import fire_event, fire_event_if

ReferenceKey = namedtuple("ReferenceKey", "schema identifier")

Expand Down Expand Up @@ -152,10 +151,11 @@ class ClickHouseRelationsCache:
:attr Set[str] schemas: The set of known/cached schemas
"""

def __init__(self) -> None:
def __init__(self, log_cache_events: bool = False) -> None:
self.relations: Dict[ReferenceKey, CachedRelation] = {}
self.lock = threading.RLock()
self.schemas: Set[Optional[str]] = set()
self.log_cache_events = log_cache_events

def add_schema(
self,
Expand Down Expand Up @@ -233,18 +233,17 @@ def add(self, relation):
:param BaseRelation relation: The underlying relation.
"""
flags = get_flags()
cached = CachedRelation(relation)
fire_event_if(
flags.LOG_CACHE_EVENTS,
self.log_cache_events,
lambda: CacheDumpGraph(before_after="before", action="adding", dump=self.dump_graph()),
)
fire_event(CacheAction(action="add_relation", ref_key=_make_ref_key_dict(cached)))

with self.lock:
self._setdefault(cached)
fire_event_if(
flags.LOG_CACHE_EVENTS,
self.log_cache_events,
lambda: CacheDumpGraph(before_after="after", action="adding", dump=self.dump_graph()),
)

Expand Down Expand Up @@ -368,9 +367,8 @@ def rename(self, old, new):
ref_key_2=new_key._asdict(),
)
)
flags = get_flags()
fire_event_if(
flags.LOG_CACHE_EVENTS,
self.log_cache_events,
lambda: CacheDumpGraph(before_after="before", action="rename", dump=self.dump_graph()),
)

Expand All @@ -381,7 +379,7 @@ def rename(self, old, new):
self._setdefault(CachedRelation(new))

fire_event_if(
flags.LOG_CACHE_EVENTS,
self.log_cache_events,
lambda: CacheDumpGraph(before_after="after", action="rename", dump=self.dump_graph()),
)

Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/clickhouse/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, TypeVar

from dbt.adapters.base.column import Column
from dbt.exceptions import DbtRuntimeError
from dbt_common.exceptions import DbtRuntimeError

Self = TypeVar('Self', bound='ClickHouseColumn')

Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/clickhouse/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import dbt.exceptions
from dbt.adapters.contracts.connection import AdapterResponse, Connection
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection

from dbt.adapters.clickhouse.dbclient import ChRetryableException, get_db_client
from dbt.adapters.clickhouse.logger import logger
Expand Down Expand Up @@ -68,7 +68,7 @@ def get_table_from_response(cls, response, column_names) -> "agate.Table":
:param response: ClickHouse query result
:param column_names: Table column names
"""
from dbt.clients.agate_helper import table_from_data_flat
from dbt_common.clients.agate_helper import table_from_data_flat

data = []
for row in response:
Expand Down Expand Up @@ -101,7 +101,7 @@ def execute(
query_result.result_set, query_result.column_names
)
else:
from dbt.clients.agate_helper import empty_table
from dbt_common.clients.agate_helper import empty_table

table = empty_table()
return AdapterResponse(_message=status), table
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/clickhouse/credentials.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

from dbt.contracts.connection import Credentials
from dbt.exceptions import DbtRuntimeError
from dbt.adapters.contracts.connection import Credentials
from dbt_common.exceptions import DbtRuntimeError


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion dbt/adapters/clickhouse/dbclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from abc import ABC, abstractmethod
from typing import Dict

from dbt.exceptions import DbtConfigError, DbtDatabaseError, FailedToConnectError
from dbt.adapters.exceptions import FailedToConnectError
from dbt_common.exceptions import DbtConfigError, DbtDatabaseError

from dbt.adapters.clickhouse.credentials import ClickHouseCredentials
from dbt.adapters.clickhouse.errors import (
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/clickhouse/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import clickhouse_connect
from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError
from dbt.exceptions import DbtDatabaseError
from dbt.version import __version__ as dbt_version
from dbt.adapters.__about__ import version as dbt_adapters_version
from dbt_common.exceptions import DbtDatabaseError

from dbt.adapters.clickhouse import ClickHouseColumn
from dbt.adapters.clickhouse.__version__ import version as dbt_clickhouse_version
Expand Down Expand Up @@ -60,7 +60,7 @@ def _create_client(self, credentials):
compress=False if credentials.compression == '' else bool(credentials.compression),
connect_timeout=credentials.connect_timeout,
send_receive_timeout=credentials.send_receive_timeout,
client_name=f'dbt/{dbt_version} dbt-clickhouse/{dbt_clickhouse_version}',
client_name=f'dbt-adapters/{dbt_adapters_version} dbt-clickhouse/{dbt_clickhouse_version}',
verify=credentials.verify,
query_limit=0,
settings=self._conn_settings,
Expand Down
63 changes: 43 additions & 20 deletions dbt/adapters/clickhouse/impl.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import csv
import io
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
from multiprocessing.context import SpawnContext
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)

from dbt.adapters.base import AdapterConfig, available
from dbt.adapters.base.impl import BaseAdapter, ConstraintSupport
from dbt.adapters.base.relation import BaseRelation, InformationSchema
from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support
from dbt.adapters.contracts.relation import Path, RelationConfig
from dbt.adapters.events.types import ConstraintNotSupported
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ConstraintType, ModelLevelConstraint
from dbt.contracts.relation import Path
from dbt.events.functions import warn_or_error
from dbt.events.types import ConstraintNotSupported
from dbt.exceptions import DbtInternalError, DbtRuntimeError, NotImplementedError
from dbt.utils import filter_null_values

import dbt
from dbt_common.contracts.constraints import ConstraintType, ModelLevelConstraint
from dbt_common.events.functions import warn_or_error
from dbt_common.exceptions import DbtInternalError, DbtRuntimeError, NotImplementedError
from dbt_common.utils import filter_null_values

from dbt.adapters.clickhouse.cache import ClickHouseRelationsCache
from dbt.adapters.clickhouse.column import ClickHouseColumn
from dbt.adapters.clickhouse.connections import ClickHouseConnectionManager
Expand Down Expand Up @@ -67,8 +78,8 @@ class ClickHouseAdapter(SQLAdapter):
}
)

def __init__(self, config):
BaseAdapter.__init__(self, config)
def __init__(self, config, mp_context: SpawnContext):
BaseAdapter.__init__(self, config, mp_context)
self.cache = ClickHouseRelationsCache()

@classmethod
Expand Down Expand Up @@ -313,21 +324,28 @@ def get_ch_database(self, schema: str):
except DbtRuntimeError:
return None

def get_catalog(self, manifest) -> Tuple["agate.Table", List[Exception]]:
from dbt.clients.agate_helper import empty_table
def get_catalog(
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
) -> Tuple["agate.Table", List[Exception]]:
from dbt_common.clients.agate_helper import empty_table

relations = self._get_catalog_relations(manifest)
relations = self._get_catalog_relations(relation_configs)
schemas = set(relation.schema for relation in relations)
if schemas:
catalog = self._get_one_catalog(InformationSchema(Path()), schemas, manifest)
catalog = self._get_one_catalog(InformationSchema(Path()), schemas, used_schemas)
else:
catalog = empty_table()
return catalog, []

def get_filtered_catalog(
self, manifest: Manifest, relations: Optional[Set[BaseRelation]] = None
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
relations: Optional[Set[BaseRelation]] = None,
):
catalog, exceptions = self.get_catalog(manifest)
catalog, exceptions = self.get_catalog(relation_configs, used_schemas)
if relations and catalog:
relation_map = {(r.schema, r.identifier) for r in relations}

Expand Down Expand Up @@ -512,8 +530,13 @@ def _expect_row_value(key: str, row: "agate.Row"):
return row[key]


def _catalog_filter_schemas(manifest: Manifest) -> Callable[["agate.Row"], bool]:
schemas = frozenset((None, s) for d, s in manifest.get_used_schemas())
def _catalog_filter_schemas(
used_schemas: FrozenSet[Tuple[str, str]]
) -> Callable[["agate.Row"], bool]:
"""Return a function that takes a row and decides if the row should be
included in the catalog output.
"""
schemas = frozenset((d.lower(), s.lower()) for d, s in used_schemas)

def test(row: "agate.Row") -> bool:
table_database = _expect_row_value('table_database', row)
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/clickhouse/logger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from dbt.events import AdapterLogger
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger('dbt_clickhouse')
6 changes: 3 additions & 3 deletions dbt/adapters/clickhouse/nativeclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import clickhouse_driver
import pkg_resources
from clickhouse_driver.errors import NetworkError, SocketTimeoutError
from dbt.exceptions import DbtDatabaseError
from dbt.version import __version__ as dbt_version
from dbt.adapters.__about__ import version as dbt_adapters_version
from dbt_common.exceptions import DbtDatabaseError

from dbt.adapters.clickhouse import ClickHouseColumn, ClickHouseCredentials
from dbt.adapters.clickhouse.__version__ import version as dbt_clickhouse_version
Expand Down Expand Up @@ -61,7 +61,7 @@ def _create_client(self, credentials: ClickHouseCredentials):
port=credentials.port,
user=credentials.user,
password=credentials.password,
client_name=f'dbt/{dbt_version} dbt-clickhouse/{dbt_clickhouse_version} clickhouse-driver/{driver_version}',
client_name=f'dbt-adapters/{dbt_adapters_version} dbt-clickhouse/{dbt_clickhouse_version} clickhouse-driver/{driver_version}',
secure=credentials.secure,
verify=credentials.verify,
connect_timeout=credentials.connect_timeout,
Expand Down
79 changes: 38 additions & 41 deletions dbt/adapters/clickhouse/relation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type

from dbt.adapters.base.relation import BaseRelation, Policy, Self
from dbt.contracts.graph.nodes import ManifestNode, SourceDefinition
from dbt.contracts.relation import HasQuoting, Path
from dbt.dataclass_schema import StrEnum
from dbt.exceptions import DbtRuntimeError
from dbt.utils import deep_merge, merge
from dbt.adapters.base.relation import BaseRelation, Path, Policy, Self
from dbt.adapters.contracts.relation import HasQuoting, RelationConfig
from dbt_common.dataclass_schema import StrEnum
from dbt_common.exceptions import DbtRuntimeError
from dbt_common.utils import deep_merge

from dbt.adapters.clickhouse.query import quote_identifier

NODE_TYPE_SOURCE = 'source'


@dataclass
class ClickHouseQuotePolicy(Policy):
Expand Down Expand Up @@ -85,51 +86,47 @@ def get_on_cluster(
return False

@classmethod
def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
def create_from(
cls: Type[Self],
quoting: HasQuoting,
relation_config: RelationConfig,
**kwargs: Any,
) -> Self:
quote_policy = kwargs.pop("quote_policy", {})

config_quoting = relation_config.quoting_dict
config_quoting.pop("column", None)
# precedence: kwargs quoting > relation config quoting > base quoting > default quoting
quote_policy = deep_merge(
cls.get_default_quote_policy().to_dict(omit_none=True),
source_quoting,
kwargs.get("quote_policy", {}),
quoting.quoting,
config_quoting,
quote_policy,
)

# If the database is set, and the source schema is "defaulted" to the source.name, override the
# schema with the database instead, since that's presumably what's intended for clickhouse
schema = source.schema
if schema == source.source_name and source.database:
schema = source.database

return cls.create(
database='',
schema=schema,
identifier=source.identifier,
quote_policy=quote_policy,
**kwargs,
)

@classmethod
def create_from_node(
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
quote_policy: Optional[Dict[str, bool]] = None,
**kwargs: Any,
) -> Self:
if quote_policy is None:
quote_policy = {}

quote_policy = merge(config.quoting, quote_policy)
schema = relation_config.schema
can_on_cluster = None
# We placed a hardcoded const (instead of importing it from dbt-core) in order to decouple the packages
if relation_config.resource_type == NODE_TYPE_SOURCE:
if schema == relation_config.source_name and relation_config.database:
schema = relation_config.database

cluster = config.credentials.cluster if config.credentials.cluster else ''
materialized = node.get_materialization() if node.get_materialization() else ''
engine = node.config.get('engine') if node.config.get('engine') else ''
can_on_cluster = cls.get_on_cluster(cluster, materialized, engine)
else:
cluster = quoting.credentials.cluster if quoting.credentials.cluster else ''
materialized = (
relation_config.config.materialized if relation_config.config.materialized else ''
)
engine = (
relation_config.config.get('engine') if relation_config.config.get('engine') else ''
)
can_on_cluster = cls.get_on_cluster(cluster, materialized, engine)

return cls.create(
database='',
schema=node.schema,
identifier=node.alias,
schema=schema,
identifier=relation_config.identifier,
quote_policy=quote_policy,
can_on_cluster=can_on_cluster,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/clickhouse/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from dbt.exceptions import DbtRuntimeError
from dbt_common.exceptions import DbtRuntimeError


def compare_versions(v1: str, v2: str) -> int:
Expand Down
Loading

0 comments on commit 506bd18

Please sign in to comment.