diff --git a/ingestion/src/metadata/automations/runner.py b/ingestion/src/metadata/automations/runner.py index 96475206ca48..92a0dcb11efa 100644 --- a/ingestion/src/metadata/automations/runner.py +++ b/ingestion/src/metadata/automations/runner.py @@ -20,6 +20,9 @@ from metadata.generated.schema.entity.automations.workflow import ( Workflow as AutomationWorkflow, ) +from metadata.ingestion.connections.test_connections import ( + raise_test_connection_exception, +) from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.connections import get_connection, get_test_connection_fn from metadata.utils.ssl_manager import SSLManager, check_ssl_and_init @@ -60,19 +63,40 @@ def _( """ Run the test connection """ - ssl_manager = None ssl_manager: SSLManager = check_ssl_and_init(request.connection.config) if ssl_manager: request.connection.config = ssl_manager.setup_ssl(request.connection.config) - connection = get_connection(request.connection.config) - # Find the test_connection function in each /connection.py file test_connection_fn = get_test_connection_fn(request.connection.config) - test_connection_fn( - metadata, connection, request.connection.config, automation_workflow - ) + + try: + connection = get_connection(request.connection.config) + + host_port_str = str(request.connection.config.hostPort or "") + if "localhost" in host_port_str: + result = test_connection_fn(metadata, connection, request.connection.config) + raise_test_connection_exception(result) + + test_connection_fn( + metadata, connection, request.connection.config, automation_workflow + ) + except Exception as error: + host_port_str = str(getattr(request.connection.config, "hostPort", None) or "") + if "localhost" not in host_port_str: + raise error + + host_port_type = type(request.connection.config.hostPort) + docker_host_port_str = host_port_str.replace( + "localhost", "host.docker.internal" + ) + request.connection.config.hostPort = host_port_type(docker_host_port_str) + + connection = get_connection(request.connection.config) + test_connection_fn( + metadata, connection, request.connection.config, automation_workflow + ) if ssl_manager: ssl_manager.cleanup_temp_files() diff --git a/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py b/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py index 80a8895288f0..e98881988fa2 100644 --- a/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py +++ b/ingestion/src/metadata/ingestion/source/dashboard/dashboard_service.py @@ -198,6 +198,9 @@ class DashboardServiceTopology(ServiceTopology): ) +from metadata.utils.helpers import retry_with_docker_host + + # pylint: disable=too-many-public-methods class DashboardServiceSource(TopologyRunnerMixin, Source, ABC): """ @@ -216,6 +219,7 @@ class DashboardServiceSource(TopologyRunnerMixin, Source, ABC): dashboard_source_state: Set = set() datamodel_source_state: Set = set() + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py b/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py index 078853aca599..b7f26e57be66 100644 --- a/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/bigquery/metadata.py @@ -97,6 +97,7 @@ from metadata.utils import fqn from metadata.utils.credentials import GOOGLE_CREDENTIALS from metadata.utils.filters import filter_by_database, filter_by_schema +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.sqlalchemy_utils import ( get_all_table_ddls, @@ -162,9 +163,11 @@ def get_columns(bq_schema): "precision": field.precision, "scale": field.scale, "max_length": field.max_length, - "system_data_type": _array_sys_data_type_repr(col_type) - if str(col_type) == "ARRAY" - else str(col_type), + "system_data_type": ( + _array_sys_data_type_repr(col_type) + if str(col_type) == "ARRAY" + else str(col_type) + ), "is_complex": is_complex_type(str(col_type)), "policy_tags": None, } @@ -223,6 +226,7 @@ class BigquerySource(LifeCycleQueryMixin, CommonDbSourceService, MultiDBSource): Database metadata from Bigquery Source """ + @retry_with_docker_host() def __init__(self, config, metadata, incremental_configuration: IncrementalConfig): # Check if the engine is established before setting project IDs # This ensures that we don't try to set project IDs when there is no engine @@ -685,9 +689,11 @@ def get_table_partition_details( return True, TablePartition( columns=[ PartitionColumnDetails( - columnName="_PARTITIONTIME" - if table.time_partitioning.type_ == "HOUR" - else "_PARTITIONDATE", + columnName=( + "_PARTITIONTIME" + if table.time_partitioning.type_ == "HOUR" + else "_PARTITIONDATE" + ), interval=str(table.time_partitioning.type_), intervalType=PartitionIntervalTypes.INGESTION_TIME, ) diff --git a/ingestion/src/metadata/ingestion/source/database/common_db_source.py b/ingestion/src/metadata/ingestion/source/database/common_db_source.py index 663b21628cbc..b773d621c74f 100644 --- a/ingestion/src/metadata/ingestion/source/database/common_db_source.py +++ b/ingestion/src/metadata/ingestion/source/database/common_db_source.py @@ -77,6 +77,7 @@ calculate_execution_time_generator, ) from metadata.utils.filters import filter_by_table +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.ssl_manager import SSLManager, check_ssl_and_init @@ -108,6 +109,7 @@ class CommonDbSourceService( - fetch_column_tags implemented at SqlColumnHandler. Sources should override this when needed """ + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/database/common_nosql_source.py b/ingestion/src/metadata/ingestion/source/database/common_nosql_source.py index cdb1a03c42b3..a3503aaa2df5 100644 --- a/ingestion/src/metadata/ingestion/source/database/common_nosql_source.py +++ b/ingestion/src/metadata/ingestion/source/database/common_nosql_source.py @@ -54,6 +54,7 @@ from metadata.utils.constants import DEFAULT_DATABASE from metadata.utils.datalake.datalake_utils import DataFrameColumnParser from metadata.utils.filters import filter_by_schema, filter_by_table +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.ssl_manager import check_ssl_and_init @@ -79,6 +80,7 @@ class CommonNoSQLSource(DatabaseServiceSource, ABC): Database metadata from NoSQL source """ + @retry_with_docker_host() def __init__(self, config: WorkflowSource, metadata: OpenMetadata): super().__init__() self.config = config diff --git a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py index 20c3f0a57847..748b9edbfe6c 100644 --- a/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/datalake/metadata.py @@ -129,9 +129,11 @@ def get_database_names(self) -> Iterable[str]: ) if filter_by_database( self.source_config.databaseFilterPattern, - database_fqn - if self.source_config.useFqnForFiltering - else database_name, + ( + database_fqn + if self.source_config.useFqnForFiltering + else database_name + ), ): self.status.filter(database_fqn, "Database Filtered out") else: @@ -180,9 +182,11 @@ def get_database_schema_names(self) -> Iterable[str]: if filter_by_schema( self.config.sourceConfig.config.schemaFilterPattern, - schema_fqn - if self.config.sourceConfig.config.useFqnForFiltering - else schema_name, + ( + schema_fqn + if self.config.sourceConfig.config.useFqnForFiltering + else schema_name + ), ): self.status.filter(schema_fqn, "Bucket Filtered Out") continue @@ -352,9 +356,11 @@ def filter_dl_table(self, table_name: str): if filter_by_table( self.config.sourceConfig.config.tableFilterPattern, - table_fqn - if self.config.sourceConfig.config.useFqnForFiltering - else table_name, + ( + table_fqn + if self.config.sourceConfig.config.useFqnForFiltering + else table_name + ), ): self.status.filter( table_fqn, diff --git a/ingestion/src/metadata/ingestion/source/database/query_parser_source.py b/ingestion/src/metadata/ingestion/source/database/query_parser_source.py index 2c4477a6969a..e24396099c35 100644 --- a/ingestion/src/metadata/ingestion/source/database/query_parser_source.py +++ b/ingestion/src/metadata/ingestion/source/database/query_parser_source.py @@ -26,7 +26,7 @@ from metadata.ingestion.lineage.models import ConnectionTypeDialectMapper from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.connections import get_test_connection_fn -from metadata.utils.helpers import get_start_and_end +from metadata.utils.helpers import get_start_and_end, retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.ssl_manager import get_ssl_connection @@ -49,6 +49,7 @@ class QueryParserSource(Source, ABC): database_field: str schema_field: str + @retry_with_docker_host() def __init__( self, config: WorkflowSource, @@ -64,9 +65,11 @@ def __init__( self.dialect = ConnectionTypeDialectMapper.dialect_of(connection_type) self.source_config = self.config.sourceConfig.config self.start, self.end = get_start_and_end(self.source_config.queryLogDuration) - self.engine = ( - get_ssl_connection(self.service_connection) if get_engine else None - ) + + self.engine = None + if get_engine: + self.engine = get_ssl_connection(self.service_connection) + self.test_connection() @property def name(self) -> str: @@ -129,5 +132,5 @@ def close(self): def test_connection(self) -> None: test_connection_fn = get_test_connection_fn(self.service_connection) - result = test_connection_fn(self.engine) + result = test_connection_fn(self.metadata, self.engine, self.service_connection) raise_test_connection_exception(result) diff --git a/ingestion/src/metadata/ingestion/source/database/saphana/lineage.py b/ingestion/src/metadata/ingestion/source/database/saphana/lineage.py index c474c4bfea6f..a63b063a4ebc 100644 --- a/ingestion/src/metadata/ingestion/source/database/saphana/lineage.py +++ b/ingestion/src/metadata/ingestion/source/database/saphana/lineage.py @@ -159,5 +159,5 @@ def parse_cdata( def test_connection(self) -> None: test_connection_fn = get_test_connection_fn(self.service_connection) - result = test_connection_fn(self.engine) + result = test_connection_fn(self.metadata, self.engine, self.service_connection) raise_test_connection_exception(result) diff --git a/ingestion/src/metadata/ingestion/source/database/unitycatalog/lineage.py b/ingestion/src/metadata/ingestion/source/database/unitycatalog/lineage.py index 3189102eec3e..8f36b033fc10 100644 --- a/ingestion/src/metadata/ingestion/source/database/unitycatalog/lineage.py +++ b/ingestion/src/metadata/ingestion/source/database/unitycatalog/lineage.py @@ -40,6 +40,7 @@ from metadata.ingestion.source.database.unitycatalog.connection import get_connection from metadata.ingestion.source.database.unitycatalog.models import LineageTableStreams from metadata.utils import fqn +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -50,6 +51,7 @@ class UnitycatalogLineageSource(Source): Lineage Unity Catalog Source """ + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py b/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py index 6eb03388cf2c..7e55dac572b6 100644 --- a/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/unitycatalog/metadata.py @@ -75,6 +75,7 @@ from metadata.ingestion.source.models import TableView from metadata.utils import fqn from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -89,6 +90,7 @@ class UnitycatalogSource( the unity catalog source """ + @retry_with_docker_host() def __init__(self, config: WorkflowSource, metadata: OpenMetadata): super().__init__() self.config = config diff --git a/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py b/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py index b6d5f0d8eb1f..06e1cd066260 100644 --- a/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py +++ b/ingestion/src/metadata/ingestion/source/messaging/messaging_service.py @@ -48,6 +48,7 @@ from metadata.ingestion.source.connections import get_connection, get_test_connection_fn from metadata.utils import fqn from metadata.utils.filters import filter_by_topic +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -125,6 +126,7 @@ class MessagingServiceSource(TopologyRunnerMixin, Source, ABC): context = TopologyContextManager(topology) topic_source_state: Set = set() + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/metadata/alationsink/metadata.py b/ingestion/src/metadata/ingestion/source/metadata/alationsink/metadata.py index c29b81f7173b..82735a542528 100644 --- a/ingestion/src/metadata/ingestion/source/metadata/alationsink/metadata.py +++ b/ingestion/src/metadata/ingestion/source/metadata/alationsink/metadata.py @@ -56,6 +56,7 @@ ) from metadata.utils import fqn from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -71,6 +72,7 @@ class AlationsinkSource(Source): config: WorkflowSource alation_sink_client: AlationSinkClient + @retry_with_docker_host() def __init__( self, config: WorkflowSource, @@ -117,9 +119,11 @@ def create_datasource_request( ), ), db_username="Test", - title=om_database.displayName - if om_database.displayName - else model_str(om_database.name), + title=( + om_database.displayName + if om_database.displayName + else model_str(om_database.name) + ), description=model_str(om_database.description), ) except Exception as exc: @@ -140,9 +144,11 @@ def create_schema_request( key=fqn._build( # pylint: disable=protected-access str(alation_datasource_id), model_str(om_schema.name) ), - title=om_schema.displayName - if om_schema.displayName - else model_str(om_schema.name), + title=( + om_schema.displayName + if om_schema.displayName + else model_str(om_schema.name) + ), description=model_str(om_schema.description), ) except Exception as exc: @@ -163,9 +169,11 @@ def create_table_request( key=fqn._build( # pylint: disable=protected-access str(alation_datasource_id), schema_name, model_str(om_table.name) ), - title=om_table.displayName - if om_table.displayName - else model_str(om_table.name), + title=( + om_table.displayName + if om_table.displayName + else model_str(om_table.name) + ), description=model_str(om_table.description), table_type=TABLE_TYPE_MAPPER.get(om_table.tableType, "TABLE"), sql=om_table.schemaDefinition, @@ -273,16 +281,22 @@ def create_column_request( table_name, model_str(om_column.name), ), - column_type=om_column.dataTypeDisplay.lower() - if om_column.dataTypeDisplay - else om_column.dataType.value.lower(), - title=om_column.displayName - if om_column.displayName - else model_str(om_column.name), + column_type=( + om_column.dataTypeDisplay.lower() + if om_column.dataTypeDisplay + else om_column.dataType.value.lower() + ), + title=( + om_column.displayName + if om_column.displayName + else model_str(om_column.name) + ), description=model_str(om_column.description), - position=str(om_column.ordinalPosition) - if om_column.ordinalPosition - else None, + position=( + str(om_column.ordinalPosition) + if om_column.ordinalPosition + else None + ), index=self._get_column_index( alation_datasource_id, om_column, table_constraints ), diff --git a/ingestion/src/metadata/ingestion/source/metadata/amundsen/metadata.py b/ingestion/src/metadata/ingestion/source/metadata/amundsen/metadata.py index 8c9d077e434d..2d1e9e5702c1 100644 --- a/ingestion/src/metadata/ingestion/source/metadata/amundsen/metadata.py +++ b/ingestion/src/metadata/ingestion/source/metadata/amundsen/metadata.py @@ -70,7 +70,7 @@ NEO4J_AMUNDSEN_USER_QUERY, ) from metadata.utils import fqn -from metadata.utils.helpers import get_standard_chart_type +from metadata.utils.helpers import get_standard_chart_type, retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.metadata_service_helper import SERVICE_TYPE_MAPPER from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_labels @@ -116,6 +116,7 @@ class AmundsenSource(Source): dashboard_service: DashboardService + @retry_with_docker_host() def __init__(self, config: WorkflowSource, metadata: OpenMetadata): super().__init__() self.config = config @@ -248,9 +249,11 @@ def _yield_create_database(self, table) -> Iterable[Either[CreateDatabaseRequest table_name = "default" database_request = CreateDatabaseRequest( - name=table_name - if hasattr(service_entity.connection.config, "supportsDatabase") - else "default", + name=( + table_name + if hasattr(service_entity.connection.config, "supportsDatabase") + else "default" + ), service=service_entity.fullyQualifiedName, ) yield Either(right=database_request) diff --git a/ingestion/src/metadata/ingestion/source/metadata/atlas/metadata.py b/ingestion/src/metadata/ingestion/source/metadata/atlas/metadata.py index 2711c5ca1469..279bedf83373 100644 --- a/ingestion/src/metadata/ingestion/source/metadata/atlas/metadata.py +++ b/ingestion/src/metadata/ingestion/source/metadata/atlas/metadata.py @@ -50,7 +50,7 @@ from metadata.ingestion.source.database.column_type_parser import ColumnTypeParser from metadata.ingestion.source.metadata.atlas.client import AtlasClient from metadata.utils import fqn -from metadata.utils.helpers import get_database_name_for_lineage +from metadata.utils.helpers import get_database_name_for_lineage, retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.metadata_service_helper import SERVICE_TYPE_MAPPER from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_labels @@ -70,6 +70,7 @@ class AtlasSource(Source): tables: Dict[str, Any] topics: Dict[str, Any] + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py b/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py index e9517f819895..c4d68da137ef 100644 --- a/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py +++ b/ingestion/src/metadata/ingestion/source/mlmodel/mlmodel_service.py @@ -51,6 +51,7 @@ from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.ingestion.source.connections import get_connection, get_test_connection_fn from metadata.utils import fqn +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -113,6 +114,7 @@ class MlModelServiceSource(TopologyRunnerMixin, Source, ABC): context = TopologyContextManager(topology) mlmodel_source_state: Set = set() + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py b/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py index d47711d0a35b..6c0ed62ed4cc 100644 --- a/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py +++ b/ingestion/src/metadata/ingestion/source/pipeline/pipeline_service.py @@ -54,6 +54,7 @@ from metadata.ingestion.source.pipeline.openlineage.utils import FQNNotFoundException from metadata.utils import fqn from metadata.utils.filters import filter_by_pipeline +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -133,6 +134,7 @@ class PipelineServiceSource(TopologyRunnerMixin, Source, ABC): context = TopologyContextManager(topology) pipeline_source_state: Set = set() + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/search/search_service.py b/ingestion/src/metadata/ingestion/source/search/search_service.py index 8ef0400748a9..10f0fc67e705 100644 --- a/ingestion/src/metadata/ingestion/source/search/search_service.py +++ b/ingestion/src/metadata/ingestion/source/search/search_service.py @@ -56,6 +56,7 @@ from metadata.ingestion.source.connections import get_connection, get_test_connection_fn from metadata.utils import fqn from metadata.utils.filters import filter_by_search_index +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger logger = ingestion_logger() @@ -138,6 +139,7 @@ class SearchServiceSource(TopologyRunnerMixin, Source, ABC): context = TopologyContextManager(topology) index_source_state: Set = set() + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/ingestion/source/storage/storage_service.py b/ingestion/src/metadata/ingestion/source/storage/storage_service.py index 9a57e5231253..47d96e06456f 100644 --- a/ingestion/src/metadata/ingestion/source/storage/storage_service.py +++ b/ingestion/src/metadata/ingestion/source/storage/storage_service.py @@ -62,6 +62,7 @@ DataFrameColumnParser, fetch_dataframe, ) +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.logger import ingestion_logger from metadata.utils.storage_metadata_config import ( StorageMetadataConfigException, @@ -140,6 +141,7 @@ class StorageServiceSource(TopologyRunnerMixin, Source, ABC): global_manifest: Optional[ManifestMetadataConfig] + @retry_with_docker_host() def __init__( self, config: WorkflowSource, diff --git a/ingestion/src/metadata/utils/helpers.py b/ingestion/src/metadata/utils/helpers.py index 0b3438c7f280..79cb4221157b 100644 --- a/ingestion/src/metadata/utils/helpers.py +++ b/ingestion/src/metadata/utils/helpers.py @@ -32,6 +32,9 @@ from metadata.generated.schema.entity.data.table import Column, Table from metadata.generated.schema.entity.feed.suggestion import Suggestion, SuggestionType from metadata.generated.schema.entity.services.databaseService import DatabaseService +from metadata.generated.schema.metadataIngestion.workflow import ( + Source as WorkflowSource, +) from metadata.generated.schema.type.basic import EntityLink from metadata.generated.schema.type.tagLabel import TagLabel from metadata.utils.constants import DEFAULT_DATABASE @@ -476,3 +479,45 @@ def init_staging_dir(directory: str) -> None: location = Path(directory) logger.info(f"Creating the directory to store staging data in {location}") location.mkdir(parents=True, exist_ok=True) + + +def retry_with_docker_host(config: Optional[WorkflowSource] = None): + """ + Retries the function on exception, replacing "localhost" with "host.docker.internal" + in the `hostPort` config if applicable. Raises the original exception if no `config` is found. + """ + + def decorator(func): + def wrapper(*args, **kwargs): + nonlocal config + try: + func(*args, **kwargs) + except Exception as error: + config = config or kwargs.get("config") + if not config: + for argument in args: + if isinstance(argument, WorkflowSource): + config = argument + break + else: + raise error + + host_port_str = str( + getattr(config.serviceConnection.root.config, "hostPort", None) + or "" + ) + if "localhost" not in host_port_str: + raise error + + host_port_type = type(config.serviceConnection.root.config.hostPort) + docker_host_port_str = host_port_str.replace( + "localhost", "host.docker.internal" + ) + config.serviceConnection.root.config.hostPort = host_port_type( + docker_host_port_str + ) + func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/ingestion/src/metadata/workflow/classification.py b/ingestion/src/metadata/workflow/classification.py index 4f1417d9da4a..c9a8df2bae7e 100644 --- a/ingestion/src/metadata/workflow/classification.py +++ b/ingestion/src/metadata/workflow/classification.py @@ -29,6 +29,9 @@ class AutoClassificationWorkflow(ProfilerWorkflow): """Auto Classification workflow implementation. Based on the Profiler logic with different steps""" def set_steps(self): + # NOTE: Call test_connection to update host value before creating the source class + self.test_connection() + source_class = self._get_source_class() self.source = source_class.create(self.config.model_dump(), self.metadata) diff --git a/ingestion/src/metadata/workflow/profiler.py b/ingestion/src/metadata/workflow/profiler.py index 50b3cde10a06..5119b9f7bea2 100644 --- a/ingestion/src/metadata/workflow/profiler.py +++ b/ingestion/src/metadata/workflow/profiler.py @@ -23,6 +23,7 @@ from metadata.profiler.processor.processor import ProfilerProcessor from metadata.profiler.source.metadata import OpenMetadataSource from metadata.profiler.source.metadata_ext import OpenMetadataSourceExt +from metadata.utils.helpers import retry_with_docker_host from metadata.utils.importer import import_sink_class from metadata.utils.logger import profiler_logger from metadata.utils.ssl_manager import get_ssl_connection @@ -57,6 +58,9 @@ def _get_source_class(self): return OpenMetadataSourceExt def set_steps(self): + # NOTE: Call test_connection to update host value before creating the source class + self.test_connection() + source_class = self._get_source_class() self.source = source_class.create(self.config.model_dump(), self.metadata) @@ -66,12 +70,16 @@ def set_steps(self): self.steps = (profiler_processor, sink) def test_connection(self) -> None: - service_config = self.config.source.serviceConnection.root.config - conn = get_ssl_connection(service_config) + @retry_with_docker_host(config=self.config.source) + def main(self): + service_config = self.config.source.serviceConnection.root.config + conn = get_ssl_connection(service_config) + + test_connection_fn = get_test_connection_fn(service_config) + result = test_connection_fn(self.metadata, conn, service_config) + raise_test_connection_exception(result) - test_connection_fn = get_test_connection_fn(service_config) - result = test_connection_fn(self.metadata, conn, service_config) - raise_test_connection_exception(result) + return main(self) def _get_sink(self) -> Sink: sink_type = self.config.sink.type diff --git a/ingestion/tests/unit/test_databricks_lineage.py b/ingestion/tests/unit/test_databricks_lineage.py index 3f710316b669..be2d4feb4af4 100644 --- a/ingestion/tests/unit/test_databricks_lineage.py +++ b/ingestion/tests/unit/test_databricks_lineage.py @@ -128,10 +128,13 @@ def __init__(self, methodName) -> None: super().__init__(methodName) config = OpenMetadataWorkflowConfig.model_validate(mock_databricks_config) - self.databricks = DatabricksLineageSource.create( - mock_databricks_config["source"], - config.workflowConfig.openMetadataServerConfig, - ) + with patch( + "metadata.ingestion.source.database.databricks.lineage.DatabricksLineageSource.test_connection" + ): + self.databricks = DatabricksLineageSource.create( + mock_databricks_config["source"], + config.workflowConfig.openMetadataServerConfig, + ) @patch( "metadata.ingestion.source.database.databricks.client.DatabricksClient.list_query_history" diff --git a/ingestion/tests/unit/test_pgspider_lineage_unit.py b/ingestion/tests/unit/test_pgspider_lineage_unit.py index a9a25bc0c5e0..d87101f4d3b2 100644 --- a/ingestion/tests/unit/test_pgspider_lineage_unit.py +++ b/ingestion/tests/unit/test_pgspider_lineage_unit.py @@ -587,10 +587,13 @@ class PGSpiderLineageUnitTests(TestCase): def __init__(self, methodName) -> None: super().__init__(methodName) config = OpenMetadataWorkflowConfig.model_validate(mock_pgspider_config) - self.postgres = PostgresLineageSource.create( - mock_pgspider_config["source"], - config.workflowConfig.openMetadataServerConfig, - ) + with patch( + "metadata.ingestion.source.database.postgres.lineage.PostgresLineageSource.test_connection" + ): + self.postgres = PostgresLineageSource.create( + mock_pgspider_config["source"], + config.workflowConfig.openMetadataServerConfig, + ) print(type(self.postgres)) @patch( diff --git a/ingestion/tests/unit/test_usage_filter.py b/ingestion/tests/unit/test_usage_filter.py index d44286d6ab7a..543fc975e948 100644 --- a/ingestion/tests/unit/test_usage_filter.py +++ b/ingestion/tests/unit/test_usage_filter.py @@ -147,9 +147,14 @@ class UsageQueryFilterTests(TestCase): @patch.object(OpenMetadata, "list_all_entities", mock_list_entities) def test_prepare_clickhouse(self): config = OpenMetadataWorkflowConfig.model_validate(mock_clickhouse_config) - clickhouse_source = ClickhouseUsageSource.create( - mock_clickhouse_config["source"], - OpenMetadata(config.workflowConfig.openMetadataServerConfig), - ) + with patch( + "metadata.ingestion.source.database.query_parser_source.get_ssl_connection" + ), patch( + "metadata.ingestion.source.database.clickhouse.usage.ClickhouseUsageSource.test_connection" + ): + clickhouse_source = ClickhouseUsageSource.create( + mock_clickhouse_config["source"], + OpenMetadata(config.workflowConfig.openMetadataServerConfig), + ) clickhouse_source.prepare() assert clickhouse_source.filters == EXPECTED_CLICKHOUSE_FILTER diff --git a/ingestion/tests/unit/test_usage_log.py b/ingestion/tests/unit/test_usage_log.py index 9b05aeed1eea..c201be5f2642 100644 --- a/ingestion/tests/unit/test_usage_log.py +++ b/ingestion/tests/unit/test_usage_log.py @@ -16,6 +16,7 @@ from pathlib import Path from unittest import TestCase +from unittest.mock import patch from metadata.generated.schema.metadataIngestion.workflow import ( OpenMetadataWorkflowConfig, @@ -150,10 +151,13 @@ class QueryLogSourceTest(TestCase): def __init__(self, methodName) -> None: super().__init__(methodName) self.config = OpenMetadataWorkflowConfig.model_validate(mock_query_log_config) - self.source = QueryLogUsageSource.create( - mock_query_log_config["source"], - self.config.workflowConfig.openMetadataServerConfig, - ) + with patch( + "metadata.ingestion.source.database.query.usage.QueryLogUsageSource.test_connection" + ): + self.source = QueryLogUsageSource.create( + mock_query_log_config["source"], + self.config.workflowConfig.openMetadataServerConfig, + ) def test_queries(self): queries = list(self.source.get_table_query()) diff --git a/ingestion/tests/unit/topology/database/test_postgres.py b/ingestion/tests/unit/topology/database/test_postgres.py index 86da89043473..226d0fb46fac 100644 --- a/ingestion/tests/unit/topology/database/test_postgres.py +++ b/ingestion/tests/unit/topology/database/test_postgres.py @@ -299,10 +299,13 @@ def __init__(self, methodName, test_connection) -> None: self.usage_config = OpenMetadataWorkflowConfig.model_validate( mock_postgres_usage_config ) - self.postgres_usage_source = PostgresUsageSource.create( - mock_postgres_usage_config["source"], - self.usage_config.workflowConfig.openMetadataServerConfig, - ) + with patch( + "metadata.ingestion.source.database.postgres.usage.PostgresUsageSource.test_connection" + ): + self.postgres_usage_source = PostgresUsageSource.create( + mock_postgres_usage_config["source"], + self.usage_config.workflowConfig.openMetadataServerConfig, + ) def test_datatype(self): inspector = types.SimpleNamespace()