Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
replace the CLI arguments with a config file
Browse files Browse the repository at this point in the history
Rationale
=========

The CLI is getting complex so it is worth loading the configuration
from a file instead.

Implementation details
======================

TOML
----

We decided to use TOML for the following reasons:

  - it is human friendly, ie easy to read and write
  - our configuration has a pretty flat structure which makes TOML
    quite adapted
  - it is well specified and has lots of implementation
  - it is well known

The other options we considered:

  - INI: it is quite frequent in the Python ecosystem to use INI for
    config files, and the standard library even provides support for
    this. However, INI is not as powerful as TOML and does not have a
    specification
  - JSON: it is very popular but is not human friendly. For instance,
    it does not support comments, is very verbose, and breaks
    easily (if a trailing comma is forgotten at the end of a list for
    instance)
  - YAML: another popular choice, but is in my opinion more complex
    than TOML.

Validation
----------

We use the third-party `schema` library to validate the
configuration. It provides a convenient way to:

- declare a schema to validate our config
- leverage third-party libraries to validate some inputs (we use the
  `validators` library to validate IP addresses and URL for instance)
- define our own validators
- transform data after it has been validated: this can be useful to
  turn a relative path into an absolute one for example
- provide user friendly error message when the configuration is
  invalid

The `Config` class
------------------

By default, the `schema` library returns a dictionary containing a
valid configuration, but that is not convenient to manipulate in
Python. Therefore, we dynamically create a `Config` class from the
configuration schema, and instantiate a `Config` object from the data
returned by the `schema` validator.

Package re-organization
-----------------------

We moved the command line and config file logic into its own `config`
sub-package, and moved the former `xain_fl.cli.main` entrypoint into
the `xain_fl.__main__` module.
  • Loading branch information
little-dude committed Jan 20, 2020
1 parent ac8da29 commit 202bd14
Show file tree
Hide file tree
Showing 10 changed files with 630 additions and 252 deletions.
40 changes: 40 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[server]

# Address to listen on for incoming gRPC connections
host = "localhost"
# Port to listen on for incoming gRPC connections
port = 50051


[ai]

# Path to a file containing a numpy ndarray to use a initial model weights.
initial_weights = "./test_array.npy"

# Number of global rounds the model is going to be trained for. This
# must be a positive integer.
rounds = 1

# Number of local epochs per round
epochs = 1

# Minimum number of participants to be selected for a round.
min_participants = 1

# Fraction of total clients that participate in a training round. This
# must be a float between 0 and 1.
fraction_participants = 1.0

[storage]

# URL to the storage service to use
endpoint = "http://localhost:9000"

# Name of the bucket for storing the aggregated models
bucket = "aggregated_weights"

# AWS secret access to use to authenticate to the storage service
secret_access_key = "my-secret"

# AWS access key ID to use to authenticate to the storage service
access_key_id = "my-key-id"
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
# TODO: change xain-proto requirement to "xain-proto==0.2.0" once it is released
"xain-proto @ git+https://github.com/xainag/xain-proto.git@37fc05566da91d263c37d203c0ba70804960be9b#egg=xain_proto-0.1.0&subdirectory=python", # Apache License 2.0
"boto3==1.10.48", # Apache License 2.0
"toml==0.10.0", # MIT
"schema==0.6.8", # MIT
"validators==0.14.1", # MIT
]

dev_require = [
Expand Down Expand Up @@ -92,5 +95,5 @@
"docs": docs_require,
"dev": dev_require + tests_require + docs_require,
},
entry_points={"console_scripts": ["coordinator=xain_fl.cli:main"]},
entry_points={"console_scripts": ["coordinator=xain_fl.__main__"]},
)
10 changes: 8 additions & 2 deletions tests/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import numpy as np

from xain_fl.coordinator.store import Store, StoreConfig
from xain_fl.config import StorageConfig
from xain_fl.coordinator.store import Store


class FakeS3Resource:
Expand Down Expand Up @@ -71,7 +72,12 @@ class TestStore(Store):
#
# pylint: disable=super-init-not-called
def __init__(self):
self.config = StoreConfig("endpoint_url", "access_key_id", "secret_access_key", "bucket")
self.config = StorageConfig(
endpoint="endpoint",
access_key_id="access_key_id",
secret_access_key="secret_access_key",
bucket="bucket",
)
self.s3 = FakeS3Resource()

def assert_wrote(self, round: int, weights: np.ndarray):
Expand Down
176 changes: 176 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# pylint: disable=missing-docstring,redefined-outer-name
import re

import pytest

from xain_fl.config import Config, InvalidConfig


@pytest.fixture
def server_sample():
"""
Return a valid "server" section
"""
return {"host": "localhost", "port": 50051}


@pytest.fixture
def ai_sample():
"""
Return a valid "ai" section
"""
return {
"initial_weights": "./test_array.npy",
"rounds": 1,
"epochs": 1,
"min_participants": 1,
"fraction_participants": 1.0,
}


