Skip to content

Commit

Permalink
Merge pull request #6 from RalfG/max-recursion-depth
Browse files Browse the repository at this point in the history
Add max_recursion_depth option
  • Loading branch information
RalfG authored Jul 3, 2023
2 parents 7108e2a + 4b82c94 commit 0f2505b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 45 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.4.0] - 03/07/2023

### Added

- Added `max_recursion_depth` argument to `CascadeConfig` to limit the depth of
hierarchically updating nested dictionaries. When the maximum nesting depth is
exceeded, the new dictionary will be used as-is, overwriting any previous
values under that dictionary tree.

## [0.3.1] - 03/07/2023

### Fixed
Expand Down
51 changes: 31 additions & 20 deletions cascade_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Cascading configuration from the CLI and config files."""

__version__ = "0.3.1"
__version__ = "0.4.0"

import json
import os
Expand All @@ -14,7 +14,12 @@
class CascadeConfig:
"""Cascading configuration."""

def __init__(self, validation_schema=None, none_overrides_value=False):
def __init__(
self,
validation_schema=None,
none_overrides_value=False,
max_recursion_depth=None,
):
"""
Cascading configuration.
Expand All @@ -25,6 +30,10 @@ def __init__(self, validation_schema=None, none_overrides_value=False):
none_overrides_value: bool
If True, a None value overrides a not-None value from the previous configuration.
If False, None values will never override not-None values.
max_recursion_depth: int, optional
Maximum depth of nested dictionaries to recurse into. When the maximum recursion depth
is reached, the nested dictionary will be replaced by the newer nested dictionary. If
None, recurse into all nested dictionaries.
Examples
--------
Expand All @@ -36,6 +45,7 @@ def __init__(self, validation_schema=None, none_overrides_value=False):
"""
self.validation_schema = validation_schema
self.none_overrides_value = none_overrides_value
self.max_recursion_depth = max_recursion_depth
self.sources = []

@property
Expand All @@ -51,21 +61,28 @@ def validation_schema(self, value):
else:
self._validation_schema = None

def _update_dict_recursively(self, original: Dict, updater: Dict) -> Dict:
def _update_dict_recursively(self, original: Dict, updater: Dict, depth: int) -> Dict:
"""Update dictionary recursively."""
depth = depth + 1
for k, v in updater.items():
if isinstance(v, dict):
if not v: # v is not None, v is empty dictionary
if not v:
# v is an empty dictionary
original[k] = dict()
elif self.max_recursion_depth and depth > self.max_recursion_depth:
# v is a populated dictionary, exceeds max depth
original[k] = v
else:
original[k] = self._update_dict_recursively(original.get(k, {}), v)
# v is a populated dictionary, can be further recursed
original[k] = self._update_dict_recursively(original.get(k, {}), v, depth)
elif isinstance(v, bool):
original[k] = v # v is True or False
elif v or k not in original: # v is not None, or key does not exist yet
# v is True or False
original[k] = v
elif (
self.none_overrides_value
): # v is None, but can override previous value
elif v or k not in original:
# v is thruthy (therefore not None), or key does not exist yet
original[k] = v
elif self.none_overrides_value:
# v is None, but can override previous value
original[k] = v
return original

Expand Down Expand Up @@ -114,7 +131,7 @@ def parse(self) -> Dict:
"""Parse all sources, cascade, validate, and return cascaded configuration."""
config = dict()
for source in self.sources:
config = self._update_dict_recursively(config, source.load())
config = self._update_dict_recursively(config, source.load(), depth=0)

if self.validation_schema:
jsonschema.validate(config, self.validation_schema.load())
Expand Down Expand Up @@ -196,9 +213,7 @@ class JSONConfigSource(_ConfigSource):

def _read(self) -> Dict:
if not isinstance(self.source, (str, os.PathLike)):
raise TypeError(
"JSONConfigSource `source` must be a string or path-like object"
)
raise TypeError("JSONConfigSource `source` must be a string or path-like object")
with open(self.source, "rt") as json_file:
config = json.load(json_file)
return config
Expand All @@ -221,9 +236,7 @@ class NamespaceConfigSource(_ConfigSource):

def _read(self) -> Dict:
if not isinstance(self.source, Namespace):
raise TypeError(
"NamespaceConfigSource `source` must be an argparse.Namespace object"
)
raise TypeError("NamespaceConfigSource `source` must be an argparse.Namespace object")
config = vars(self.source)
return config

Expand Down Expand Up @@ -256,7 +269,5 @@ def load(self) -> Dict:
elif isinstance(self.source, Dict):
schema = self.source
else:
raise TypeError(
"ValidationSchema `source` must be of type string, path-like, or dict"
)
raise TypeError("ValidationSchema `source` must be of type string, path-like, or dict")
return schema
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ docs = [

[tool.flit.module]
name = "cascade_config"

[tool.black]
line-length = 99
target-version = ['py38']

[tool.ruff]
line-length = 99
target-version = 'py38'

59 changes: 34 additions & 25 deletions tests/test_cascade_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@

import argparse
import json
from os import stat
import tempfile
from typing import Type

import pytest
import jsonschema
import pytest

import cascade_config
from cascade_config import ValidationSchema

TEST_SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
Expand All @@ -26,29 +23,29 @@
},
"log_level": {
"type": "string",
"enum": ["debug", "info", "warning", "error", "critical"]
}
}
"enum": ["debug", "info", "warning", "error", "critical"],
},
},
}
}
},
}

TEST_SAMPLE = {
"config_example": {"num_cpu": 1, "log_level": "info"}
}
TEST_SAMPLE_2 = {
"config_example": {"log_level": "debug"}, "test": True
}
TEST_SAMPLE_CASC = {
"config_example": {"num_cpu": 1, "log_level": "debug"}, "test": True
TEST_SAMPLE = {"config_example": {"num_cpu": 1, "log_level": "info"}}
TEST_SAMPLE_2 = {"config_example": {"log_level": "debug"}, "test": True}
TEST_SAMPLE_CASC = {"config_example": {"num_cpu": 1, "log_level": "debug"}, "test": True}
TEST_SAMPLE_INVALID = {"config_example": {"num_cpu": "not_a_number", "log_level": "info"}}
TEST_SAMPLE_NESTED_A = {"config_example": {"num_cpu": 1, "depth_2": {"depth_3a": True}}}
TEST_SAMPLE_NESTED_B = {"config_example": {"num_cpu": 1, "depth_2": {"depth_3b": False}}}
TEST_SAMPLE_NESTED_RESULT_NOMAX = {
"config_example": {"num_cpu": 1, "depth_2": {"depth_3a": True, "depth_3b": False}}
}
TEST_SAMPLE_NESTED_RESULT_MAX1 = {"config_example": {"num_cpu": 1, "depth_2": {"depth_3b": False}}}

TEST_SAMPLE_INVALID = {
"config_example": {"num_cpu": "not_a_number", "log_level": "info"}
}

def get_sample_namespace(test_sample):
flatten = lambda l: [item for sublist in l for item in sublist]
def flatten(lst):
return [item for sublist in lst for item in sublist]

test_args = flatten([[f"--{i[0]}", f"{i[1]}"] for i in test_sample.items()])
parser = argparse.ArgumentParser()
parser.add_argument("--num_cpu", type=int)
Expand All @@ -61,15 +58,17 @@ class TestCascadeConfig:

@staticmethod
def get_json_file(test_sample):
with tempfile.NamedTemporaryFile(mode='wt', delete=False) as json_file:
with tempfile.NamedTemporaryFile(mode="wt", delete=False) as json_file:
json.dump(test_sample, json_file)
json_file.seek(0)
json_file_name = json_file.name
return json_file_name

@staticmethod
def get_sample_namespace(test_sample):
flatten = lambda l: [item for sublist in l for item in sublist]
def flatten(lst):
return [item for sublist in lst for item in sublist]

test_args = flatten([[f"--{i[0]}", f"{i[1]}"] for i in test_sample.items()])
parser = argparse.ArgumentParser()
parser.add_argument("--num_cpu", type=int)
Expand Down Expand Up @@ -105,9 +104,7 @@ def test_single_config_namespace(self):
subkey = "config_example"
cc = cascade_config.CascadeConfig()
cc.add_namespace(
get_sample_namespace(TEST_SAMPLE[subkey]),
subkey=subkey,
validation_schema=TEST_SCHEMA
get_sample_namespace(TEST_SAMPLE[subkey]), subkey=subkey, validation_schema=TEST_SCHEMA
)
assert cc.parse() == TEST_SAMPLE

Expand Down Expand Up @@ -143,6 +140,18 @@ def test_multiple_configs(self):
cc.add_json(self.get_json_file(TEST_SAMPLE_2))
assert cc.parse() == TEST_SAMPLE_CASC

def test_max_recursion(self):
"""Test max_recursion_depth argument."""
cc = cascade_config.CascadeConfig(max_recursion_depth=None)
cc.add_dict(TEST_SAMPLE_NESTED_A)
cc.add_dict(TEST_SAMPLE_NESTED_B)
assert cc.parse() == TEST_SAMPLE_NESTED_RESULT_NOMAX

cc = cascade_config.CascadeConfig(max_recursion_depth=1)
cc.add_dict(TEST_SAMPLE_NESTED_A)
cc.add_dict(TEST_SAMPLE_NESTED_B)
assert cc.parse() == TEST_SAMPLE_NESTED_RESULT_MAX1

def test_validation_schema_from_object(self):
with pytest.raises(TypeError):
cascade_config.ValidationSchema.from_object(42)
Expand Down

0 comments on commit 0f2505b

Please sign in to comment.