Fix #13149: Multiple Project Id for Datalake GCS (#14846)
* Fix Multiple Project Id for datalake gcs

* Optimize logic

* Fix Tests

* Add Datalake GCS Tests

* Add multiple project id gcs test
ayush-shah authored Jan 25, 2024
1 parent 951917b commit 1552aeb
Showing 5 changed files with 244 additions and 44 deletions.
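
For context, a minimal sketch of the idea behind this fix, using the MultipleProjectId and SingleProjectId models imported in the diff below (the project IDs here are made up): a GCS credential may carry several project IDs, and the connector bootstraps its storage client with the first one, then re-creates the client per project while listing databases.

from metadata.generated.schema.security.credentials.gcpValues import (
    MultipleProjectId,
    SingleProjectId,
)

# A multi-project credential holds its project IDs in __root__ as a list;
# the connector picks the first entry to initialize the client.
project_ids = MultipleProjectId.parse_obj(["project-a", "project-b"])
first_project = SingleProjectId.parse_obj(project_ids.__root__[0])
print(first_project.__root__)  # -> "project-a"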
@@ -210,7 +210,6 @@ def __init__(self, config, metadata):
# as per service connection config, which would result in an error.
self.test_connection = lambda: None
super().__init__(config, metadata)
self.temp_credentials = None
self.client = None
# Used to delete temp json file created while initializing bigquery client
self.temp_credentials_file_path = []
@@ -366,18 +365,18 @@ def yield_database_schema(
schema_name=schema_name,
),
)

dataset_obj = self.client.get_dataset(schema_name)
if dataset_obj.labels and self.source_config.includeTags:
database_schema_request_obj.tags = []
for label_classification, label_tag_name in dataset_obj.labels.items():
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=label_tag_name,
classification_name=label_classification,
)
if tag_label:
database_schema_request_obj.tags.append(tag_label)
if self.source_config.includeTags:
dataset_obj = self.client.get_dataset(schema_name)
if dataset_obj.labels:
database_schema_request_obj.tags = []
for label_classification, label_tag_name in dataset_obj.labels.items():
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=label_tag_name,
classification_name=label_classification,
)
if tag_label:
database_schema_request_obj.tags.append(tag_label)
yield Either(right=database_schema_request_obj)

def get_table_obj(self, table_name: str):
@@ -530,8 +529,6 @@ def clean_raw_data_type(self, raw_data_type):

def close(self):
super().close()
if self.temp_credentials:
os.unlink(self.temp_credentials)
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if isinstance(
self.service_connection.credentials.gcpConfig, GcpCredentialsValues
@@ -12,10 +12,14 @@
"""
Source connection handler
"""
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial, singledispatch
from typing import Optional

from google.cloud import storage

from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
@@ -31,9 +35,13 @@
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.security.credentials.gcpValues import (
MultipleProjectId,
SingleProjectId,
)
from metadata.ingestion.connections.test_connections import test_connection_steps
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.credentials import set_google_credentials
from metadata.utils.credentials import GOOGLE_CREDENTIALS, set_google_credentials


# Only import specific datalake dependencies if necessary
@@ -65,9 +73,15 @@ def _(config: S3Config):

@get_datalake_client.register
def _(config: GCSConfig):
from google.cloud import storage

set_google_credentials(gcp_credentials=config.securityConfig)
gcs_config = deepcopy(config)
if hasattr(config.securityConfig, "gcpConfig") and isinstance(
config.securityConfig.gcpConfig.projectId, MultipleProjectId
):
gcs_config: GCSConfig = deepcopy(config)
gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
gcs_config.securityConfig.gcpConfig.projectId.__root__[0]
)
set_google_credentials(gcp_credentials=gcs_config.securityConfig)
gcs_client = storage.Client()
return gcs_client

@@ -96,6 +110,15 @@ def _(config: AzureConfig):
)


def set_gcs_datalake_client(config: GCSConfig, project_id: str):
gcs_config = deepcopy(config)
if hasattr(gcs_config.securityConfig, "gcpConfig"):
gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
project_id
)
return get_datalake_client(config=gcs_config)
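
A hedged usage sketch of the helper above: gcs_config stands for the GCSConfig taken from the service connection, and the project ID is one made-up entry from the configured list. For a GCSConfig, the registered get_datalake_client returns a google.cloud.storage.Client, so buckets can be listed directly.

# Hypothetical call site; gcs_config is assumed to be the GCSConfig from the
# service connection, and "my-second-project" is a made-up project ID.
client = set_gcs_datalake_client(config=gcs_config, project_id="my-second-project")
for bucket in client.list_buckets():
    print(bucket.name)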


def get_connection(connection: DatalakeConnection) -> DatalakeClient:
"""
Create connection.
Expand Down Expand Up @@ -125,6 +148,10 @@ def test_connection(
func = partial(connection.client.get_bucket, connection.config.bucketName)
else:
func = connection.client.list_buckets
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if GOOGLE_CREDENTIALS in os.environ:
os.remove(os.environ[GOOGLE_CREDENTIALS])
del os.environ[GOOGLE_CREDENTIALS]

if isinstance(config, S3Config):
if connection.config.bucketName:
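
The GCS branch of test_connection above clears credential state once the client exists. A minimal standalone sketch of that cleanup pattern, assuming GOOGLE_CREDENTIALS resolves to the standard GOOGLE_APPLICATION_CREDENTIALS variable (an assumption, since the constant is imported from metadata.utils.credentials):

import os

GOOGLE_CREDENTIALS = "GOOGLE_APPLICATION_CREDENTIALS"  # assumed value of the imported constant

# Drop the project hint and delete the temporary credentials file that may
# have been exported while the client was initialized.
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if GOOGLE_CREDENTIALS in os.environ:
    os.remove(os.environ[GOOGLE_CREDENTIALS])
    del os.environ[GOOGLE_CREDENTIALS]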
120 changes: 99 additions & 21 deletions ingestion/src/metadata/ingestion/source/database/datalake/metadata.py
@@ -13,6 +13,7 @@
DataLake connector to fetch metadata from files stored in S3, GCS and HDFS
"""
import json
import os
import traceback
from typing import Any, Iterable, Tuple, Union

@@ -53,12 +54,18 @@
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.security.credentials.gcpValues import (
GcpCredentialsValues,
)
from metadata.ingestion.api.models import Either
from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.datalake.connection import (
set_gcs_datalake_client,
)
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.storage.storage_service import (
OPENMETADATA_TEMPLATE_FILE_NAME,
@@ -68,12 +75,13 @@
from metadata.readers.file.config_source_factory import get_reader
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.credentials import GOOGLE_CREDENTIALS
from metadata.utils.datalake.datalake_utils import (
fetch_dataframe,
get_columns,
get_file_format_type,
)
from metadata.utils.filters import filter_by_schema, filter_by_table
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.logger import ingestion_logger
from metadata.utils.s3_utils import list_s3_objects

@@ -96,8 +104,10 @@ def __init__(self, config: WorkflowSource, metadata: OpenMetadata):
)
self.metadata = metadata
self.service_connection = self.config.serviceConnection.__root__.config
self.temp_credentials_file_path = []
self.connection = get_connection(self.service_connection)

if GOOGLE_CREDENTIALS in os.environ:
self.temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])
self.client = self.connection.client
self.table_constraints = None
self.database_source_state = set()
@@ -125,8 +135,47 @@ def get_database_names(self) -> Iterable[str]:
Sources with multiple databases should overwrite this and
apply the necessary filters.
"""
database_name = self.service_connection.databaseName or DEFAULT_DATABASE
yield database_name
if isinstance(self.config_source, GCSConfig):
project_id_list = (
self.service_connection.configSource.securityConfig.gcpConfig.projectId.__root__
)
if not isinstance(
project_id_list,
list,
):
project_id_list = [project_id_list]
for project_id in project_id_list:
database_fqn = fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.database_service,
database_name=project_id,
)
if filter_by_database(
self.source_config.databaseFilterPattern,
database_fqn
if self.source_config.useFqnForFiltering
else project_id,
):
self.status.filter(database_fqn, "Database Filtered out")
else:
try:
self.client = set_gcs_datalake_client(
config=self.config_source, project_id=project_id
)
if GOOGLE_CREDENTIALS in os.environ:
self.temp_credentials_file_path.append(
os.environ[GOOGLE_CREDENTIALS]
)
yield project_id
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(
f"Error trying to connect to database {project_id}: {exc}"
)
else:
database_name = self.service_connection.databaseName or DEFAULT_DATABASE
yield database_name

def yield_database(
self, database_name: str
@@ -135,6 +184,8 @@ def yield_database(
From topology.
Prepare a database request and pass it to the sink
"""
if isinstance(self.config_source, GCSConfig):
database_name = self.client.project
yield Either(
right=CreateDatabaseRequest(
name=database_name,
@@ -143,24 +194,42 @@
)

def fetch_gcs_bucket_names(self):
for bucket in self.client.list_buckets():
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.database_service,
database_name=self.context.database,
schema_name=bucket.name,
)
if filter_by_schema(
self.config.sourceConfig.config.schemaFilterPattern,
schema_fqn
if self.config.sourceConfig.config.useFqnForFiltering
else bucket.name,
):
self.status.filter(schema_fqn, "Bucket Filtered Out")
continue
"""
Fetch Google Cloud Storage buckets
"""
try:
# List all the buckets in the project
for bucket in self.client.list_buckets():
# Build a fully qualified name (FQN) for each bucket
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.database_service,
database_name=self.context.database,
schema_name=bucket.name,
)

yield bucket.name
# Check if the bucket matches a certain filter pattern
if filter_by_schema(
self.config.sourceConfig.config.schemaFilterPattern,
schema_fqn
if self.config.sourceConfig.config.useFqnForFiltering
else bucket.name,
):
# If it does not match, the bucket is filtered out
self.status.filter(schema_fqn, "Bucket Filtered Out")
continue

# If it does match, the bucket name is yielded
yield bucket.name
except Exception as exc:
yield Either(
left=StackTraceError(
name="Bucket",
error=f"Unexpected exception to yield bucket: {exc}",
stackTrace=traceback.format_exc(),
)
)

def fetch_s3_bucket_names(self):
for bucket in self.client.list_buckets()["Buckets"]:
@@ -434,3 +503,12 @@ def filter_dl_table(self, table_name: str):
def close(self):
if isinstance(self.config_source, AzureConfig):
self.client.close()
if isinstance(self.config_source, GCSConfig):
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if isinstance(self.service_connection, GcpCredentialsValues) and (
GOOGLE_CREDENTIALS in os.environ
):
del os.environ[GOOGLE_CREDENTIALS]
for temp_file_path in self.temp_credentials_file_path:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
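
A small sketch of the temp-credentials lifecycle implied by this change (hedged: set_gcs_datalake_client may write a temporary JSON key file and export its path via the env var named by GOOGLE_CREDENTIALS, whose value is assumed below). Each per-project client registers its file so close() can remove them all at the end of the run.

import os

GOOGLE_CREDENTIALS = "GOOGLE_APPLICATION_CREDENTIALS"  # assumed value of the imported constant

temp_credentials_file_path = []
for project_id in ["project-a", "project-b"]:  # made-up project IDs
    # set_gcs_datalake_client(config=..., project_id=project_id) would run here
    if GOOGLE_CREDENTIALS in os.environ:
        temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])

# Mirrors the cleanup performed in close(): remove every recorded temp file.
for temp_file_path in temp_credentials_file_path:
    if os.path.exists(temp_file_path):
        os.remove(temp_file_path)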
7 changes: 7 additions & 0 deletions ingestion/tests/unit/test_entity_link.py
@@ -97,6 +97,13 @@ def validate(self, fn_resp, check_split):
"<#E::table::随机的>",
["table", "随机的"],
),
EntityLinkTest(
'<#E::table::ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv">',
[
"table",
'ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv"',
],
),
]
for x in xs:
x.validate(entity_link.split(x.entitylink), x.split_list)
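
The new test case above covers entity links whose table FQN ends in a quoted name containing a path separator. A hedged illustration of what the split is expected to return (import path assumed):

from metadata.utils import entity_link  # assumed import path

parts = entity_link.split(
    '<#E::table::ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv">'
)
assert parts == [
    "table",
    'ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv"',
]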