Skip to content
This repository was archived by the owner on Jun 1, 2023. It is now read-only.

Configuration refactoring #51

Merged
merged 6 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run_tests(self):
"Programming Language :: Python :: 3.8",
"Topic :: Software Development :: Libraries :: Python Modules"],
install_requires=[
"cryptojwt>=1.5.0",
"cryptojwt>=1.6.0",
"pyOpenSSL",
"filelock>=3.0.12",
'pyyaml>=5.1.2'
Expand Down
70 changes: 35 additions & 35 deletions src/oidcmsg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__author__ = "Roland Hedberg"
__version__ = "1.5.0"
__version__ = "1.5.1"

import os
from typing import Dict
Expand Down Expand Up @@ -34,37 +34,37 @@ def proper_path(path):
return path


def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str):
"""
This is for adding a base path to path specified in a configuration

:param conf: Configuration
:param item_paths: The relative item path
:param base_path: An absolute path to add to the relative
"""
for section, items in item_paths.items():
if section == "":
part = conf
else:
part = conf.get(section)

if part:
if isinstance(items, list):
for attr in items:
_path = part.get(attr)
if _path:
if _path.startswith("/"):
continue
elif _path == "":
part[attr] = "./" + _path
else:
part[attr] = os.path.join(base_path, _path)
elif items is None:
if part.startswith("/"):
continue
elif part == "":
conf[section] = "./"
else:
conf[section] = os.path.join(base_path, part)
else: # Assume items is dictionary like
add_base_path(part, items, base_path)
# def add_base_path(conf: Dict[str, str], item_paths: dict, base_path: str):
# """
# This is for adding a base path to path specified in a configuration
#
# :param conf: Configuration
# :param item_paths: The relative item path
# :param base_path: An absolute path to add to the relative
# """
# for section, items in item_paths.items():
# if section == "":
# part = conf
# else:
# part = conf.get(section)
#
# if part:
# if isinstance(items, list):
# for attr in items:
# _path = part.get(attr)
# if _path:
# if _path.startswith("/"):
# continue
# elif _path == "":
# part[attr] = "./" + _path
# else:
# part[attr] = os.path.join(base_path, _path)
# elif items is None:
# if part.startswith("/"):
# continue
# elif part == "":
# conf[section] = "./"
# else:
# conf[section] = os.path.join(base_path, part)
# else: # Assume items is dictionary like
# add_base_path(part, items, base_path)
167 changes: 124 additions & 43 deletions src/oidcmsg/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from oidcmsg.logging import configure_logging
from oidcmsg.util import load_yaml_config

DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir',
'private_path', 'public_path', 'db_file']
DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename',
'private_path', 'public_path', 'db_file', 'jwks_file']

URIS = ["redirect_uris", 'issuer', 'base_url']
DEFAULT_DIR_ATTRIBUTE_NAMES = ['template_dir']


def lower_or_upper(config, param, default=None):
Expand All @@ -22,17 +22,31 @@ def lower_or_upper(config, param, default=None):
return res


def add_base_path(conf: dict, base_path: str, file_attributes: List[str]):
def add_path_to_filename(filename, base_path):
if filename == "" or filename.startswith("/"):
return filename
else:
return os.path.join(base_path, filename)


def add_path_to_directory_name(directory_name, base_path):
if directory_name.startswith("/"):
return directory_name
elif directory_name == "":
return "./" + directory_name
else:
return os.path.join(base_path, directory_name)


def add_base_path(conf: dict, base_path: str, attributes: List[str], attribute_type: str = "file"):
for key, val in conf.items():
if key in file_attributes:
if val.startswith("/"):
continue
elif val == "":
conf[key] = "./" + val
if key in attributes:
if attribute_type == "file":
conf[key] = add_path_to_filename(val, base_path)
else:
conf[key] = os.path.join(base_path, val)
conf[key] = add_path_to_directory_name(val, base_path)
if isinstance(val, dict):
conf[key] = add_base_path(val, base_path, file_attributes)
conf[key] = add_base_path(val, base_path, attributes, attribute_type)

return conf

Expand All @@ -53,41 +67,71 @@ def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int):
return conf


class Base:
class Base(dict):
""" Configuration base class """

parameter = {}
uris = ["issuer", "base_url"]

def __init__(self,
conf: Dict,
base_path: str = '',
file_attributes: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None,
domain: Optional[str] = "",
port: Optional[int] = 0,
):
dict.__init__(self)
self._file_attributes = file_attributes or DEFAULT_FILE_ATTRIBUTE_NAMES
self._dir_attributes = dir_attributes or DEFAULT_DIR_ATTRIBUTE_NAMES

if file_attributes is None:
file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES

if base_path and file_attributes:
if base_path:
# this adds a base path to all paths in the configuration
add_base_path(conf, base_path, file_attributes)
if self._file_attributes:
add_base_path(conf, base_path, self._file_attributes, "file")
if self._dir_attributes:
add_base_path(conf, base_path, self._dir_attributes, "dir")

def __getitem__(self, item):
if item in self.__dict__:
return self.__dict__[item]
# entity info
self.domain = domain or conf.get("domain", "127.0.0.1")
self.port = port or conf.get("port", 80)

self.conf = set_domain_and_port(conf, self.uris, self.domain, self.port)

def __getattr__(self, item, default=None):
if item in self:
return self[item]
else:
raise KeyError
return default

def get(self, item, default=None):
return getattr(self, item, default)
def __setattr__(self, key, value):
if key in self:
raise KeyError('{} has already been set'.format(key))
super(Base, self).__setitem__(key, value)

def __setitem__(self, key, value):
if key in self:
raise KeyError('{} has already been set'.format(key))
super(Base, self).__setitem__(key, value)

def __contains__(self, item):
return item in self.__dict__
def get(self, item, default=None):
return self.__getattr__(item, default)

def items(self):
for key in self.__dict__:
for key in self.keys():
if key.startswith('__') and key.endswith('__'):
continue
yield key, getattr(self, key)

def extend(self, entity_conf, conf, base_path, file_attributes, domain, port):
def extend(self,
conf: Dict,
base_path: str,
domain: str,
port: int,
entity_conf: Optional[List[dict]] = None,
file_attributes: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None,
):
for econf in entity_conf:
_path = econf.get("path")
_cnf = conf
Expand All @@ -98,11 +142,49 @@ def extend(self, entity_conf, conf, base_path, file_attributes, domain, port):
_cls = econf["class"]
setattr(self, _attr,
_cls(_cnf, base_path=base_path, file_attributes=file_attributes,
domain=domain, port=port))
domain=domain, port=port, dir_attributes=dir_attributes))

def complete_paths(self, conf: Dict, keys: List[str], default_config: Dict, base_path: str):
for key in keys:
_val = conf.get(key)
if _val is None and key in default_config:
_val = default_config[key]
if key in self._file_attributes:
_val = add_path_to_filename(_val, base_path)
elif key in self._dir_attributes:
_val = add_path_to_directory_name(_val, base_path)
if not _val:
continue

setattr(self, key, _val)

def format(self, conf, base_path: str, domain: str, port: int,
file_attributes: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None) -> None:
"""
Formats parts of the configuration. That includes replacing the strings {domain} and {port}
with the used domain and port and making references to files and directories absolute
rather then relative. The formatting is done in place.

:param dir_attributes:
:param conf: The configuration part
:param base_path: The base path used to make file/directory refrences absolute
:param file_attributes: Attribute names that refer to files or directories.
:param domain: The domain name
:param port: The port used
"""
if isinstance(conf, dict):
if file_attributes:
add_base_path(conf, base_path, file_attributes, attribute_type="file")
if dir_attributes:
add_base_path(conf, base_path, dir_attributes, attribute_type="dir")
if isinstance(conf, dict):
set_domain_and_port(conf, self.uris, domain=domain, port=port)


class Configuration(Base):
"""Server Configuration"""
"""Entity Configuration Base"""
uris = ["redirect_uris", 'issuer', 'base_url', 'server_name']

