Commit 91dd430

Merge branch 'main' into feature/issue_#573

aminmovahed-db authored May 21, 2024
2 parents 8881919 + ab03d5c commit 91dd430
Showing 18 changed files with 834 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/assessment.md
@@ -519,13 +519,13 @@ The Databricks ML Runtime is not supported on Shared Compute mode clusters. Reco

### AF301.1 - spark.catalog.x

The `spark.catalog.` pattern was found. Commonly used functions in `spark.catalog`, such as `tableExists`, `listTables`, and `setCurrentCatalog`, are not allowlisted on shared clusters for security reasons. `spark.sql("<sql command>")` may be a better alternative.
The `spark.catalog.` pattern was found. Commonly used functions in `spark.catalog`, such as `tableExists`, `listTables`, and `setCurrentCatalog`, are not allowlisted on shared clusters for security reasons. `spark.sql("<sql command>")` may be a better alternative. These commands are available on DBR 14.1 and above, so upgrading your DBR version may also resolve this finding.
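
For example, a table-existence check can usually be rewritten against `spark.sql`; a minimal sketch, assuming an active `spark` session and illustrative schema/table names:

```python
# Instead of spark.catalog.tableExists("my_schema.my_table"):
table_exists = spark.sql("SHOW TABLES IN my_schema LIKE 'my_table'").count() > 0

# Instead of spark.catalog.listTables("my_schema"):
tables = spark.sql("SHOW TABLES IN my_schema").collect()
```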

[[back to top](#migration-assessment-report)]

### AF301.2 - spark.catalog.x (spark._jsparkSession.catalog)

The `spark._jsparkSession.catalog` pattern was found. Commonly used functions in `spark.catalog`, such as `tableExists`, `listTables`, and `setCurrentCatalog`, are not allowlisted on shared clusters for security reasons. `spark.sql("<sql command>")` may be a better alternative.
The `spark._jsparkSession.catalog` pattern was found. Commonly used functions in `spark.catalog`, such as `tableExists`, `listTables`, and `setCurrentCatalog`, are not allowlisted on shared clusters for security reasons. `spark.sql("<sql command>")` may be a better alternative. The corresponding `spark.catalog.x` methods may work on DBR 14.1 and above.

[[back to top](#migration-assessment-report)]

22 changes: 17 additions & 5 deletions src/databricks/labs/ucx/assessment/aws.py
@@ -46,7 +46,7 @@ class AWSRoleAction:

    @property
    def role_name(self):
        role_match = re.match(AWSInstanceProfile.ROLE_NAME_REGEX, self.role_arn)
        role_match = re.match(AWSResources.ROLE_NAME_REGEX, self.role_arn)
        return role_match.group(1)


@@ -55,15 +55,13 @@ class AWSInstanceProfile:
    instance_profile_arn: str
    iam_role_arn: str | None = None

    ROLE_NAME_REGEX = r"arn:aws:iam::[0-9]+:(?:instance-profile|role)\/([a-zA-Z0-9+=,.@_-]*)$"

    @property
    def role_name(self) -> str | None:
        if self.iam_role_arn:
            arn = self.iam_role_arn
        else:
            arn = self.instance_profile_arn
        role_match = re.match(self.ROLE_NAME_REGEX, arn)
        role_match = re.match(AWSResources.ROLE_NAME_REGEX, arn)
        if not role_match:
            logger.error(f"Role ARN does not match the expected pattern: {arn}")
            return None
@@ -88,6 +86,7 @@ class AWSResources:
"arn:aws:iam::414351767826:role/unity-catalog-prod-UCMasterRole-14S5ZJVKOTYTL",
"arn:aws:iam::707343435239:role/unity-catalog-dev-UCMasterRole-G3MMN8SP21FO",
]
ROLE_NAME_REGEX = r"arn:aws:iam::[0-9]+:(?:instance-profile|role)\/([a-zA-Z0-9+=,.@_-]*)$"

    def __init__(self, profile: str, command_runner: Callable[[str], tuple[int, str, str]] = run_command):
        self._profile = profile
@@ -350,7 +349,7 @@ def create_migration_role(self, role_name: str) -> str | None:
        assume_role_json = self._get_json_for_cli(aws_role_trust_doc)
        return self._create_role(role_name, assume_role_json)

    def get_instance_profile(self, instance_profile_name: str) -> str | None:
    def get_instance_profile_arn(self, instance_profile_name: str) -> str | None:
        instance_profile = self._run_json_command(
            f"iam get-instance-profile --instance-profile-name {instance_profile_name}"
        )
@@ -360,6 +359,19 @@ def get_instance_profile(self, instance_profile_name: str) -> str | None:

        return instance_profile["InstanceProfile"]["Arn"]

    def get_instance_profile_role_arn(self, instance_profile_name: str) -> str | None:
        instance_profile = self._run_json_command(
            f"iam get-instance-profile --instance-profile-name {instance_profile_name}"
        )

        if not instance_profile:
            return None

        try:
            return instance_profile["InstanceProfile"]["Roles"][0]["Arn"]
        except (KeyError, IndexError):
            return None

    def create_instance_profile(self, instance_profile_name: str) -> str | None:
        instance_profile = self._run_json_command(
            f"iam create-instance-profile --instance-profile-name {instance_profile_name}"
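For reference, the now-shared `ROLE_NAME_REGEX` extracts the trailing name from both role and instance-profile ARNs; a minimal sketch with made-up ARNs:

```python
import re

# Same pattern as AWSResources.ROLE_NAME_REGEX above.
ROLE_NAME_REGEX = r"arn:aws:iam::[0-9]+:(?:instance-profile|role)\/([a-zA-Z0-9+=,.@_-]*)$"

for arn in (
    "arn:aws:iam::123456789012:instance-profile/my-profile",
    "arn:aws:iam::123456789012:role/my-role",
):
    match = re.match(ROLE_NAME_REGEX, arn)
    print(match.group(1) if match else None)  # -> my-profile, then my-role
```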
15 changes: 7 additions & 8 deletions src/databricks/labs/ucx/aws/access.py
@@ -111,13 +111,12 @@ def _get_instance_profiles(self) -> Iterable[AWSInstanceProfile]:
        instance_profiles = self._ws.instance_profiles.list()
        result_instance_profiles = []
        for instance_profile in instance_profiles:
            if not instance_profile.iam_role_arn:
                instance_profile.iam_role_arn = instance_profile.instance_profile_arn.replace(
                    "instance-profile", "role"
                )
            result_instance_profiles.append(
                AWSInstanceProfile(instance_profile.instance_profile_arn, instance_profile.iam_role_arn)
            )
            iam_role_arn = instance_profile.iam_role_arn
            role_match = re.match(AWSResources.ROLE_NAME_REGEX, instance_profile.instance_profile_arn)
            if role_match is not None:
                instance_profile_name = role_match.group(1)
                iam_role_arn = self._aws_resources.get_instance_profile_role_arn(instance_profile_name)
            result_instance_profiles.append(AWSInstanceProfile(instance_profile.instance_profile_arn, iam_role_arn))

        return result_instance_profiles
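
The switch to `get_instance_profile_role_arn` matters because an instance profile's attached role does not have to share the profile's name, so deriving the role ARN by string substitution can point at a role that does not exist; a sketch with made-up ARNs:

```python
# Made-up ARNs illustrating the old derivation's failure mode.
profile_arn = "arn:aws:iam::123456789012:instance-profile/my-profile"

# Old approach: assume the role name equals the profile name.
naive_role_arn = profile_arn.replace("instance-profile", "role")
# -> "arn:aws:iam::123456789012:role/my-profile"

# In reality, the role attached to the profile may be named differently,
# e.g. "arn:aws:iam::123456789012:role/data-access-role"; only an
# `iam get-instance-profile` lookup (as above) returns the real role ARN.
```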

@@ -230,7 +229,7 @@ def _update_sql_dac_with_instance_profile(self, iam_instance_profile: AWSInstanc
        )

    def get_instance_profile(self, instance_profile_name: str) -> AWSInstanceProfile | None:
        instance_profile_arn = self._aws_resources.get_instance_profile(instance_profile_name)
        instance_profile_arn = self._aws_resources.get_instance_profile_arn(instance_profile_name)

        if not instance_profile_arn:
            return None
13 changes: 12 additions & 1 deletion src/databricks/labs/ucx/mixins/fixtures.py
@@ -8,7 +8,7 @@
import subprocess
import sys
from collections.abc import Callable, Generator, MutableMapping
from datetime import timedelta
from datetime import timedelta, datetime
from pathlib import Path
from typing import BinaryIO

@@ -55,6 +55,7 @@
# pylint: disable=redefined-outer-name,too-many-try-statements,import-outside-toplevel,unnecessary-lambda,too-complex,invalid-name

logger = logging.getLogger(__name__)
JOBS_PURGE_TIMEOUT = timedelta(days=1)


def factory(name, create, remove):
@@ -798,6 +799,16 @@ def create(notebook_path: str | Path | None = None, **kwargs):
                    timeout_seconds=0,
                )
            ]

        # add RemoveAfter tag for test job cleanup
        date_to_remove = (datetime.now() + JOBS_PURGE_TIMEOUT).strftime("%Y-%m-%d")
        remove_after_tag = {"key": "RemoveAfter", "value": date_to_remove}

        if "tags" not in kwargs:
            kwargs["tags"] = [remove_after_tag]
        else:
            kwargs["tags"].append(remove_after_tag)

        job = ws.jobs.create(**kwargs)
        logger.info(f"Job: {ws.config.host}#job/{job.job_id}")
        return job
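
One detail worth noting about the tag value: zero-padded `%Y-%m-%d` strings sort lexicographically in date order, so a cleanup sweep can decide expiry with a plain string comparison; a minimal sketch (the cleanup side itself is not part of this change):

```python
from datetime import datetime, timedelta

JOBS_PURGE_TIMEOUT = timedelta(days=1)

# Same computation as the fixture above.
date_to_remove = (datetime.now() + JOBS_PURGE_TIMEOUT).strftime("%Y-%m-%d")

# Zero-padded ISO dates compare correctly as plain strings,
# which keeps a tag-driven sweeper trivial to write.
assert "2024-05-21" < "2024-05-22" < "2024-06-01"
today = datetime.now().strftime("%Y-%m-%d")
is_expired = date_to_remove < today  # False until the purge window passes
```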
Empty file.
105 changes: 105 additions & 0 deletions src/databricks/labs/ucx/recon/base.py
@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class TableIdentifier:
    catalog: str
    schema: str
    table: str

    @property
    def catalog_escaped(self):
        return f"`{self.catalog}`"

    @property
    def schema_escaped(self):
        return f"`{self.schema}`"

    @property
    def table_escaped(self):
        return f"`{self.table}`"

    @property
    def fqn_escaped(self):
        return f"{self.catalog_escaped}.{self.schema_escaped}.{self.table_escaped}"


@dataclass(frozen=True)
class ColumnMetadata:
    name: str
    data_type: str


@dataclass
class TableMetadata:
    identifier: TableIdentifier
    columns: list[ColumnMetadata]

    def get_column_metadata(self, column_name: str) -> ColumnMetadata | None:
        for column in self.columns:
            if column.name == column_name:
                return column
        return None


@dataclass
class DataProfilingResult:
    row_count: int
    table_metadata: TableMetadata


@dataclass
class SchemaComparisonEntry:
    source_column: str | None
    source_datatype: str | None
    target_column: str | None
    target_datatype: str | None
    is_matching: bool
    notes: str | None


@dataclass
class SchemaComparisonResult:
    is_matching: bool
    data: list[SchemaComparisonEntry]


@dataclass
class DataComparisonResult:
    source_row_count: int
    target_row_count: int
    num_missing_records_in_target: int
    num_missing_records_in_source: int


class TableMetadataRetriever(ABC):
    @abstractmethod
    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
        """
        Get metadata for a given table
        """


class DataProfiler(ABC):
    @abstractmethod
    def profile_data(self, entity: TableIdentifier) -> DataProfilingResult:
        """
        Profile data for a given table
        """


class SchemaComparator(ABC):
    @abstractmethod
    def compare_schema(self, source: TableIdentifier, target: TableIdentifier) -> SchemaComparisonResult:
        """
        Compare schema for two tables
        """


class DataComparator(ABC):
    @abstractmethod
    def compare_data(self, source: TableIdentifier, target: TableIdentifier) -> DataComparisonResult:
        """
        Compare data for two tables
        """
114 changes: 114 additions & 0 deletions src/databricks/labs/ucx/recon/data_comparator.py
@@ -0,0 +1,114 @@
from collections.abc import Iterator

from databricks.labs.lsql.backends import SqlBackend
from databricks.labs.lsql.core import Row

from .base import (
    DataComparator,
    DataComparisonResult,
    TableIdentifier,
    DataProfiler,
    DataProfilingResult,
)


class StandardDataComparator(DataComparator):
    DATA_COMPARISON_QUERY_TEMPLATE = """
    WITH compare_results AS (
        SELECT
            CASE
                WHEN source.hash_value IS NULL AND target.hash_value IS NULL THEN TRUE
                WHEN source.hash_value IS NULL OR target.hash_value IS NULL THEN FALSE
                WHEN source.hash_value = target.hash_value THEN TRUE
                ELSE FALSE
            END AS is_match,
            CASE
                WHEN target.hash_value IS NULL THEN 1
                ELSE 0
            END AS num_missing_records_in_target,
            CASE
                WHEN source.hash_value IS NULL THEN 1
                ELSE 0
            END AS num_missing_records_in_source
        FROM (
            SELECT {source_hash_expr} AS hash_value
            FROM {source_table_fqn}
        ) AS source
        FULL OUTER JOIN (
            SELECT {target_hash_expr} AS hash_value
            FROM {target_table_fqn}
        ) AS target
        ON source.hash_value = target.hash_value
    )
    SELECT
        COUNT(*) AS total_mismatches,
        COALESCE(SUM(num_missing_records_in_target), 0) AS num_missing_records_in_target,
        COALESCE(SUM(num_missing_records_in_source), 0) AS num_missing_records_in_source
    FROM compare_results
    WHERE is_match IS FALSE;
    """

    def __init__(self, sql_backend: SqlBackend, data_profiler: DataProfiler):
        self._sql_backend = sql_backend
        self._data_profiler = data_profiler

    def compare_data(self, source: TableIdentifier, target: TableIdentifier) -> DataComparisonResult:
        """
        This method compares the data of two tables. It takes two TableIdentifier objects as input, which represent
        the source and target tables for which the data are to be compared.

        Note: This method does not handle exceptions raised during the execution of the SQL query or
        the retrieval of the table metadata. These exceptions are expected to be handled by the caller
        in a manner appropriate for their context.
        """
        source_data_profile = self._data_profiler.profile_data(source)
        target_data_profile = self._data_profiler.profile_data(target)
        comparison_query = StandardDataComparator.build_data_comparison_query(
            source_data_profile,
            target_data_profile,
        )
        query_result: Iterator[Row] = self._sql_backend.fetch(comparison_query)
        count_row = next(query_result)
        num_missing_records_in_target = int(count_row["num_missing_records_in_target"])
        num_missing_records_in_source = int(count_row["num_missing_records_in_source"])
        return DataComparisonResult(
            source_row_count=source_data_profile.row_count,
            target_row_count=target_data_profile.row_count,
            num_missing_records_in_target=num_missing_records_in_target,
            num_missing_records_in_source=num_missing_records_in_source,
        )

    @classmethod
    def build_data_comparison_query(
        cls,
        source_data_profile: DataProfilingResult,
        target_data_profile: DataProfilingResult,
    ) -> str:
        source_table = source_data_profile.table_metadata.identifier
        target_table = target_data_profile.table_metadata.identifier
        source_hash_inputs = _build_data_comparison_hash_inputs(source_data_profile)
        target_hash_inputs = _build_data_comparison_hash_inputs(target_data_profile)
        comparison_query = StandardDataComparator.DATA_COMPARISON_QUERY_TEMPLATE.format(
            source_hash_expr=f"SHA2(CONCAT_WS('|', {', '.join(source_hash_inputs)}), 256)",
            target_hash_expr=f"SHA2(CONCAT_WS('|', {', '.join(target_hash_inputs)}), 256)",
            source_table_fqn=source_table.fqn_escaped,
            target_table_fqn=target_table.fqn_escaped,
        )

        return comparison_query


def _build_data_comparison_hash_inputs(data_profile: DataProfilingResult) -> list[str]:
    source_metadata = data_profile.table_metadata
    inputs = []
    for column in source_metadata.columns:
        data_type = column.data_type.lower()
        transformed_column = column.name

        if data_type.startswith("array"):
            transformed_column = f"TO_JSON(SORT_ARRAY({column.name}))"
        elif data_type.startswith("map") or data_type.startswith("struct"):
            transformed_column = f"TO_JSON({column.name})"

        inputs.append(f"COALESCE(TRIM({transformed_column}), '')")
    return inputs
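
To see what `_build_data_comparison_hash_inputs` feeds into the per-row hash expression, a sketch over a hypothetical three-column schema (identifiers and types are illustrative):

```python
from databricks.labs.ucx.recon.base import (
    ColumnMetadata,
    DataProfilingResult,
    TableIdentifier,
    TableMetadata,
)
from databricks.labs.ucx.recon.data_comparator import _build_data_comparison_hash_inputs

profile = DataProfilingResult(
    row_count=3,
    table_metadata=TableMetadata(
        identifier=TableIdentifier("main", "sales", "orders"),
        columns=[
            ColumnMetadata("id", "int"),
            ColumnMetadata("items", "array<string>"),
            ColumnMetadata("attrs", "map<string,string>"),
        ],
    ),
)
print(_build_data_comparison_hash_inputs(profile))
# -> ["COALESCE(TRIM(id), '')",
#     "COALESCE(TRIM(TO_JSON(SORT_ARRAY(items))), '')",
#     "COALESCE(TRIM(TO_JSON(attrs)), '')"]
```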
35 changes: 35 additions & 0 deletions src/databricks/labs/ucx/recon/data_profiler.py
@@ -0,0 +1,35 @@
from collections.abc import Iterator

from databricks.labs.lsql.backends import SqlBackend
from databricks.labs.lsql.core import Row

from .base import DataProfiler, DataProfilingResult, TableIdentifier, TableMetadataRetriever


class StandardDataProfiler(DataProfiler):
    def __init__(self, sql_backend: SqlBackend, metadata_retriever: TableMetadataRetriever):
        self._sql_backend = sql_backend
        self._metadata_retriever = metadata_retriever

    def profile_data(self, entity: TableIdentifier) -> DataProfilingResult:
        """
        This method profiles the data in the given table. It takes a TableIdentifier object as input, which represents
        the table to be profiled. The method performs two main operations:
        1. It retrieves the row count of the table.
        2. It retrieves the metadata of the table using a TableMetadataRetriever instance.

        Note: This method does not handle exceptions raised during the execution of the SQL query or the retrieval
        of the table metadata. These exceptions are expected to be handled by the caller
        in a manner appropriate for their context.
        """
        row_count = self._get_table_row_count(entity)
        return DataProfilingResult(
            row_count,
            self._metadata_retriever.get_metadata(entity),
        )

    def _get_table_row_count(self, entity: TableIdentifier) -> int:
        query_result: Iterator[Row] = self._sql_backend.fetch(f"SELECT COUNT(*) as row_count FROM {entity.fqn_escaped}")
        count_row = next(query_result)
        return int(count_row[0])
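
A hedged usage sketch with duck-typed stand-ins; a real `SqlBackend` and metadata retriever would be wired in instead, and both fakes here are hypothetical:

```python
from databricks.labs.ucx.recon.base import (
    ColumnMetadata,
    TableIdentifier,
    TableMetadata,
    TableMetadataRetriever,
)
from databricks.labs.ucx.recon.data_profiler import StandardDataProfiler


class _FakeBackend:
    """Stand-in for SqlBackend: profile_data only needs fetch() to yield one indexable row."""

    def fetch(self, sql: str):
        return iter([(3,)])  # pretend COUNT(*) returned 3


class _FakeRetriever(TableMetadataRetriever):
    def get_metadata(self, entity: TableIdentifier) -> TableMetadata:
        return TableMetadata(entity, [ColumnMetadata("id", "int")])


profiler = StandardDataProfiler(_FakeBackend(), _FakeRetriever())
result = profiler.profile_data(TableIdentifier("main", "sales", "orders"))
print(result.row_count)  # -> 3
```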
