Skip to content

Commit

Permalink
fix(ingestion): use correct native data type in all SQLAlchemy sources by compiling data type using dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
Masterchen09 committed Jul 15, 2024
1 parent ea7d6a9 commit 9a7f61a
Show file tree
Hide file tree
Showing 8 changed files with 483 additions and 466 deletions.
2 changes: 2 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,14 @@ def get_schema_fields_for_column(
self,
dataset_name: str,
column: Dict,
inspector: Inspector,
pk_constraints: Optional[dict] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = get_schema_fields_for_sqlalchemy_column(
column_name=column["name"],
column_type=column["type"],
inspector=inspector,
description=column.get("comment", None),
nullable=column.get("nullable", True),
is_part_of_key=True
Expand Down
3 changes: 2 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,12 @@ def get_schema_fields_for_column(
self,
dataset_name: str,
column: Dict[Any, Any],
inspector: Inspector,
pk_constraints: Optional[Dict[Any, Any]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = super().get_schema_fields_for_column(
dataset_name, column, pk_constraints
dataset_name, column, inspector, pk_constraints
)

if self._COMPLEX_TYPE.match(fields[0].nativeDataType) and isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def loop_tables(
)

# add table schema fields
schema_fields = self.get_schema_fields(dataset_name, columns)
schema_fields = self.get_schema_fields(dataset_name, columns, inspector)

self._set_partition_key(columns, schema_fields)

Expand Down Expand Up @@ -754,7 +754,7 @@ def loop_views(

# add view schema fields
schema_fields = self.get_schema_fields(
dataset.dataset_name, dataset.columns
dataset.dataset_name, dataset.columns, inspector
)

schema_metadata = get_schema_metadata(
Expand Down Expand Up @@ -877,6 +877,7 @@ def get_schema_fields_for_column(
self,
dataset_name: str,
column: Dict[Any, Any],
inspector: Inspector,
pk_constraints: Optional[Dict[Any, Any]] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
from datahub.utilities.lossy_collections import LossyList
from datahub.utilities.registries.domain_registry import DomainRegistry
from datahub.utilities.sqlalchemy_query_combiner import SQLAlchemyQueryCombinerReport
from datahub.utilities.sqlalchemy_type_converter import get_native_data_type_for_sqlalchemy_type

if TYPE_CHECKING:
from datahub.ingestion.source.ge_data_profiler import (
Expand Down Expand Up @@ -771,7 +772,7 @@ def _process_table(
pk_constraints: dict = inspector.get_pk_constraint(table, schema)
foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
schema_fields = self.get_schema_fields(
dataset_name, columns, pk_constraints, tags=extra_tags
dataset_name, columns, inspector, pk_constraints, tags=extra_tags
)
schema_metadata = get_schema_metadata(
self.report,
Expand Down Expand Up @@ -948,6 +949,7 @@ def get_schema_fields(
self,
dataset_name: str,
columns: List[dict],
inspector: Inspector,
pk_constraints: Optional[dict] = None,
tags: Optional[Dict[str, List[str]]] = None,
) -> List[SchemaField]:
Expand All @@ -957,7 +959,7 @@ def get_schema_fields(
if tags:
column_tags = tags.get(column["name"], [])
fields = self.get_schema_fields_for_column(
dataset_name, column, pk_constraints, tags=column_tags
dataset_name, column, inspector, pk_constraints, tags=column_tags
)
canonical_schema.extend(fields)
return canonical_schema
Expand All @@ -966,6 +968,7 @@ def get_schema_fields_for_column(
self,
dataset_name: str,
column: dict,
inspector: Inspector,
pk_constraints: Optional[dict] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
Expand All @@ -974,10 +977,11 @@ def get_schema_fields_for_column(
tags_str = [make_tag_urn(t) for t in tags]
tags_tac = [TagAssociationClass(t) for t in tags_str]
gtc = GlobalTagsClass(tags_tac)
full_type = column.get("full_type")
field = SchemaField(
fieldPath=column["name"],
type=get_column_type(self.report, dataset_name, column["type"]),
nativeDataType=column.get("full_type", repr(column["type"])),
nativeDataType=full_type if full_type is not None else get_native_data_type_for_sqlalchemy_type(column["type"], inspector=inspector),
description=column.get("comment", None),
nullable=column["nullable"],
recursive=False,
Expand Down Expand Up @@ -1042,7 +1046,7 @@ def _process_view(
self.warn(logger, dataset_name, "unable to get schema for this view")
schema_metadata = None
else:
schema_fields = self.get_schema_fields(dataset_name, columns)
schema_fields = self.get_schema_fields(dataset_name, columns, inspector)
schema_metadata = get_schema_metadata(
self.report,
dataset_name,
Expand Down
3 changes: 2 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,12 @@ def get_schema_fields_for_column(
self,
dataset_name: str,
column: dict,
inspector: Inspector,
pk_constraints: Optional[dict] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = super().get_schema_fields_for_column(
dataset_name, column, pk_constraints
dataset_name, column, inspector, pk_constraints
)

if isinstance(column["type"], (datatype.ROW, sqltypes.ARRAY, datatype.MAP)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _process_projections(
foreign_keys = self._get_foreign_keys(
dataset_urn, inspector, schema, projection
)
schema_fields = self.get_schema_fields(dataset_name, columns, pk_constraints)
schema_fields = self.get_schema_fields(dataset_name, columns, inspector, pk_constraints)
schema_metadata = get_schema_metadata(
self.report,
dataset_name,
Expand Down Expand Up @@ -673,7 +673,7 @@ def _process_models(
)
dataset_snapshot.aspects.append(dataset_properties)

schema_fields = self.get_schema_fields(dataset_name, columns)
schema_fields = self.get_schema_fields(dataset_name, columns, inspector)

schema_metadata = get_schema_metadata(
self.report,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, List, Optional, Type, Union

from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector

from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
Expand Down Expand Up @@ -150,6 +151,7 @@ def get_avro_for_sqlalchemy_column(
def get_schema_fields_for_sqlalchemy_column(
column_name: str,
column_type: types.TypeEngine,
inspector: Inspector,
description: Optional[str] = None,
nullable: Optional[bool] = True,
is_part_of_key: Optional[bool] = False,
Expand Down Expand Up @@ -189,7 +191,7 @@ def get_schema_fields_for_sqlalchemy_column(
SchemaField(
fieldPath=column_name,
type=SchemaFieldDataTypeClass(type=NullTypeClass()),
nativeDataType=str(column_type),
nativeDataType=get_native_data_type_for_sqlalchemy_type(column_type, inspector),
)
]

Expand All @@ -209,3 +211,9 @@ def get_schema_fields_for_sqlalchemy_column(
)

return schema_fields

def get_native_data_type_for_sqlalchemy_type(
    column_type: types.TypeEngine, inspector: Inspector
) -> str:
    """Return the native (dialect-specific) data type string for a SQLAlchemy type.

    Args:
        column_type: The SQLAlchemy type instance reported for the column.
        inspector: Inspector whose dialect is used to compile the type into
            the backend's native type name (e.g. ``VARCHAR(255)``).

    Returns:
        The dialect-compiled type string, or a generic fallback when the
        type cannot be compiled.
    """
    # NullType has no dialect rendering, so report its generic visit name
    # instead of attempting (and failing) to compile it.
    # NOTE: the original used `instanceof`, which is not Python and raised
    # NameError whenever this branch was reached; `isinstance` is the fix.
    if isinstance(column_type, types.NullType):
        return types.NullType.__visit_name__

    try:
        # Compiling against the inspector's dialect yields the backend's
        # native type name rather than SQLAlchemy's generic repr.
        return column_type.compile(dialect=inspector.dialect)
    except Exception:
        # Some dialects cannot compile every type (NotImplementedError /
        # CompileError); fall back to repr rather than failing ingestion.
        return repr(column_type)
Loading

0 comments on commit 9a7f61a

Please sign in to comment.