MINOR - generic profiler optimization for sampling and BQ #14507

Merged: 11 commits, Dec 27, 2023
@@ -32,6 +32,9 @@ def _get_struct_columns(self, columns: dict, parent: str):
for key, value in columns.items():
if not isinstance(value, STRUCT):
col = Column(f"{parent}.{key}", value)
# pylint: disable=protected-access
col._set_parent(self.table.__table__)
# pylint: enable=protected-access
columns_list.append(col)
else:
col = self._get_struct_columns(
@@ -413,6 +413,8 @@ def _create_thread_safe_runner(
partition_details=self.partition_details,
profile_sample_query=self.profile_query,
)
return thread_local.runner
thread_local.runner._sample = sample # pylint: disable=protected-access
return thread_local.runner

@@ -431,7 +433,7 @@ def compute_metrics_in_thread(
session,
metric_func.table,
)
sample = sampler.random_sample()
sample = sampler.random_sample(metric_func.column)
runner = self._create_thread_safe_runner(
session,
metric_func.table,
@@ -565,7 +567,7 @@ def get_hybrid_metrics(
dictionary of results
"""
sampler = self._get_sampler(table=kwargs.get("table"))
sample = sampler.random_sample()
sample = sampler.random_sample(column)
try:
return metric(column).fn(sample, column_results, self.session)
except Exception as exc:
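Reviewer note: the `_create_thread_safe_runner` change above builds one `QueryRunner` per worker thread and, on reuse, only swaps in the new sample through the private `_sample` hook. A minimal standalone sketch of that caching pattern (the `make_runner` factory is hypothetical, not from this PR):

import threading

thread_local = threading.local()

def get_runner(make_runner, sample):
    # First call on a thread pays the construction cost once; later calls
    # only refresh the sample so the runner sees the column-scoped CTE.
    if not hasattr(thread_local, "runner"):
        thread_local.runner = make_runner(sample)
        return thread_local.runner
    thread_local.runner._sample = sample  # same private hook as the diff above
    return thread_local.runner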
@@ -173,34 +173,64 @@ def bigquery_table_construct(runner: QueryRunner, **kwargs):
Args:
runner (QueryRunner): query runner object
"""
conn_config = kwargs.get("conn_config")
conn_config = cast(BigQueryConnection, conn_config)
try:
schema_name, table_name = _get_table_and_schema_name(runner.table)
project_id = conn_config.credentials.gcpConfig.projectId.__root__
except AttributeError:
raise AttributeError(ERROR_MSG)

conn_config = kwargs.get("conn_config")
conn_config = cast(BigQueryConnection, conn_config)

table_storage = _build_table(
"TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
)
col_names, col_count = _get_col_names_and_count(runner.table)
columns = [
Column("total_rows").label("rowCount"),
Column("total_logical_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
col_names,
col_count,
]

where_clause = [
Column("table_schema") == schema_name,
Column("table_name") == table_name,
]
def table_storage():
"""Fall back method if retrieving table metadata from`__TABLES__` fails"""
table_storage = _build_table(
"TABLE_STORAGE", f"region-{conn_config.usageLocation}.INFORMATION_SCHEMA"
)

query = _build_query(columns, table_storage, where_clause)
columns = [
Column("total_rows").label("rowCount"),
Column("total_logical_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
col_names,
col_count,
]

where_clause = [
Column("project_id") == project_id,
Column("table_schema") == schema_name,
Column("table_name") == table_name,
]

query = _build_query(columns, table_storage, where_clause)

return runner._session.execute(query).first()

def tables():
"""retrieve table metadata from `__TABLES__`"""
table_meta = _build_table("__TABLES__", f"{project_id}.{schema_name}")
columns = [
Column("row_count").label("rowCount"),
Column("size_bytes").label("sizeInBytes"),
Column("creation_time").label("createDateTime"),
col_names,
col_count,
]
where_clause = [
Column("project_id") == project_id,
Column("dataset_id") == schema_name,
Column("table_id") == table_name,
]

query = _build_query(columns, table_meta, where_clause)
return runner._session.execute(query).first()

return runner._session.execute(query).first()
try:
return tables()
except Exception as exc:
logger.debug(f"Error retrieving table metadata from `__TABLES__`: {exc}")
return table_storage()


def clickhouse_table_construct(runner: QueryRunner, **kwargs):
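Reviewer note: `bigquery_table_construct` above now tries the dataset-local `__TABLES__` view first and falls back to the region-wide `INFORMATION_SCHEMA.TABLE_STORAGE` view only on failure. The two paths map roughly to the queries below (illustrative sketches; project, dataset, and table names are placeholders):

# Primary path: the legacy `__TABLES__` view, scoped to a single dataset.
TABLES_QUERY = """
SELECT row_count      AS rowCount,
       size_bytes     AS sizeInBytes,
       creation_time  AS createDateTime
FROM `my-project.my_dataset.__TABLES__`
WHERE project_id = 'my-project'
  AND dataset_id = 'my_dataset'
  AND table_id = 'my_table'
"""

# Fallback: TABLE_STORAGE, scoped to the connection's usageLocation region.
TABLE_STORAGE_QUERY = """
SELECT total_rows           AS rowCount,
       total_logical_bytes  AS sizeInBytes,
       creation_time        AS createDateTime
FROM `region-us`.INFORMATION_SCHEMA.TABLE_STORAGE
WHERE project_id = 'my-project'
  AND table_schema = 'my_dataset'
  AND table_name = 'my_table'
"""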
@@ -112,8 +112,8 @@ def handle_and_execute(_self, *args, **kwargs):
)
if self.build_sample:
return (
_self.client.query(
_self.table,
_self._base_sample_query(
kwargs.get("column"),
(ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
)
.filter(partition_filter)
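Reviewer note: for context, `handle_and_execute` lives inside the `partition_filter_handler` decorator factory. A much-simplified skeleton of that pattern (hypothetical, not the actual implementation):

import functools

def partition_filter_handler(build_sample=False):
    """Wrap a query builder so a partition filter is applied when needed."""
    def decorate(func):
        @functools.wraps(func)
        def handle_and_execute(_self, *args, **kwargs):
            partition_details = getattr(_self, "partition_details", None)
            if not partition_details:
                # no partitioning configured: run the builder untouched
                return func(_self, *args, **kwargs)
            # otherwise build partition_filter and, when build_sample=True,
            # apply it to the sample query as in the diff above
            return func(_self, *args, **kwargs)
        return handle_and_execute
    return decorate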
@@ -14,6 +14,7 @@
"""
from typing import Dict, Optional

from sqlalchemy import Column
from sqlalchemy.orm import Query

from metadata.generated.schema.entity.data.table import ProfileSampleType, TableType
@@ -50,20 +51,47 @@ def __init__(
)
self.table_type: TableType = table_type

def _base_sample_query(self, column: Optional[Column], label=None):
"""Base query for sampling

Args:
column (Optional[Column]): if computing a column metric, only sample that column
label (optional): label to apply to the sampled expression

Returns:
"""
# pylint: disable=import-outside-toplevel
from sqlalchemy_bigquery import STRUCT

if column is not None:
column_parts = column.name.split(".")
if len(column_parts) > 1:
# for struct columns (e.g. `foo.bar`) we need to create a new column
# for the struct root (e.g. `foo`) and select it inside the sample
# query; the outer query then reads `foo.bar`, e.g.
# WITH sample AS (SELECT `foo` FROM table TABLESAMPLE SYSTEM (n PERCENT))
# SELECT `foo.bar` FROM sample
# FROM sample TABLESAMPLE SYSTEM (n PERCENT)
column = Column(column_parts[0], STRUCT)
# pylint: disable=protected-access
column._set_parent(self.table.__table__)
# pylint: enable=protected-access

return super()._base_sample_query(column, label=label)

@partition_filter_handler(build_sample=True)
def get_sample_query(self) -> Query:
def get_sample_query(self, *, column=None) -> Query:
"""get query for sample data"""
# TABLESAMPLE SYSTEM is not supported for views
if (
self.profile_sample_type == ProfileSampleType.PERCENTAGE
and self.table_type != TableType.View
):
return (
self._base_sample_query()
self._base_sample_query(column)
.suffix_with(
f"TABLESAMPLE SYSTEM ({self.profile_sample or 100} PERCENT)",
)
.cte(f"{self.table.__tablename__}_sample")
)

return super().get_sample_query()
return super().get_sample_query(column=column)
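Reviewer note: a minimal sketch of the struct-root trick in `_base_sample_query` above (the `Customers` model and `address.city` column are hypothetical):

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
from sqlalchemy_bigquery import STRUCT

Base = declarative_base()

class Customers(Base):
    __tablename__ = "customers"
    id = Column(Integer, primary_key=True)
    address = Column(STRUCT(city=String))  # struct column with a `city` leaf

leaf = "address.city"                      # the column being profiled
root = Column(leaf.split(".")[0], STRUCT)  # rebuild a Column for the struct root
root._set_parent(Customers.__table__)      # pylint: disable=protected-access
# `root` is what the sample CTE selects; the outer query then reads the
# full `address.city` path from the sampled rows.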
@@ -68,17 +68,28 @@ class SQASampler(SamplerInterface):
run the query in the whole table.
"""

def _base_sample_query(self, label=None):
def _base_sample_query(self, column: Optional[Column], label=None):
"""Base query for sampling

Args:
column (Optional[Column]): if computing a column metric, only sample that column
label (optional): label to apply to the sampled expression

Returns:
"""
# only sample the column if we are computing a column metric, to limit the amount of data scanned
entity = self.table if column is None else column
if label is not None:
return self.client.query(self.table, label)
return self.client.query(self.table)
return self.client.query(entity, label)
return self.client.query(entity)

@partition_filter_handler(build_sample=True)
def get_sample_query(self) -> Query:
def get_sample_query(self, *, column=None) -> Query:
"""get query for sample data"""
if self.profile_sample_type == ProfileSampleType.PERCENTAGE:
rnd = (
self._base_sample_query(
column,
(ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
)
.suffix_with(
@@ -94,6 +105,7 @@ def get_sample_query(self) -> Query:

table_query = self.client.query(self.table)
session_query = self._base_sample_query(
column,
(ModuloFn(RandomNumFn(), table_query.count())).label(RANDOM_LABEL),
)
return (
.cte(f"{self.table.__tablename__}_rnd")
)

def random_sample(self) -> Union[DeclarativeMeta, AliasedClass]:
def random_sample(self, column=None) -> Union[DeclarativeMeta, AliasedClass]:
"""
Either return a sampled CTE of table, or
the full table if no sampling is required.
@@ -117,7 +129,7 @@ def random_sample(self) -> Union[DeclarativeMeta, AliasedClass]:
return self.table

# Add new RandomNumFn column
sampled = self.get_sample_query()
sampled = self.get_sample_query(column=column)

# Assign as an alias
return aliased(self.table, sampled)
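Reviewer note: returning `aliased(self.table, sampled)` is what keeps metric code sampling-agnostic. A hedged usage sketch, as a self-contained function over assumed `sampler`, `session`, and model arguments:

from sqlalchemy import func

def mean_age(sampler, session, users_model):
    # random_sample() hands back either the mapped class itself (no
    # sampling) or an alias of it over the sampling CTE, so the metric
    # query below is identical either way.
    sampled = sampler.random_sample(users_model.__table__.c.age)
    return session.query(func.avg(sampled.age)).scalar()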
@@ -33,9 +33,10 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

def _base_sample_query(self, label=None):
def _base_sample_query(self, column, label=None):
sqa_columns = [col for col in inspect(self.table).c if col.name != RANDOM_LABEL]
return self.client.query(self.table, label).where(
entity = self.table if column is None else column
return self.client.query(entity, label).where(
or_(
*[
text(f"is_nan({cols.name}) = False")
15 changes: 10 additions & 5 deletions ingestion/src/metadata/utils/partition.py
@@ -15,6 +15,7 @@

from metadata.generated.schema.entity.data.table import (
IntervalType,
PartitionIntervalUnit,
PartitionProfilerConfig,
Table,
)
@@ -47,8 +48,10 @@ def get_partition_details(entity: Table) -> Optional[PartitionProfilerConfig]:
return PartitionProfilerConfig(
enablePartitioning=True,
partitionColumnName=entity.tablePartition.columns[0],
partitionIntervalUnit=entity.tablePartition.interval,
partitionInterval=30,
partitionIntervalUnit=PartitionIntervalUnit.DAY
if entity.tablePartition.interval != "HOUR"
else entity.tablePartition.interval,
partitionInterval=1,
partitionIntervalType=entity.tablePartition.intervalType.value,
partitionValues=None,
partitionIntegerRangeStart=None,
@@ -60,8 +63,10 @@
partitionColumnName="_PARTITIONDATE"
if entity.tablePartition.interval == "DAY"
else "_PARTITIONTIME",
partitionIntervalUnit=entity.tablePartition.interval,
partitionInterval=30,
partitionIntervalUnit=PartitionIntervalUnit.DAY
if entity.tablePartition.interval != "HOUR"
else entity.tablePartition.interval,
partitionInterval=1,
partitionIntervalType=entity.tablePartition.intervalType.value,
partitionValues=None,
partitionIntegerRangeStart=None,
@@ -72,7 +77,7 @@
enablePartitioning=True,
partitionColumnName=entity.tablePartition.columns[0],
partitionIntervalUnit=None,
partitionInterval=30,
partitionInterval=None,
partitionIntervalType=entity.tablePartition.intervalType.value,
partitionValues=None,
partitionIntegerRangeStart=1,
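Reviewer note: `partitionInterval` drops from 30 to 1, so by default the profiler scans only the most recent partition instead of a 30-unit window. For an ingestion-time, day-partitioned BigQuery table the resulting filter is roughly the shape below (illustrative only; the real clause is assembled by the partition filter handler):

# BigQuery ingestion-time pseudo-column filter implied by
# partitionInterval=1 with PartitionIntervalUnit.DAY:
PARTITION_FILTER = (
    "WHERE _PARTITIONDATE >= DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)"
)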
6 changes: 3 additions & 3 deletions ingestion/tests/unit/profiler/test_profiler_partitions.py
@@ -152,7 +152,7 @@ def test_partition_details_time_unit(self):

if resp:
assert resp.partitionColumnName == "e"
assert resp.partitionInterval == 30
assert resp.partitionInterval == 1
assert not resp.partitionValues
else:
assert False
@@ -187,7 +187,7 @@ def test_partition_details_ingestion_time_date(self):

if resp:
assert resp.partitionColumnName == "_PARTITIONDATE"
assert resp.partitionInterval == 30
assert resp.partitionInterval == 1
assert not resp.partitionValues
else:
assert False
@@ -221,7 +221,7 @@ def test_partition_details_ingestion_time_hour(self):

if resp:
assert resp.partitionColumnName == "_PARTITIONTIME"
assert resp.partitionInterval == 30
assert resp.partitionInterval == 1
assert not resp.partitionValues
else:
assert False
4 changes: 2 additions & 2 deletions ingestion/tests/unit/test_partition.py
@@ -82,7 +82,7 @@ def test_get_partition_details():
assert partition.enablePartitioning == True
assert partition.partitionColumnName == "_PARTITIONTIME"
assert partition.partitionIntervalType == PartitionIntervalType.INGESTION_TIME
assert partition.partitionInterval == 30
assert partition.partitionInterval == 1
assert partition.partitionIntervalUnit == PartitionIntervalUnit.HOUR

table_entity = MockTable(
Expand All @@ -97,5 +97,5 @@ def test_get_partition_details():
assert partition.enablePartitioning == True
assert partition.partitionColumnName == "_PARTITIONDATE"
assert partition.partitionIntervalType == PartitionIntervalType.INGESTION_TIME
assert partition.partitionInterval == 30
assert partition.partitionInterval == 1
assert partition.partitionIntervalUnit == PartitionIntervalUnit.DAY