This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

PB-159: remove weights from gRPC messages
References:

https://xainag.atlassian.net/browse/PB-159

Needs to be merged along with:

- https://github.com/xainag/xain-proto/pull/25
- https://github.com/xainag/xain-sdk/pull/88
- #298

Summary:

Remove the weights from the gRPC messages. From now on, weights are
exchanged via S3 buckets.

The sequence diagram below illustrates this new behavior.

At the beginning of a round (1), each selected participant sends a
`StartTrainingRoundRequest`, and the coordinator responds with the
same `StartTrainingRoundResponse` as before, except that it no longer
contains the global weights.

Instead, the participant fetches these weights from the store (2). S3
buckets are key-value stores, and the key for global weights is the
round number.

Then, the participant trains. Once done, it uploads its local weights
to the S3 bucket (3). The key is `<participant_id>/<round_number>`.

Finally (4), the participant sends its `EndTrainingRoundRequest`.
Before answering, the coordinator retrieves the local weights that the
participant has uploaded.
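
The four steps above can be sketched with a plain dictionary standing
in for the S3 bucket. This is only an illustration of the key scheme
and the ordering of the calls: the function names, the fixed
participant ID, and the "training" step are assumptions made for the
example and are not part of the xain-fl or xain-sdk code.

```python
from typing import Dict, List

# Hypothetical in-memory stand-in for the S3 bucket: a key-value mapping.
Store = Dict[str, List[float]]


def global_weights_key(round_number: int) -> str:
    """Global weights are keyed by the round number."""
    return str(round_number)


def local_weights_key(participant_id: str, round_number: int) -> str:
    """Local weights are keyed by `<participant_id>/<round_number>`."""
    return f"{participant_id}/{round_number}"


def participant_round(store: Store, participant_id: str, round_number: int) -> None:
    # (2) Fetch the global weights for this round from the store.
    global_weights = store[global_weights_key(round_number)]
    # Placeholder for the real training step: shift every weight by one.
    local_weights = [weight + 1.0 for weight in global_weights]
    # (3) Upload the local weights before sending the end-of-round request.
    store[local_weights_key(participant_id, round_number)] = local_weights


def coordinator_on_end_training(
    store: Store, participant_id: str, round_number: int
) -> List[float]:
    # (4) On the end-of-round request, read what the participant uploaded.
    return store[local_weights_key(participant_id, round_number)]


store: Store = {global_weights_key(0): [0.0, 0.5]}
participant_round(store, "participant-1", 0)
received = coordinator_on_end_training(store, "participant-1", 0)
```

The only thing that travels between participant and coordinator here
is the information needed to rebuild the key, which is the point of
the change: the weights themselves never appear in a message.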

_**Important note**: At the moment, the participants don't know their
ID, because the coordinator does not send it to them. Thus, they
currently generate a random ID when they start, and send it to the
coordinator so that it can retrieve the participant's weights. This is
why the `EndTrainingRoundRequest` currently has a `participant_id`
field._
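
As a rough illustration of this workaround, the sketch below generates
a random ID once and carries it in the request so that the receiver
can rebuild the storage key. The commit does not say how the ID is
generated; using a version 4 UUID is an assumption, and
`EndTrainingRoundRequestSketch` is a hypothetical stand-in for the
real protobuf message.

```python
import uuid
from dataclasses import dataclass


@dataclass(frozen=True)
class EndTrainingRoundRequestSketch:
    """Hypothetical stand-in for the protobuf message; only the ID is modelled."""

    participant_id: str


def new_participant_id() -> str:
    # Assumption: a random (version 4) UUID is used as the participant ID.
    return str(uuid.uuid4())


# The participant picks its ID once, at start-up.
participant_id = new_participant_id()

# The ID travels in the request, so the coordinator can derive the key
# under which the local weights for round 0 were uploaded.
request = EndTrainingRoundRequestSketch(participant_id=participant_id)
key = f"{request.participant_id}/0"
```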

```
    P                                C                      Store
1.  |   StartTrainingRoundRequest    |                        |
    | -----------------------------> |                        |
    |   StartTrainingRoundResponse   |                        |
    | <----------------------------- |                        |
    |                                |                        |
    |                Get global weights (key="round")         |
2.  | ------------------------------------------------------> |
    |                         Global weights                  |
    | <------------------------------------------------------ |
    |                                |                        |
    | [train...]                     |                        |
    |                                |                        |
3.  |       Set local weights (key="participant/round")       |
    | ------------------------------------------------------> |
    |                               Ok                        |
    | <------------------------------------------------------ |
    |                                |                        |
4.  |   EndTrainingRoundRequest      |                        |
    | -----------------------------> | Get local weights (key="participant/round")
    |                                | ---------------------> |
    |                                | Local weights          |
    |  EndTrainingRoundResponse      | <--------------------- |
    | <----------------------------- |                        |
```

At the end of the round, the coordinator writes the new global weights
to the S3 bucket, using the number of the next round as the key (see
the sequence diagram below).

```
P                                C                      Store
|   EndTrainingRoundRequest      |                        |
| -----------------------------> | Get local weights (key="participant/round")
|                                | ---------------------> |
|                                | Local weights          |
|  EndTrainingRoundResponse      | <--------------------- |
| <----------------------------- |                        |
|                                |                        |
|                                | Set global weights (key="round + 1")
|                                | ---------------------> |
|                                | Ok                     |
|                                | <--------------------- |
```
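
A minimal sketch of this end-of-round step, again with a dictionary in
place of the bucket. The unweighted mean is only a placeholder for the
coordinator's real aggregator, and `finish_round` is a hypothetical
name chosen for the example.

```python
from typing import Dict, List

Store = Dict[str, List[float]]


def finish_round(
    store: Store, participant_ids: List[str], round_number: int
) -> List[float]:
    # Read the local weights each participant uploaded for this round.
    local_weights = [store[f"{pid}/{round_number}"] for pid in participant_ids]
    # Placeholder aggregation: element-wise unweighted mean.
    aggregated = [sum(column) / len(local_weights) for column in zip(*local_weights)]
    # The new global weights are stored under the number of the next round.
    store[str(round_number + 1)] = aggregated
    return aggregated


store: Store = {"a/0": [1.0, 3.0], "b/0": [3.0, 5.0]}
new_global_weights = finish_round(store, ["a", "b"], 0)
```

Writing under `round + 1` means a participant selected for the next
round can fetch its starting weights with the same "key is the round
number" rule it used in the current one.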

Implementation notes:

- Initially, we thought we would be using different buckets for the
  local and global weights. For now, however, we use the same bucket
  for both.

- We currently store the global weights under different keys. It turns
  out that this brings unnecessary complexity, so we'll probably
  simplify this in the future.

- For now, the coordinator doesn't send any storage information to the
  participants. Thus, the participants need to be configured with the
  storage information. In the future, the `StartTrainingRoundResponse`
  could contain the endpoint URL, bucket name, etc.
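
To make the last note concrete, here is a sketch of the settings a
participant would have to be given out of band. The four field names
mirror the storage settings used in this commit's tests;
`ParticipantStorageConfig` and `load_storage_config` are hypothetical
names and do not exist in xain-sdk.

```python
from dataclasses import dataclass
from typing import Mapping

REQUIRED_KEYS = ("endpoint", "bucket", "access_key_id", "secret_access_key")


@dataclass(frozen=True)
class ParticipantStorageConfig:
    """Hypothetical container for the storage settings of a participant."""

    endpoint: str
    bucket: str
    access_key_id: str
    secret_access_key: str


def load_storage_config(raw: Mapping[str, str]) -> ParticipantStorageConfig:
    # Fail early if a setting is missing, since the coordinator will not
    # provide it in the `StartTrainingRoundResponse`.
    missing = [key for key in REQUIRED_KEYS if key not in raw]
    if missing:
        raise ValueError(f"missing storage settings: {', '.join(missing)}")
    return ParticipantStorageConfig(**{key: raw[key] for key in REQUIRED_KEYS})


config = load_storage_config(
    {
        "endpoint": "http://localhost:9000",
        "bucket": "bucket",
        "access_key_id": "my-key-id",
        "secret_access_key": "my-secret",
    }
)
```
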
little-dude committed Feb 14, 2020
1 parent 7ff475b commit 4f17e1f
Showing 9 changed files with 324 additions and 182 deletions.
5 changes: 3 additions & 2 deletions setup.py
@@ -27,7 +27,7 @@
"numpy==1.15", # BSD
"grpcio==1.23", # Apache License 2.0
"structlog==19.2.0", # Apache License 2.0
"xain-proto==0.5.0", # Apache License 2.0
"xain-proto @ git+https://github.com/xainag/xain-proto.git@PB-159-use-s3-for-transfering-weights#egg=xain_proto-0.6.0&subdirectory=python", # Apache License 2.0
"boto3==1.10.48", # Apache License 2.0
"toml==0.10.0", # MIT
"schema~=0.7", # MIT
@@ -52,7 +52,8 @@
"pytest==5.3.2", # MIT license
"pytest-cov==2.8.1", # MIT
"pytest-watch==4.2.0", # MIT
"xain-sdk==0.5.0", # Apache License 2.0
"pytest-mock==2.0.0", # MIT
"xain-sdk@ git+https://github.com/xainag/xain-sdk.git@PB-159-use-s3-for-transfering-weights#egg=xain_sdk-0.6.0", # Apache License 2.0
]

docs_require = [
110 changes: 104 additions & 6 deletions tests/conftest.py
@@ -5,16 +5,114 @@
import threading

import grpc
import numpy as np
from numpy import ndarray
import pytest
from xain_proto.fl import coordinator_pb2_grpc
from xain_proto.fl.coordinator_pb2 import EndTrainingRoundRequest

from xain_fl.coordinator.coordinator import Coordinator
from xain_fl.coordinator.coordinator_grpc import CoordinatorGrpc
from xain_fl.coordinator.heartbeat import monitor_heartbeats
from xain_fl.fl.coordinator.aggregate import ModelSumAggregator
from xain_fl.fl.coordinator.controller import IdController
from xain_fl.coordinator.metrics_store import (
    AbstractMetricsStore,
    NullObjectMetricsStore,
)
from xain_fl.coordinator.store import (
    AbstractGlobalWeightsWriter,
    AbstractLocalWeightsReader,
)
from xain_fl.fl.coordinator.aggregate import (
    Aggregator,
    ModelSumAggregator,
    WeightedAverageAggregator,
)
from xain_fl.fl.coordinator.controller import Controller, IdController, RandomController

from .port_forwarding import ConnectionManager
from .store import MockS3Coordinator, MockS3Participant, MockS3Resource

# pylint: disable=redefined-outer-name


@pytest.fixture(scope="function")
def s3_mock_stores():
    """
    Create a fake S3 store
    """

    s3_resource = MockS3Resource()
    participant_store = MockS3Participant(s3_resource)
    coordinator_store = MockS3Coordinator(s3_resource)
    return (coordinator_store, participant_store)


@pytest.fixture(scope="function")
def store(s3_mock_stores):
    return s3_mock_stores[0].s3.fake_store


@pytest.fixture(scope="function")
def participant_store(s3_mock_stores):
    return s3_mock_stores[1]


@pytest.fixture(scope="function")
def end_training_request(s3_mock_stores):
    participant_store = s3_mock_stores[1]

    def wrapped(
        coordinator: Coordinator,
        participant_id: str,
        round: int = 0,
        weights: ndarray = ndarray([]),
    ):
        participant_store.write_weights(participant_id, round, weights)
        coordinator.on_message(
            EndTrainingRoundRequest(participant_id=participant_id), participant_id
        )

    return wrapped


@pytest.fixture(scope="function")
def coordinator(s3_mock_stores):
    """
    A function that instantiates a new coordinator.
    """
    store: MockS3Coordinator = s3_mock_stores[0]
    default_global_weights_writer: AbstractGlobalWeightsWriter = store
    default_local_weights_reader: AbstractLocalWeightsReader = store

    # pylint: disable=too-many-arguments
    def wrapped(
        global_weights_writer=default_global_weights_writer,
        local_weights_reader=default_local_weights_reader,
        metrics_store: AbstractMetricsStore = NullObjectMetricsStore(),
        num_rounds: int = 1,
        minimum_participants_in_round: int = 1,
        fraction_of_participants: float = 1.0,
        weights: ndarray = np.empty(shape=(0,)),
        epochs: int = 1,
        epoch_base: int = 0,
        aggregator: Aggregator = WeightedAverageAggregator(),
        controller: Controller = RandomController(),
    ):
        return Coordinator(
            global_weights_writer,
            local_weights_reader,
            metrics_store=metrics_store,
            num_rounds=num_rounds,
            minimum_participants_in_round=minimum_participants_in_round,
            fraction_of_participants=fraction_of_participants,
            weights=weights,
            epochs=epochs,
            epoch_base=epoch_base,
            aggregator=aggregator,
            controller=controller,
        )

    return wrapped


@pytest.fixture()
@@ -39,14 +137,14 @@ def metrics_sample():


@pytest.fixture
def coordinator_service():
def coordinator_service(coordinator):
"""[summary]
.. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425)
"""

server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
coordinator = Coordinator(
coordinator = coordinator(
minimum_participants_in_round=10, fraction_of_participants=1.0
)
coordinator_grpc = CoordinatorGrpc(coordinator)
@@ -58,7 +156,7 @@ def coordinator_service():


@pytest.fixture
def mock_coordinator_service():
def mock_coordinator_service(coordinator):
"""[summary]
.. todo:: Advance docstrings (https://xainag.atlassian.net/browse/XP-425)
@@ -67,7 +165,7 @@ def mock_coordinator_service():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
agg = ModelSumAggregator()
ctrl = IdController()
coordinator = Coordinator(
coordinator = coordinator(
num_rounds=2,
minimum_participants_in_round=1,
fraction_of_participants=1.0,
81 changes: 65 additions & 16 deletions tests/store.py
@@ -6,9 +6,10 @@
import typing

import numpy as np
from xain_sdk.store import S3GlobalWeightsReader, S3LocalWeightsWriter

from xain_fl.config import StorageConfig
from xain_fl.coordinator.store import S3GlobalWeightsWriter
from xain_fl.coordinator.store import S3GlobalWeightsWriter, S3LocalWeightsReader


class MockS3Resource:
@@ -62,43 +63,44 @@ def download_fileobj(self, key: str, buf: typing.IO):
self.reads[key] += 1


class MockS3Writer(S3GlobalWeightsWriter):
class MockS3Coordinator(S3GlobalWeightsWriter, S3LocalWeightsReader):
"""A partial mock of the
``xain-fl.coordinator.store.S3GlobalWeightsWriter`` class that
does not perform any IO. Instead, data is stored in memory.
``xain-fl.coordinator.store.S3GlobalWeightsWriter`` and
``xain-fl.coordinator.store.S3LocalWeightsReader`` class that does
not perform any IO. Instead, data is stored in memory.
"""

# We DO NOT want to call the parent class __init__, since it tries
# to initialize a connection to a non-existent external resource
#
# pylint: disable=super-init-not-called
def __init__(self):
def __init__(self, mock_s3_resource):
self.config = StorageConfig(
endpoint="endpoint",
access_key_id="access_key_id",
secret_access_key="secret_access_key",
global_weights_bucket="bucket",
local_weights_bucket="bucket",
bucket="bucket",
)
self.s3 = MockS3Resource()
self.s3 = mock_s3_resource

def assert_read(self, participant_id: str, round: int):
key = f"{participant_id}/{round}"
reads = self.s3.reads[key]
assert reads == 1, f"got {reads} reads for round {key}, expected 1"

def assert_wrote(self, round: int, weights: np.ndarray):
"""Check that the given weights have been written to the store for the
given round.
given round.
Args:
weights (np.ndarray): weights to store
round (int): round to which the weights belong
weights: weights to store
round: round to which the weights belong
"""
writes = self.s3.writes[str(round)]
# Under normal conditions, we should write data exactly once
assert writes == 1, f"got {writes} writes for round {round}, expected 1"
# If the arrays contains `NaN` we cannot compare them, so we
# replace them by zeros to do the comparison
stored_array = np.nan_to_num(self.s3.fake_store[str(round)])
expected_array = np.nan_to_num(weights)
assert np.array_equal(stored_array, expected_array)
assert_ndarray_eq(self.s3.fake_store[str(round)], weights)

def assert_didnt_write(self, round: int):
"""Check that the weights for the given round have NOT been written to the store.
@@ -108,3 +110,50 @@ def assert_didnt_write(self, round: int):
"""
assert self.s3.writes[str(round)] == 0


class MockS3Participant(S3LocalWeightsWriter, S3GlobalWeightsReader):
    """A partial mock of the ``xain_sdk.store.S3GlobalWeightsReader`` and
    ``xain_sdk.store.S3LocalWeightsWriter`` class that does not
    perform any IO. Instead, data is stored in memory.
    """

    def __init__(self, mock_s3_resource):
        self.config = StorageConfig(
            endpoint="endpoint",
            access_key_id="access_key_id",
            secret_access_key="secret_access_key",
            bucket="bucket",
        )
        self.s3 = mock_s3_resource

    def assert_wrote(self, participant_id: str, round: int, weights: np.ndarray):
        """Check that the given weights have been written to the store for the
        given round.

        Args:
            weights: weights to store
            participant_id: ID of the participant
            round: round to which the weights belong
        """
        key = f"{participant_id}/{round}"
        writes = self.s3.writes[key]
        assert writes == 1, f"got {writes} writes for {key}, expected 1"
        assert_ndarray_eq(self.s3.fake_store[key], weights)

    def assert_didnt_write(self, participant_id: str, round: int):
        """Check that the weights for the given round have NOT been written to
        the store.

        Args:
            participant_id: ID of the participant
            round: round to which the weights belong
        """
        key = f"{participant_id}/{round}"
        assert self.s3.writes[key] == 0


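def assert_ndarray_eq(nd_array1, ndarray2):
    # Arrays containing `NaN` cannot be compared directly, so replace `NaN`
    # by zeros on both sides, then compare the first array with the second.
    assert np.array_equal(np.nan_to_num(nd_array1), np.nan_to_num(ndarray2))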
6 changes: 2 additions & 4 deletions tests/test_config.py
@@ -43,8 +43,7 @@ def storage_sample():
"""
return {
"endpoint": "http://localhost:9000",
"global_weights_bucket": "aggregated_weights",
"local_weights_bucket": "participants_weights",
"bucket": "bucket",
"secret_access_key": "my-secret",
"access_key_id": "my-key-id",
}
@@ -135,8 +134,7 @@ def test_load_valid_config(config_sample): # pylint: disable=redefined-outer-na
assert config.ai.fraction_participants == 1.0

assert config.storage.endpoint == "http://localhost:9000"
assert config.storage.global_weights_bucket == "aggregated_weights"
assert config.storage.local_weights_bucket == "participants_weights"
assert config.storage.bucket == "bucket"
assert config.storage.secret_access_key == "my-secret"
assert config.storage.access_key_id == "my-key-id"
