diff --git a/setup.py b/setup.py index 8a3413a..f9d55d8 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def run_tests(self): "Programming Language :: Python :: 3.8", "Topic :: Software Development :: Libraries :: Python Modules"], install_requires=[ - "cryptojwt>=1.5.0", + "cryptojwt>=1.6.0", "pyOpenSSL", "filelock>=3.0.12", 'pyyaml>=5.1.2' diff --git a/src/oidcmsg/__init__.py b/src/oidcmsg/__init__.py index 9b83b4d..724b7a7 100644 --- a/src/oidcmsg/__init__.py +++ b/src/oidcmsg/__init__.py @@ -1,5 +1,5 @@ __author__ = "Roland Hedberg" -__version__ = "1.5.0" +__version__ = "1.5.1" import os from typing import Dict @@ -34,37 +34,37 @@ def proper_path(path): return path -def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str): - """ - This is for adding a base path to path specified in a configuration - - :param conf: Configuration - :param item_paths: The relative item path - :param base_path: An absolute path to add to the relative - """ - for section, items in item_paths.items(): - if section == "": - part = conf - else: - part = conf.get(section) - - if part: - if isinstance(items, list): - for attr in items: - _path = part.get(attr) - if _path: - if _path.startswith("/"): - continue - elif _path == "": - part[attr] = "./" + _path - else: - part[attr] = os.path.join(base_path, _path) - elif items is None: - if part.startswith("/"): - continue - elif part == "": - conf[section] = "./" - else: - conf[section] = os.path.join(base_path, part) - else: # Assume items is dictionary like - add_base_path(part, items, base_path) +# def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str): +# """ +# This is for adding a base path to path specified in a configuration +# +# :param conf: Configuration +# :param item_paths: The relative item path +# :param base_path: An absolute path to add to the relative +# """ +# for section, items in item_paths.items(): +# if section == "": +# part = conf +# else: +# part = conf.get(section) +# +# if part: +# if isinstance(items, list): +# for attr in items: +# _path = part.get(attr) +# if _path: +# if _path.startswith("/"): +# continue +# elif _path == "": +# part[attr] = "./" + _path +# else: +# part[attr] = os.path.join(base_path, _path) +# elif items is None: +# if part.startswith("/"): +# continue +# elif part == "": +# conf[section] = "./" +# else: +# conf[section] = os.path.join(base_path, part) +# else: # Assume items is dictionary like +# add_base_path(part, items, base_path) diff --git a/src/oidcmsg/configure.py b/src/oidcmsg/configure.py index 715df32..abe408e 100644 --- a/src/oidcmsg/configure.py +++ b/src/oidcmsg/configure.py @@ -9,10 +9,10 @@ from oidcmsg.logging import configure_logging from oidcmsg.util import load_yaml_config -DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir', - 'private_path', 'public_path', 'db_file'] +DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', + 'private_path', 'public_path', 'db_file', 'jwks_file'] -URIS = ["redirect_uris", 'issuer', 'base_url'] +DEFAULT_DIR_ATTRIBUTE_NAMES = ['template_dir'] def lower_or_upper(config, param, default=None): @@ -22,17 +22,31 @@ def lower_or_upper(config, param, default=None): return res -def add_base_path(conf: dict, base_path: str, file_attributes: List[str]): +def add_path_to_filename(filename, base_path): + if filename == "" or filename.startswith("/"): + return filename + else: + return os.path.join(base_path, filename) + + +def add_path_to_directory_name(directory_name, base_path): + if directory_name.startswith("/"): + return directory_name + elif directory_name == "": + return "./" + directory_name + else: + return os.path.join(base_path, directory_name) + + +def add_base_path(conf: dict, base_path: str, attributes: List[str], attribute_type: str = "file"): for key, val in conf.items(): - if key in file_attributes: - if val.startswith("/"): - continue - elif val == "": - conf[key] = "./" + val + if key in attributes: + if attribute_type == "file": + conf[key] = add_path_to_filename(val, base_path) else: - conf[key] = os.path.join(base_path, val) + conf[key] = add_path_to_directory_name(val, base_path) if isinstance(val, dict): - conf[key] = add_base_path(val, base_path, file_attributes) + conf[key] = add_base_path(val, base_path, attributes, attribute_type) return conf @@ -53,41 +67,71 @@ def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int): return conf -class Base: +class Base(dict): """ Configuration base class """ + parameter = {} + uris = ["issuer", "base_url"] + def __init__(self, conf: Dict, base_path: str = '', file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, ): + dict.__init__(self) + self._file_attributes = file_attributes or DEFAULT_FILE_ATTRIBUTE_NAMES + self._dir_attributes = dir_attributes or DEFAULT_DIR_ATTRIBUTE_NAMES - if file_attributes is None: - file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES - - if base_path and file_attributes: + if base_path: # this adds a base path to all paths in the configuration - add_base_path(conf, base_path, file_attributes) + if self._file_attributes: + add_base_path(conf, base_path, self._file_attributes, "file") + if self._dir_attributes: + add_base_path(conf, base_path, self._dir_attributes, "dir") - def __getitem__(self, item): - if item in self.__dict__: - return self.__dict__[item] + # entity info + self.domain = domain or conf.get("domain", "127.0.0.1") + self.port = port or conf.get("port", 80) + + self.conf = set_domain_and_port(conf, self.uris, self.domain, self.port) + + def __getattr__(self, item, default=None): + if item in self: + return self[item] else: - raise KeyError + return default - def get(self, item, default=None): - return getattr(self, item, default) + def __setattr__(self, key, value): + if key in self: + raise KeyError('{} has already been set'.format(key)) + super(Base, self).__setitem__(key, value) + + def __setitem__(self, key, value): + if key in self: + raise KeyError('{} has already been set'.format(key)) + super(Base, self).__setitem__(key, value) - def __contains__(self, item): - return item in self.__dict__ + def get(self, item, default=None): + return self.__getattr__(item, default) def items(self): - for key in self.__dict__: + for key in self.keys(): if key.startswith('__') and key.endswith('__'): continue yield key, getattr(self, key) - def extend(self, entity_conf, conf, base_path, file_attributes, domain, port): + def extend(self, + conf: Dict, + base_path: str, + domain: str, + port: int, + entity_conf: Optional[List[dict]] = None, + file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, + ): for econf in entity_conf: _path = econf.get("path") _cnf = conf @@ -98,11 +142,49 @@ def extend(self, entity_conf, conf, base_path, file_attributes, domain, port): _cls = econf["class"] setattr(self, _attr, _cls(_cnf, base_path=base_path, file_attributes=file_attributes, - domain=domain, port=port)) + domain=domain, port=port, dir_attributes=dir_attributes)) + + def complete_paths(self, conf: Dict, keys: List[str], default_config: Dict, base_path: str): + for key in keys: + _val = conf.get(key) + if _val is None and key in default_config: + _val = default_config[key] + if key in self._file_attributes: + _val = add_path_to_filename(_val, base_path) + elif key in self._dir_attributes: + _val = add_path_to_directory_name(_val, base_path) + if not _val: + continue + + setattr(self, key, _val) + + def format(self, conf, base_path: str, domain: str, port: int, + file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None) -> None: + """ + Formats parts of the configuration. That includes replacing the strings {domain} and {port} + with the used domain and port and making references to files and directories absolute + rather then relative. The formatting is done in place. + + :param dir_attributes: + :param conf: The configuration part + :param base_path: The base path used to make file/directory refrences absolute + :param file_attributes: Attribute names that refer to files or directories. + :param domain: The domain name + :param port: The port used + """ + if isinstance(conf, dict): + if file_attributes: + add_base_path(conf, base_path, file_attributes, attribute_type="file") + if dir_attributes: + add_base_path(conf, base_path, dir_attributes, attribute_type="dir") + if isinstance(conf, dict): + set_domain_and_port(conf, self.uris, domain=domain, port=port) class Configuration(Base): - """Server Configuration""" + """Entity Configuration Base""" + uris = ["redirect_uris", 'issuer', 'base_url', 'server_name'] def __init__(self, conf: Dict, @@ -111,27 +193,24 @@ def __init__(self, file_attributes: Optional[List[str]] = None, domain: Optional[str] = "", port: Optional[int] = 0, + dir_attributes: Optional[List[str]] = None, ): - Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes, + dir_attributes=dir_attributes, domain=domain, port=port) - log_conf = conf.get('logging') + log_conf = self.conf.get('logging') if log_conf: self.logger = configure_logging(config=log_conf).getChild(__name__) else: self.logger = logging.getLogger('oidcrp') - self.web_conf = lower_or_upper(conf, "webserver") - - # entity info - if not domain: - domain = conf.get("domain", "127.0.0.1") - - if not port: - port = conf.get("port", 80) + self.web_conf = lower_or_upper(self.conf, "webserver") if entity_conf: - self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path, - file_attributes=file_attributes, domain=domain, port=port) + self.extend(conf=self.conf, base_path=base_path, + domain=self.domain, port=self.port, entity_conf=entity_conf, + file_attributes=self._file_attributes, + dir_attributes=self._dir_attributes) def create_from_config_file(cls, @@ -140,7 +219,9 @@ def create_from_config_file(cls, entity_conf: Optional[List[dict]] = None, file_attributes: Optional[List[str]] = None, domain: Optional[str] = "", - port: Optional[int] = 0): + port: Optional[int] = 0, + dir_attributes: Optional[List[str]] = None + ): if filename.endswith(".yaml"): """Load configuration as YAML""" _cnf = load_yaml_config(filename) @@ -158,4 +239,4 @@ def create_from_config_file(cls, return cls(_cnf, entity_conf=entity_conf, base_path=base_path, file_attributes=file_attributes, - domain=domain, port=port) + domain=domain, port=port, dir_attributes=dir_attributes) diff --git a/tests/server_conf.json b/tests/server_conf.json index d0e7858..8343229 100644 --- a/tests/server_conf.json +++ b/tests/server_conf.json @@ -38,6 +38,7 @@ "httpc_params": { "verify": false }, + "hash_seed": "MustangSally", "keys": { "private_path": "private/jwks.json", "key_defs": [ diff --git a/tests/test_03_time_util.py b/tests/test_03_time_util.py index a053676..f5041b4 100644 --- a/tests/test_03_time_util.py +++ b/tests/test_03_time_util.py @@ -258,9 +258,3 @@ def test_later_than_str(): b = in_a_while(seconds=20) assert later_than(b, a) assert later_than(a, b) is False - - -def test_utc_time(): - utc_now = utc_time_sans_frac() - expected_utc_now = int((datetime.utcnow() - datetime(1970, 1, 1)).total_seconds()) - assert utc_now == expected_utc_now diff --git a/tests/test_20_config.py b/tests/test_20_config.py index e5ac3a8..08a4454 100644 --- a/tests/test_20_config.py +++ b/tests/test_20_config.py @@ -10,7 +10,6 @@ from oidcmsg.configure import Configuration from oidcmsg.configure import create_from_config_file from oidcmsg.configure import lower_or_upper -from oidcmsg.configure import set_domain_and_port from oidcmsg.util import rndstr _dirname = os.path.dirname(os.path.abspath(__file__)) @@ -26,23 +25,14 @@ def __init__(self, domain: Optional[str] = "", port: Optional[int] = 0, file_attributes: Optional[List[str]] = None, - uris: Optional[List[str]] = None + uris: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None ): - - Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes, + dir_attributes=dir_attributes) self.keys = lower_or_upper(conf, 'keys') - if not domain: - domain = conf.get("domain", "127.0.0.1") - - if not port: - port = conf.get("port", 80) - - if uris is None: - uris = URIS - conf = set_domain_and_port(conf, uris, domain, port) - self.hash_seed = lower_or_upper(conf, 'hash_seed', rndstr(32)) self.base_url = conf.get("base_url") self.httpc_params = conf.get("httpc_params", {"verify": False}) @@ -74,5 +64,6 @@ def test_entity_config(filename): assert configuration.httpc_params == {"verify": False} assert configuration['keys'] ni = dict(configuration.items()) - assert len(ni) == 4 - assert set(ni.keys()) == {'keys', 'base_url', 'httpc_params', 'hash_seed'} + assert len(ni) == 9 + assert set(ni.keys()) == {'base_url', '_dir_attributes', '_file_attributes', 'hash_seed', + 'httpc_params', 'keys', 'conf', 'port', 'domain'}