Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def parse_table_schemas(
else None,
)

@staticmethod
def get_metadata_from_parser(
self,
inputs: list[DbTableMeta],
outputs: list[DbTableMeta],
database_info: DatabaseInfo,
Expand Down Expand Up @@ -315,6 +315,7 @@ def generate_openlineage_metadata_from_sql(
:param database_info: database specific information
:param database: when passed it takes precedence over parsed database name
:param sqlalchemy_engine: when passed, engine's dialect is used to compile SQL queries
:param use_connection: if call to db should be performed to enrich datasets (e.g., with schema)
"""
job_facets: dict[str, JobFacet] = {"sql": sql_job.SQLJobFacet(query=self.normalize_sql(sql))}
parse_result = self.parse(sql=self.split_sql_string(sql))
Expand All @@ -338,17 +339,28 @@ def generate_openlineage_metadata_from_sql(
)

namespace = self.create_namespace(database_info=database_info)
inputs: list[Dataset] = []
outputs: list[Dataset] = []
if use_connection:
inputs, outputs = self.parse_table_schemas(
hook=hook,
inputs=parse_result.in_tables,
outputs=parse_result.out_tables,
namespace=namespace,
database=database,
database_info=database_info,
sqlalchemy_engine=sqlalchemy_engine,
)
else:
try:
inputs, outputs = self.parse_table_schemas(
hook=hook,
inputs=parse_result.in_tables,
outputs=parse_result.out_tables,
namespace=namespace,
database=database,
database_info=database_info,
sqlalchemy_engine=sqlalchemy_engine,
)
except Exception as e:
self.log.warning(
"OpenLineage method failed to enrich datasets using db metadata. Exception: `%s`",
e,
)
self.log.debug("OpenLineage failure details:", exc_info=True)

# If call to db failed or was not performed, use datasets from sql parsing alone
if not inputs and not outputs:
inputs, outputs = self.get_metadata_from_parser(
inputs=parse_result.in_tables,
outputs=parse_result.out_tables,
Expand Down
50 changes: 50 additions & 0 deletions providers/openlineage/tests/unit/openlineage/test_sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,53 @@ def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns
}
)
assert metadata.job_facets["sql"].query.replace(" ", "") == formatted_sql.replace(" ", "")

def test_generate_openlineage_metadata_from_sql_with_db_error(self):
parser = SQLParser(default_schema="ANOTHER_SCHEMA")
db_info = DatabaseInfo(scheme="myscheme", authority="host:port")

hook = MagicMock()

sql = """INSERT INTO popular_orders_day_of_week (order_day_of_week)
SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week
FROM top_delivery_times -- irrelevant comment"""

hook.get_conn.side_effect = RuntimeError("Simulated DB error")

formatted_sql = """INSERT INTO popular_orders_day_of_week (order_day_of_week)
SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week
FROM top_delivery_times"""
expected_schema = "ANOTHER_SCHEMA"
metadata = parser.generate_openlineage_metadata_from_sql(
sql=sql, hook=hook, database_info=db_info, use_connection=True
)

assert metadata.inputs == [
Dataset(
namespace="myscheme://host:port",
name=f"{expected_schema}.top_delivery_times",
facets={},
)
]
assert len(metadata.outputs) == 1
assert metadata.outputs[0].namespace == "myscheme://host:port"
assert metadata.outputs[0].name == f"{expected_schema}.popular_orders_day_of_week"
assert len(metadata.outputs[0].facets) == 1
assert metadata.outputs[0].facets[
"columnLineage"
] == column_lineage_dataset.ColumnLineageDatasetFacet(
fields={
"order_day_of_week": column_lineage_dataset.Fields(
inputFields=[
column_lineage_dataset.InputField(
namespace="myscheme://host:port",
name=f"{expected_schema}.top_delivery_times",
field="order_placed_on",
)
],
transformationDescription="",
transformationType="",
)
}
)
assert metadata.job_facets["sql"].query.replace(" ", "") == formatted_sql.replace(" ", "")