add test connection for tableau, dbt core & cloud sources
shubhamjagtap639 committed Nov 28, 2023
1 parent 057617a commit 0364fec
Showing 3 changed files with 124 additions and 44 deletions.
89 changes: 60 additions & 29 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py
@@ -14,7 +14,12 @@
     platform_name,
     support_status,
 )
-from datahub.ingestion.api.source import SourceCapability
+from datahub.ingestion.api.source import (
+    CapabilityReport,
+    SourceCapability,
+    TestableSource,
+    TestConnectionReport,
+)
 from datahub.ingestion.source.dbt.dbt_common import (
     DBTColumn,
     DBTCommonConfig,
@@ -177,7 +182,7 @@ class DBTCloudConfig(DBTCommonConfig):
 @support_status(SupportStatus.INCUBATING)
 @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
 @capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
-class DBTCloudSource(DBTSourceBase):
+class DBTCloudSource(DBTSourceBase, TestableSource):
     """
     This source pulls dbt metadata directly from the dbt Cloud APIs.
@@ -199,6 +204,57 @@ def create(cls, config_dict, ctx):
         config = DBTCloudConfig.parse_obj(config_dict)
         return cls(config, ctx, "dbt")
 
+    @staticmethod
+    def test_connection(config_dict: dict) -> TestConnectionReport:
+        test_report = TestConnectionReport()
+        try:
+            source_config = DBTCloudConfig.parse_obj_allow_extras(config_dict)
+            DBTCloudSource._send_graphql_query(
+                metadata_endpoint=source_config.metadata_endpoint,
+                token=source_config.token,
+                query=_DBT_GRAPHQL_QUERY.format(type="tests", fields="jobId"),
+                variables={
+                    "jobId": source_config.job_id,
+                    "runId": source_config.run_id,
+                },
+            )
+            test_report.basic_connectivity = CapabilityReport(capable=True)
+        except Exception as e:
+            test_report.basic_connectivity = CapabilityReport(
+                capable=False, failure_reason=str(e)
+            )
+        return test_report
+
+    @staticmethod
+    def _send_graphql_query(
+        metadata_endpoint: str, token: str, query: str, variables: Dict
+    ) -> Dict:
+        logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
+        response = requests.post(
+            metadata_endpoint,
+            json={
+                "query": query,
+                "variables": variables,
+            },
+            headers={
+                "Authorization": f"Bearer {token}",
+                "X-dbt-partner-source": "acryldatahub",
+            },
+        )
+
+        try:
+            res = response.json()
+            if "errors" in res:
+                raise ValueError(
+                    f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
+                )
+            data = res["data"]
+        except JSONDecodeError as e:
+            response.raise_for_status()
+            raise e
+
+        return data
+
     def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
         # TODO: In dbt Cloud, commands are scheduled as part of jobs, where
         # each job can have multiple runs. We currently only fully support
@@ -213,6 +269,8 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
         for node_type, fields in _DBT_FIELDS_BY_TYPE.items():
             logger.info(f"Fetching {node_type} from dbt Cloud")
             data = self._send_graphql_query(
+                metadata_endpoint=self.config.metadata_endpoint,
+                token=self.config.token,
                 query=_DBT_GRAPHQL_QUERY.format(type=node_type, fields=fields),
                 variables={
                     "jobId": self.config.job_id,
@@ -232,33 +290,6 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
 
         return nodes, additional_metadata
 
-    def _send_graphql_query(self, query: str, variables: Dict) -> Dict:
-        logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
-        response = requests.post(
-            self.config.metadata_endpoint,
-            json={
-                "query": query,
-                "variables": variables,
-            },
-            headers={
-                "Authorization": f"Bearer {self.config.token}",
-                "X-dbt-partner-source": "acryldatahub",
-            },
-        )
-
-        try:
-            res = response.json()
-            if "errors" in res:
-                raise ValueError(
-                    f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
-                )
-            data = res["data"]
-        except JSONDecodeError as e:
-            response.raise_for_status()
-            raise e
-
-        return data
-
     def _parse_into_dbt_node(self, node: Dict) -> DBTNode:
         key = node["uniqueId"]
 
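For reference, the new entry point is a static method, so a caller can probe connectivity from a raw recipe dict without constructing the source. A minimal sketch (the values are placeholders; the real DBTCloudConfig may require additional fields such as account and project identifiers, and the endpoint shown is an assumption):

    from datahub.ingestion.source.dbt.dbt_cloud import DBTCloudSource

    # Placeholder values for illustration only; parse_obj_allow_extras
    # tolerates keys beyond those the connectivity check uses.
    config_dict = {
        "metadata_endpoint": "https://metadata.cloud.getdbt.com/graphql",
        "token": "<dbt-cloud-service-token>",
        "job_id": 123456,
        "run_id": 7890123,
    }

    report = DBTCloudSource.test_connection(config_dict)
    connectivity = report.basic_connectivity
    print(connectivity.capable, connectivity.failure_reason)

A failure never raises: the except branch folds any exception into CapabilityReport(capable=False, failure_reason=...), so callers can always inspect the report.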
56 changes: 43 additions & 13 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py
@@ -18,7 +18,12 @@
     platform_name,
     support_status,
 )
-from datahub.ingestion.api.source import SourceCapability
+from datahub.ingestion.api.source import (
+    CapabilityReport,
+    SourceCapability,
+    TestableSource,
+    TestConnectionReport,
+)
 from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
 from datahub.ingestion.source.dbt.dbt_common import (
     DBTColumn,
@@ -60,11 +65,6 @@ class DBTCoreConfig(DBTCommonConfig):
 
     _github_info_deprecated = pydantic_renamed_field("github_info", "git_info")
 
-    @property
-    def s3_client(self):
-        assert self.aws_connection
-        return self.aws_connection.get_s3_client()
-
     @validator("aws_connection")
     def aws_connection_needed_if_s3_uris_present(
         cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
@@ -363,7 +363,7 @@ def load_test_results(
 @support_status(SupportStatus.CERTIFIED)
 @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
 @capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
-class DBTCoreSource(DBTSourceBase):
+class DBTCoreSource(DBTSourceBase, TestableSource):
     """
     The artifacts used by this source are:
     - [dbt manifest file](https://docs.getdbt.com/reference/artifacts/manifest-json)
@@ -387,12 +387,34 @@ def create(cls, config_dict, ctx):
         config = DBTCoreConfig.parse_obj(config_dict)
         return cls(config, ctx, "dbt")
 
-    def load_file_as_json(self, uri: str) -> Any:
+    @staticmethod
+    def test_connection(config_dict: dict) -> TestConnectionReport:
+        test_report = TestConnectionReport()
+        try:
+            source_config = DBTCoreConfig.parse_obj_allow_extras(config_dict)
+            DBTCoreSource.load_file_as_json(
+                source_config.manifest_path, source_config.aws_connection
+            )
+            DBTCoreSource.load_file_as_json(
+                source_config.catalog_path, source_config.aws_connection
+            )
+            test_report.basic_connectivity = CapabilityReport(capable=True)
+        except Exception as e:
+            test_report.basic_connectivity = CapabilityReport(
+                capable=False, failure_reason=str(e)
+            )
+        return test_report
+
+    @staticmethod
+    def load_file_as_json(
+        uri: str, aws_connection: Optional[AwsConnectionConfig]
+    ) -> Dict:
         if re.match("^https?://", uri):
             return json.loads(requests.get(uri).text)
         elif re.match("^s3://", uri):
             u = urlparse(uri)
-            response = self.config.s3_client.get_object(
+            assert aws_connection
+            response = aws_connection.get_s3_client().get_object(
                 Bucket=u.netloc, Key=u.path.lstrip("/")
             )
             return json.loads(response["Body"].read().decode("utf-8"))
@@ -410,12 +432,18 @@ def loadManifestAndCatalog(
         Optional[str],
         Optional[str],
     ]:
-        dbt_manifest_json = self.load_file_as_json(self.config.manifest_path)
+        dbt_manifest_json = self.load_file_as_json(
+            self.config.manifest_path, self.config.aws_connection
+        )
 
-        dbt_catalog_json = self.load_file_as_json(self.config.catalog_path)
+        dbt_catalog_json = self.load_file_as_json(
+            self.config.catalog_path, self.config.aws_connection
+        )
 
         if self.config.sources_path is not None:
-            dbt_sources_json = self.load_file_as_json(self.config.sources_path)
+            dbt_sources_json = self.load_file_as_json(
+                self.config.sources_path, self.config.aws_connection
+            )
             sources_results = dbt_sources_json["results"]
         else:
             sources_results = {}
@@ -478,7 +506,9 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
         # This will populate the test_results field on each test node.
         all_nodes = load_test_results(
             self.config,
-            self.load_file_as_json(self.config.test_results_path),
+            self.load_file_as_json(
+                self.config.test_results_path, self.config.aws_connection
+            ),
             all_nodes,
         )
 
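The dbt Core variant follows the same shape but checks artifact reachability rather than an API: test_connection must be able to load both the manifest and catalog JSON, wherever they live (local path, http(s) URL, or s3 URI). Making load_file_as_json static, with the AwsConnectionConfig passed explicitly in place of the removed s3_client property, is what lets the check run before a source instance exists. A sketch with placeholder paths (target_platform is assumed here only to satisfy the broader config; it plays no role in the check itself):

    from datahub.ingestion.source.dbt.dbt_core import DBTCoreSource

    # Placeholder values for illustration only.
    config_dict = {
        "manifest_path": "./target/manifest.json",
        "catalog_path": "./target/catalog.json",
        "target_platform": "postgres",
    }

    report = DBTCoreSource.test_connection(config_dict)
    print(report.basic_connectivity)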
23 changes: 21 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/tableau.py
@@ -57,7 +57,13 @@
     platform_name,
     support_status,
 )
-from datahub.ingestion.api.source import MetadataWorkUnitProcessor, Source
+from datahub.ingestion.api.source import (
+    CapabilityReport,
+    MetadataWorkUnitProcessor,
+    Source,
+    TestableSource,
+    TestConnectionReport,
+)
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source import tableau_constant as c
 from datahub.ingestion.source.common.subtypes import (
@@ -456,7 +462,7 @@ class TableauSourceReport(StaleEntityRemovalSourceReport):
     SourceCapability.LINEAGE_FINE,
     "Enabled by default, configure using `extract_column_level_lineage`",
 )
-class TableauSource(StatefulIngestionSourceBase):
+class TableauSource(StatefulIngestionSourceBase, TestableSource):
     platform = "tableau"
 
     def __hash__(self):
@@ -496,6 +502,19 @@ def __init__(
 
         self._authenticate()
 
+    @staticmethod
+    def test_connection(config_dict: dict) -> TestConnectionReport:
+        test_report = TestConnectionReport()
+        try:
+            source_config = TableauConfig.parse_obj_allow_extras(config_dict)
+            source_config.make_tableau_client()
+            test_report.basic_connectivity = CapabilityReport(capable=True)
+        except Exception as e:
+            test_report.basic_connectivity = CapabilityReport(
+                capable=False, failure_reason=str(e)
+            )
+        return test_report
+
     def close(self) -> None:
         try:
             if self.server is not None:
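For Tableau the probe is simply whether make_tableau_client can build an authenticated client from the parsed config. The recipe keys below do not appear in this diff, so treat them as hypothetical stand-ins for the usual Tableau connection fields:

    from datahub.ingestion.source.tableau import TableauSource

    # Hypothetical recipe keys; consult TableauConfig for the actual fields.
    config_dict = {
        "connect_uri": "https://tableau.example.com",
        "site": "my_site",
        "token_name": "<pat-name>",
        "token_value": "<pat-secret>",
    }

    report = TableauSource.test_connection(config_dict)
    print(report.basic_connectivity)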
