fix(ingestion): use correct native data type in all SQLAlchemy sources by compiling data type using dialect
Masterchen09 committed Jun 28, 2024
1 parent 6f1f769 commit b5fbc10
Showing 9 changed files with 478 additions and 468 deletions.
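
The gist of the change: `repr()` and `str()` on a SQLAlchemy type object yield SQLAlchemy's own representation, while `TypeEngine.compile(dialect=...)` yields the native type name used by the actual database engine. A minimal sketch of the difference (the dialect picks are illustrative; outputs assume a stock SQLAlchemy install):

```python
from sqlalchemy import types
from sqlalchemy.dialects import mssql, mysql, oracle

col_type = types.Text()

# repr() returns the SQLAlchemy object, not a database type name.
print(repr(col_type))  # Text()

# str() compiles against SQLAlchemy's *default* dialect only.
print(str(col_type))  # TEXT

# compile(dialect=...) yields the engine-specific native type name.
print(col_type.compile(dialect=oracle.dialect()))  # CLOB
print(col_type.compile(dialect=mysql.dialect()))   # TEXT
print(types.Boolean().compile(dialect=mssql.dialect()))  # BIT
```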
2 changes: 2 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -439,12 +439,14 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict,
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = get_schema_fields_for_sqlalchemy_column(
             column_name=column["name"],
             column_type=column["type"],
+            inspector=inspector,
             description=column.get("comment", None),
             nullable=column.get("nullable", True),
             is_part_of_key=True
3 changes: 2 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/hive.py
@@ -169,11 +169,12 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict[Any, Any],
+        inspector: Inspector,
         pk_constraints: Optional[Dict[Any, Any]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = super().get_schema_fields_for_column(
-            dataset_name, column, pk_constraints
+            dataset_name, column, inspector, pk_constraints
         )

         if self._COMPLEX_TYPE.match(fields[0].nativeDataType) and isinstance(
@@ -521,7 +521,7 @@ def loop_tables(
             )

             # add table schema fields
-            schema_fields = self.get_schema_fields(dataset_name, columns)
+            schema_fields = self.get_schema_fields(dataset_name, columns, inspector)

             self._set_partition_key(columns, schema_fields)

@@ -754,7 +754,7 @@ def loop_views(

             # add view schema fields
             schema_fields = self.get_schema_fields(
-                dataset.dataset_name, dataset.columns
+                dataset.dataset_name, dataset.columns, inspector
             )

             schema_metadata = get_schema_metadata(
@@ -877,6 +877,7 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: Dict[Any, Any],
+        inspector: Inspector,
         pk_constraints: Optional[Dict[Any, Any]] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
@@ -773,7 +773,7 @@ def _process_table(
         pk_constraints: dict = inspector.get_pk_constraint(table, schema)
         foreign_keys = self._get_foreign_keys(dataset_urn, inspector, schema, table)
         schema_fields = self.get_schema_fields(
-            dataset_name, columns, pk_constraints, tags=extra_tags
+            dataset_name, columns, inspector, pk_constraints, tags=extra_tags
         )
         schema_metadata = get_schema_metadata(
             self.report,
@@ -950,6 +950,7 @@ def get_schema_fields(
         self,
         dataset_name: str,
         columns: List[dict],
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         tags: Optional[Dict[str, List[str]]] = None,
     ) -> List[SchemaField]:
@@ -959,7 +960,7 @@ def get_schema_fields(
             if tags:
                 column_tags = tags.get(column["name"], [])
             fields = self.get_schema_fields_for_column(
-                dataset_name, column, pk_constraints, tags=column_tags
+                dataset_name, column, inspector, pk_constraints, tags=column_tags
             )
             canonical_schema.extend(fields)
         return canonical_schema
@@ -968,6 +969,7 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: dict,
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
@@ -979,7 +981,7 @@ def get_schema_fields_for_column(
         field = SchemaField(
             fieldPath=column["name"],
             type=get_column_type(self.report, dataset_name, column["type"]),
-            nativeDataType=column.get("full_type", repr(column["type"])),
+            nativeDataType=column.get("full_type", column["type"].compile(dialect=inspector.dialect)),
             description=column.get("comment", None),
             nullable=column["nullable"],
             recursive=False,
@@ -1044,7 +1046,7 @@ def _process_view(
             self.warn(logger, dataset_name, "unable to get schema for this view")
             schema_metadata = None
         else:
-            schema_fields = self.get_schema_fields(dataset_name, columns)
+            schema_fields = self.get_schema_fields(dataset_name, columns, inspector)
             schema_metadata = get_schema_metadata(
                 self.report,
                 dataset_name,
3 changes: 2 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/trino.py
@@ -387,11 +387,12 @@ def get_schema_fields_for_column(
         self,
         dataset_name: str,
         column: dict,
+        inspector: Inspector,
         pk_constraints: Optional[dict] = None,
         tags: Optional[List[str]] = None,
     ) -> List[SchemaField]:
         fields = super().get_schema_fields_for_column(
-            dataset_name, column, pk_constraints
+            dataset_name, column, inspector, pk_constraints
         )

         if isinstance(column["type"], (datatype.ROW, sqltypes.ARRAY, datatype.MAP)):
@@ -470,7 +470,7 @@ def _process_projections(
         foreign_keys = self._get_foreign_keys(
             dataset_urn, inspector, schema, projection
         )
-        schema_fields = self.get_schema_fields(dataset_name, columns, pk_constraints)
+        schema_fields = self.get_schema_fields(dataset_name, columns, inspector, pk_constraints)
         schema_metadata = get_schema_metadata(
             self.report,
             dataset_name,
@@ -679,7 +679,7 @@ def _process_models(
         )
         dataset_snapshot.aspects.append(dataset_properties)

-        schema_fields = self.get_schema_fields(dataset_name, columns)
+        schema_fields = self.get_schema_fields(dataset_name, columns, inspector)

         schema_metadata = get_schema_metadata(
             self.report,
5 changes: 3 additions & 2 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -486,7 +486,7 @@ def _select_statement_cll(  # noqa: C901
        # )

        # Generate SELECT lineage.
-        direct_raw_col_upstreams = _get_direct_raw_col_upstreams(lineage_node)
+        direct_raw_col_upstreams = _get_direct_raw_col_upstreams(lineage_node, dialect)

        # column_logic = lineage_node.source

@@ -621,6 +621,7 @@ def _column_level_lineage(

 def _get_direct_raw_col_upstreams(
     lineage_node: sqlglot.lineage.Node,
+    dialect: sqlglot.Dialect,
 ) -> Set[_ColumnRef]:
     # Using a set here to deduplicate upstreams.
     direct_raw_col_upstreams: Set[_ColumnRef] = set()
@@ -641,7 +642,7 @@ def _get_direct_raw_col_upstreams(

        # Parse the column name out of the node name.
        # Sqlglot calls .sql(), so we have to do the inverse.
-        normalized_col = sqlglot.parse_one(node.name).this.name
+        normalized_col = sqlglot.parse_one(node.name, dialect=dialect).this.name
        if node.subfield:
            normalized_col = f"{normalized_col}.{node.subfield}"

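
The `dialect` now threaded into `_get_direct_raw_col_upstreams` matters because node names come from sqlglot's `.sql()`, which applies dialect-specific identifier quoting; re-parsing must use the same dialect to recover the raw column name. A sketch under that assumption (the backtick-quoted name is a hypothetical Hive-style node name):

```python
import sqlglot

# Hypothetical node name as .sql() would emit it under a backtick-quoting
# dialect such as Hive.
node_name = "`user id`"

# Parsing with the matching dialect undoes the quoting.
print(sqlglot.parse_one(node_name, dialect="hive").this.name)  # user id

# The default dialect only recognizes double-quoted identifiers, so the
# same string would fail to tokenize without the dialect argument.
```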
@@ -4,6 +4,7 @@
 from typing import Any, Dict, List, Optional, Type, Union

 from sqlalchemy import types
+from sqlalchemy.engine.reflection import Inspector

 from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
 from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
@@ -150,6 +151,7 @@ def get_avro_for_sqlalchemy_column(
 def get_schema_fields_for_sqlalchemy_column(
     column_name: str,
     column_type: types.TypeEngine,
+    inspector: Inspector,
     description: Optional[str] = None,
     nullable: Optional[bool] = True,
     is_part_of_key: Optional[bool] = False,
@@ -189,7 +191,7 @@ def get_schema_fields_for_sqlalchemy_column(
         SchemaField(
             fieldPath=column_name,
             type=SchemaFieldDataTypeClass(type=NullTypeClass()),
-            nativeDataType=str(column_type),
+            nativeDataType=column_type.compile(dialect=inspector.dialect),
         )
     ]

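For callers, the `inspector` argument is what supplies the dialect: any SQLAlchemy `Inspector` exposes the dialect of the engine it was created from. A hedged usage sketch (the in-memory SQLite engine is a stand-in for a real source connection):

```python
from sqlalchemy import create_engine, inspect, types

# Stand-in for the engine a SQL ingestion source would already hold.
engine = create_engine("sqlite://")
inspector = inspect(engine)

# Mirrors the nativeDataType fallback in the helper above.
print(types.BigInteger().compile(dialect=inspector.dialect))  # BIGINT
```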