Skip to content

Commit

Permalink
fix(ingest): better correctness on the emitter -> graph conversion (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and llance committed Jan 13, 2025
1 parent 981570e commit 2c4b997
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 112 deletions.
11 changes: 10 additions & 1 deletion metadata-ingestion/src/datahub/cli/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import typing
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import click
import requests
Expand Down Expand Up @@ -33,6 +33,15 @@ def first_non_null(ls: List[Optional[str]]) -> Optional[str]:
return next((el for el in ls if el is not None and el.strip() != ""), None)


_T = TypeVar("_T")


def get_or_else(value: Optional[_T], default: _T) -> _T:
# Normally we'd use `value or default`. However, that runs into issues
# when value is falsey but not None.
return value if value is not None else default


def parse_run_restli_response(response: requests.Response) -> dict:
response_json = response.json()
if response.status_code != 200:
Expand Down
209 changes: 125 additions & 84 deletions metadata-ingestion/src/datahub/emitter/rest_emitter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from __future__ import annotations

import functools
import json
import logging
import os
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

import requests
from deprecated import deprecated
Expand All @@ -12,9 +24,13 @@

from datahub import nice_version_name
from datahub.cli import config_utils
from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url
from datahub.cli.cli_utils import ensure_has_system_metadata, fixup_gms_url, get_or_else
from datahub.cli.env_utils import get_boolean_env_variable
from datahub.configuration.common import ConfigurationError, OperationalError
from datahub.configuration.common import (
ConfigModel,
ConfigurationError,
OperationalError,
)
from datahub.emitter.generic_emitter import Emitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.request_helper import make_curl_command
Expand All @@ -31,10 +47,8 @@

logger = logging.getLogger(__name__)

_DEFAULT_CONNECT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect
_DEFAULT_READ_TIMEOUT_SEC = (
30 # Any ingest call taking longer than 30 seconds should be abandoned
)
_DEFAULT_TIMEOUT_SEC = 30 # 30 seconds should be plenty to connect
_TIMEOUT_LOWER_BOUND_SEC = 1 # if below this, we log a warning
_DEFAULT_RETRY_STATUS_CODES = [ # Additional status codes to retry on
429,
500,
Expand Down Expand Up @@ -63,15 +77,76 @@
)


class RequestsSessionConfig(ConfigModel):
timeout: Union[float, Tuple[float, float], None] = _DEFAULT_TIMEOUT_SEC

retry_status_codes: List[int] = _DEFAULT_RETRY_STATUS_CODES
retry_methods: List[str] = _DEFAULT_RETRY_METHODS
retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES

extra_headers: Dict[str, str] = {}

ca_certificate_path: Optional[str] = None
client_certificate_path: Optional[str] = None
disable_ssl_verification: bool = False

def build_session(self) -> requests.Session:
session = requests.Session()

if self.extra_headers:
session.headers.update(self.extra_headers)

if self.client_certificate_path:
session.cert = self.client_certificate_path

if self.ca_certificate_path:
session.verify = self.ca_certificate_path

if self.disable_ssl_verification:
session.verify = False

try:
# Set raise_on_status to False to propagate errors:
# https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception
# Must call `raise_for_status` after making a request, which we do
retry_strategy = Retry(
total=self.retry_max_times,
status_forcelist=self.retry_status_codes,
backoff_factor=2,
allowed_methods=self.retry_methods,
raise_on_status=False,
)
except TypeError:
# Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`.
retry_strategy = Retry(
total=self.retry_max_times,
status_forcelist=self.retry_status_codes,
backoff_factor=2,
method_whitelist=self.retry_methods,
raise_on_status=False,
)

adapter = HTTPAdapter(
pool_connections=100, pool_maxsize=100, max_retries=retry_strategy
)
session.mount("http://", adapter)
session.mount("https://", adapter)

if self.timeout is not None:
# Shim session.request to apply default timeout values.
# Via https://stackoverflow.com/a/59317604.
session.request = functools.partial( # type: ignore
session.request,
timeout=self.timeout,
)

return session


class DataHubRestEmitter(Closeable, Emitter):
_gms_server: str
_token: Optional[str]
_session: requests.Session
_connect_timeout_sec: float = _DEFAULT_CONNECT_TIMEOUT_SEC
_read_timeout_sec: float = _DEFAULT_READ_TIMEOUT_SEC
_retry_status_codes: List[int] = _DEFAULT_RETRY_STATUS_CODES
_retry_methods: List[str] = _DEFAULT_RETRY_METHODS
_retry_max_times: int = _DEFAULT_RETRY_MAX_TIMES

def __init__(
self,
Expand Down Expand Up @@ -102,15 +177,13 @@ def __init__(

self._session = requests.Session()

self._session.headers.update(
{
"X-RestLi-Protocol-Version": "2.0.0",
"X-DataHub-Py-Cli-Version": nice_version_name(),
"Content-Type": "application/json",
}
)
headers = {
"X-RestLi-Protocol-Version": "2.0.0",
"X-DataHub-Py-Cli-Version": nice_version_name(),
"Content-Type": "application/json",
}
if token:
self._session.headers.update({"Authorization": f"Bearer {token}"})
headers["Authorization"] = f"Bearer {token}"
else:
# HACK: When no token is provided but system auth env variables are set, we use them.
# Ideally this should simply get passed in as config, instead of being sneakily injected
Expand All @@ -119,75 +192,43 @@ def __init__(
# rest emitter, and the rest sink uses the rest emitter under the hood.
system_auth = config_utils.get_system_auth()
if system_auth is not None:
self._session.headers.update({"Authorization": system_auth})

if extra_headers:
self._session.headers.update(extra_headers)

if client_certificate_path:
self._session.cert = client_certificate_path

if ca_certificate_path:
self._session.verify = ca_certificate_path

if disable_ssl_verification:
self._session.verify = False

self._connect_timeout_sec = (
connect_timeout_sec or timeout_sec or _DEFAULT_CONNECT_TIMEOUT_SEC
)
self._read_timeout_sec = (
read_timeout_sec or timeout_sec or _DEFAULT_READ_TIMEOUT_SEC
)

if self._connect_timeout_sec < 1 or self._read_timeout_sec < 1:
logger.warning(
f"Setting timeout values lower than 1 second is not recommended. Your configuration is connect_timeout:{self._connect_timeout_sec}s, read_timeout:{self._read_timeout_sec}s"
)

if retry_status_codes is not None: # Only if missing. Empty list is allowed
self._retry_status_codes = retry_status_codes

if retry_methods is not None:
self._retry_methods = retry_methods

if retry_max_times:
self._retry_max_times = retry_max_times
headers["Authorization"] = system_auth

try:
# Set raise_on_status to False to propagate errors:
# https://stackoverflow.com/questions/70189330/determine-status-code-from-python-retry-exception
# Must call `raise_for_status` after making a request, which we do
retry_strategy = Retry(
total=self._retry_max_times,
status_forcelist=self._retry_status_codes,
backoff_factor=2,
allowed_methods=self._retry_methods,
raise_on_status=False,
)
except TypeError:
# Prior to urllib3 1.26, the Retry class used `method_whitelist` instead of `allowed_methods`.
retry_strategy = Retry(
total=self._retry_max_times,
status_forcelist=self._retry_status_codes,
backoff_factor=2,
method_whitelist=self._retry_methods,
raise_on_status=False,
timeout: float | tuple[float, float]
if connect_timeout_sec is not None or read_timeout_sec is not None:
timeout = (
connect_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC,
read_timeout_sec or timeout_sec or _DEFAULT_TIMEOUT_SEC,
)
if (
timeout[0] < _TIMEOUT_LOWER_BOUND_SEC
or timeout[1] < _TIMEOUT_LOWER_BOUND_SEC
):
logger.warning(
f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is (connect_timeout, read_timeout) = {timeout} seconds"
)
else:
timeout = get_or_else(timeout_sec, _DEFAULT_TIMEOUT_SEC)
if timeout < _TIMEOUT_LOWER_BOUND_SEC:
logger.warning(
f"Setting timeout values lower than {_TIMEOUT_LOWER_BOUND_SEC} second is not recommended. Your configuration is timeout = {timeout} seconds"
)

adapter = HTTPAdapter(
pool_connections=100, pool_maxsize=100, max_retries=retry_strategy
)
self._session.mount("http://", adapter)
self._session.mount("https://", adapter)

# Shim session.request to apply default timeout values.
# Via https://stackoverflow.com/a/59317604.
self._session.request = functools.partial( # type: ignore
self._session.request,
timeout=(self._connect_timeout_sec, self._read_timeout_sec),
self._session_config = RequestsSessionConfig(
timeout=timeout,
retry_status_codes=get_or_else(
retry_status_codes, _DEFAULT_RETRY_STATUS_CODES
),
retry_methods=get_or_else(retry_methods, _DEFAULT_RETRY_METHODS),
retry_max_times=get_or_else(retry_max_times, _DEFAULT_RETRY_MAX_TIMES),
extra_headers={**headers, **(extra_headers or {})},
ca_certificate_path=ca_certificate_path,
client_certificate_path=client_certificate_path,
disable_ssl_verification=disable_ssl_verification,
)

self._session = self._session_config.build_session()

def test_connection(self) -> None:
url = f"{self._gms_server}/config"
response = self._session.get(url)
Expand Down
25 changes: 14 additions & 11 deletions metadata-ingestion/src/datahub/ingestion/graph/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,24 @@ def frontend_base_url(self) -> str:

@classmethod
def from_emitter(cls, emitter: DatahubRestEmitter) -> "DataHubGraph":
session_config = emitter._session_config
if isinstance(session_config.timeout, tuple):
# TODO: This is slightly lossy. Eventually, we want to modify the emitter
# to accept a tuple for timeout_sec, and then we'll be able to remove this.
timeout_sec: Optional[float] = session_config.timeout[0]
else:
timeout_sec = session_config.timeout
return cls(
DatahubClientConfig(
server=emitter._gms_server,
token=emitter._token,
timeout_sec=emitter._read_timeout_sec,
retry_status_codes=emitter._retry_status_codes,
retry_max_times=emitter._retry_max_times,
extra_headers=emitter._session.headers,
disable_ssl_verification=emitter._session.verify is False,
ca_certificate_path=(
emitter._session.verify
if isinstance(emitter._session.verify, str)
else None
),
client_certificate_path=emitter._session.cert,
timeout_sec=timeout_sec,
retry_status_codes=session_config.retry_status_codes,
retry_max_times=session_config.retry_max_times,
extra_headers=session_config.extra_headers,
disable_ssl_verification=session_config.disable_ssl_verification,
ca_certificate_path=session_config.ca_certificate_path,
client_certificate_path=session_config.client_certificate_path,
)
)

Expand Down
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/ingestion/graph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DatahubClientConfig(ConfigModel):
# by callers / the CLI, but the actual client should not have any magic.
server: str
token: Optional[str] = None
timeout_sec: Optional[int] = None
timeout_sec: Optional[float] = None
retry_status_codes: Optional[List[int]] = None
retry_max_times: Optional[int] = None
extra_headers: Optional[Dict[str, str]] = None
Expand Down
32 changes: 17 additions & 15 deletions metadata-ingestion/tests/unit/sdk/test_rest_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,41 @@
MOCK_GMS_ENDPOINT = "http://fakegmshost:8080"


def test_datahub_rest_emitter_construction():
def test_datahub_rest_emitter_construction() -> None:
emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT)
assert emitter._connect_timeout_sec == rest_emitter._DEFAULT_CONNECT_TIMEOUT_SEC
assert emitter._read_timeout_sec == rest_emitter._DEFAULT_READ_TIMEOUT_SEC
assert emitter._retry_status_codes == rest_emitter._DEFAULT_RETRY_STATUS_CODES
assert emitter._retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES
assert emitter._session_config.timeout == rest_emitter._DEFAULT_TIMEOUT_SEC
assert (
emitter._session_config.retry_status_codes
== rest_emitter._DEFAULT_RETRY_STATUS_CODES
)
assert (
emitter._session_config.retry_max_times == rest_emitter._DEFAULT_RETRY_MAX_TIMES
)


def test_datahub_rest_emitter_timeout_construction():
def test_datahub_rest_emitter_timeout_construction() -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, connect_timeout_sec=2, read_timeout_sec=4
)
assert emitter._connect_timeout_sec == 2
assert emitter._read_timeout_sec == 4
assert emitter._session_config.timeout == (2, 4)


def test_datahub_rest_emitter_general_timeout_construction():
def test_datahub_rest_emitter_general_timeout_construction() -> None:
emitter = DatahubRestEmitter(MOCK_GMS_ENDPOINT, timeout_sec=2, read_timeout_sec=4)
assert emitter._connect_timeout_sec == 2
assert emitter._read_timeout_sec == 4
assert emitter._session_config.timeout == (2, 4)


def test_datahub_rest_emitter_retry_construction():
def test_datahub_rest_emitter_retry_construction() -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT,
retry_status_codes=[418],
retry_max_times=42,
)
assert emitter._retry_status_codes == [418]
assert emitter._retry_max_times == 42
assert emitter._session_config.retry_status_codes == [418]
assert emitter._session_config.retry_max_times == 42


def test_datahub_rest_emitter_extra_params():
def test_datahub_rest_emitter_extra_params() -> None:
emitter = DatahubRestEmitter(
MOCK_GMS_ENDPOINT, extra_headers={"key1": "value1", "key2": "value2"}
)
Expand Down

0 comments on commit 2c4b997

Please sign in to comment.