Skip to content

Commit

Permalink
feat(ingest): improve config loading helpers (#9477)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Jan 3, 2024
1 parent 186b6f9 commit f06b5c7
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 20 deletions.
48 changes: 29 additions & 19 deletions metadata-ingestion/src/datahub/configuration/config_loader.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,59 @@
import io
import os
import pathlib
import re
import sys
import tempfile
import unittest.mock
from typing import Any, Dict, Set, Union
from typing import Any, Dict, Mapping, Optional, Set, Union
from urllib import parse

import requests
from expandvars import UnboundVariable, expandvars
from expandvars import UnboundVariable, expand

from datahub.configuration.common import ConfigurationError, ConfigurationMechanism
from datahub.configuration.json_loader import JsonConfigurationMechanism
from datahub.configuration.toml import TomlConfigurationMechanism
from datahub.configuration.yaml import YamlConfigurationMechanism

Environ = Mapping[str, str]

def _resolve_element(element: str) -> str:

def _resolve_element(element: str, environ: Environ) -> str:
if re.search(r"(\$\{).+(\})", element):
return expandvars(element, nounset=True)
return expand(element, nounset=True, environ=environ)
elif element.startswith("$"):
try:
return expandvars(element, nounset=True)
return expand(element, nounset=True, environ=environ)
except UnboundVariable:
return element
else:
return element


def _resolve_list(ele_list: list) -> list:
def _resolve_list(ele_list: list, environ: Environ) -> list:
new_v: list = []
for ele in ele_list:
if isinstance(ele, str):
new_v.append(_resolve_element(ele))
new_v.append(_resolve_element(ele, environ=environ))
elif isinstance(ele, list):
new_v.append(_resolve_list(ele))
new_v.append(_resolve_list(ele, environ=environ))
elif isinstance(ele, dict):
new_v.append(resolve_env_variables(ele))
new_v.append(resolve_env_variables(ele, environ=environ))
else:
new_v.append(ele)
return new_v


def resolve_env_variables(config: dict) -> dict:
def resolve_env_variables(config: dict, environ: Environ) -> dict:
new_dict: Dict[Any, Any] = {}
for k, v in config.items():
if isinstance(v, dict):
new_dict[k] = resolve_env_variables(v)
new_dict[k] = resolve_env_variables(v, environ=environ)
elif isinstance(v, list):
new_dict[k] = _resolve_list(v)
new_dict[k] = _resolve_list(v, environ=environ)
elif isinstance(v, str):
new_dict[k] = _resolve_element(v)
new_dict[k] = _resolve_element(v, environ=environ)
else:
new_dict[k] = v
return new_dict
Expand All @@ -60,13 +63,20 @@ def list_referenced_env_variables(config: dict) -> Set[str]:
# This is a bit of a hack, but expandvars does a bunch of escaping
# and other logic that we don't want to duplicate here.

with unittest.mock.patch("expandvars.getenv") as mock_getenv:
mock_getenv.return_value = "mocked_value"
vars = set()

def mock_get_env(key: str, default: Optional[str] = None) -> str:
vars.add(key)
if default is not None:
return default
return "mocked_value"

mock = unittest.mock.MagicMock()
mock.get.side_effect = mock_get_env

resolve_env_variables(config)
resolve_env_variables(config, environ=mock)

calls = mock_getenv.mock_calls
return set([call[1][0] for call in calls])
return vars


WRITE_TO_FILE_DIRECTIVE_PREFIX = "__DATAHUB_TO_FILE_"
Expand Down Expand Up @@ -147,7 +157,7 @@ def load_config_file(

config = raw_config.copy()
if resolve_env_vars:
config = resolve_env_variables(config)
config = resolve_env_variables(config, environ=os.environ)
if process_directives:
config = _process_directives(config)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -112,7 +113,7 @@ def default_sink_is_datahub_rest(cls, values: Dict[str, Any]) -> Any:
}
# resolve env variables if present
default_sink_config = config_loader.resolve_env_variables(
default_sink_config
default_sink_config, environ=os.environ
)
values["sink"] = default_sink_config

Expand Down
Empty file.
66 changes: 66 additions & 0 deletions metadata-ingestion/src/datahub/secret/datahub_secret_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import logging
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, validator

from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.secret.datahub_secrets_client import DataHubSecretsClient
from datahub.secret.secret_store import SecretStore

logger = logging.getLogger(__name__)


class DataHubSecretStoreConfig(BaseModel):
graph_client: Optional[DataHubGraph] = None
graph_client_config: Optional[DatahubClientConfig] = None

class Config:
arbitrary_types_allowed = True

@validator("graph_client")
def check_graph_connection(cls, v: DataHubGraph) -> DataHubGraph:
if v is not None:
v.test_connection()
return v


# An implementation of SecretStore that fetches secrets from DataHub
class DataHubSecretStore(SecretStore):
# Client for fetching secrets from DataHub GraphQL API
client: DataHubSecretsClient

def __init__(self, config: DataHubSecretStoreConfig):
# Attempt to establish an outbound connection to DataHub and create a client.
if config.graph_client is not None:
self.client = DataHubSecretsClient(graph=config.graph_client)
elif config.graph_client_config is not None:
graph = DataHubGraph(config.graph_client_config)
self.client = DataHubSecretsClient(graph)
else:
raise Exception(
"Invalid configuration provided: unable to construct DataHub Graph Client."
)

def get_secret_values(self, secret_names: List[str]) -> Dict[str, Union[str, None]]:
# Fetch the secret from DataHub, using the credentials provided in the configuration.
# Use the GraphQL API.
try:
return self.client.get_secret_values(secret_names)
except Exception:
# Failed to resolve secrets, return empty.
logger.exception(
f"Caught exception while attempting to fetch secrets from DataHub. Secret names: {secret_names}"
)
return {}

def get_secret_value(self, secret_name: str) -> Union[str, None]:
secret_value_dict = self.get_secret_values([secret_name])
return secret_value_dict.get(secret_name)

def get_id(self) -> str:
return "datahub"

@classmethod
def create(cls, config: Any) -> "DataHubSecretStore":
config = DataHubSecretStoreConfig.parse_obj(config)
return cls(config)
45 changes: 45 additions & 0 deletions metadata-ingestion/src/datahub/secret/datahub_secrets_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Dict, List, Optional

from datahub.ingestion.graph.client import DataHubGraph


class DataHubSecretsClient:
"""Class used to fetch secrets from DataHub."""

graph: DataHubGraph

def __init__(self, graph: DataHubGraph):
self.graph = graph

def get_secret_values(self, secret_names: List[str]) -> Dict[str, Optional[str]]:
if len(secret_names) == 0:
return {}

request_json = {
"query": """query getSecretValues($input: GetSecretValuesInput!) {\n
getSecretValues(input: $input) {\n
name\n
value\n
}\n
}""",
"variables": {"input": {"secrets": secret_names}},
}
# TODO: Use graph.execute_graphql() instead.

# Fetch secrets using GraphQL API f
response = self.graph._session.post(
f"{self.graph.config.server}/api/graphql", json=request_json
)
response.raise_for_status()

# Verify response
res_data = response.json()
if "errors" in res_data:
raise Exception("Failed to retrieve secrets from DataHub.")

# Convert list of name, value secret pairs into a dict and return
secret_value_list = res_data["data"]["getSecretValues"]
secret_value_dict = dict()
for secret_value in secret_value_list:
secret_value_dict[secret_value["name"]] = secret_value["value"]
return secret_value_dict
59 changes: 59 additions & 0 deletions metadata-ingestion/src/datahub/secret/secret_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
import logging
from typing import List

from datahub.configuration.config_loader import (
list_referenced_env_variables,
resolve_env_variables,
)
from datahub.secret.secret_store import SecretStore

logger = logging.getLogger(__name__)


def resolve_secrets(secret_names: List[str], secret_stores: List[SecretStore]) -> dict:
# Attempt to resolve secret using by checking each configured secret store.
final_secret_values = dict({})

for secret_store in secret_stores:
try:
# Retrieve secret values from the store.
secret_values_dict = secret_store.get_secret_values(secret_names)
# Overlay secret values from each store, if not None.
for secret_name, secret_value in secret_values_dict.items():
if secret_value is not None:
# HACK: We previously, incorrectly replaced newline characters with
# a r'\n' string. This was a lossy conversion, since we can no longer
# distinguish between a newline character and the literal '\n' in
# the secret value. For now, we assume that all r'\n' strings are
# actually newline characters. This will break if a secret value
# genuinely contains the string r'\n'.
# Once this PR https://github.com/datahub-project/datahub/pull/9484
# has baked for a while, we should be able to remove this hack.
# TODO: This logic should live in the DataHub secret client/store,
# not the general secret resolution logic.
secret_value = secret_value.replace(r"\n", "\n")

final_secret_values[secret_name] = secret_value
except Exception:
logger.exception(
f"Failed to fetch secret values from secret store with id {secret_store.get_id()}"
)
return final_secret_values


def resolve_recipe(recipe: str, secret_stores: List[SecretStore]) -> dict:
json_recipe_raw = json.loads(recipe)

# 1. Extract all secrets needing resolved.
secrets_to_resolve = list_referenced_env_variables(json_recipe_raw)

# 2. Resolve secret values
secret_values_dict = resolve_secrets(list(secrets_to_resolve), secret_stores)

# 3. Substitute secrets into recipe file
json_recipe_resolved = resolve_env_variables(
json_recipe_raw, environ=secret_values_dict
)

return json_recipe_resolved
43 changes: 43 additions & 0 deletions metadata-ingestion/src/datahub/secret/secret_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from abc import abstractmethod
from typing import Dict, List, Optional

from datahub.configuration.common import ConfigModel


class SecretStoreConfig(ConfigModel):
type: str
config: Dict


class SecretStore:
"""
Abstract base class for a Secret Store, or a class that resolves "secret" values by name.
"""

@classmethod
@abstractmethod
def create(cls, configs: dict) -> "SecretStore":
pass

@abstractmethod
def get_secret_values(self, secret_names: List[str]) -> Dict[str, Optional[str]]:
"""
Attempt to fetch a group of secrets, returning a Dictionary of the secret of None if one
cannot be resolved by the store.
"""

def get_secret_value(self, secret_name: str) -> Optional[str]:
secret_value_dict = self.get_secret_values([secret_name])
return secret_value_dict.get(secret_name)

@abstractmethod
def get_id(self) -> str:
"""
Get a unique name or id associated with the Secret Store.
"""

@abstractmethod
def close(self) -> None:
"""
Wraps up the task
"""

0 comments on commit f06b5c7

Please sign in to comment.