diff --git a/airflow/providers/openlineage/conf.py b/airflow/providers/openlineage/conf.py index f79511a22d052..a9601a416bebf 100644 --- a/airflow/providers/openlineage/conf.py +++ b/airflow/providers/openlineage/conf.py @@ -14,6 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +This module provides functions for safely retrieving and handling OpenLineage configurations. + +To prevent errors caused by invalid user-provided configuration values, we use ``conf.get()`` +to fetch values as strings and perform safe conversions using custom functions. + +Any invalid configuration values should be treated as incorrect and replaced with default values. +For example, if the default for boolean ``custom_ol_var`` is False, any non-true value provided: +``"asdf"``, ``12345``, ``{"key": 1}`` or empty string, will result in False being used. + +By using default values for invalid configuration values, we ensure that the configurations are handled +safely, preventing potential runtime errors due to conversion issues. +""" from __future__ import annotations @@ -30,6 +43,13 @@ def _is_true(arg: Any) -> bool: return str(arg).lower().strip() in ("true", "1", "t") +def _safe_int_convert(arg: Any, default: int) -> int: + try: + return int(arg) + except (ValueError, TypeError): + return default + + @cache def config_path(check_legacy_env_var: bool = True) -> str: """[openlineage] config_path.""" @@ -108,5 +128,5 @@ def is_disabled() -> bool: @cache def dag_state_change_process_pool_size() -> int: """[openlineage] dag_state_change_process_pool_size.""" - option = conf.getint(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback=1) - return option + option = conf.get(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="") + return _safe_int_convert(str(option).strip(), default=1) diff --git a/tests/providers/openlineage/test_conf.py b/tests/providers/openlineage/test_conf.py index 6271481689212..f52d8453acc2c 100644 --- a/tests/providers/openlineage/test_conf.py +++ b/tests/providers/openlineage/test_conf.py @@ -23,8 +23,10 @@ from airflow.providers.openlineage.conf import ( _is_true, + _safe_int_convert, config_path, custom_extractors, + dag_state_change_process_pool_size, disabled_operators, is_disabled, is_source_enabled, @@ -49,6 +51,7 @@ _CONFIG_OPTION_DISABLED = "disabled" _VAR_URL = "OPENLINEAGE_URL" _CONFIG_OPTION_SELECTIVE_ENABLE = "selective_enable" +_CONFIG_OPTION_DAG_STATE_CHANGE_PROCESS_POOL_SIZE = "dag_state_change_process_pool_size" _BOOL_PARAMS = ( ("1", True), @@ -76,6 +79,7 @@ def clear_cache(): transport.cache_clear() is_disabled.cache_clear() selective_enable.cache_clear() + dag_state_change_process_pool_size.cache_clear() try: yield finally: @@ -87,6 +91,7 @@ def clear_cache(): transport.cache_clear() is_disabled.cache_clear() selective_enable.cache_clear() + dag_state_change_process_pool_size.cache_clear() @pytest.mark.parametrize( @@ -103,6 +108,35 @@ def test_is_true(var_string, expected): assert _is_true(var_string) is expected +@pytest.mark.parametrize( + "input_value, expected", + [ + ("123", 123), + (456, 456), + ("789", 789), + (0, 0), + ("0", 0), + ], +) +def test_safe_int_convert(input_value, expected): + assert _safe_int_convert(input_value, default=1) == expected + + +@pytest.mark.parametrize( + "input_value, default", + [ + ("abc", 1), + ("", 2), + (None, 3), + ("123abc", 4), + ([], 5), + ("1.2", 6), + ], +) +def test_safe_int_convert_erroneous_values(input_value, default): + assert _safe_int_convert(input_value, default) == default + + @env_vars({_VAR_CONFIG_PATH: "env_var_path"}) @conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): None}) def test_config_path_legacy_env_var_is_used_when_no_conf_option_set(): @@ -456,3 +490,25 @@ def test_is_disabled_empty_conf_option(): ) def test_is_disabled_do_not_fail_if_conf_option_missing(): assert is_disabled() is True + + +@pytest.mark.parametrize( + ("var_string", "expected"), + ( + ("1", 1), + ("2 ", 2), + (" 3", 3), + ("4.56", 1), # default + ("asdf", 1), # default + ("true", 1), # default + ("false", 1), # default + ("None", 1), # default + ("", 1), # default + (" ", 1), # default + (None, 1), # default + ), +) +def test_dag_state_change_process_pool_size(var_string, expected): + with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DAG_STATE_CHANGE_PROCESS_POOL_SIZE): var_string}): + result = dag_state_change_process_pool_size() + assert result == expected