diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py index 1866599fa21c67..b39e05a8db4de1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py @@ -75,8 +75,37 @@ def set_metadata_endpoint(cls, values: dict) -> dict: def infer_metadata_endpoint(access_url: str) -> Optional[str]: - # See https://docs.getdbt.com/docs/cloud/about-cloud/access-regions-ip-addresses#api-access-urls - # and https://docs.getdbt.com/docs/dbt-cloud-apis/discovery-querying#discovery-api-endpoints + """Infer the dbt metadata endpoint from the access URL. + + See https://docs.getdbt.com/docs/cloud/about-cloud/access-regions-ip-addresses#api-access-urls + and https://docs.getdbt.com/docs/dbt-cloud-apis/discovery-querying#discovery-api-endpoints + for more information. + + Args: + access_url: The dbt Cloud access URL. This is the URL of the dbt Cloud UI. + + Returns: + The metadata endpoint, or None if it couldn't be inferred. + + Examples: + # Standard multi-tenant deployments. + >>> infer_metadata_endpoint("https://cloud.getdbt.com") + 'https://metadata.cloud.getdbt.com/graphql' + + >>> infer_metadata_endpoint("https://au.dbt.com") + 'https://metadata.au.dbt.com/graphql' + + >>> infer_metadata_endpoint("https://emea.dbt.com") + 'https://metadata.emea.dbt.com/graphql' + + # Cell-based deployment. + >>> infer_metadata_endpoint("https://prefix.us1.dbt.com") + 'https://prefix.metadata.us1.dbt.com/graphql' + + # Test with an "internal" URL. + >>> infer_metadata_endpoint("http://dbt.corp.internal") + 'http://metadata.dbt.corp.internal/graphql' + """ try: parsed_uri = urlparse(access_url) @@ -86,13 +115,18 @@ def infer_metadata_endpoint(access_url: str) -> Optional[str]: logger.debug(f"Unable to parse access URL {access_url}: {e}", exc_info=e) return None - if parsed_uri.hostname.endswith(".dbt.com"): + if parsed_uri.hostname.endswith(".getdbt.com") or parsed_uri.hostname in { + # Two special cases of multi-tenant deployments that use the dbt.com domain + # instead of getdbt.com. + "au.dbt.com", + "emea.dbt.com", + }: + return f"{parsed_uri.scheme}://metadata.{parsed_uri.netloc}/graphql" + elif parsed_uri.hostname.endswith(".dbt.com"): # For cell-based deployments. # prefix.region.dbt.com -> prefix.metadata.region.dbt.com hostname_parts = parsed_uri.hostname.split(".", maxsplit=1) return f"{parsed_uri.scheme}://{hostname_parts[0]}.metadata.{hostname_parts[1]}/graphql" - elif parsed_uri.hostname.endswith(".getdbt.com"): - return f"{parsed_uri.scheme}://metadata.{parsed_uri.netloc}/graphql" else: # The self-hosted variants also have the metadata. prefix. return f"{parsed_uri.scheme}://metadata.{parsed_uri.netloc}/graphql" @@ -403,10 +437,12 @@ def _parse_into_dbt_node(self, node: Dict) -> DBTNode: columns = [] if "columns" in node and node["columns"] is not None: # columns will be empty for ephemeral models - columns = [ - self._parse_into_dbt_column(column) - for column in sorted(node["columns"], key=lambda c: c["index"]) - ] + columns = list( + sorted( + [self._parse_into_dbt_column(column) for column in node["columns"]], + key=lambda c: c.index, + ) + ) test_info = None test_result = None @@ -494,7 +530,10 @@ def _parse_into_dbt_column( name=column["name"], comment=column.get("comment", ""), description=column["description"], - index=column["index"], + # For some reason, the index sometimes comes back as None from the dbt Cloud API. + # In that case, we just assume that the column is at the end of the table by + # assigning it a very large index. + index=column["index"] if column["index"] is not None else 10**6, data_type=column["type"], meta=column["meta"], tags=column["tags"], diff --git a/metadata-ingestion/tests/unit/test_dbt_source.py b/metadata-ingestion/tests/unit/test_dbt_source.py index 90ff78b16f652b..7d01ecd034523d 100644 --- a/metadata-ingestion/tests/unit/test_dbt_source.py +++ b/metadata-ingestion/tests/unit/test_dbt_source.py @@ -1,3 +1,4 @@ +import doctest from datetime import timedelta from typing import Dict, List, Union from unittest import mock @@ -7,10 +8,8 @@ from datahub.emitter import mce_builder from datahub.ingestion.api.common import PipelineContext -from datahub.ingestion.source.dbt.dbt_cloud import ( - DBTCloudConfig, - infer_metadata_endpoint, -) +from datahub.ingestion.source.dbt import dbt_cloud +from datahub.ingestion.source.dbt.dbt_cloud import DBTCloudConfig from datahub.ingestion.source.dbt.dbt_core import ( DBTCoreConfig, DBTCoreSource, @@ -401,17 +400,7 @@ def test_dbt_cloud_config_with_defined_metadata_endpoint(): def test_infer_metadata_endpoint() -> None: - assert ( - infer_metadata_endpoint("https://cloud.getdbt.com") - == "https://metadata.cloud.getdbt.com/graphql" - ) - assert ( - infer_metadata_endpoint("https://prefix.us1.dbt.com") - == "https://prefix.metadata.us1.dbt.com/graphql" - ) - assert ( - infer_metadata_endpoint("http://dbt.corp.internal") - ) == "http://metadata.dbt.corp.internal/graphql" + assert doctest.testmod(dbt_cloud, raise_on_error=True).attempted > 0 def test_dbt_time_parsing() -> None: