This repository has been archived by the owner on Aug 30, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
replace the CLI arguments with a config file
Rationale ========= The CLI is getting complex so it is worth loading the configuration from a file instead. Implementation details ====================== TOML ---- We decided to use TOML for the following reasons: - it is human friendly, ie easy to read and write - our configuration has a pretty flat structure which makes TOML quite adapted - it is well specified and has lots of implementation - it is well known The other options we considered: - INI: it is quite frequent in the Python ecosystem to use INI for config files, and the standard library even provides support for this. However, INI is not as powerful as TOML and does not have a specification - JSON: it is very popular but is not human friendly. For instance, it does not support comments, is very verbose, and breaks easily (if a trailing comma is forgotten at the end of a list for instance) - YAML: another popular choice, but is in my opinion more complex than TOML. Validation ---------- We use the third-party `schema` library to validate the configuration. It provides a convenient way to: - declare a schema to validate our config - leverage third-party libraries to validate some inputs (we use the `validators` library to validate IP addresses and URL for instance) - define our own validators - transform data after it has been validated: this can be useful to turn a relative path into an absolute one for example - provide user friendly error message when the configuration is invalid The `Config` class ------------------ By default, the `schema` library returns a dictionary containing a valid configuration, but that is not convenient to manipulate in Python. Therefore, we dynamically create a `Config` class from the configuration schema, and instantiate a `Config` object from the data returned by the `schema` validator. Package re-organization ----------------------- We moved the command line and config file logic into its own `config` sub-package, and moved the former `xain_fl.cli.main` entrypoint into the `xain_fl.__main__` module.
- Loading branch information
1 parent
ac8da29
commit 202bd14
Showing
10 changed files
with
630 additions
and
252 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
[server] | ||
|
||
# Address to listen on for incoming gRPC connections | ||
host = "localhost" | ||
# Port to listen on for incoming gRPC connections | ||
port = 50051 | ||
|
||
|
||
[ai] | ||
|
||
# Path to a file containing a numpy ndarray to use a initial model weights. | ||
initial_weights = "./test_array.npy" | ||
|
||
# Number of global rounds the model is going to be trained for. This | ||
# must be a positive integer. | ||
rounds = 1 | ||
|
||
# Number of local epochs per round | ||
epochs = 1 | ||
|
||
# Minimum number of participants to be selected for a round. | ||
min_participants = 1 | ||
|
||
# Fraction of total clients that participate in a training round. This | ||
# must be a float between 0 and 1. | ||
fraction_participants = 1.0 | ||
|
||
[storage] | ||
|
||
# URL to the storage service to use | ||
endpoint = "http://localhost:9000" | ||
|
||
# Name of the bucket for storing the aggregated models | ||
bucket = "aggregated_weights" | ||
|
||
# AWS secret access to use to authenticate to the storage service | ||
secret_access_key = "my-secret" | ||
|
||
# AWS access key ID to use to authenticate to the storage service | ||
access_key_id = "my-key-id" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# pylint: disable=missing-docstring,redefined-outer-name | ||
import re | ||
|
||
import pytest | ||
|
||
from xain_fl.config import Config, InvalidConfig | ||
|
||
|
||
@pytest.fixture | ||
def server_sample(): | ||
""" | ||
Return a valid "server" section | ||
""" | ||
return {"host": "localhost", "port": 50051} | ||
|
||
|
||
@pytest.fixture | ||
def ai_sample(): | ||
""" | ||
Return a valid "ai" section | ||
""" | ||
return { | ||
"initial_weights": "./test_array.npy", | ||
"rounds": 1, | ||
"epochs": 1, | ||
"min_participants": 1, | ||
"fraction_participants": 1.0, | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def storage_sample(): | ||
""" | ||
Return a valid "storage" section | ||
""" | ||
return { | ||
"endpoint": "http://localhost:9000", | ||
"bucket": "aggregated_weights", | ||
"secret_access_key": "my-secret", | ||
"access_key_id": "my-key-id", | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def config_sample(server_sample, ai_sample, storage_sample): | ||
""" | ||
Return a valid config | ||
""" | ||
return { | ||
"ai": ai_sample, | ||
"server": server_sample, | ||
"storage": storage_sample, | ||
} | ||
|
||
|
||
def test_load_valid_config(config_sample): | ||
""" | ||
Check that a valid config is loaded correctly | ||
""" | ||
config = Config.from_unchecked_dict(config_sample) | ||
|
||
assert config.server.host == "localhost" | ||
assert config.server.port == 50051 | ||
|
||
assert config.ai.initial_weights == "./test_array.npy" | ||
assert config.ai.rounds == 1 | ||
assert config.ai.epochs == 1 | ||
assert config.ai.min_participants == 1 | ||
assert config.ai.fraction_participants == 1.0 | ||
|
||
assert config.storage.endpoint == "http://localhost:9000" | ||
assert config.storage.bucket == "aggregated_weights" | ||
assert config.storage.secret_access_key == "my-secret" | ||
assert config.storage.access_key_id == "my-key-id" | ||
|
||
|
||
def test_server_config_ip_address(config_sample, server_sample): | ||
"""Check that the config is loaded correctly when the `server.host` | ||
key is an IP address | ||
""" | ||
# Ipv4 host | ||
server_sample["host"] = "1.2.3.4" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
# Ipv6 host | ||
server_sample["host"] = "::1" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
|
||
def test_server_config_extra_key(config_sample, server_sample): | ||
"""Check that the config is rejected when the server section contains | ||
an extra key | ||
""" | ||
server_sample["extra-key"] = "foo" | ||
config_sample["server"] = server_sample | ||
|
||
with AssertInvalid() as err: | ||
Config.from_unchecked_dict(config_sample) | ||
|
||
err.check_section("server") | ||
err.check_extra_key("extra-key") | ||
|
||
|
||
def test_server_config_invalid_host(config_sample, server_sample): | ||
"""Check that the config is rejected when the `server.host` key is | ||
invalid. | ||
""" | ||
server_sample["host"] = 1.0 | ||
config_sample["server"] = server_sample | ||
|
||
with AssertInvalid() as err: | ||
Config.from_unchecked_dict(config_sample) | ||
|
||
err.check_other( | ||
re.compile("Invalid `server.host`: value must be a valid domain name or IP address") | ||
) | ||
|
||
|
||
# FIXME: The library we use for validation rejects valid IPv6 | ||
# addresses | ||
@pytest.mark.xfail | ||
def test_server_config_valid_ipv6(config_sample, server_sample): | ||
"""Check some edge cases with IPv6 `server.host` key""" | ||
server_sample["host"] = "::" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
server_sample["host"] = "fe80::" | ||
config_sample["server"] = server_sample | ||
config = Config.from_unchecked_dict(config_sample) | ||
assert config.server.host == server_sample["host"] | ||
|
||
|
||
# Adapted from unittest's assertRaises | ||
class AssertInvalid: | ||
"""A context manager that check that an `xainfl.config.InvalidConfig` | ||
exception is raised, and provides helpers to perform checks on the | ||
exception. | ||
""" | ||
|
||
def __init__(self): | ||
self.message = None | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, _tb): | ||
if exc_type is None: | ||
raise Exception("Did not get an exception") | ||
if not isinstance(exc_value, InvalidConfig): | ||
# let this un-expected exception be re-raised | ||
return False | ||
|
||
self.message = str(exc_value) | ||
|
||
return True | ||
|
||
def check_section(self, section): | ||
needle = re.compile(f"Key '{section}' error:") | ||
assert re.search(needle, self.message) | ||
|
||
def check_extra_key(self, key): | ||
needle = re.compile(f"Wrong keys '{key}' in") | ||
assert re.search(needle, self.message) | ||
|
||
def check_other(self, needle): | ||
assert re.search(needle, self.message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""This module is the entrypoint to start a new coordinator instance. | ||
""" | ||
import sys | ||
|
||
import numpy as np | ||
|
||
from xain_fl.config import Config, InvalidConfig, get_cmd_parameters | ||
from xain_fl.coordinator.coordinator import Coordinator | ||
from xain_fl.coordinator.store import Store | ||
from xain_fl.serve import serve | ||
|
||
|
||
def main(): | ||
"""Start a coordinator instance | ||
""" | ||
|
||
args = get_cmd_parameters() | ||
try: | ||
config = Config.load(args.config) | ||
except InvalidConfig as err: | ||
print(err, file=sys.stderr) | ||
sys.exit(1) | ||
|
||
coordinator = Coordinator( | ||
weights=list(np.load(config.ai.initial_weights, allow_pickle=True)), | ||
num_rounds=config.ai.rounds, | ||
epochs=config.ai.epochs, | ||
minimum_participants_in_round=config.ai.min_participants, | ||
fraction_of_participants=config.ai.fraction_participants, | ||
) | ||
|
||
store = Store(config.storage) | ||
|
||
serve(coordinator=coordinator, store=store, host=config.server.host, port=config.server.port) | ||
|
||
|
||
main() |
Oops, something went wrong.