Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #464

Closed
wants to merge 19 commits into from
Closed

Dev #464

Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
migrated work for fabric connection
cody-scott committed Dec 18, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit f3ded8bda1cd99f6173a34f1f4f2ef734b2fe265
201 changes: 41 additions & 160 deletions dbt/adapters/sqlserver/sql_server_adapter.py
Original file line number Diff line number Diff line change
@@ -1,175 +1,56 @@
from typing import List, Optional

import agate
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.cache import _make_ref_key_msg
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME
from dbt.events.functions import fire_event
from dbt.events.types import SchemaCreation

from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn
from dbt.adapters.sqlserver.sql_server_configs import SQLServerConfigs
from dbt.adapters.sqlserver.sql_server_connection_manager import SQLServerConnectionManager
# from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support

# https://github.com/microsoft/dbt-fabric/blob/main/dbt/adapters/fabric/fabric_adapter.py
from dbt.adapters.fabric import FabricAdapter

class SQLServerAdapter(SQLAdapter):

class SQLServerAdapter(FabricAdapter):
ConnectionManager = SQLServerConnectionManager
Column = SQLServerColumn
AdapterSpecificConfigs = SQLServerConfigs

def create_schema(self, relation: BaseRelation) -> None:
relation = relation.without_identifier()
fire_event(SchemaCreation(relation=_make_ref_key_msg(relation)))
macro_name = CREATE_SCHEMA_MACRO_NAME
kwargs = {
"relation": relation,
}

if self.config.credentials.schema_authorization:
kwargs["schema_authorization"] = self.config.credentials.schema_authorization
macro_name = "sqlserver__create_schema_with_authorization"

self.execute_macro(macro_name, kwargs=kwargs)
self.commit_if_has_connection()
# _capabilities: CapabilityDict = CapabilityDict(
# {
# Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full),
# Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full),
# }
# )

# region - these are implement in fabric but not in sqlserver
# _capabilities: CapabilityDict = CapabilityDict(
# {
# Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full),
# Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full),
# }
# )
# CONSTRAINT_SUPPORT = {
# ConstraintType.check: ConstraintSupport.NOT_SUPPORTED,
# ConstraintType.not_null: ConstraintSupport.ENFORCED,
# ConstraintType.unique: ConstraintSupport.ENFORCED,
# ConstraintType.primary_key: ConstraintSupport.ENFORCED,
# ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
# }

# @available.parse(lambda *a, **k: [])
# def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
# """Get a list of the Columns with names and data types from the given sql."""
# _, cursor = self.connections.add_select_query(sql)

# columns = [
# self.Column.create(
# column_name, self.connections.data_type_code_to_name(column_type_code)
# )
# # https://peps.python.org/pep-0249/#description
# for column_name, column_type_code, *_ in cursor.description
# ]
# return columns
# endregion

@classmethod
def date_function(cls):
return "getdate()"

@classmethod
def convert_text_type(cls, agate_table, col_idx):
column = agate_table.columns[col_idx]
# see https://github.com/fishtown-analytics/dbt/pull/2255
lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()]
max_len = max(lens) if lens else 64
length = max_len if max_len > 16 else 16
return "varchar({})".format(length)

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
return "datetime"

@classmethod
def convert_boolean_type(cls, agate_table, col_idx):
return "bit"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "float" if decimals else "int"

@classmethod
def convert_time_type(cls, agate_table, col_idx):
return "datetime"

# Methods used in adapter tests
def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
# note: 'interval' is not supported for T-SQL
# for backwards compatibility, we're compelled to set some sort of
# default. A lot of searching has lead me to believe that the
# '+ interval' syntax used in postgres/redshift is relatively common
# and might even be the SQL standard's intention.
return f"DATEADD({interval},{number},{add_to})"

