diff --git a/CHANGELOG.md b/CHANGELOG.md index 21fd806..3216875 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `DbtCloudMetadataClient` and `get_metadata_client` method to `DbtCloudCredentials` to enable interaction with the dbt Cloud metadata API - [#109](https://github.com/PrefectHQ/prefect-dbt/pull/109) +- Added `get_client` method to `DbtCloudCredentials` - [#109](https://github.com/PrefectHQ/prefect-dbt/pull/109) + ### Changed ### Deprecated diff --git a/prefect_dbt/cloud/clients.py b/prefect_dbt/cloud/clients.py index d3d75d1..54a5ce4 100644 --- a/prefect_dbt/cloud/clients.py +++ b/prefect_dbt/cloud/clients.py @@ -3,6 +3,7 @@ import prefect from httpx import AsyncClient, Response +from sgqlc.endpoint.http import HTTPEndpoint from typing_extensions import Literal from prefect_dbt.cloud.models import TriggerJobRunOptions @@ -18,7 +19,7 @@ class DbtCloudAdministrativeClient: domain: Domain at which the dbt Cloud API is hosted. """ - def __init__(self, api_key: str, account_id: int, domain: str): + def __init__(self, api_key: str, account_id: int, domain: str = "cloud.getdbt.com"): self._closed = False self._started = False @@ -188,3 +189,46 @@ async def __aenter__(self): async def __aexit__(self, *exc): self._closed = True await self._admin_client.__aexit__() + + +class DbtCloudMetadataClient: + """ + Client for interacting with the dbt Cloud metadata API. + + Args: + api_key: API key to authenticate with the dbt Cloud metadata API. + domain: Domain at which the dbt Cloud metadata API is hosted; queries are + sent to `https://{domain}/graphql`. 
+ """ + + def __init__(self, api_key: str, domain: str = "metadata.cloud.getdbt.com"): + self._http_endpoint = HTTPEndpoint( + base_headers={ + "Authorization": f"Bearer {api_key}", + "user-agent": f"prefect-{prefect.__version__}", + "content-type": "application/json", + }, + url=f"https://{domain}/graphql", + ) + + def query( + self, + query: str, + variables: Optional[Dict] = None, + operation_name: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Run a GraphQL query against the dbt Cloud metadata API. + + Args: + query: The GraphQL query to run. + variables: The values of any variables defined in the GraphQL query. + operation_name: The name of the operation to run if multiple operations + are defined in the provided query. + + Returns: + The result of the GraphQL query. + """ + return self._http_endpoint( + query=query, variables=variables, operation_name=operation_name + ) diff --git a/prefect_dbt/cloud/credentials.py b/prefect_dbt/cloud/credentials.py index 5759144..f1708c1 100644 --- a/prefect_dbt/cloud/credentials.py +++ b/prefect_dbt/cloud/credentials.py @@ -1,11 +1,17 @@ """Module containing credentials for interacting with dbt Cloud""" -from prefect.blocks.core import Block -from pydantic import SecretStr +from typing import Union -from prefect_dbt.cloud.clients import DbtCloudAdministrativeClient +from prefect.blocks.abstract import CredentialsBlock +from pydantic import Field, SecretStr +from typing_extensions import Literal +from prefect_dbt.cloud.clients import ( + DbtCloudAdministrativeClient, + DbtCloudMetadataClient, +) -class DbtCloudCredentials(Block): + +class DbtCloudCredentials(CredentialsBlock): """ Credentials block for credential use across dbt Cloud tasks and flows. 
@@ -55,17 +61,123 @@ def trigger_dbt_cloud_job_run_flow(): _block_type_name = "dbt Cloud Credentials" _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/5zE9lxfzBHjw3tnEup4wWL/9a001902ed43a84c6c96d23b24622e19/dbt-bit_tm.png?h=250" # noqa - api_key: SecretStr - account_id: int - domain: str = "cloud.getdbt.com" - - def get_administrative_client(self): + api_key: SecretStr = Field( + default=..., + title="API Key", + description="A dbt Cloud API key to use for authentication.", + ) + account_id: int = Field( + default=..., title="Account ID", description="The ID of your dbt Cloud account." + ) + domain: str = Field( + default="cloud.getdbt.com", + description="The base domain of your dbt Cloud instance.", + ) + + def get_administrative_client(self) -> DbtCloudAdministrativeClient: """ Returns a newly instantiated client for working with the dbt Cloud administrative API. + + Returns: + An authenticated dbt Cloud administrative API client. """ return DbtCloudAdministrativeClient( api_key=self.api_key.get_secret_value(), account_id=self.account_id, domain=self.domain, ) + + def get_metadata_client(self) -> DbtCloudMetadataClient: + """ + Returns a newly instantiated client for working with the dbt Cloud + metadata API. 
+ + Example: + Sending queries via the returned metadata client: + ```python + from prefect_dbt import DbtCloudCredentials + + credentials_block = DbtCloudCredentials.load("test-account") + metadata_client = credentials_block.get_metadata_client() + query = \"\"\" + { + metrics(jobId: 123) { + uniqueId + name + packageName + tags + label + runId + description + type + sql + timestamp + timeGrains + dimensions + meta + resourceType + filters { + field + operator + value + } + model { + name + } + } + } + \"\"\" + metadata_client.query(query) + # Result: + # { + # "data": { + # "metrics": [ + # { + # "uniqueId": "metric.tpch.total_revenue", + # "name": "total_revenue", + # "packageName": "tpch", + # "tags": [], + # "label": "Total Revenue ($)", + # "runId": 108952046, + # "description": "", + # "type": "sum", + # "sql": "net_item_sales_amount", + # "timestamp": "order_date", + # "timeGrains": ["day", "week", "month"], + # "dimensions": ["status_code", "priority_code"], + # "meta": {}, + # "resourceType": "metric", + # "filters": [], + # "model": { "name": "fct_orders" } + # } + # ] + # } + # } + ``` + + Returns: + An authenticated dbt Cloud metadata API client. + """ + return DbtCloudMetadataClient( + api_key=self.api_key.get_secret_value(), + domain=f"metadata.{self.domain}", + ) + + def get_client( + self, client_type: Literal["administrative", "metadata"] + ) -> Union[DbtCloudAdministrativeClient, DbtCloudMetadataClient]: + """ + Returns a newly instantiated client for working with the dbt Cloud API. + + Args: + client_type: Type of client to return. Accepts either 'administrative' + or 'metadata'. + + Returns: + The authenticated client of the requested type. 
+ """ + get_client_method = getattr(self, f"get_{client_type}_client", None) + if get_client_method is None: + raise ValueError(f"'{client_type}' is not a supported client type.") + return get_client_method() diff --git a/requirements.txt b/requirements.txt index 72476df..b91a602 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ prefect>=2.7.2 prefect_shell>=0.1.0 +sgqlc>=16.0.0 diff --git a/tests/cloud/test_clients.py b/tests/cloud/test_clients.py new file mode 100644 index 0000000..6f72182 --- /dev/null +++ b/tests/cloud/test_clients.py @@ -0,0 +1,67 @@ +import json +from unittest.mock import MagicMock + +from prefect_dbt.cloud.clients import DbtCloudMetadataClient + + +def test_metadata_client_query(monkeypatch): + mock_response = { + "data": { + "metrics": [ + { + "uniqueId": "metric.tpch.total_revenue", + "name": "total_revenue", + "packageName": "tpch", + "tags": [], + "label": "Total Revenue ($)", + "runId": 108952046, + "description": "", + "type": "sum", + "sql": "net_item_sales_amount", + "timestamp": "order_date", + "timeGrains": ["day", "week", "month"], + "dimensions": ["status_code", "priority_code"], + "meta": {}, + "resourceType": "metric", + "filters": [], + "model": {"name": "fct_orders"}, + } + ] + } + } + urlopen_mock = MagicMock() + urlopen_mock.getcode.return_value = 200 + urlopen_mock.return_value = urlopen_mock + urlopen_mock.read.return_value = json.dumps(mock_response).encode() + urlopen_mock.__enter__.return_value = urlopen_mock + monkeypatch.setattr("urllib.request.urlopen", urlopen_mock) + dbt_cloud_metadata_client = DbtCloudMetadataClient(api_key="my_api_key") + mock_query = """ + { + metrics(jobId: 123) { + uniqueId + name + packageName + tags + label + runId + description + type + sql + timestamp + timeGrains + dimensions + meta + resourceType + filters { + field + operator + value + } + model { + name + } + } + } + """ + assert dbt_cloud_metadata_client.query(mock_query) == mock_response diff --git 
a/tests/cloud/test_cloud_credentials.py b/tests/cloud/test_cloud_credentials.py new file mode 100644 index 0000000..e25deaf --- /dev/null +++ b/tests/cloud/test_cloud_credentials.py @@ -0,0 +1,35 @@ +import pytest + +from prefect_dbt import DbtCloudCredentials +from prefect_dbt.cloud.clients import ( + DbtCloudAdministrativeClient, + DbtCloudMetadataClient, +) + + +@pytest.fixture +def dbt_cloud_credentials(): + return DbtCloudCredentials(api_key="my_api_key", account_id=123456789) + + +def test_get_administrative_client(dbt_cloud_credentials: DbtCloudCredentials): + assert isinstance( + dbt_cloud_credentials.get_administrative_client(), DbtCloudAdministrativeClient + ) + + +def test_get_metadata_client(dbt_cloud_credentials: DbtCloudCredentials): + assert isinstance( + dbt_cloud_credentials.get_metadata_client(), DbtCloudMetadataClient + ) + + +def test_get_client(dbt_cloud_credentials: DbtCloudCredentials): + assert isinstance( + dbt_cloud_credentials.get_client("administrative"), DbtCloudAdministrativeClient + ) + assert isinstance( + dbt_cloud_credentials.get_client("metadata"), DbtCloudMetadataClient + ) + with pytest.raises(ValueError, match="'blorp' is not a supported client type"): + dbt_cloud_credentials.get_client("blorp")