@pytest.fixture
def storage_sample():
"""
Return a valid "storage" section
"""
return {
"endpoint": "http://localhost:9000",
"bucket": "aggregated_weights",
"secret_access_key": "my-secret",
"access_key_id": "my-key-id",
}


@pytest.fixture
def config_sample(server_sample, ai_sample, storage_sample):
"""
Return a valid config
"""
return {
"ai": ai_sample,
"server": server_sample,
"storage": storage_sample,
}


def test_load_valid_config(config_sample):
"""
Check that a valid config is loaded correctly
"""
config = Config.from_unchecked_dict(config_sample)

assert config.server.host == "localhost"
assert config.server.port == 50051

assert config.ai.initial_weights == "./test_array.npy"
assert config.ai.rounds == 1
assert config.ai.epochs == 1
assert config.ai.min_participants == 1
assert config.ai.fraction_participants == 1.0

assert config.storage.endpoint == "http://localhost:9000"
assert config.storage.bucket == "aggregated_weights"
assert config.storage.secret_access_key == "my-secret"
assert config.storage.access_key_id == "my-key-id"


def test_server_config_ip_address(config_sample, server_sample):
"""Check that the config is loaded correctly when the `server.host`
key is an IP address
"""
# Ipv4 host
server_sample["host"] = "1.2.3.4"
config_sample["server"] = server_sample
config = Config.from_unchecked_dict(config_sample)
assert config.server.host == server_sample["host"]

# Ipv6 host
server_sample["host"] = "::1"
config_sample["server"] = server_sample
config = Config.from_unchecked_dict(config_sample)
assert config.server.host == server_sample["host"]


def test_server_config_extra_key(config_sample, server_sample):
"""Check that the config is rejected when the server section contains
an extra key
"""
server_sample["extra-key"] = "foo"
config_sample["server"] = server_sample

with AssertInvalid() as err:
Config.from_unchecked_dict(config_sample)

err.check_section("server")
err.check_extra_key("extra-key")


def test_server_config_invalid_host(config_sample, server_sample):
"""Check that the config is rejected when the `server.host` key is
invalid.
"""
server_sample["host"] = 1.0
config_sample["server"] = server_sample

with AssertInvalid() as err:
Config.from_unchecked_dict(config_sample)

err.check_other(
re.compile("Invalid `server.host`: value must be a valid domain name or IP address")
)


# FIXME: The library we use for validation rejects valid IPv6
# addresses
@pytest.mark.xfail
def test_server_config_valid_ipv6(config_sample, server_sample):
"""Check some edge cases with IPv6 `server.host` key"""
server_sample["host"] = "::"
config_sample["server"] = server_sample
config = Config.from_unchecked_dict(config_sample)
assert config.server.host == server_sample["host"]

server_sample["host"] = "fe80::"
config_sample["server"] = server_sample
config = Config.from_unchecked_dict(config_sample)
assert config.server.host == server_sample["host"]


# Adapted from unittest's assertRaises
class AssertInvalid:
"""A context manager that check that an `xainfl.config.InvalidConfig`
exception is raised, and provides helpers to perform checks on the
exception.
"""

def __init__(self):
self.message = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, _tb):
if exc_type is None:
raise Exception("Did not get an exception")
if not isinstance(exc_value, InvalidConfig):
# let this un-expected exception be re-raised
return False

self.message = str(exc_value)

return True

def check_section(self, section):
needle = re.compile(f"Key '{section}' error:")
assert re.search(needle, self.message)

def check_extra_key(self, key):
needle = re.compile(f"Wrong keys '{key}' in")
assert re.search(needle, self.message)

def check_other(self, needle):
assert re.search(needle, self.message)
38 changes: 38 additions & 0 deletions xain_fl/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""This module is the entrypoint to start a new coordinator instance.
"""
import sys

import numpy as np

from xain_fl.config import Config, InvalidConfig, get_cmd_parameters
from xain_fl.coordinator.coordinator import Coordinator
from xain_fl.coordinator.store import Store
from xain_fl.serve import serve


def main():
"""Start a coordinator instance
"""

args = get_cmd_parameters()
try:
config = Config.load(args.config)
except InvalidConfig as err:
print(err, file=sys.stderr)
sys.exit(1)

coordinator = Coordinator(
weights=list(np.load(config.ai.initial_weights, allow_pickle=True)),
num_rounds=config.ai.rounds,
epochs=config.ai.epochs,
minimum_participants_in_round=config.ai.min_participants,
fraction_of_participants=config.ai.fraction_participants,
)

store = Store(config.storage)

serve(coordinator=coordinator, store=store, host=config.server.host, port=config.server.port)


main()
Loading

0 comments on commit 202bd14

Please sign in to comment.