def string_add_sql(
self,
add_to: str,
value: str,
location="append",
) -> str:
"""
`+` is T-SQL's string concatenation operator
"""
if location == "append":
return f"{add_to} + '{value}'"
elif location == "prepend":
return f"'{value}' + {add_to}"
else:
raise ValueError(f'Got an unexpected location value of "{location}"')

def get_rows_different_sql(
self,
relation_a: BaseRelation,
relation_b: BaseRelation,
column_names: Optional[List[str]] = None,
except_operator: str = "EXCEPT",
) -> str:
"""
note: using is not supported on Synapse so COLUMNS_EQUAL_SQL is adjsuted
Generate SQL for a query that returns a single row with a two
columns: the number of rows that are different between the two
relations and the number of mismatched rows.
"""
# This method only really exists for test reasons.
names: List[str]
if column_names is None:
columns = self.get_columns_in_relation(relation_a)
names = sorted((self.quote(c.name) for c in columns))
else:
names = sorted((self.quote(n) for n in column_names))
columns_csv = ", ".join(names)

sql = COLUMNS_EQUAL_SQL.format(
columns=columns_csv,
relation_a=str(relation_a),
relation_b=str(relation_b),
except_op=except_operator,
)

return sql

def valid_incremental_strategies(self):
"""The set of standard builtin strategies which this adapter supports out-of-the-box.
Not used to validate custom strategies defined by end users.
"""
return ["append", "delete+insert", "merge", "insert_overwrite"]

# This is for use in the test suite
def run_sql_for_tests(self, sql, fetch, conn):
cursor = conn.handle.cursor()
try:
cursor.execute(sql)
if not fetch:
conn.handle.commit()
if fetch == "one":
return cursor.fetchone()
elif fetch == "all":
return cursor.fetchall()
else:
return
except BaseException:
if conn.handle and not getattr(conn.handle, "closed", True):
conn.handle.rollback()
raise
finally:
conn.transaction_open = False


COLUMNS_EQUAL_SQL = """
with diff_count as (
SELECT
1 as id,
COUNT(*) as num_missing FROM (
(SELECT {columns} FROM {relation_a} {except_op}
SELECT {columns} FROM {relation_b})
UNION ALL
(SELECT {columns} FROM {relation_b} {except_op}
SELECT {columns} FROM {relation_a})
) as a
), table_a as (
SELECT COUNT(*) as num_rows FROM {relation_a}
), table_b as (
SELECT COUNT(*) as num_rows FROM {relation_b}
), row_count_diff as (
select
1 as id,
table_a.num_rows - table_b.num_rows as difference
from table_a, table_b
)
select
row_count_diff.difference as row_count_difference,
diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count on row_count_diff.id = diff_count.id
""".strip()
21 changes: 3 additions & 18 deletions dbt/adapters/sqlserver/sql_server_column.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,5 @@
from typing import Any, ClassVar, Dict

from dbt.adapters.base import Column
from dbt.adapters.fabric import FabricColumn


class SQLServerColumn(Column):
TYPE_LABELS: ClassVar[Dict[str, str]] = {
"STRING": "VARCHAR(MAX)",
"TIMESTAMP": "DATETIMEOFFSET",
"FLOAT": "FLOAT",
"INTEGER": "INT",
"BOOLEAN": "BIT",
}

@classmethod
def string_type(cls, size: int) -> str:
return f"varchar({size if size > 0 else 'MAX'})"

def literal(self, value: Any) -> str:
return "cast('{}' as {})".format(value, self.data_type)
class SQLServerColumn(FabricColumn):
...
8 changes: 3 additions & 5 deletions dbt/adapters/sqlserver/sql_server_configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from dataclasses import dataclass
from typing import Optional

from dbt.adapters.protocol import AdapterConfig

from dbt.adapters.fabric import FabricConfigs

@dataclass
class SQLServerConfigs(AdapterConfig):
auto_provision_aad_principals: Optional[bool] = False
class SQLServerConfigs(FabricConfigs):
...
Loading