def __init__(self,
conf: Dict,
Expand All @@ -111,27 +193,24 @@ def __init__(self,
file_attributes: Optional[List[str]] = None,
domain: Optional[str] = "",
port: Optional[int] = 0,
dir_attributes: Optional[List[str]] = None,
):
Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes)
Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes,
dir_attributes=dir_attributes, domain=domain, port=port)

log_conf = conf.get('logging')
log_conf = self.conf.get('logging')
if log_conf:
self.logger = configure_logging(config=log_conf).getChild(__name__)
else:
self.logger = logging.getLogger('oidcrp')

self.web_conf = lower_or_upper(conf, "webserver")

# entity info
if not domain:
domain = conf.get("domain", "127.0.0.1")

if not port:
port = conf.get("port", 80)
self.web_conf = lower_or_upper(self.conf, "webserver")

if entity_conf:
self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path,
file_attributes=file_attributes, domain=domain, port=port)
self.extend(conf=self.conf, base_path=base_path,
domain=self.domain, port=self.port, entity_conf=entity_conf,
file_attributes=self._file_attributes,
dir_attributes=self._dir_attributes)


def create_from_config_file(cls,
Expand All @@ -140,7 +219,9 @@ def create_from_config_file(cls,
entity_conf: Optional[List[dict]] = None,
file_attributes: Optional[List[str]] = None,
domain: Optional[str] = "",
port: Optional[int] = 0):
port: Optional[int] = 0,
dir_attributes: Optional[List[str]] = None
):
if filename.endswith(".yaml"):
"""Load configuration as YAML"""
_cnf = load_yaml_config(filename)
Expand All @@ -158,4 +239,4 @@ def create_from_config_file(cls,
return cls(_cnf,
entity_conf=entity_conf,
base_path=base_path, file_attributes=file_attributes,
domain=domain, port=port)
domain=domain, port=port, dir_attributes=dir_attributes)
1 change: 1 addition & 0 deletions tests/server_conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"httpc_params": {
"verify": false
},
"hash_seed": "MustangSally",
"keys": {
"private_path": "private/jwks.json",
"key_defs": [
Expand Down
6 changes: 0 additions & 6 deletions tests/test_03_time_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,3 @@ def test_later_than_str():
b = in_a_while(seconds=20)
assert later_than(b, a)
assert later_than(a, b) is False


def test_utc_time():
utc_now = utc_time_sans_frac()
expected_utc_now = int((datetime.utcnow() - datetime(1970, 1, 1)).total_seconds())
assert utc_now == expected_utc_now
23 changes: 7 additions & 16 deletions tests/test_20_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from oidcmsg.configure import Configuration
from oidcmsg.configure import create_from_config_file
from oidcmsg.configure import lower_or_upper
from oidcmsg.configure import set_domain_and_port
from oidcmsg.util import rndstr

_dirname = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -26,23 +25,14 @@ def __init__(self,
domain: Optional[str] = "",
port: Optional[int] = 0,
file_attributes: Optional[List[str]] = None,
uris: Optional[List[str]] = None
uris: Optional[List[str]] = None,
dir_attributes: Optional[List[str]] = None
):

Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes)
Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes,
dir_attributes=dir_attributes)

self.keys = lower_or_upper(conf, 'keys')

if not domain:
domain = conf.get("domain", "127.0.0.1")

if not port:
port = conf.get("port", 80)

if uris is None:
uris = URIS
conf = set_domain_and_port(conf, uris, domain, port)

self.hash_seed = lower_or_upper(conf, 'hash_seed', rndstr(32))
self.base_url = conf.get("base_url")
self.httpc_params = conf.get("httpc_params", {"verify": False})
Expand Down Expand Up @@ -74,5 +64,6 @@ def test_entity_config(filename):
assert configuration.httpc_params == {"verify": False}
assert configuration['keys']
ni = dict(configuration.items())
assert len(ni) == 4
assert set(ni.keys()) == {'keys', 'base_url', 'httpc_params', 'hash_seed'}
assert len(ni) == 9
assert set(ni.keys()) == {'base_url', '_dir_attributes', '_file_attributes', 'hash_seed',
'httpc_params', 'keys', 'conf', 'port', 'domain'}