Fix #13149: Multiple Project Id for Datalake GCS #14846

Merged: 5 commits, Jan 25, 2024
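In short, this change lets a GCS datalake connection carry multiple GCP project IDs and builds one google.cloud.storage client per project, with each project surfaced as a separate database. A minimal sketch of that idea, assuming ambient Google credentials and made-up project IDs (an illustration only, not the PR's exact code path):

from google.cloud import storage

def gcs_clients_for_projects(project_ids):
    # One storage client per configured GCP project; assumes the active
    # credentials are allowed to list buckets in each of these projects.
    return {project_id: storage.Client(project=project_id) for project_id in project_ids}

# Hypothetical usage with placeholder project IDs:
# for project_id, client in gcs_clients_for_projects(["analytics-prod", "analytics-dev"]).items():
#     print(project_id, [bucket.name for bucket in client.list_buckets()])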
@@ -210,7 +210,6 @@ def __init__(self, config, metadata):
# as per service connection config, which would result in an error.
self.test_connection = lambda: None
super().__init__(config, metadata)
self.temp_credentials = None
self.client = None
# Used to delete temp json file created while initializing bigquery client
self.temp_credentials_file_path = []
@@ -366,18 +365,18 @@ def yield_database_schema(
schema_name=schema_name,
),
)

dataset_obj = self.client.get_dataset(schema_name)
if dataset_obj.labels and self.source_config.includeTags:
database_schema_request_obj.tags = []
for label_classification, label_tag_name in dataset_obj.labels.items():
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=label_tag_name,
classification_name=label_classification,
)
if tag_label:
database_schema_request_obj.tags.append(tag_label)
if self.source_config.includeTags:
dataset_obj = self.client.get_dataset(schema_name)
if dataset_obj.labels:
database_schema_request_obj.tags = []
for label_classification, label_tag_name in dataset_obj.labels.items():
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=label_tag_name,
classification_name=label_classification,
)
if tag_label:
database_schema_request_obj.tags.append(tag_label)
yield Either(right=database_schema_request_obj)

def get_table_obj(self, table_name: str):
@@ -530,8 +529,6 @@ def clean_raw_data_type(self, raw_data_type):

def close(self):
super().close()
if self.temp_credentials:
os.unlink(self.temp_credentials)
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if isinstance(
self.service_connection.credentials.gcpConfig, GcpCredentialsValues
@@ -12,10 +12,14 @@
"""
Source connection handler
"""
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial, singledispatch
from typing import Optional

from google.cloud import storage

from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow,
)
@@ -31,9 +35,13 @@
from metadata.generated.schema.entity.services.connections.database.datalakeConnection import (
DatalakeConnection,
)
from metadata.generated.schema.security.credentials.gcpValues import (
MultipleProjectId,
SingleProjectId,
)
from metadata.ingestion.connections.test_connections import test_connection_steps
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils.credentials import set_google_credentials
from metadata.utils.credentials import GOOGLE_CREDENTIALS, set_google_credentials


# Only import specific datalake dependencies if necessary
@@ -65,9 +73,15 @@ def _(config: S3Config):

@get_datalake_client.register
def _(config: GCSConfig):
from google.cloud import storage

set_google_credentials(gcp_credentials=config.securityConfig)
gcs_config = deepcopy(config)
if hasattr(config.securityConfig, "gcpConfig") and isinstance(
config.securityConfig.gcpConfig.projectId, MultipleProjectId
):
gcs_config: GCSConfig = deepcopy(config)
gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
gcs_config.securityConfig.gcpConfig.projectId.__root__[0]
)
set_google_credentials(gcp_credentials=gcs_config.securityConfig)
gcs_client = storage.Client()
return gcs_client

@@ -96,6 +110,15 @@ def _(config: AzureConfig):
)


def set_gcs_datalake_client(config: GCSConfig, project_id: str):
gcs_config = deepcopy(config)
if hasattr(gcs_config.securityConfig, "gcpConfig"):
gcs_config.securityConfig.gcpConfig.projectId = SingleProjectId.parse_obj(
project_id
)
return get_datalake_client(config=gcs_config)
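
For context, a rough usage sketch of this helper; the config object and project ID below are placeholders, not values from this PR:

# Hypothetical call site: gcs_config is an existing GCSConfig and
# "my-data-project" is a made-up project ID.
# client = set_gcs_datalake_client(config=gcs_config, project_id="my-data-project")
# bucket_names = [bucket.name for bucket in client.list_buckets()]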


def get_connection(connection: DatalakeConnection) -> DatalakeClient:
"""
Create connection.
@@ -125,6 +148,10 @@ def test_connection(
func = partial(connection.client.get_bucket, connection.config.bucketName)
else:
func = connection.client.list_buckets
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if GOOGLE_CREDENTIALS in os.environ:
os.remove(os.environ[GOOGLE_CREDENTIALS])
del os.environ[GOOGLE_CREDENTIALS]

if isinstance(config, S3Config):
if connection.config.bucketName:
@@ -13,6 +13,7 @@
DataLake connector to fetch metadata from files stored in S3, GCS and HDFS
"""
import json
import os
import traceback
from typing import Any, Iterable, Tuple, Union

@@ -53,12 +54,18 @@
from metadata.generated.schema.metadataIngestion.workflow import (
Source as WorkflowSource,
)
from metadata.generated.schema.security.credentials.gcpValues import (
GcpCredentialsValues,
)
from metadata.ingestion.api.models import Either
from metadata.ingestion.api.steps import InvalidSourceException
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.source.connections import get_connection
from metadata.ingestion.source.database.database_service import DatabaseServiceSource
from metadata.ingestion.source.database.datalake.connection import (
set_gcs_datalake_client,
)
from metadata.ingestion.source.database.stored_procedures_mixin import QueryByProcedure
from metadata.ingestion.source.storage.storage_service import (
OPENMETADATA_TEMPLATE_FILE_NAME,
@@ -68,12 +75,13 @@
from metadata.readers.file.config_source_factory import get_reader
from metadata.utils import fqn
from metadata.utils.constants import DEFAULT_DATABASE
from metadata.utils.credentials import GOOGLE_CREDENTIALS
from metadata.utils.datalake.datalake_utils import (
fetch_dataframe,
get_columns,
get_file_format_type,
)
from metadata.utils.filters import filter_by_schema, filter_by_table
from metadata.utils.filters import filter_by_database, filter_by_schema, filter_by_table
from metadata.utils.logger import ingestion_logger
from metadata.utils.s3_utils import list_s3_objects

@@ -96,8 +104,10 @@ def __init__(self, config: WorkflowSource, metadata: OpenMetadata):
)
self.metadata = metadata
self.service_connection = self.config.serviceConnection.__root__.config
self.temp_credentials_file_path = []
self.connection = get_connection(self.service_connection)

if GOOGLE_CREDENTIALS in os.environ:
self.temp_credentials_file_path.append(os.environ[GOOGLE_CREDENTIALS])
self.client = self.connection.client
self.table_constraints = None
self.database_source_state = set()
@@ -125,8 +135,47 @@ def get_database_names(self) -> Iterable[str]:
Sources with multiple databases should overwrite this and
apply the necessary filters.
"""
database_name = self.service_connection.databaseName or DEFAULT_DATABASE
yield database_name
if isinstance(self.config_source, GCSConfig):
project_id_list = (
self.service_connection.configSource.securityConfig.gcpConfig.projectId.__root__
)
if not isinstance(
project_id_list,
list,
):
project_id_list = [project_id_list]
for project_id in project_id_list:
database_fqn = fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.database_service,
database_name=project_id,
)
if filter_by_database(
self.source_config.databaseFilterPattern,
database_fqn
if self.source_config.useFqnForFiltering
else project_id,
):
self.status.filter(database_fqn, "Database Filtered out")
else:
try:
self.client = set_gcs_datalake_client(
config=self.config_source, project_id=project_id
)
if GOOGLE_CREDENTIALS in os.environ:
self.temp_credentials_file_path.append(
os.environ[GOOGLE_CREDENTIALS]
)
yield project_id
except Exception as exc:
logger.debug(traceback.format_exc())
logger.error(
f"Error trying to connect to database {project_id}: {exc}"
)
else:
database_name = self.service_connection.databaseName or DEFAULT_DATABASE
yield database_name
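
As a rough illustration of the filtering step above, here is a stand-in for the include/exclude check that databaseFilterPattern performs; the helper name, pattern semantics, and project IDs are assumptions, not OpenMetadata's actual filter_by_database implementation:

import re

def is_database_filtered(name, includes=None, excludes=None):
    # Keep a project only if it matches some include regex (when given)
    # and matches no exclude regex.
    if includes and not any(re.search(pattern, name) for pattern in includes):
        return True
    return bool(excludes) and any(re.search(pattern, name) for pattern in excludes)

# is_database_filtered("analytics-prod", includes=[r"^analytics-"])  -> False (kept)
# is_database_filtered("sandbox-123", includes=[r"^analytics-"])     -> True  (filtered out)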

def yield_database(
self, database_name: str
@@ -135,6 +184,8 @@ def yield_database(
From topology.
Prepare a database request and pass it to the sink
"""
if isinstance(self.config_source, GCSConfig):
database_name = self.client.project
yield Either(
right=CreateDatabaseRequest(
name=database_name,
@@ -143,24 +194,42 @@ def yield_database(
)

def fetch_gcs_bucket_names(self):
for bucket in self.client.list_buckets():
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.database_service,
database_name=self.context.database,
schema_name=bucket.name,
)
if filter_by_schema(
self.config.sourceConfig.config.schemaFilterPattern,
schema_fqn
if self.config.sourceConfig.config.useFqnForFiltering
else bucket.name,
):
self.status.filter(schema_fqn, "Bucket Filtered Out")
continue
"""
Fetch Google cloud storage buckets
"""
try:
# List all the buckets in the project
for bucket in self.client.list_buckets():
# Build a fully qualified name (FQN) for each bucket
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.database_service,
database_name=self.context.database,
schema_name=bucket.name,
)

yield bucket.name
# Check if the bucket matches a certain filter pattern
if filter_by_schema(
self.config.sourceConfig.config.schemaFilterPattern,
schema_fqn
if self.config.sourceConfig.config.useFqnForFiltering
else bucket.name,
):
# If it does not match, the bucket is filtered out
self.status.filter(schema_fqn, "Bucket Filtered Out")
continue

# If it does match, the bucket name is yielded
yield bucket.name
except Exception as exc:
yield Either(
left=StackTraceError(
name="Bucket",
error=f"Unexpected exception to yield bucket: {exc}",
stackTrace=traceback.format_exc(),
)
)

def fetch_s3_bucket_names(self):
for bucket in self.client.list_buckets()["Buckets"]:
@@ -434,3 +503,12 @@ def filter_dl_table(self, table_name: str):
def close(self):
if isinstance(self.config_source, AzureConfig):
self.client.close()
if isinstance(self.config_source, GCSConfig):
os.environ.pop("GOOGLE_CLOUD_PROJECT", "")
if isinstance(self.service_connection, GcpCredentialsValues) and (
GOOGLE_CREDENTIALS in os.environ
):
del os.environ[GOOGLE_CREDENTIALS]
for temp_file_path in self.temp_credentials_file_path:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
7 changes: 7 additions & 0 deletions ingestion/tests/unit/test_entity_link.py
@@ -97,6 +97,13 @@ def validate(self, fn_resp, check_split):
"<#E::table::随机的>",
["table", "随机的"],
),
EntityLinkTest(
'<#E::table::ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv">',
[
"table",
'ExampleWithFolder.withfolder.examplewithfolder."folderpath/username.csv"',
],
),
]
for x in xs:
x.validate(entity_link.split(x.entitylink), x.split_list)