diff --git a/metadata-ingestion/src/datahub/cli/cli_utils.py b/metadata-ingestion/src/datahub/cli/cli_utils.py index 946ce0df74c3d..18f93b71a44f8 100644 --- a/metadata-ingestion/src/datahub/cli/cli_utils.py +++ b/metadata-ingestion/src/datahub/cli/cli_utils.py @@ -1,5 +1,6 @@ import json import logging +import os import os.path import sys import typing @@ -66,6 +67,8 @@ ENV_METADATA_HOST = "DATAHUB_GMS_HOST" ENV_METADATA_TOKEN = "DATAHUB_GMS_TOKEN" +config_override: Dict = {} + class GmsConfig(BaseModel): server: str @@ -76,6 +79,13 @@ class DatahubConfig(BaseModel): gms: GmsConfig +def set_env_variables_override_config(host: str, token: Optional[str]) -> None: + """Should be used to override the config when using rest emitter""" + config_override[ENV_METADATA_HOST] = host + if token is not None: + config_override[ENV_METADATA_TOKEN] = token + + def write_datahub_config(host: str, token: Optional[str]) -> None: config = { "gms": { @@ -137,29 +147,30 @@ def guess_entity_type(urn: str) -> str: return urn.split(":")[2] -def get_token(): - _, gms_token_env = get_details_from_env() - if should_skip_config(): +def get_host_and_token(): + gms_host_env, gms_token_env = get_details_from_env() + if len(config_override.keys()) > 0: + gms_host = config_override.get(ENV_METADATA_HOST) + gms_token = config_override.get(ENV_METADATA_TOKEN) + elif should_skip_config(): + gms_host = gms_host_env gms_token = gms_token_env else: ensure_datahub_config() - _, gms_token_conf = get_details_from_config() + gms_host_conf, gms_token_conf = get_details_from_config() + gms_host = first_non_null([gms_host_env, gms_host_conf]) gms_token = first_non_null([gms_token_env, gms_token_conf]) - return gms_token + return gms_host, gms_token + + +def get_token(): + return get_host_and_token()[1] def get_session_and_host(): session = requests.Session() - gms_host_env, gms_token_env = get_details_from_env() - if should_skip_config(): - gms_host = gms_host_env - gms_token = gms_token_env - else: - ensure_datahub_config() - gms_host_conf, gms_token_conf = get_details_from_config() - gms_host = first_non_null([gms_host_env, gms_host_conf]) - gms_token = first_non_null([gms_token_env, gms_token_conf]) + gms_host, gms_token = get_host_and_token() if gms_host is None or gms_host.strip() == "": log.error( diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py index 2755751856c68..2b7e51d0607dd 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Union, cast +from datahub.cli.cli_utils import set_env_variables_override_config from datahub.configuration.common import OperationalError from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.emitter.rest_emitter import DatahubRestEmitter @@ -57,7 +58,9 @@ def __init__(self, ctx: PipelineContext, config: DatahubRestSinkConfig): .get("linkedin/datahub", {}) .get("version", "") ) - logger.info("Setting gms config") + logger.debug("Setting env variables to override config") + set_env_variables_override_config(self.config.server, self.config.token) + logger.debug("Setting gms config") set_gms_config(gms_config) self.executor = concurrent.futures.ThreadPoolExecutor( max_workers=self.config.max_threads