test: update the tests, Client(execution_mode) #210

Merged · 5 commits · Sep 21, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

+ - Update the `Client`: it now takes a `backend_type` argument instead of `debug=True` plus the `DEBUG_SPAWNER` environment variable to select the spawner (#210)

## [0.33.0] - 2022-09-19

### Removed
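In practice the migration looks like this (a minimal sketch: the enum members and constructor arguments are taken from this diff, while the concrete values are placeholders):

```python
import substra

# Before this change: the backend was implied by debug=True, and the spawner
# was selected via the DEBUG_SPAWNER environment variable ("docker" or "subprocess"):
#     client = substra.Client(debug=True, url=address, insecure=False, token=token)

# After this change: the backend is selected explicitly.
client = substra.Client(
    backend_type=substra.BackendType.LOCAL_SUBPROCESS,  # or LOCAL_DOCKER / REMOTE
    url="http://substra.example.org",  # placeholder address
    insecure=False,
    token=None,  # placeholder; remote clients obtain a token via login
)
```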
10 changes: 5 additions & 5 deletions Makefile
@@ -17,21 +17,21 @@ test-remote-sdk: pyclean
pytest tests -rs -v --durations=0 -m "not workflows" -n $(PARALLELISM)

test-remote-workflows: pyclean
- pytest tests -v --durations=0 -m "workflows"
+ pytest tests -v --durations=0 -m "workflows"

test-minimal: pyclean
pytest tests -rs -v --durations=0 -m "not slow and not workflows" -n $(PARALLELISM)

test-local: test-subprocess test-docker test-subprocess-workflows

test-docker: pyclean
- DEBUG_SPAWNER=docker pytest tests -rs -v --durations=0 -m "not workflows" --local
+ pytest tests -rs -v --durations=0 -m "not workflows" --mode=docker

test-subprocess: pyclean
- DEBUG_SPAWNER=subprocess pytest tests -rs -v --durations=0 -m "not workflows and not subprocess_skip" --local
+ pytest tests -rs -v --durations=0 -m "not workflows and not subprocess_skip" --mode=subprocess

test-subprocess-workflows: pyclean
- DEBUG_SPAWNER=subprocess pytest tests -v --durations=0 -m "workflows" --local
+ pytest tests -v --durations=0 -m "workflows" --mode=subprocess

test-all: test-local test-remote

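The `DEBUG_SPAWNER` environment variable is gone; the spawner is now chosen through the `--mode` pytest option added in `tests/conftest.py` below. As a hedged sketch, the `test-subprocess` target is roughly equivalent to this programmatic invocation:

```python
import pytest

# Same selection as the test-subprocess target above; --mode replaces the
# former DEBUG_SPAWNER=subprocess + --local combination.
pytest.main([
    "tests", "-rs", "-v", "--durations=0",
    "-m", "not workflows and not subprocess_skip",
    "--mode=subprocess",
])
```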
4 changes: 2 additions & 2 deletions substratest/client.py
@@ -38,7 +38,7 @@ class Client:

def __init__(
self,
- debug: bool,
+ backend_type: substra.BackendType,
organization_id: str,
address: str,
user: str,
@@ -51,7 +51,7 @@ def __init__(
super().__init__()

self.organization_id = organization_id
- self._client = substra.Client(debug=debug, url=address, insecure=False, token=token)
+ self._client = substra.Client(backend_type=backend_type, url=address, insecure=False, token=token)
if not token:
token = self._client.login(user, password)
self._api_client = _APIClient(address, token)
53 changes: 30 additions & 23 deletions tests/conftest.py
@@ -3,6 +3,7 @@
import uuid

import pytest
+ import substra
from substra.sdk.schemas import AlgoCategory

import substratest as sbt
@@ -55,9 +56,10 @@ def pytest_configure(config):
def pytest_addoption(parser):
"""Command line arguments to configure the network to be local or remote"""
parser.addoption(
"--local",
action="store_true",
help="Run the tests on the local backend only. Otherwise run the tests only on the remote backend.",
"--mode",
choices=["subprocess", "docker", "remote"],
default="remote",
help="Choose the mode on which to run the tests",
)
parser.addoption(
"--nb-train-datasamples",
@@ -77,8 +79,8 @@ def pytest_collection_modifyitems(config, items):
"""Skip the remote tests if local backend and local tests if remote backend.
By default, run only on the remote backend.
"""
- local = config.getoption("--local")
- if local:
+ mode = substra.BackendType(config.getoption("--mode"))
+ if mode != substra.BackendType.REMOTE:
skip_marker = pytest.mark.skip(reason="remove the --local option to run")
keyword = "remote_only"
else:
@@ -90,11 +92,9 @@


@pytest.fixture(scope="session")
- def client_debug_local(request):
- local = request.config.getoption("--local")
- if local:
- return True
- return False
+ def client_mode(request):
+ mode = request.config.getoption("--mode")
+ return substra.BackendType(mode)


class _DataEnv:
@@ -129,23 +129,26 @@ class Network:


@pytest.fixture
- def factory(request, cfg, client_debug_local):
+ def factory(request, cfg, client_mode):
"""Factory fixture.

Provide class methods to simply create asset specification in order to add them
to the substra framework.
"""
name = f"{TESTS_RUN_UUID}_{request.node.name}"
- with sbt.AssetsFactory(name=name, cfg=cfg, client_debug_local=client_debug_local) as f:
+ with sbt.AssetsFactory(
+ name=name,
+ cfg=cfg,
+ client_debug_local=(client_mode in [substra.BackendType.LOCAL_SUBPROCESS, substra.BackendType.LOCAL_DOCKER]),
+ ) as f:
yield f


@pytest.fixture(scope="session")
- def cfg(client_debug_local):
- if not client_debug_local:
+ def cfg(client_mode):
+ if client_mode == substra.BackendType.REMOTE:
return settings.Settings.load()
else:
- # TODO check what enable_intermediate_model_removal does
return settings.Settings.load_local_backend()


@@ -162,7 +165,7 @@ def debug_factory(request, cfg):


@pytest.fixture(scope="session")
- def network(cfg, client_debug_local):
+ def network(cfg, client_mode):
"""Network fixture.

Network must be started outside of the tests environment and the network is kept
@@ -174,7 +177,7 @@ def network(cfg, client_mode):
"""
clients = [
sbt.Client(
- debug=client_debug_local,
+ backend_type=client_mode,
organization_id=n.msp_id,
address=n.address,
user=n.user,
@@ -191,7 +194,7 @@


@pytest.fixture(scope="session")
- def default_data_env(cfg, network, client_debug_local):
+ def default_data_env(cfg, network, client_mode):
"""Fixture with pre-existing assets in all organizations.

The following assets will be created for each organization:
@@ -207,7 +210,11 @@
"""
factory_name = f"{TESTS_RUN_UUID}_global"

- with sbt.AssetsFactory(name=factory_name, cfg=cfg, client_debug_local=client_debug_local) as f:
+ with sbt.AssetsFactory(
+ name=factory_name,
+ cfg=cfg,
+ client_debug_local=(client_mode in [substra.BackendType.LOCAL_SUBPROCESS, substra.BackendType.LOCAL_DOCKER]),
+ ) as f:
datasets = []
metrics = []
for index, client in enumerate(network.clients):
@@ -338,17 +345,17 @@ def channel(cfg, network):


@pytest.fixture(scope="session")
- def debug_client(cfg, client):
+ def hybrid_client(cfg, client):
"""
- Client fixture in debug mode (first organization).
+ Client fixture in hybrid mode (first organization).
Use it with @pytest.mark.remote_only
"""
organization = cfg.organizations[0]
- # Debug client and client share the same
+ # Hybrid client and client share the same
# token, otherwise when one connects the other
# is disconnected.
return sbt.Client(
- debug=True,
+ backend_type=substra.BackendType.LOCAL_DOCKER,
organization_id=organization.msp_id,
address=organization.address,
user=organization.user,
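The `client_mode` fixture above passes the raw `--mode` string straight to the enum constructor, which implies that `substra.BackendType` is a string-valued enum whose values match the CLI choices. A small sketch of the assumed mapping (an inference from this diff, not a documented guarantee):

```python
import substra

# Assumed correspondence between --mode strings and BackendType members:
assert substra.BackendType("remote") == substra.BackendType.REMOTE
assert substra.BackendType("subprocess") == substra.BackendType.LOCAL_SUBPROCESS
assert substra.BackendType("docker") == substra.BackendType.LOCAL_DOCKER
```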
6 changes: 3 additions & 3 deletions tests/test_execution.py
@@ -211,7 +211,7 @@ def test_traintuple_execution_failure(factory, network, default_dataset_1):

spec = factory.create_traintuple(algo=algo, inputs=default_dataset_1.train_data_inputs)

- if network.clients[0].backend_mode != substra.BackendType.DEPLOYED:
+ if network.clients[0].backend_mode != substra.BackendType.REMOTE:
with pytest.raises(substra.sdk.backends.local.compute.spawner.base.ExecutionError):
network.clients[0].add_traintuple(spec)
else:
@@ -236,7 +236,7 @@ def test_composite_traintuple_execution_failure(factory, client, default_dataset
algo = client.add_algo(spec)

spec = factory.create_composite_traintuple(algo=algo, inputs=default_dataset.train_data_inputs)
- if client.backend_mode == substra.BackendType.DEPLOYED:
+ if client.backend_mode == substra.BackendType.REMOTE:
composite_traintuple = client.add_composite_traintuple(spec)
composite_traintuple = client.wait(composite_traintuple, raises=False)

@@ -277,7 +277,7 @@ def test_aggregatetuple_execution_failure(factory, client, default_dataset):
worker=client.organization_id,
)

- if client.backend_mode == substra.BackendType.DEPLOYED:
+ if client.backend_mode == substra.BackendType.REMOTE:
aggregatetuple = client.add_aggregatetuple(spec)
aggregatetuple = client.wait(aggregatetuple, raises=False)

28 changes: 14 additions & 14 deletions tests/test_hybrid_debug.py → tests/test_hybrid_mode.py
@@ -21,7 +21,7 @@ def docker_available() -> bool:

@pytest.mark.remote_only
@pytest.mark.slow
- def test_execution_debug(client, debug_client, debug_factory, default_dataset):
+ def test_execution_debug(client, hybrid_client, debug_factory, default_dataset):

spec = debug_factory.create_algo(AlgoCategory.simple)
simple_algo = client.add_algo(spec)
@@ -36,7 +36,7 @@ def test_execution_debug(client, debug_client, debug_factory, default_dataset):
spec = debug_factory.create_traintuple(
algo=simple_algo, inputs=default_dataset.opener_input + default_dataset.train_data_sample_inputs[:1]
)
- traintuple = debug_client.add_traintuple(spec)
+ traintuple = hybrid_client.add_traintuple(spec)
assert traintuple.status == models.Status.done
assert len(traintuple.train.models) != 0

@@ -47,7 +47,7 @@ def test_execution_debug(client, debug_client, debug_factory, default_dataset):
+ default_dataset.train_data_sample_inputs[:1]
+ FLTaskInputGenerator.train_to_predict(traintuple.key),
)
- predicttuple = debug_client.add_predicttuple(spec)
+ predicttuple = hybrid_client.add_predicttuple(spec)
assert predicttuple.status == models.Status.done

spec = debug_factory.create_testtuple(
@@ -56,19 +56,19 @@ def test_execution_debug(client, debug_client, debug_factory, default_dataset):
+ default_dataset.train_data_sample_inputs[:1]
+ FLTaskInputGenerator.predict_to_test(predicttuple.key),
)
- testtuple = debug_client.add_testtuple(spec)
+ testtuple = hybrid_client.add_testtuple(spec)
assert testtuple.status == models.Status.done
assert list(testtuple.test.perfs.values())[0] == 3


@pytest.mark.remote_only
@pytest.mark.slow
- def test_debug_compute_plan_aggregate_composite(network, client, debug_client, debug_factory, default_datasets):
+ def test_debug_compute_plan_aggregate_composite(network, client, hybrid_client, debug_factory, default_datasets):
"""
Debug / Compute plan version of the
`test_aggregate_composite_traintuples` method from `test_execution.py`
"""
- aggregate_worker = debug_client.organization_id
+ aggregate_worker = hybrid_client.organization_id
number_of_rounds = 2

# register algos on first organization
@@ -133,8 +133,8 @@ def test_debug_compute_plan_aggregate_composite(network, client, debug_client, d
algo=metric, inputs=dataset.train_data_inputs + FLTaskInputGenerator.predict_to_test(spec.predicttuple_id)
)

- cp = debug_client.add_compute_plan(cp_spec)
- traintuples = debug_client.list_compute_plan_traintuples(cp.key)
+ cp = hybrid_client.add_compute_plan(cp_spec)
+ traintuples = hybrid_client.list_compute_plan_traintuples(cp.key)
composite_traintuples = client.list_compute_plan_composite_traintuples(cp.key)
aggregatetuples = client.list_compute_plan_aggregatetuples(cp.key)
predicttuples = client.list_compute_plan_predicttuples(cp.key)
@@ -146,13 +146,13 @@ def test_debug_compute_plan_aggregate_composite(network, client, debug_client, d


@pytest.mark.remote_only
- def test_debug_download_dataset(debug_client, default_dataset):
- debug_client.download_opener(default_dataset.key)
+ def test_debug_download_dataset(hybrid_client, default_dataset):
+ hybrid_client.download_opener(default_dataset.key)


@pytest.mark.remote_only
@pytest.mark.slow
- def test_test_data_traintuple(client, debug_client, debug_factory, default_dataset):
+ def test_test_data_traintuple(client, hybrid_client, debug_factory, default_dataset):
"""Check that we can't use test data samples for traintuples"""
spec = debug_factory.create_algo(AlgoCategory.simple)
algo = client.add_algo(spec)
@@ -165,13 +165,13 @@ def test_test_data_traintuple(client, debug_client, debug_factory, default_datas
)

with pytest.raises(InvalidRequest) as e:
- debug_client.add_traintuple(spec)
+ hybrid_client.add_traintuple(spec)
assert "Cannot create train task with test data" in str(e.value)


@pytest.mark.remote_only
@pytest.mark.slow
- def test_fake_data_sample_key(client, debug_client, debug_factory, default_dataset):
+ def test_fake_data_sample_key(client, hybrid_client, debug_factory, default_dataset):
"""Check that a traintuple can't run with a fake train_data_sample_keys"""
spec = debug_factory.create_algo(AlgoCategory.simple)
algo = client.add_algo(spec)
@@ -183,5 +183,5 @@ def test_fake_data_sample_key(client, debug_client, debug_factory, default_datas
)

with pytest.raises(InvalidRequest) as e:
- debug_client.add_traintuple(spec)
+ hybrid_client.add_traintuple(spec)
assert "Could not get all the data_samples in the database with the given data_sample_keys" in str(e.value)
2 changes: 1 addition & 1 deletion tests/workflows/README.md
@@ -9,4 +9,4 @@ mnist_workflow:

## run locally with docker or subprocess

- `DEBUG_SPAWNER=subprocess pytest tests -v --durations=0 -m "workflows" --local`
+ `pytest tests -v --durations=0 -m "workflows" --mode=subprocess`
2 changes: 1 addition & 1 deletion tests/workflows/mnist-fedavg/assets/aggregate_algo.py
@@ -55,7 +55,7 @@ def predict(self, inputs, outputs):
X = inputs["X"]
X = torch.FloatTensor(X)

model = self.load_model(inputs["model"])
model = self.load_model(inputs["models"])
model.eval()
# add the context manager to reduce computation overhead
with torch.no_grad():
Expand Down