fix(ingestion): use correct native data type in all SQLAlchemy sources by compiling data type using dialect
Masterchen09 committed Jul 20, 2024
1 parent 20574cf commit f5eaa01
Showing 8 changed files with 509 additions and 464 deletions.
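
The change hinges on one SQLAlchemy behavior: repr() on a type object yields the generic, dialect-agnostic spelling, while TypeEngine.compile(dialect=...) renders the type exactly as the target database spells it. A minimal sketch of the difference (the DateTime column and the Postgres/MySQL dialects are illustrative choices, not taken from this diff):

from sqlalchemy import types
from sqlalchemy.dialects import mysql, postgresql

column_type = types.DateTime()

# The old fallback: the generic SQLAlchemy repr, identical for every source.
print(repr(column_type))  # DateTime()

# The new behavior: compile against the concrete dialect.
print(column_type.compile(dialect=postgresql.dialect()))  # TIMESTAMP WITHOUT TIME ZONE
print(column_type.compile(dialect=mysql.dialect()))  # DATETIME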
2 changes: 2 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -499,13 +499,15 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict,
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = get_schema_fields_for_sqlalchemy_column(
             column_name=column["name"],
             column_type=column["type"],
+            inspector=inspector,
             description=column.get("comment", None),
             nullable=column.get("nullable", True),
             is_part_of_key=(
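
Each source now threads the live Inspector down to the type-converter helper. A hedged sketch of calling the updated get_schema_fields_for_sqlalchemy_column (the SQLite engine and the column values are stand-ins for illustration; Athena would pass its own inspector):

from sqlalchemy import create_engine, inspect, types

from datahub.utilities.sqlalchemy_type_converter import (
    get_schema_fields_for_sqlalchemy_column,
)

inspector = inspect(create_engine("sqlite://"))

fields = get_schema_fields_for_sqlalchemy_column(
    column_name="event_ts",
    column_type=types.TIMESTAMP(),
    inspector=inspector,  # newly required so the native type can be compiled
    description=None,
    nullable=True,
)
print(type(fields[0]).__name__)  # SchemaField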
6 changes: 5 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/hive.py
@@ -169,12 +169,16 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict[Any, Any],
+        inspector: Inspector,
         pk_constraints: Optional[Dict[Any, Any]] = None,
         partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = super().get_schema_fields_for_column(
-            dataset_name, column, pk_constraints
+            dataset_name,
+            column,
+            inspector,
+            pk_constraints,
         )

         if self._COMPLEX_TYPE.match(fields[0].nativeDataType) and isinstance(
@@ -521,7 +521,7 @@ def loop_tables(
             )

             # add table schema fields
-            schema_fields = self.get_schema_fields(dataset_name, columns)
+            schema_fields = self.get_schema_fields(dataset_name, columns, inspector)

             self._set_partition_key(columns, schema_fields)

@@ -754,7 +754,9 @@ def loop_views(

             # add view schema fields
             schema_fields = self.get_schema_fields(
-                dataset.dataset_name, dataset.columns
+                dataset.dataset_name,
+                dataset.columns,
+                inspector,
             )

             schema_metadata = get_schema_metadata(
@@ -877,6 +879,7 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict[Any, Any],
+        inspector: Inspector,
         pk_constraints: Optional[Dict[Any, Any]] = None,
         partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
17 changes: 15 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
@@ -120,6 +120,9 @@
 from datahub.utilities.lossy_collections import LossyList
 from datahub.utilities.registries.domain_registry import DomainRegistry
 from datahub.utilities.sqlalchemy_query_combiner import SQLAlchemyQueryCombinerReport
+from datahub.utilities.sqlalchemy_type_converter import (
+    get_native_data_type_for_sqlalchemy_type,
+)

 if TYPE_CHECKING:
     from datahub.ingestion.source.ge_data_profiler import (
@@ -788,6 +791,7 @@ def _process_table(
         schema_fields = self.get_schema_fields(
             dataset_name,
             columns,
+            inspector,
             pk_constraints,
             tags=extra_tags,
             partition_keys=partitions,
@@ -968,6 +972,7 @@ def get_schema_fields(
         self,
         dataset_name: str,
         columns: List[dict],
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         partition_keys: Optional[List[str]] = None,
         tags: Optional[Dict[str, List[str]]] = None,
@@ -980,6 +985,7 @@
             fields = self.get_schema_fields_for_column(
                 dataset_name,
                 column,
+                inspector,
                 pk_constraints,
                 tags=column_tags,
                 partition_keys=partition_keys,
@@ -991,6 +997,7 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: dict,
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
@@ -1000,10 +1007,16 @@ def get_schema_fields_for_column(
             tags_str = [make_tag_urn(t) for t in tags]
             tags_tac = [TagAssociationClass(t) for t in tags_str]
             gtc = GlobalTagsClass(tags_tac)
+        full_type = column.get("full_type")
         field = SchemaField(
             fieldPath=column["name"],
             type=get_column_type(self.report, dataset_name, column["type"]),
-            nativeDataType=column.get("full_type", repr(column["type"])),
+            nativeDataType=full_type
+            if full_type is not None
+            else get_native_data_type_for_sqlalchemy_type(
+                column["type"],
+                inspector=inspector,
+            ),
             description=column.get("comment", None),
             nullable=column["nullable"],
             recursive=False,
@@ -1076,7 +1089,7 @@ def _process_view(
             self.warn(logger, dataset_name, "unable to get schema for this view")
             schema_metadata = None
         else:
-            schema_fields = self.get_schema_fields(dataset_name, columns)
+            schema_fields = self.get_schema_fields(dataset_name, columns, inspector)
             schema_metadata = get_schema_metadata(
                 self.report,
                 dataset_name,
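
The behavioral core of the sql_common.py change is the nativeDataType fallback: a full_type supplied by reflection still wins, but the old repr(column["type"]) default is replaced by dialect compilation. A small sketch of the same decision logic, with SQLite standing in for a real source dialect and an illustrative column dict:

from sqlalchemy import create_engine, inspect, types

from datahub.utilities.sqlalchemy_type_converter import (
    get_native_data_type_for_sqlalchemy_type,
)

inspector = inspect(create_engine("sqlite://"))
column = {"name": "created_at", "type": types.DateTime(), "nullable": True}

full_type = column.get("full_type")  # not set here, so the fallback applies
native_data_type = (
    full_type
    if full_type is not None
    else get_native_data_type_for_sqlalchemy_type(column["type"], inspector=inspector)
)
print(native_data_type)  # DATETIME -- rather than the old "DateTime()" repr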
6 changes: 5 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/trino.py
@@ -387,12 +387,16 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: dict,
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         partition_keys: Optional[List[str]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = super().get_schema_fields_for_column(
-            dataset_name, column, pk_constraints
+            dataset_name,
+            column,
+            inspector,
+            pk_constraints,
         )

         if isinstance(column["type"], (datatype.ROW, sqltypes.ARRAY, datatype.MAP)):
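
Trino keeps its dedicated handling for nested types: only columns that are not ROW, ARRAY, or MAP fall through to the base-class path and thus to the dialect-compiled native type. A sketch of that branch condition, assuming the trino client's trino.sqlalchemy.datatype module (which the source's datatype references suggest):

from sqlalchemy.sql import sqltypes
from trino.sqlalchemy import datatype  # assumed import path

column_type = sqltypes.ARRAY(sqltypes.INTEGER())

# Nested types take Trino's own converter; anything else uses the
# inherited get_schema_fields_for_column and its compiled native type.
is_nested = isinstance(column_type, (datatype.ROW, sqltypes.ARRAY, datatype.MAP))
print(is_nested)  # True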
@@ -469,7 +469,12 @@ def _process_projections(
             foreign_keys = self._get_foreign_keys(
                 dataset_urn, inspector, schema, projection
             )
-            schema_fields = self.get_schema_fields(dataset_name, columns, pk_constraints)
+            schema_fields = self.get_schema_fields(
+                dataset_name,
+                columns,
+                inspector,
+                pk_constraints,
+            )
             schema_metadata = get_schema_metadata(
                 self.report,
                 dataset_name,
@@ -673,7 +678,7 @@ def _process_models(
             )
             dataset_snapshot.aspects.append(dataset_properties)

-            schema_fields = self.get_schema_fields(dataset_name, columns)
+            schema_fields = self.get_schema_fields(dataset_name, columns, inspector)

             schema_metadata = get_schema_metadata(
                 self.report,
@@ -5,6 +5,7 @@
 from typing import Any, Dict, List, Optional, Type, Union

 from sqlalchemy import types
+from sqlalchemy.engine.reflection import Inspector

 from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
 from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
@@ -176,6 +177,7 @@ def get_avro_for_sqlalchemy_column(
 def get_schema_fields_for_sqlalchemy_column(
     column_name: str,
     column_type: types.TypeEngine,
+    inspector: Inspector,
     description: Optional[str] = None,
     nullable: Optional[bool] = True,
     is_part_of_key: Optional[bool] = False,
@@ -216,7 +218,10 @@
         SchemaField(
             fieldPath=column_name,
             type=SchemaFieldDataTypeClass(type=NullTypeClass()),
-            nativeDataType=str(column_type),
+            nativeDataType=get_native_data_type_for_sqlalchemy_type(
+                column_type,
+                inspector,
+            ),
         )
     ]
@@ -240,3 +245,12 @@
     )

     return schema_fields
+
+
+def get_native_data_type_for_sqlalchemy_type(
+    column_type: types.TypeEngine, inspector: Inspector
+) -> str:
+    if isinstance(column_type, types.NullType):
+        return types.NullType.__visit_name__
+
+    return column_type.compile(dialect=inspector.dialect)
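
The new helper lives in datahub/utilities/sqlalchemy_type_converter.py (the module sql_common.py imports it from): NullType short-circuits to its visit name, everything else is compiled against the source's dialect. A quick usage sketch, again with SQLite as a stand-in dialect:

from sqlalchemy import create_engine, inspect, types

from datahub.utilities.sqlalchemy_type_converter import (
    get_native_data_type_for_sqlalchemy_type,
)

inspector = inspect(create_engine("sqlite://"))

print(get_native_data_type_for_sqlalchemy_type(types.NullType(), inspector))  # null
print(get_native_data_type_for_sqlalchemy_type(types.Integer(), inspector))  # INTEGER
print(get_native_data_type_for_sqlalchemy_type(types.VARCHAR(42), inspector))  # VARCHAR(42)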
