diff --git a/ingestion/setup.py b/ingestion/setup.py index ff54d8406ba7..b117351d6a0f 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -137,7 +137,7 @@ "requests>=2.23", "requests-aws4auth~=1.1", # Only depends on requests as external package. Leaving as base. "sqlalchemy>=1.4.0,<2", - "collate-sqllineage~=1.4.0", + "collate-sqllineage~=1.5.0", "tabulate==0.9.0", "typing-inspect", "packaging", # For version parsing diff --git a/ingestion/src/metadata/ingestion/lineage/parser.py b/ingestion/src/metadata/ingestion/lineage/parser.py index 2423f0dda3b4..f7fad5fe812a 100644 --- a/ingestion/src/metadata/ingestion/lineage/parser.py +++ b/ingestion/src/metadata/ingestion/lineage/parser.py @@ -15,12 +15,12 @@ from collections import defaultdict from copy import deepcopy from logging.config import DictConfigurator -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import sqlparse from cached_property import cached_property from collate_sqllineage import SQLPARSE_DIALECT -from collate_sqllineage.core.models import Column, Table +from collate_sqllineage.core.models import Column, DataFunction, Table from collate_sqllineage.exceptions import SQLLineageException from collate_sqllineage.runner import LineageRunner from sqlparse.sql import Comparison, Identifier, Parenthesis, Statement @@ -110,7 +110,7 @@ def intermediate_tables(self) -> List[Table]: return [] @cached_property - def source_tables(self) -> List[Table]: + def source_tables(self) -> List[Union[Table, DataFunction]]: """ Get a list of source tables """ @@ -374,7 +374,9 @@ def retrieve_tables(self, tables: List[Any]) -> List[Table]: if not self._clean_query: return [] return [ - self.clean_table_name(table) for table in tables if isinstance(table, Table) + self.clean_table_name(table) + for table in tables + if isinstance(table, (Table, DataFunction)) ] @classmethod diff --git a/ingestion/src/metadata/ingestion/lineage/sql_lineage.py b/ingestion/src/metadata/ingestion/lineage/sql_lineage.py index 3655bbc9251b..c8954fad6f0f 100644 --- a/ingestion/src/metadata/ingestion/lineage/sql_lineage.py +++ b/ingestion/src/metadata/ingestion/lineage/sql_lineage.py @@ -13,9 +13,18 @@ """ import itertools import traceback -from typing import Any, Iterable, List, Optional, Tuple +from collections import defaultdict +from typing import Any, Iterable, List, Optional, Tuple, Union + +from collate_sqllineage.core.models import Column, DataFunction +from collate_sqllineage.core.models import Table as LineageTable from metadata.generated.schema.api.lineage.addLineage import AddLineageRequest +from metadata.generated.schema.entity.data.storedProcedure import ( + Language, + StoredProcedure, + StoredProcedureType, +) from metadata.generated.schema.entity.data.table import Table from metadata.generated.schema.entity.services.ingestionPipelines.status import ( StackTraceError, @@ -163,6 +172,187 @@ def get_table_fqn_from_query_name( return database_query, schema_query, table +def __process_intermediate_column_lineage( + intermediate_column_lineage: dict, + result: dict, + source_table: str, + intermediate_table: str, + intermediate_column: str, + source_column: str, +): + # Check intermediate dictionary for mappings + for ( + target_table, + target_mappings, + ) in intermediate_column_lineage[intermediate_table].items(): + for inter_col, target_col in target_mappings: + if intermediate_column == inter_col: + # Append to the result dictionary + if target_table not in result[source_table]: + result[source_table][target_table] = [] + result[source_table][target_table].append((source_column, target_col)) + + +def __process_column_mappings( + mappings: dict, result: dict, source_table: str, intermediate_column_lineage: dict +): + for intermediate_table, column_pairs in mappings.items(): + # Iterate through each column mapping in the original dictionary + for source_column, intermediate_column in column_pairs: + if intermediate_table in intermediate_column_lineage: + __process_intermediate_column_lineage( + intermediate_column_lineage, + result, + source_table, + intermediate_table, + intermediate_column, + source_column, + ) + + +def handle_udf_column_lineage( + column_lineage_original: dict, + column_lineage_generated: List[Tuple[Column, Column]], +): + """ + Handle UDF column lineage + """ + try: + result = defaultdict(dict) + intermediate_column_lineage = populate_column_lineage_map( + column_lineage_generated + ) + # Iterate through the original dictionary + for source_table, mappings in column_lineage_original.items(): + __process_column_mappings( + mappings, result, source_table, intermediate_column_lineage + ) + column_lineage_original.update(result) + except Exception as exc: + logger.debug(traceback.format_exc()) + logger.error(f"Error handling UDF column lineage: {exc}") + + +# pylint: disable=too-many-arguments +def __process_udf_es_results( + metadata: OpenMetadata, + dialect: Dialect, + source_table: Union[DataFunction, LineageTable], + database_name: Optional[str], + schema_name: Optional[str], + service_name: Optional[str], + timeout_seconds: int, + column_lineage: dict, + es_result_entities: List[StoredProcedure], + procedure: Optional[StoredProcedure] = None, +): + for entity in es_result_entities: + if ( + entity.storedProcedureType == StoredProcedureType.UDF + and entity.storedProcedureCode + and entity.storedProcedureCode.language == Language.SQL + ): + expected_table_name = str(source_table).replace( + f"{DEFAULT_SCHEMA_NAME}.", "" + ) + lineage_parser = LineageParser( + f"create table {str(expected_table_name)} as {entity.storedProcedureCode.code}", + dialect=dialect, + timeout_seconds=timeout_seconds, + ) + handle_udf_column_lineage(column_lineage, lineage_parser.column_lineage) + for source in lineage_parser.source_tables or []: + yield from get_source_table_names( + metadata, + dialect, + source, + database_name, + schema_name, + service_name, + timeout_seconds, + column_lineage, + procedure or entity, + ) + + +def __process_udf_table_names( + metadata: OpenMetadata, + dialect: Dialect, + source_table: Union[DataFunction, LineageTable], + database_name: Optional[str], + schema_name: Optional[str], + service_name: Optional[str], + timeout_seconds: int, + column_lineage: dict, + procedure: Optional[StoredProcedure] = None, +): + database_query, schema_query, table = get_table_fqn_from_query_name( + str(source_table) + ) + function_fqn_string = build_es_fqn_search_string( + database_query or database_name, + schema_query or schema_name, + service_name, + table, + ) + es_result_entities: Optional[List[StoredProcedure]] = metadata.es_search_from_fqn( + entity_type=StoredProcedure, + fqn_search_string=function_fqn_string, + ) + if es_result_entities: + yield from __process_udf_es_results( + metadata, + dialect, + source_table, + database_name, + schema_name, + service_name, + timeout_seconds, + column_lineage, + es_result_entities, + procedure, + ) + + +def get_source_table_names( + metadata: OpenMetadata, + dialect: Dialect, + source_table: Union[DataFunction, LineageTable], + database_name: Optional[str], + schema_name: Optional[str], + service_name: Optional[str], + timeout_seconds: int, + column_lineage: dict, + procedure: Optional[StoredProcedure] = None, +) -> Iterable[Tuple[Optional[EntityReference], str]]: + """ + Get source table names from DataFunction + """ + try: + if not isinstance(source_table, DataFunction): + yield EntityReference( + id=procedure.id.root, type="storedProcedure" + ) if procedure else None, str(source_table) + else: + yield from __process_udf_table_names( + metadata, + dialect, + source_table, + database_name, + schema_name, + service_name, + timeout_seconds, + column_lineage, + procedure, + ) + + except Exception as exc: + logger.debug(traceback.format_exc()) + logger.error( + f"Error getting source table names for table [{source_table}]: {exc}" + ) + + def get_table_entities_from_query( metadata: OpenMetadata, service_name: str, @@ -244,6 +434,7 @@ def get_column_lineage( return column_lineage +# pylint: disable=too-many-arguments def _build_table_lineage( from_entity: Table, to_entity: Table, @@ -252,6 +443,7 @@ def _build_table_lineage( masked_query: str, column_lineage_map: dict, lineage_source: LineageSource = LineageSource.QueryLineage, + procedure: Optional[EntityReference] = None, ) -> Either[AddLineageRequest]: """ Prepare the lineage request generator @@ -276,7 +468,9 @@ def _build_table_lineage( from_table_raw_name=str(from_table_raw_name), column_lineage_map=column_lineage_map, ) - lineage_details = LineageDetails(sqlQuery=masked_query, source=lineage_source) + lineage_details = LineageDetails( + sqlQuery=masked_query, source=lineage_source, pipeline=procedure + ) if col_lineage: lineage_details.columnsLineage = col_lineage lineage = AddLineageRequest( @@ -315,6 +509,7 @@ def _create_lineage_by_table_name( masked_query: str, column_lineage_map: dict, lineage_source: LineageSource = LineageSource.QueryLineage, + procedure: Optional[EntityReference] = None, ) -> Iterable[Either[AddLineageRequest]]: """ This method is to create a lineage between two tables @@ -358,6 +553,7 @@ def _create_lineage_by_table_name( masked_query=masked_query, column_lineage_map=column_lineage_map, lineage_source=lineage_source, + procedure=procedure, ) except Exception as exc: @@ -429,17 +625,28 @@ def get_lineage_by_query( for intermediate_table in lineage_parser.intermediate_tables: for source_table in lineage_parser.source_tables: - yield from _create_lineage_by_table_name( - metadata, - from_table=str(source_table), - to_table=str(intermediate_table), - service_name=service_name, + for procedure, from_table_name in get_source_table_names( + metadata=metadata, + dialect=dialect, + source_table=source_table, database_name=database_name, schema_name=schema_name, - masked_query=masked_query, - column_lineage_map=column_lineage, - lineage_source=lineage_source, - ) + service_name=service_name, + timeout_seconds=timeout_seconds, + column_lineage=column_lineage, + ): + yield from _create_lineage_by_table_name( + metadata, + from_table=str(from_table_name), + to_table=str(intermediate_table), + service_name=service_name, + database_name=database_name, + schema_name=schema_name, + masked_query=masked_query, + column_lineage_map=column_lineage, + lineage_source=lineage_source, + procedure=procedure, + ) for target_table in lineage_parser.target_tables: yield from _create_lineage_by_table_name( metadata, @@ -455,17 +662,28 @@ def get_lineage_by_query( if not lineage_parser.intermediate_tables: for target_table in lineage_parser.target_tables: for source_table in lineage_parser.source_tables: - yield from _create_lineage_by_table_name( - metadata, - from_table=str(source_table), - to_table=str(target_table), - service_name=service_name, + for procedure, from_table_name in get_source_table_names( + metadata=metadata, + dialect=dialect, + source_table=source_table, database_name=database_name, schema_name=schema_name, - masked_query=masked_query, - column_lineage_map=column_lineage, - lineage_source=lineage_source, - ) + service_name=service_name, + timeout_seconds=timeout_seconds, + column_lineage=column_lineage, + ): + yield from _create_lineage_by_table_name( + metadata, + from_table=str(from_table_name), + to_table=str(target_table), + service_name=service_name, + database_name=database_name, + schema_name=schema_name, + masked_query=masked_query, + column_lineage_map=column_lineage, + lineage_source=lineage_source, + procedure=procedure, + ) if not lineage_parser.query_parsing_success: query_parsing_failures.add( QueryParsingError( @@ -505,17 +723,28 @@ def get_lineage_via_table_entity( to_table_name = table_entity.name.root for from_table_name in lineage_parser.source_tables: - yield from _create_lineage_by_table_name( - metadata, - from_table=str(from_table_name), - to_table=f"{schema_name}.{to_table_name}", - service_name=service_name, + for procedure, source_table in get_source_table_names( + metadata=metadata, + dialect=dialect, + source_table=from_table_name, database_name=database_name, schema_name=schema_name, - masked_query=masked_query, - column_lineage_map=column_lineage, - lineage_source=lineage_source, - ) or [] + service_name=service_name, + timeout_seconds=timeout_seconds, + column_lineage=column_lineage, + ): + yield from _create_lineage_by_table_name( + metadata, + from_table=str(source_table), + to_table=f"{schema_name}.{to_table_name}", + service_name=service_name, + database_name=database_name, + schema_name=schema_name, + masked_query=masked_query, + column_lineage_map=column_lineage, + lineage_source=lineage_source, + procedure=procedure, + ) or [] if not lineage_parser.query_parsing_success: query_parsing_failures.add( QueryParsingError( diff --git a/ingestion/src/metadata/ingestion/models/patch_request.py b/ingestion/src/metadata/ingestion/models/patch_request.py index f7fc2236e856..dd8789ecae66 100644 --- a/ingestion/src/metadata/ingestion/models/patch_request.py +++ b/ingestion/src/metadata/ingestion/models/patch_request.py @@ -100,6 +100,7 @@ class PatchedEntity(BaseModel): "fileFormat": True, # Stored Procedure Fields "storedProcedureCode": True, + "storedProcedureType": True, "code": True, # Dashboard Entity Fields "chartType": True, diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/lineage.py b/ingestion/src/metadata/ingestion/source/database/snowflake/lineage.py index 80ab7eeef5e6..2d85ed8772e2 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/lineage.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/lineage.py @@ -47,7 +47,6 @@ class SnowflakeLineageSource( AND ( QUERY_TYPE IN ('MERGE', 'UPDATE','CREATE_TABLE_AS_SELECT') OR (QUERY_TYPE = 'INSERT' and query_text ILIKE '%%insert%%into%%select%%') - OR (QUERY_TYPE = 'ALTER' and query_text ILIKE '%%alter%%table%%swap%%') ) """ diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py index 7ce7483cf098..da5a7034a26d 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py @@ -27,7 +27,10 @@ ) from metadata.generated.schema.entity.data.database import Database from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema -from metadata.generated.schema.entity.data.storedProcedure import StoredProcedureCode +from metadata.generated.schema.entity.data.storedProcedure import ( + StoredProcedureCode, + StoredProcedureType, +) from metadata.generated.schema.entity.data.table import ( PartitionColumnDetails, PartitionIntervalTypes, @@ -71,6 +74,7 @@ SnowflakeStoredProcedure, ) from metadata.ingestion.source.database.snowflake.queries import ( + SNOWFLAKE_DESC_FUNCTION, SNOWFLAKE_DESC_STORED_PROCEDURE, SNOWFLAKE_FETCH_ALL_TAGS, SNOWFLAKE_GET_CLUSTER_KEY, @@ -78,6 +82,7 @@ SNOWFLAKE_GET_DATABASE_COMMENTS, SNOWFLAKE_GET_DATABASES, SNOWFLAKE_GET_EXTERNAL_LOCATIONS, + SNOWFLAKE_GET_FUNCTIONS, SNOWFLAKE_GET_ORGANIZATION_NAME, SNOWFLAKE_GET_SCHEMA_COMMENTS, SNOWFLAKE_GET_STORED_PROCEDURES, @@ -623,26 +628,34 @@ def query_view_names_and_types( return views + def _get_stored_procedures_internal( + self, query: str + ) -> Iterable[SnowflakeStoredProcedure]: + results = self.engine.execute( + query.format( + database_name=self.context.get().database, + schema_name=self.context.get().database_schema, + ) + ).all() + for row in results: + stored_procedure = SnowflakeStoredProcedure.model_validate(dict(row)) + if stored_procedure.definition is None: + logger.debug( + f"Missing ownership permissions on procedure {stored_procedure.name}." + " Trying to fetch description via DESCRIBE." + ) + stored_procedure.definition = self.describe_procedure_definition( + stored_procedure + ) + yield stored_procedure + def get_stored_procedures(self) -> Iterable[SnowflakeStoredProcedure]: """List Snowflake stored procedures""" if self.source_config.includeStoredProcedures: - results = self.engine.execute( - SNOWFLAKE_GET_STORED_PROCEDURES.format( - database_name=self.context.get().database, - schema_name=self.context.get().database_schema, - ) - ).all() - for row in results: - stored_procedure = SnowflakeStoredProcedure.model_validate(dict(row)) - if stored_procedure.definition is None: - logger.debug( - f"Missing ownership permissions on procedure {stored_procedure.name}." - " Trying to fetch description via DESCRIBE." - ) - stored_procedure.definition = self.describe_procedure_definition( - stored_procedure - ) - yield stored_procedure + yield from self._get_stored_procedures_internal( + SNOWFLAKE_GET_STORED_PROCEDURES + ) + yield from self._get_stored_procedures_internal(SNOWFLAKE_GET_FUNCTIONS) def describe_procedure_definition( self, stored_procedure: SnowflakeStoredProcedure @@ -654,8 +667,12 @@ def describe_procedure_definition( Then, if the procedure is created with `EXECUTE AS CALLER`, we can still try to get the definition with a DESCRIBE. """ + if stored_procedure.procedure_type == StoredProcedureType.StoredProcedure.value: + query = SNOWFLAKE_DESC_STORED_PROCEDURE + else: + query = SNOWFLAKE_DESC_FUNCTION res = self.engine.execute( - SNOWFLAKE_DESC_STORED_PROCEDURE.format( + query.format( database_name=self.context.get().database, schema_name=self.context.get().database_schema, procedure_name=stored_procedure.name, @@ -677,6 +694,8 @@ def yield_stored_procedure( language=STORED_PROC_LANGUAGE_MAP.get(stored_procedure.language), code=stored_procedure.definition, ), + storedProcedureType=stored_procedure.procedure_type + or StoredProcedureType.StoredProcedure.value, databaseSchema=fqn.build( metadata=self.metadata, entity_type=DatabaseSchema, diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/models.py b/ingestion/src/metadata/ingestion/source/database/snowflake/models.py index 91a4bc3da8d9..594adb1149fd 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/models.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/models.py @@ -50,6 +50,7 @@ class SnowflakeStoredProcedure(BaseModel): None, alias="SIGNATURE", description="Used to build the source URL" ) comment: Optional[str] = Field(None, alias="COMMENT") + procedure_type: Optional[str] = Field(None, alias="PROCEDURE_TYPE") # Update the signature to clean it up on read @field_validator("signature") diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py b/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py index bf541214959d..671e2ebfa8a4 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/queries.py @@ -294,17 +294,38 @@ PROCEDURE_LANGUAGE AS language, PROCEDURE_DEFINITION AS definition, ARGUMENT_SIGNATURE AS signature, - COMMENT as comment + COMMENT as comment, + 'StoredProcedure' as procedure_type FROM INFORMATION_SCHEMA.PROCEDURES WHERE PROCEDURE_CATALOG = '{database_name}' AND PROCEDURE_SCHEMA = '{schema_name}' """ ) +SNOWFLAKE_GET_FUNCTIONS = textwrap.dedent( + """ +SELECT + FUNCTION_NAME AS name, + FUNCTION_OWNER AS owner, + FUNCTION_LANGUAGE AS language, + FUNCTION_DEFINITION AS definition, + ARGUMENT_SIGNATURE AS signature, + COMMENT as comment, + 'UDF' as procedure_type +FROM INFORMATION_SCHEMA.FUNCTIONS +WHERE FUNCTION_CATALOG = '{database_name}' + AND FUNCTION_SCHEMA = '{schema_name}' + """ +) + SNOWFLAKE_DESC_STORED_PROCEDURE = ( "DESC PROCEDURE {database_name}.{schema_name}.{procedure_name}{procedure_signature}" ) +SNOWFLAKE_DESC_FUNCTION = ( + "DESC FUNCTION {database_name}.{schema_name}.{procedure_name}{procedure_signature}" +) + SNOWFLAKE_GET_STORED_PROCEDURE_QUERIES = textwrap.dedent( """ WITH SP_HISTORY AS (