Skip to content

Commit

Permalink
test: update the tests, Client(execution_mode) (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
Esadruhn authored Sep 21, 2022
1 parent 033f139 commit cb254d7
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 49 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Update the Client, it takes a backend type instead of debug=True + env variable to set the spawner - (#210)

## [0.33.0] - 2022-09-19

### Removed
Expand Down
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ test-remote-sdk: pyclean
pytest tests -rs -v --durations=0 -m "not workflows" -n $(PARALLELISM)

test-remote-workflows: pyclean
pytest tests -v --durations=0 -m "workflows"
pytest tests -v --durations=0 -m "workflows"

test-minimal: pyclean
pytest tests -rs -v --durations=0 -m "not slow and not workflows" -n $(PARALLELISM)

test-local: test-subprocess test-docker test-subprocess-workflows

test-docker: pyclean
DEBUG_SPAWNER=docker pytest tests -rs -v --durations=0 -m "not workflows" --local
pytest tests -rs -v --durations=0 -m "not workflows" --mode=docker

test-subprocess: pyclean
DEBUG_SPAWNER=subprocess pytest tests -rs -v --durations=0 -m "not workflows and not subprocess_skip" --local
pytest tests -rs -v --durations=0 -m "not workflows and not subprocess_skip" --mode=subprocess

test-subprocess-workflows: pyclean
DEBUG_SPAWNER=subprocess pytest tests -v --durations=0 -m "workflows" --local
pytest tests -v --durations=0 -m "workflows" --mode=subprocess

test-all: test-local test-remote

Expand Down
4 changes: 2 additions & 2 deletions substratest/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Client:

def __init__(
self,
debug: bool,
backend_type: substra.BackendType,
organization_id: str,
address: str,
user: str,
Expand All @@ -51,7 +51,7 @@ def __init__(
super().__init__()

self.organization_id = organization_id
self._client = substra.Client(debug=debug, url=address, insecure=False, token=token)
self._client = substra.Client(backend_type=backend_type, url=address, insecure=False, token=token)
if not token:
token = self._client.login(user, password)
self._api_client = _APIClient(address, token)
Expand Down
53 changes: 30 additions & 23 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid

import pytest
import substra
from substra.sdk.schemas import AlgoCategory

import substratest as sbt
Expand Down Expand Up @@ -55,9 +56,10 @@ def pytest_configure(config):
def pytest_addoption(parser):
"""Command line arguments to configure the network to be local or remote"""
parser.addoption(
"--local",
action="store_true",
help="Run the tests on the local backend only. Otherwise run the tests only on the remote backend.",
"--mode",
choices=["subprocess", "docker", "remote"],
default="remote",
help="Choose the mode on which to run the tests",
)
parser.addoption(
"--nb-train-datasamples",
Expand All @@ -77,8 +79,8 @@ def pytest_collection_modifyitems(config, items):
"""Skip the remote tests if local backend and local tests if remote backend.
By default, run only on the remote backend.
"""
local = config.getoption("--local")
if local:
mode = substra.BackendType(config.getoption("--mode"))
if mode != substra.BackendType.REMOTE:
skip_marker = pytest.mark.skip(reason="remove the --local option to run")
keyword = "remote_only"
else:
Expand All @@ -90,11 +92,9 @@ def pytest_collection_modifyitems(config, items):


@pytest.fixture(scope="session")
def client_debug_local(request):
local = request.config.getoption("--local")
if local:
return True
return False
def client_mode(request):
mode = request.config.getoption("--mode")
return substra.BackendType(mode)


class _DataEnv:
Expand Down Expand Up @@ -129,23 +129,26 @@ class Network:


@pytest.fixture
def factory(request, cfg, client_debug_local):
def factory(request, cfg, client_mode):
"""Factory fixture.
Provide class methods to simply create asset specification in order to add them
to the substra framework.
"""
name = f"{TESTS_RUN_UUID}_{request.node.name}"
with sbt.AssetsFactory(name=name, cfg=cfg, client_debug_local=client_debug_local) as f:
with sbt.AssetsFactory(
name=name,
cfg=cfg,
client_debug_local=(client_mode in [substra.BackendType.LOCAL_SUBPROCESS, substra.BackendType.LOCAL_DOCKER]),
) as f:
yield f


@pytest.fixture(scope="session")
def cfg(client_debug_local):
if not client_debug_local:
def cfg(client_mode):
if client_mode == substra.BackendType.REMOTE:
return settings.Settings.load()
else:
# TODO check what enable_intermediate_model_removal does
return settings.Settings.load_local_backend()


Expand All @@ -162,7 +165,7 @@ def debug_factory(request, cfg):


@pytest.fixture(scope="session")
def network(cfg, client_debug_local):
def network(cfg, client_mode):
"""Network fixture.
Network must be started outside of the tests environment and the network is kept
Expand All @@ -174,7 +177,7 @@ def network(cfg, client_debug_local):
"""
clients = [
sbt.Client(
debug=client_debug_local,
backend_type=client_mode,
organization_id=n.msp_id,
address=n.address,
user=n.user,
Expand All @@ -191,7 +194,7 @@ def network(cfg, client_debug_local):


@pytest.fixture(scope="session")
def default_data_env(cfg, network, client_debug_local):
def default_data_env(cfg, network, client_mode):
"""Fixture with pre-existing assets in all organizations.
The following assets will be created for each organization:
Expand All @@ -207,7 +210,11 @@ def default_data_env(cfg, network, client_debug_local):
"""
factory_name = f"{TESTS_RUN_UUID}_global"

with sbt.AssetsFactory(name=factory_name, cfg=cfg, client_debug_local=client_debug_local) as f:
with sbt.AssetsFactory(
name=factory_name,
cfg=cfg,
client_debug_local=(client_mode in [substra.BackendType.LOCAL_SUBPROCESS, substra.BackendType.LOCAL_DOCKER]),
) as f:
datasets = []
metrics = []
for index, client in enumerate(network.clients):
Expand Down Expand Up @@ -338,17 +345,17 @@ def channel(cfg, network):


@pytest.fixture(scope="session")
def debug_client(cfg, client):
def hybrid_client(cfg, client):
"""
Client fixture in debug mode (first organization).
Client fixture in hybrid mode (first organization).
Use it with @pytest.mark.remote_only
"""
organization = cfg.organizations[0]
# Debug client and client share the same
# Hybrid client and client share the same
# token, otherwise when one connects the other
# is disconnected.
return sbt.Client(
debug=True,
backend_type=substra.BackendType.LOCAL_DOCKER,
organization_id=organization.msp_id,
address=organization.address,
user=organization.user,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_traintuple_execution_failure(factory, network, default_dataset_1):

spec = factory.create_traintuple(algo=algo, inputs=default_dataset_1.train_data_inputs)

if network.clients[0].backend_mode != substra.BackendType.DEPLOYED:
if network.clients[0].backend_mode != substra.BackendType.REMOTE:
with pytest.raises(substra.sdk.backends.local.compute.spawner.base.ExecutionError):
network.clients[0].add_traintuple(spec)
else:
Expand All @@ -236,7 +236,7 @@ def test_composite_traintuple_execution_failure(factory, client, default_dataset
algo = client.add_algo(spec)

spec = factory.create_composite_traintuple(algo=algo, inputs=default_dataset.train_data_inputs)
if client.backend_mode == substra.BackendType.DEPLOYED:
if client.backend_mode == substra.BackendType.REMOTE:
composite_traintuple = client.add_composite_traintuple(spec)
composite_traintuple = client.wait(composite_traintuple, raises=False)

Expand Down Expand Up @@ -277,7 +277,7 @@ def test_aggregatetuple_execution_failure(factory, client, default_dataset):
worker=client.organization_id,
)

if client.backend_mode == substra.BackendType.DEPLOYED:
if client.backend_mode == substra.BackendType.REMOTE:
aggregatetuple = client.add_aggregatetuple(spec)
aggregatetuple = client.wait(aggregatetuple, raises=False)

Expand Down
28 changes: 14 additions & 14 deletions tests/test_hybrid_debug.py → tests/test_hybrid_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def docker_available() -> bool:

@pytest.mark.remote_only
@pytest.mark.slow
def test_execution_debug(client, debug_client, debug_factory, default_dataset):
def test_execution_debug(client, hybrid_client, debug_factory, default_dataset):

spec = debug_factory.create_algo(AlgoCategory.simple)
simple_algo = client.add_algo(spec)
Expand All @@ -36,7 +36,7 @@ def test_execution_debug(client, debug_client, debug_factory, default_dataset):
spec = debug_factory.create_traintuple(
algo=simple_algo, inputs=default_dataset.opener_input + default_dataset.train_data_sample_inputs[:1]
)
traintuple = debug_client.add_traintuple(spec)
traintuple = hybrid_client.add_traintuple(spec)
assert traintuple.status == models.Status.done
assert len(traintuple.train.models) != 0

Expand All @@ -47,7 +47,7 @@ def test_execution_debug(client, debug_client, debug_factory, default_dataset):
+ default_dataset.train_data_sample_inputs[:1]
+ FLTaskInputGenerator.train_to_predict(traintuple.key),
)
predicttuple = debug_client.add_predicttuple(spec)
predicttuple = hybrid_client.add_predicttuple(spec)
assert predicttuple.status == models.Status.done

spec = debug_factory.create_testtuple(
Expand All @@ -56,19 +56,19 @@ def test_execution_debug(client, debug_client, debug_factory, default_dataset):
+ default_dataset.train_data_sample_inputs[:1]
+ FLTaskInputGenerator.predict_to_test(predicttuple.key),
)
testtuple = debug_client.add_testtuple(spec)
testtuple = hybrid_client.add_testtuple(spec)
assert testtuple.status == models.Status.done
assert list(testtuple.test.perfs.values())[0] == 3


@pytest.mark.remote_only
@pytest.mark.slow
def test_debug_compute_plan_aggregate_composite(network, client, debug_client, debug_factory, default_datasets):
def test_debug_compute_plan_aggregate_composite(network, client, hybrid_client, debug_factory, default_datasets):
"""
Debug / Compute plan version of the
`test_aggregate_composite_traintuples` method from `test_execution.py`
"""
aggregate_worker = debug_client.organization_id
aggregate_worker = hybrid_client.organization_id
number_of_rounds = 2

# register algos on first organization
Expand Down Expand Up @@ -133,8 +133,8 @@ def test_debug_compute_plan_aggregate_composite(network, client, debug_client, d
algo=metric, inputs=dataset.train_data_inputs + FLTaskInputGenerator.predict_to_test(spec.predicttuple_id)
)

cp = debug_client.add_compute_plan(cp_spec)
traintuples = debug_client.list_compute_plan_traintuples(cp.key)
cp = hybrid_client.add_compute_plan(cp_spec)
traintuples = hybrid_client.list_compute_plan_traintuples(cp.key)
composite_traintuples = client.list_compute_plan_composite_traintuples(cp.key)
aggregatetuples = client.list_compute_plan_aggregatetuples(cp.key)
predicttuples = client.list_compute_plan_predicttuples(cp.key)
Expand All @@ -146,13 +146,13 @@ def test_debug_compute_plan_aggregate_composite(network, client, debug_client, d


@pytest.mark.remote_only
def test_debug_download_dataset(debug_client, default_dataset):
debug_client.download_opener(default_dataset.key)
def test_debug_download_dataset(hybrid_client, default_dataset):
hybrid_client.download_opener(default_dataset.key)


@pytest.mark.remote_only
@pytest.mark.slow
def test_test_data_traintuple(client, debug_client, debug_factory, default_dataset):
def test_test_data_traintuple(client, hybrid_client, debug_factory, default_dataset):
"""Check that we can't use test data samples for traintuples"""
spec = debug_factory.create_algo(AlgoCategory.simple)
algo = client.add_algo(spec)
Expand All @@ -165,13 +165,13 @@ def test_test_data_traintuple(client, debug_client, debug_factory, default_datas
)

with pytest.raises(InvalidRequest) as e:
debug_client.add_traintuple(spec)
hybrid_client.add_traintuple(spec)
assert "Cannot create train task with test data" in str(e.value)


@pytest.mark.remote_only
@pytest.mark.slow
def test_fake_data_sample_key(client, debug_client, debug_factory, default_dataset):
def test_fake_data_sample_key(client, hybrid_client, debug_factory, default_dataset):
"""Check that a traintuple can't run with a fake train_data_sample_keys"""
spec = debug_factory.create_algo(AlgoCategory.simple)
algo = client.add_algo(spec)
Expand All @@ -183,5 +183,5 @@ def test_fake_data_sample_key(client, debug_client, debug_factory, default_datas
)

with pytest.raises(InvalidRequest) as e:
debug_client.add_traintuple(spec)
hybrid_client.add_traintuple(spec)
assert "Could not get all the data_samples in the database with the given data_sample_keys" in str(e.value)
2 changes: 1 addition & 1 deletion tests/workflows/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ mnist_workflow:

## run locally with docker or subprocess

`DEBUG_SPAWNER=subprocess pytest tests -v --durations=0 -m "workflows" --local`
`pytest tests -v --durations=0 -m "workflows" --subprocess`
2 changes: 1 addition & 1 deletion tests/workflows/mnist-fedavg/assets/aggregate_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def predict(self, inputs, outputs):
X = inputs["X"]
X = torch.FloatTensor(X)

model = self.load_model(inputs["model"])
model = self.load_model(inputs["models"])
model.eval()
# add the context manager to reduce computation overhead
with torch.no_grad():
Expand Down

0 comments on commit cb254d7

Please sign in to comment.