refactor(ingest): streamline two-tier db config validation (#5986)
hsheth2 authored Sep 21, 2022
1 parent b638bcf commit 68db859
Showing 4 changed files with 57 additions and 35 deletions.
File 1 of 4: datahub/configuration/validate_field_rename.py (new file; path inferred from the import in the next file)
@@ -0,0 +1,39 @@
+import warnings
+from typing import Callable, Type, TypeVar
+
+import pydantic
+
+_T = TypeVar("_T")
+
+
+def _default_rename_transform(value: _T) -> _T:
+    return value
+
+
+def pydantic_renamed_field(
+    old_name: str,
+    new_name: str,
+    transform: Callable[[_T], _T] = _default_rename_transform,
+) -> classmethod:
+    def _validate_field_rename(cls: Type, values: dict) -> dict:
+        if old_name in values:
+            if new_name in values:
+                raise ValueError(
+                    f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}."
+                )
+            else:
+                warnings.warn(
+                    f"The {old_name} is deprecated, please use {new_name} instead.",
+                    UserWarning,
+                )
+                values[new_name] = transform(values.pop(old_name))
+        return values
+
+    # Why aren't we using pydantic.validator here?
+    # The `values` argument that is passed to field validators only contains items
+    # that have already been validated in the pre-process phase, which happens if
+    # they have an associated field and a pre=True validator. However, the root
+    # validator with pre=True gets all the values that were passed in.
+    # Given that a renamed field doesn't show up in the fields list, we can't use
+    # the field-level validator, even with a different field name.
+    return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)
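For illustration, a minimal, self-contained sketch of how a config class might adopt this helper (the ExampleConfig model and its field names are hypothetical, not part of this commit):

import pydantic

from datahub.configuration.validate_field_rename import pydantic_renamed_field


class ExampleConfig(pydantic.BaseModel):
    # Current field name; "table_limit" is the deprecated alias.
    max_tables: int = 100

    _rename_table_limit = pydantic_renamed_field("table_limit", "max_tables")


# The deprecated key is accepted with a UserWarning and mapped onto max_tables.
config = ExampleConfig.parse_obj({"table_limit": 5})
assert config.max_tables == 5

# Supplying both "table_limit" and "max_tables" raises a ValueError.

Because the validator runs with pre=True, the rename happens before field validation, so the value still goes through normal type coercion for the new field.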
File 2 of 4: datahub/ingestion/source/sql/two_tier_sql_source.py
@@ -1,12 +1,11 @@
 import typing
-from typing import Any, Dict
 
-import pydantic
 from pydantic.fields import Field
 from sqlalchemy import create_engine, inspect
 from sqlalchemy.engine.reflection import Inspector
 
 from datahub.configuration.common import AllowDenyPattern
+from datahub.configuration.validate_field_rename import pydantic_renamed_field
 from datahub.emitter.mcp_builder import PlatformKey
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.sql.sql_common import (
@@ -24,40 +23,26 @@ class TwoTierSQLAlchemyConfig(BasicSQLAlchemyConfig):
         description="Regex patterns for databases to filter in ingestion.",
     )
     schema_pattern: AllowDenyPattern = Field(
+        # The superclass contains a `schema_pattern` field, so we need this here
+        # to override the documentation.
         default=AllowDenyPattern.allow_all(),
-        description="Deprecated in favour of database_pattern. Regex patterns for schemas to filter in ingestion. "
-        "Specify regex to only match the schema name. e.g. to match all tables in schema analytics, "
-        "use the regex 'analytics'",
+        description="Deprecated in favour of database_pattern.",
     )
 
-    @pydantic.root_validator()
-    def ensure_profiling_pattern_is_passed_to_profiling(
-        cls, values: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        allow_all_pattern = AllowDenyPattern.allow_all()
-        schema_pattern = values.get("schema_pattern")
-        database_pattern = values.get("database_pattern")
-        if (
-            database_pattern == allow_all_pattern
-            and schema_pattern != allow_all_pattern
-        ):
-            logger.warning(
-                "Updating 'database_pattern' to 'schema_pattern'. Please stop using deprecated "
-                "'schema_pattern'. Use 'database_pattern' instead. "
-            )
-            values["database_pattern"] = schema_pattern
-        return values
+    _schema_pattern_deprecated = pydantic_renamed_field(
+        "schema_pattern", "database_pattern"
+    )
 
     def get_sql_alchemy_url(
         self,
         uri_opts: typing.Optional[typing.Dict[str, typing.Any]] = None,
         current_db: typing.Optional[str] = None,
     ) -> str:
         return self.sqlalchemy_uri or make_sqlalchemy_uri(
-            self.scheme,  # type: ignore
+            self.scheme,
             self.username,
             self.password.get_secret_value() if self.password else None,
-            self.host_port,  # type: ignore
+            self.host_port,
             current_db if current_db else self.database,
             uri_opts=uri_opts,
         )
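For reference, a sketch of the URI this method assembles. The expected output assumes make_sqlalchemy_uri (imported from sql_common above) produces the standard scheme://user:password@host:port/database layout; all values are hypothetical:

uri = make_sqlalchemy_uri(
    "postgresql",      # scheme
    "user",            # username
    "secret",          # password, already unwrapped from SecretStr
    "localhost:5432",  # host_port
    "mydb",            # current_db if provided, else self.database
)
assert uri == "postgresql://user:secret@localhost:5432/mydb"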
@@ -70,6 +55,8 @@ def __init__(self, config, ctx, platform):
         self.config: TwoTierSQLAlchemyConfig = config
 
     def get_parent_container_key(self, db_name: str, schema: str) -> PlatformKey:
+        # Because our overridden get_allowed_schemas method returns db_name as the schema name,
+        # the db_name and schema here will be the same. Hence, we just ignore the schema parameter.
         return self.gen_database_key(db_name)
 
     def get_allowed_schemas(
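Putting the pieces together, a hedged sketch of the rename behaviour on this config (values are hypothetical, and it assumes host_port is the only required connection field when constructing TwoTierSQLAlchemyConfig directly):

import warnings

config_dict = {
    "host_port": "localhost:5432",               # assumed required field
    "schema_pattern": {"allow": ["analytics"]},  # deprecated key
}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = TwoTierSQLAlchemyConfig.parse_obj(config_dict)

# The deprecated schema_pattern value was moved onto database_pattern.
assert config.database_pattern.allow == ["analytics"]
assert any(issubclass(w.category, UserWarning) for w in caught)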
File 3 of 4: unit tests for the DataHub checkpointing provider
@@ -14,7 +14,7 @@
     JobId,
     JobStateKey,
 )
-from datahub.ingestion.source.sql.mysql import MySQLConfig
+from datahub.ingestion.source.sql.postgres import PostgresConfig
 from datahub.ingestion.source.state.checkpoint import Checkpoint
 from datahub.ingestion.source.state.sql_common_state import (
     BaseSQLAlchemyCheckpointState,
@@ -124,7 +124,7 @@ def test_provider(self):
             pipeline_name=self.pipeline_name,
             platform_instance_id=self.platform_instance_id,
             run_id=self.run_id,
-            config=MySQLConfig(),
+            config=PostgresConfig(host_port="localhost:5432"),
             state=job1_state_obj,
         )
         # Job2 - Checkpoint with a BaseUsageCheckpointState state
@@ -136,22 +136,18 @@ def test_provider(self):
             pipeline_name=self.pipeline_name,
             platform_instance_id=self.platform_instance_id,
             run_id=self.run_id,
-            config=MySQLConfig(),
+            config=PostgresConfig(host_port="localhost:5432"),
             state=job2_state_obj,
         )
 
         # 2. Set the provider's state_to_commit.
         self.provider.state_to_commit = {
             # NOTE: state_to_commit accepts only the aspect version of the checkpoint.
             self.job_names[0]: job1_checkpoint.to_checkpoint_aspect(
-                # fmt: off
                 max_allowed_state_size=2**20
-                # fmt: on
             ),
             self.job_names[1]: job2_checkpoint.to_checkpoint_aspect(
-                # fmt: off
                 max_allowed_state_size=2**20
-                # fmt: on
             ),
         }

File 4 of 4: unit tests for Checkpoint creation and serde
@@ -4,7 +4,7 @@
 import pytest
 
 from datahub.emitter.mce_builder import make_dataset_urn
-from datahub.ingestion.source.sql.mysql import MySQLConfig
+from datahub.ingestion.source.sql.postgres import PostgresConfig
 from datahub.ingestion.source.sql.sql_common import BasicSQLAlchemyConfig
 from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
 from datahub.ingestion.source.state.sql_common_state import (
@@ -21,7 +21,7 @@
 test_platform_instance_id: str = "test_platform_instance_1"
 test_job_name: str = "test_job_1"
 test_run_id: str = "test_run_1"
-test_source_config: BasicSQLAlchemyConfig = MySQLConfig()
+test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")
 
 # 2. Create the params for parametrized tests.
 

Expand Down Expand Up @@ -79,7 +79,7 @@ def test_create_from_checkpoint_aspect(state_obj):
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
config_class=MySQLConfig,
config_class=PostgresConfig,
)

expected_checkpoint_obj = Checkpoint(
@@ -125,6 +125,6 @@ def test_serde_idempotence(state_obj):
         job_name=test_job_name,
         checkpoint_aspect=checkpoint_aspect,
         state_class=type(state_obj),
-        config_class=MySQLConfig,
+        config_class=PostgresConfig,
     )
     assert orig_checkpoint_obj == serde_checkpoint_obj
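Taken together, these tests exercise a round trip like the following sketch (constructor and method names are taken from the diffs above; the field values are placeholders and the exact Checkpoint field list is assumed):

checkpoint = Checkpoint(
    job_name="test_job_1",
    pipeline_name="test_pipeline_1",
    platform_instance_id="test_platform_instance_1",
    run_id="test_run_1",
    config=PostgresConfig(host_port="localhost:5432"),
    state=BaseSQLAlchemyCheckpointState(),
)

# Serialize to the aspect form; 2**20 caps the stored state at 1 MiB.
aspect = checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20)

# Rebuild from the aspect and verify the round trip is lossless.
restored = Checkpoint.create_from_checkpoint_aspect(
    job_name="test_job_1",
    checkpoint_aspect=aspect,
    state_class=BaseSQLAlchemyCheckpointState,
    config_class=PostgresConfig,
)
assert checkpoint == restored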
